Revert D30752939: [pytorch][PR] nvfuser update
authorEli Uriegas <eliuriegas@fb.com>
Thu, 16 Sep 2021 00:37:10 +0000 (17:37 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 16 Sep 2021 00:38:47 +0000 (17:38 -0700)
Test Plan: revert-hammer

Differential Revision:
D30752939 (https://github.com/pytorch/pytorch/commit/cfaecaf40bd6cabd3f4e0ef0d8c7252655349b61)

Original commit changeset: ce122e80f01b

fbshipit-source-id: 57685df8f9946032a06eff1de8a3d1498500d2d2

181 files changed:
CMakeLists.txt
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/core/interned_strings.h
benchmarks/cpp/nvfuser/CMakeLists.txt [deleted file]
benchmarks/cpp/nvfuser/batch_norm.cpp [deleted file]
benchmarks/cpp/nvfuser/bert.cpp [deleted file]
benchmarks/cpp/nvfuser/broadcast.cpp [deleted file]
benchmarks/cpp/nvfuser/gelu_backward.cpp [deleted file]
benchmarks/cpp/nvfuser/heuristic_cache.cpp [deleted file]
benchmarks/cpp/nvfuser/heuristic_lookup.cpp [deleted file]
benchmarks/cpp/nvfuser/instance_norm.cpp [deleted file]
benchmarks/cpp/nvfuser/layer_norm.cpp [deleted file]
benchmarks/cpp/nvfuser/lstm_cell.cpp [deleted file]
benchmarks/cpp/nvfuser/main.cpp [deleted file]
benchmarks/cpp/nvfuser/reduction.cpp [deleted file]
benchmarks/cpp/nvfuser/scale_bias_relu.cpp [deleted file]
benchmarks/cpp/nvfuser/softmax.cpp [deleted file]
benchmarks/cpp/nvfuser/utils.cpp [deleted file]
benchmarks/cpp/nvfuser/utils.h [deleted file]
caffe2/CMakeLists.txt
cmake/Summary.cmake
test/cpp/jit/CMakeLists.txt
test/cpp/jit/test_gpu.cpp
test/cpp/jit/test_gpu_shift.cpp [deleted file]
test/cpp/jit/test_gpu_validator.h [deleted file]
test/test_jit_cuda_fuser.py
tools/build_variables.bzl
torch/csrc/jit/codegen/cuda/arith.cpp
torch/csrc/jit/codegen/cuda/arith.h
torch/csrc/jit/codegen/cuda/codegen.cpp
torch/csrc/jit/codegen/cuda/codegen.h
torch/csrc/jit/codegen/cuda/compute_at.cpp
torch/csrc/jit/codegen/cuda/compute_at.h
torch/csrc/jit/codegen/cuda/compute_at_map.cpp [deleted file]
torch/csrc/jit/codegen/cuda/compute_at_map.h [deleted file]
torch/csrc/jit/codegen/cuda/disjoint_set.h [deleted file]
torch/csrc/jit/codegen/cuda/dispatch.cpp
torch/csrc/jit/codegen/cuda/dispatch.h
torch/csrc/jit/codegen/cuda/executor.cpp
torch/csrc/jit/codegen/cuda/executor.h
torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp
torch/csrc/jit/codegen/cuda/executor_kernel_arg.h
torch/csrc/jit/codegen/cuda/executor_launch_params.cpp
torch/csrc/jit/codegen/cuda/executor_launch_params.h
torch/csrc/jit/codegen/cuda/executor_utils.cpp
torch/csrc/jit/codegen/cuda/executor_utils.h
torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
torch/csrc/jit/codegen/cuda/expr_evaluator.h
torch/csrc/jit/codegen/cuda/fusion.cpp
torch/csrc/jit/codegen/cuda/fusion.h
torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp [deleted file]
torch/csrc/jit/codegen/cuda/fusion_segmenter.h [deleted file]
torch/csrc/jit/codegen/cuda/graph_fuser.cpp
torch/csrc/jit/codegen/cuda/index_compute.cpp
torch/csrc/jit/codegen/cuda/index_compute.h
torch/csrc/jit/codegen/cuda/index_reference_replay.cpp [deleted file]
torch/csrc/jit/codegen/cuda/index_reference_replay.h [deleted file]
torch/csrc/jit/codegen/cuda/instrumentation.cpp
torch/csrc/jit/codegen/cuda/instrumentation.h
torch/csrc/jit/codegen/cuda/interface.cpp
torch/csrc/jit/codegen/cuda/interface.h
torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
torch/csrc/jit/codegen/cuda/ir_base_nodes.h
torch/csrc/jit/codegen/cuda/ir_cloner.cpp
torch/csrc/jit/codegen/cuda/ir_cloner.h
torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
torch/csrc/jit/codegen/cuda/ir_graphviz.h
torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
torch/csrc/jit/codegen/cuda/ir_iostream.cpp
torch/csrc/jit/codegen/cuda/ir_iostream.h
torch/csrc/jit/codegen/cuda/ir_nodes.cpp
torch/csrc/jit/codegen/cuda/ir_printer.h
torch/csrc/jit/codegen/cuda/ir_utils.cpp [deleted file]
torch/csrc/jit/codegen/cuda/ir_utils.h
torch/csrc/jit/codegen/cuda/iter_visitor.cpp
torch/csrc/jit/codegen/cuda/iter_visitor.h
torch/csrc/jit/codegen/cuda/kernel.cpp
torch/csrc/jit/codegen/cuda/kernel.h
torch/csrc/jit/codegen/cuda/kernel_cache.cpp
torch/csrc/jit/codegen/cuda/kernel_cache.h
torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp [deleted file]
torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h [deleted file]
torch/csrc/jit/codegen/cuda/kernel_ir.cpp
torch/csrc/jit/codegen/cuda/kernel_ir.h
torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp
torch/csrc/jit/codegen/cuda/kernel_ir_builder.h
torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp
torch/csrc/jit/codegen/cuda/kernel_ir_printer.h
torch/csrc/jit/codegen/cuda/lower2device.cpp
torch/csrc/jit/codegen/cuda/lower2device.h
torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp
torch/csrc/jit/codegen/cuda/lower_alias_memory.h
torch/csrc/jit/codegen/cuda/lower_allocation.cpp [deleted file]
torch/csrc/jit/codegen/cuda/lower_allocation.h [deleted file]
torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp [deleted file]
torch/csrc/jit/codegen/cuda/lower_expr_sort.h [deleted file]
torch/csrc/jit/codegen/cuda/lower_index.cpp
torch/csrc/jit/codegen/cuda/lower_index.h
torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp
torch/csrc/jit/codegen/cuda/lower_insert_syncs.h
torch/csrc/jit/codegen/cuda/lower_loops.cpp
torch/csrc/jit/codegen/cuda/lower_loops.h
torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp [deleted file]
torch/csrc/jit/codegen/cuda/lower_magic_zero.h [deleted file]
torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp [deleted file]
torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h [deleted file]
torch/csrc/jit/codegen/cuda/lower_predicate.cpp [deleted file]
torch/csrc/jit/codegen/cuda/lower_predicate.h [deleted file]
torch/csrc/jit/codegen/cuda/lower_shift.cpp [deleted file]
torch/csrc/jit/codegen/cuda/lower_shift.h [deleted file]
torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
torch/csrc/jit/codegen/cuda/lower_thread_predicate.h
torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp [deleted file]
torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h [deleted file]
torch/csrc/jit/codegen/cuda/lower_unroll.cpp
torch/csrc/jit/codegen/cuda/lower_unroll.h
torch/csrc/jit/codegen/cuda/lower_utils.cpp
torch/csrc/jit/codegen/cuda/lower_utils.h
torch/csrc/jit/codegen/cuda/lower_validation.cpp
torch/csrc/jit/codegen/cuda/lower_validation.h
torch/csrc/jit/codegen/cuda/manager.cpp
torch/csrc/jit/codegen/cuda/mutator.cpp
torch/csrc/jit/codegen/cuda/ops/all_ops.h [deleted file]
torch/csrc/jit/codegen/cuda/ops/composite.cpp [deleted file]
torch/csrc/jit/codegen/cuda/ops/composite.h [deleted file]
torch/csrc/jit/codegen/cuda/ops/normalization.cpp [deleted file]
torch/csrc/jit/codegen/cuda/ops/normalization.h [deleted file]
torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp [deleted file]
torch/csrc/jit/codegen/cuda/parallel_dimension_map.h [deleted file]
torch/csrc/jit/codegen/cuda/parallel_type_bitmap.cpp [deleted file]
torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h [deleted file]
torch/csrc/jit/codegen/cuda/parser.cpp
torch/csrc/jit/codegen/cuda/parser.h
torch/csrc/jit/codegen/cuda/partition.cpp
torch/csrc/jit/codegen/cuda/partition.h
torch/csrc/jit/codegen/cuda/predicate_compute.cpp
torch/csrc/jit/codegen/cuda/predicate_compute.h
torch/csrc/jit/codegen/cuda/register_interface.cpp
torch/csrc/jit/codegen/cuda/root_domain_map.cpp [deleted file]
torch/csrc/jit/codegen/cuda/root_domain_map.h [deleted file]
torch/csrc/jit/codegen/cuda/runtime/block_reduction.cu
torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu [deleted file]
torch/csrc/jit/codegen/cuda/runtime/block_sync_default.cu [deleted file]
torch/csrc/jit/codegen/cuda/runtime/broadcast.cu
torch/csrc/jit/codegen/cuda/runtime/fp16_support.cu
torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu
torch/csrc/jit/codegen/cuda/runtime/helpers.cu
torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
torch/csrc/jit/codegen/cuda/runtime/tensor.cu
torch/csrc/jit/codegen/cuda/runtime/welford.cu [deleted file]
torch/csrc/jit/codegen/cuda/scheduler.cpp [new file with mode: 0644]
torch/csrc/jit/codegen/cuda/scheduler.h [new file with mode: 0644]
torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h [deleted file]
torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp [deleted file]
torch/csrc/jit/codegen/cuda/scheduler/normalization.h [deleted file]
torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp [deleted file]
torch/csrc/jit/codegen/cuda/scheduler/pointwise.h [deleted file]
torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h [deleted file]
torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp [deleted file]
torch/csrc/jit/codegen/cuda/scheduler/reduction.h [deleted file]
torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h [deleted file]
torch/csrc/jit/codegen/cuda/scheduler/registry.cpp [deleted file]
torch/csrc/jit/codegen/cuda/scheduler/registry.h [deleted file]
torch/csrc/jit/codegen/cuda/scheduler/utils.cpp [deleted file]
torch/csrc/jit/codegen/cuda/scheduler/utils.h [deleted file]
torch/csrc/jit/codegen/cuda/shape_inference.cpp
torch/csrc/jit/codegen/cuda/tensor_view.cpp
torch/csrc/jit/codegen/cuda/transform_iter.cpp
torch/csrc/jit/codegen/cuda/transform_iter.h
torch/csrc/jit/codegen/cuda/transform_replay.cpp
torch/csrc/jit/codegen/cuda/transform_replay.h
torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
torch/csrc/jit/codegen/cuda/type.cpp
torch/csrc/jit/codegen/cuda/type.h
torch/csrc/jit/codegen/cuda/utils.cpp [deleted file]
torch/csrc/jit/codegen/cuda/utils.h
torch/csrc/jit/runtime/autodiff.cpp
torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp
torch/csrc/jit/runtime/profiling_record.cpp
torch/testing/_internal/jit_utils.py

index c63aedc..0956b6a 100644 (file)
@@ -195,9 +195,6 @@ cmake_dependent_option(
     USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
     "USE_CUDNN" OFF)
 cmake_dependent_option(
-    BUILD_NVFUSER_BENCHMARK "Build C++ binaries for nvfuser benchmarks" ON
-    "USE_CUDA;BUILD_TEST" OFF)
-cmake_dependent_option(
   USE_WHOLE_CUDNN "Use whole-library linking for cuDNN" OFF
     "USE_STATIC_CUDNN" OFF)
 cmake_dependent_option(
index 80aed8a..d766c69 100644 (file)
@@ -208,8 +208,6 @@ _(aten, avg_pool3d_forward) \
 _(aten, baddbmm) \
 _(aten, bartlett_window) \
 _(aten, batch_norm) \
-_(aten, _batch_norm_impl_index) \
-_(aten, _batch_norm_impl_index_backward) \
 _(aten, bernoulli) \
 _(aten, bilinear) \
 _(aten, binary_cross_entropy) \
@@ -351,7 +349,6 @@ _(aten, full_like) \
 _(aten, gather) \
 _(aten, gcd) \
 _(aten, gelu) \
-_(aten, gelu_backward) \
 _(aten, geometric) \
 _(aten, geqrf) \
 _(aten, get_device) \
@@ -521,8 +518,6 @@ _(aten, narrow) \
 _(aten, narrow_copy) \
 _(aten, native_batch_norm) \
 _(aten, native_batch_norm_backward) \
-_(aten, native_layer_norm) \
-_(aten, native_layer_norm_backward) \
 _(aten, native_clone) \
 _(aten, native_get_device) \
 _(aten, native_norm) \
index b4c4319..7ed3cf8 100644 (file)
@@ -42,7 +42,6 @@ namespace c10 {
   _(prim, CudaFusionGroup)           \
   _(prim, CudaFusionGuard)           \
   _(prim, FunctionalGraph)           \
-  _(prim, add_optional)              \
   _(prim, DifferentiableGraph)       \
   _(prim, TensorExprGroup)           \
   _(prim, StaticSubgraph)            \
diff --git a/benchmarks/cpp/nvfuser/CMakeLists.txt b/benchmarks/cpp/nvfuser/CMakeLists.txt
deleted file mode 100644 (file)
index 1024c40..0000000
+++ /dev/null
@@ -1,18 +0,0 @@
-if(USE_CUDA)
-  add_executable(nvfuser_bench
-    batch_norm.cpp
-    bert.cpp
-    broadcast.cpp
-    gelu_backward.cpp
-    heuristic_lookup.cpp
-    instance_norm.cpp
-    layer_norm.cpp
-    lstm_cell.cpp
-    reduction.cpp
-    softmax.cpp
-    scale_bias_relu.cpp
-    utils.cpp
-    main.cpp)
-
-  target_link_libraries(nvfuser_bench PRIVATE torch_library benchmark)
-endif()
diff --git a/benchmarks/cpp/nvfuser/batch_norm.cpp b/benchmarks/cpp/nvfuser/batch_norm.cpp
deleted file mode 100644 (file)
index 7d57f15..0000000
+++ /dev/null
@@ -1,296 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-//------------------------------------------------------------------------------
-
-static void setupBatchNorm(Fusion* fusion, DataType dtype) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  FusionGuard fg(fusion);
-
-  const bool kTraining = true;
-  const float kMomentum = 0.1;
-  const float kEps = 1e-5;
-
-  // setup fusion
-  auto input = makeContigTensor(4, dtype);
-  auto weight = makeContigTensor(1, dtype);
-  auto bias = makeContigTensor(1, dtype);
-  auto running_mean = makeContigTensor(1, DataType::Float);
-  auto running_var = makeContigTensor(1, DataType::Float);
-
-  fusion->addInput(input);
-  fusion->addInput(weight);
-  fusion->addInput(bias);
-  fusion->addInput(running_mean);
-  fusion->addInput(running_var);
-
-  if (dtype == DataType::Half) {
-    input = castOp(DataType::Float, input);
-    weight = castOp(DataType::Float, weight);
-    bias = castOp(DataType::Float, bias);
-  }
-
-  auto momentum_ptr = new Double(kMomentum);
-  auto eps_ptr = new Double(kEps);
-
-  auto result = batch_norm(
-      input,
-      weight,
-      bias,
-      running_mean,
-      running_var,
-      kTraining,
-      momentum_ptr,
-      eps_ptr);
-
-  auto output = result.output;
-
-  if (dtype == DataType::Half) {
-    output = castOp(DataType::Half, output);
-  }
-
-  fusion->addOutput(output);
-}
-
-static void NvFuserScheduler_BatchNorm(
-    benchmark::State& benchmark_state,
-    FusionExecutorCache* fusion_executor_cache,
-    DataType dtype) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  const bool kTraining = true;
-  const float kMomentum = 0.1;
-  const float kEps = 1e-5;
-
-  std::vector<int64_t> input_shape{
-      benchmark_state.range(0),
-      benchmark_state.range(1),
-      benchmark_state.range(2),
-      benchmark_state.range(2)};
-
-  // inputs
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  auto fp32_options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor at_x = at::randn(input_shape, options);
-  at::Tensor at_weight = at::ones({input_shape[1]}, options);
-  at::Tensor at_bias = at::zeros({input_shape[1]}, options);
-  at::Tensor at_run_mean = at::zeros({input_shape[1]}, fp32_options);
-  at::Tensor at_run_var = at::ones({input_shape[1]}, fp32_options);
-  std::vector<c10::IValue> aten_inputs(
-      {at_x, at_weight, at_bias, at_run_mean, at_run_var});
-
-  runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
-
-  benchmark_state.SetBytesProcessed(
-      (int64_t(benchmark_state.iterations()) *
-       (2 * (at_x.numel() + at_weight.numel() + at_bias.numel())) *
-       int64_t(dataTypeSize(dtype))) +
-      (2 * (at_run_mean.numel() + at_run_var.numel()) *
-       int64_t(dataTypeSize(DataType::Float))));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_BatchNorm(
-    benchmark::State& benchmark_state,
-    DataType dtype) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  const float kMomentum = 0.1;
-  const float kEps = 1e-5;
-  std::vector<int64_t> input_shape{
-      benchmark_state.range(0),
-      benchmark_state.range(1),
-      benchmark_state.range(2),
-      benchmark_state.range(2)};
-
-  // inputs
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  auto fp32_options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor at_x = at::randn(input_shape, options);
-  at::Tensor at_weight = at::ones({input_shape[1]}, options);
-  at::Tensor at_bias = at::zeros({input_shape[1]}, options);
-  at::Tensor at_running_mean = at::zeros({input_shape[1]}, fp32_options);
-  at::Tensor at_running_var = at::ones({input_shape[1]}, fp32_options);
-
-  auto ato_weight = c10::optional<at::Tensor>(at_weight);
-  auto ato_bias = c10::optional<at::Tensor>(at_bias);
-  auto ato_running_mean = c10::optional<at::Tensor>(at_running_mean);
-  auto ato_running_var = c10::optional<at::Tensor>(at_running_var);
-
-  auto output = at::batch_norm(
-      at_x,
-      ato_weight,
-      ato_bias,
-      ato_running_mean,
-      ato_running_var,
-      true,
-      kMomentum,
-      kEps,
-      true);
-  cudaDeviceSynchronize();
-
-  for (auto _ : benchmark_state) {
-    CudaKernelTimer timer;
-    auto output = at::batch_norm(
-        at_x,
-        ato_weight,
-        ato_bias,
-        ato_running_mean,
-        ato_running_var,
-        true,
-        kMomentum,
-        kEps,
-        true);
-    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
-    cudaDeviceSynchronize();
-  }
-  benchmark_state.SetBytesProcessed(
-      (int64_t(benchmark_state.iterations()) *
-       (2 * (at_x.numel() + at_weight.numel() + at_bias.numel())) *
-       int64_t(dataTypeSize(dtype))) +
-      (2 * (at_running_mean.numel() + at_running_var.numel()) *
-       int64_t(dataTypeSize(DataType::Float))));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_BatchNorm_fp32(benchmark::State& benchmark_state) {
-  Baseline_BatchNorm(benchmark_state, DataType::Float);
-}
-
-static void Baseline_BatchNorm_fp16(benchmark::State& benchmark_state) {
-  Baseline_BatchNorm(benchmark_state, DataType::Half);
-}
-
-//------------------------------------------------------------------------------
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_BatchNorm_fp32,
-    setupBatchNorm,
-    NvFuserScheduler_BatchNorm,
-    DataType::Float);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{32, 32}, {64, 512}, {8, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{64, 128}, {64, 128}, {8, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{128, 128}, {128, 512}, {8, 128}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{16, 64}, {2, 4}, {128, 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_BatchNorm_fp16,
-    setupBatchNorm,
-    NvFuserScheduler_BatchNorm,
-    DataType::Half);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{32, 32}, {64, 512}, {8, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{64, 128}, {64, 128}, {8, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{128, 128}, {128, 512}, {8, 128}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{16, 64}, {2, 4}, {128, 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-BENCHMARK(Baseline_BatchNorm_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{32, 32}, {64, 512}, {8, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{64, 128}, {64, 128}, {8, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{128, 128}, {128, 512}, {8, 128}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{16, 64}, {2, 4}, {128, 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{32, 32}, {64, 512}, {8, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{64, 128}, {64, 128}, {8, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{128, 128}, {128, 512}, {8, 128}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{16, 64}, {2, 4}, {128, 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
diff --git a/benchmarks/cpp/nvfuser/bert.cpp b/benchmarks/cpp/nvfuser/bert.cpp
deleted file mode 100644 (file)
index bec916f..0000000
+++ /dev/null
@@ -1,746 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include <sstream>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-// Return reduction tensor view and output of reduction
-static void setupDivMaxSoftmaxDropoutForward(Fusion* fusion, DataType dtype) {
-  FusionGuard fg(fusion);
-
-  bool is_fp16 = dtype == DataType::Half;
-
-  TensorView* tv0 = TensorViewBuilder()
-                        .ndims(4)
-                        .dtype(dtype)
-                        .contiguity({true, false, false, true})
-                        .shape({-1, 1, 1, -1})
-                        .build();
-  TensorView* tv1 = makeContigTensor(4, dtype);
-
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
-
-  // TODO: should be input
-  auto d16 = new Double(1.0);
-
-  if (is_fp16) {
-    tv0 = castOp(DataType::Float, tv0);
-    tv1 = castOp(DataType::Float, tv1);
-  }
-
-  auto tv2 = div(tv1, d16);
-  auto tv3 = add(tv2, tv0);
-
-  auto tv10 = softmax(tv3, 3);
-  auto dropout_tvs = dropout(tv10, new Double(0.9));
-  auto tv12 = dropout_tvs.mask;
-  auto tv14 = dropout_tvs.output;
-
-  if (is_fp16) {
-    tv14 = castOp(DataType::Half, tv14);
-    tv10 = castOp(DataType::Half, tv10);
-    tv3 = castOp(DataType::Half, tv3);
-  }
-
-  fusion->addOutput(tv14);
-  fusion->addOutput(tv12);
-  fusion->addOutput(tv10);
-  fusion->addOutput(tv3);
-}
-
-static void setupDivMaxSoftmaxDropoutBackward(Fusion* fusion, DataType dtype) {
-  TensorView* tv0 = makeContigTensor(4, dtype);
-  // Strangely tv1 isn't used anywhere, need to come back to that...
-  TensorView* tv1 = makeContigTensor(4, dtype);
-  TensorView* tv2 = makeContigTensor(4, dtype);
-  TensorView* tv3 = makeContigTensor(4, DataType::Bool);
-
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
-  fusion->addInput(tv2);
-  fusion->addInput(tv3);
-
-  bool is_fp16 = dtype == DataType::Half;
-  if (is_fp16) {
-    tv0 = castOp(DataType::Float, tv0);
-    tv1 = castOp(DataType::Float, tv1);
-    tv2 = castOp(DataType::Float, tv2);
-  }
-
-  // TODO: should be inputs
-  auto d32 = new Double(1.0);
-  // fusion->addInput(d32);
-  auto d33 = new Double(2.0);
-  // fusion->addInput(d33);
-
-  auto tv4 = mul(tv2, tv3);
-  auto tv5 = mul(tv4, d33);
-  auto tv6 = mul(tv5, tv0);
-  auto tv7 = sum(tv6, {-1});
-  auto tv8 = broadcast(tv7, {false, false, false, true});
-  auto tv9 = mul(tv0, tv8);
-  auto tv10 = sub(tv6, tv9);
-  auto tv11 = div(tv10, d32);
-
-  if (is_fp16) {
-    tv10 = castOp(DataType::Half, tv10);
-    tv11 = castOp(DataType::Half, tv11);
-  }
-
-  fusion->addOutput(tv11);
-  fusion->addOutput(tv10);
-}
-
-static void MagicScheduler_DivMaxSoftDropFwd(
-    benchmark::State& benchmark_state,
-    DataType dtype) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto w = benchmark_state.range(0);
-  auto x = benchmark_state.range(1);
-  auto y = benchmark_state.range(2);
-  auto z = benchmark_state.range(3);
-
-  setupDivMaxSoftmaxDropoutForward(&fusion, dtype);
-
-  auto tvs = ir_utils::allTvs(&fusion);
-
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({w, 1, 1, z}, options);
-  at::Tensor t1 = at::randn({w, x, y, z}, options);
-
-  std::vector<c10::IValue> at_inputs = {t0, t1};
-  std::vector<at::Tensor> cg_outputs;
-
-  auto norm_params = getNormalizationHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  scheduleNormalization(&fusion, norm_params.value());
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.setMeasureKernelTimeFlag(true);
-  // Sync everything up before we start
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    CudaKernelTimer timer;
-    cg_outputs = fe.runFusion({t0, t1}, norm_params.value().lparams);
-    benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
-  }
-  // Sync everything up before we're finished, don't want to run ahead on the
-  // cpu while benchmarking.
-  cudaDeviceSynchronize();
-
-  int64_t bytes = 0;
-  for (auto tensor : std::vector<at::Tensor>({t0, t1})) {
-    bytes += tensor.numel() *
-        (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
-  }
-
-  for (auto tensor : cg_outputs) {
-    bytes += tensor.numel() *
-        (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
-  }
-
-  benchmark_state.SetBytesProcessed(
-      bytes * int64_t(benchmark_state.iterations()));
-}
-
-static void MagicScheduler_DivMaxSoftDropBwd(
-    benchmark::State& benchmark_state,
-    DataType dtype) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto w = benchmark_state.range(0);
-  auto x = benchmark_state.range(1);
-  auto y = benchmark_state.range(2);
-  auto z = benchmark_state.range(3);
-
-  setupDivMaxSoftmaxDropoutBackward(&fusion, dtype);
-
-  auto tvs = ir_utils::allTvs(&fusion);
-
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({w, x, y, z}, options);
-  at::Tensor t1 = at::randn({w, x, y, z}, options);
-  at::Tensor t2 = at::randn({w, x, y, z}, options);
-  at::Tensor t3 = at::randn({w, x, y, z}, options).round().to(at::kBool);
-
-  std::vector<c10::IValue> at_inputs = {t0, t1, t2, t3};
-  std::vector<at::Tensor> cg_outputs;
-
-  auto norm_params = getNormalizationHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  scheduleNormalization(&fusion, norm_params.value());
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.setMeasureKernelTimeFlag(true);
-  // Sync everything up before we start
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    CudaKernelTimer timer;
-    cg_outputs = fe.runFusion({t0, t1, t2, t3}, norm_params.value().lparams);
-    benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
-  }
-  // Sync everything up before we're finished, don't want to run ahead on the
-  // cpu while benchmarking.
-  cudaDeviceSynchronize();
-
-  int64_t bytes = 0;
-  // Some reason t1 isn't used, ignore it.
-  for (auto tensor : std::vector<at::Tensor>({t0, t2, t3})) {
-    bytes += tensor.numel() *
-        (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
-  }
-
-  for (auto tensor : cg_outputs) {
-    bytes += tensor.numel() *
-        (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
-  }
-
-  benchmark_state.SetBytesProcessed(
-      bytes * int64_t(benchmark_state.iterations()));
-}
-
-static void setupBiasDropoutAddLayernormFwd(Fusion* fusion, DataType dtype) {
-  FusionGuard fg(fusion);
-
-  bool is_fp16 = dtype == DataType::Half;
-
-  TensorView* tv0 = makeContigTensor(1, dtype);
-  TensorView* tv1 = makeContigTensor(1, dtype);
-  TensorView* tv2 = makeContigTensor(3, dtype);
-  TensorView* tv3 = makeContigTensor(3, dtype);
-  TensorView* tv4 = makeContigTensor(1, dtype);
-
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
-  fusion->addInput(tv2);
-  fusion->addInput(tv3);
-  fusion->addInput(tv4);
-
-  if (is_fp16) {
-    tv0 = castOp(DataType::Float, tv0);
-    tv1 = castOp(DataType::Float, tv1);
-    tv2 = castOp(DataType::Float, tv2);
-    tv3 = castOp(DataType::Float, tv3);
-    tv4 = castOp(DataType::Float, tv4);
-  }
-
-  auto tv5 = broadcast(tv4, {true, true, false});
-  auto tv6 = add(tv3, tv5);
-  auto dropout_outs = dropout(tv6, new Double(0.9));
-
-  auto tv8 = dropout_outs.output;
-  auto tv10 = dropout_outs.mask;
-
-  auto tv11 = add(tv10, tv2);
-
-  auto layer_norm_outs = layer_norm(tv11, 1, tv0, tv1, new Double(1e-5));
-  auto tv14 = layer_norm_outs.output;
-  auto tv21 = layer_norm_outs.mean;
-  auto tv26 = layer_norm_outs.invstd;
-
-  if (is_fp16) {
-    tv11 = castOp(DataType::Half, tv11);
-    tv14 = castOp(DataType::Half, tv14);
-    tv21 = castOp(DataType::Half, tv21);
-    tv26 = castOp(DataType::Half, tv26);
-  }
-
-  fusion->addOutput(tv8);
-  fusion->addOutput(tv11);
-  fusion->addOutput(tv14);
-  fusion->addOutput(tv21);
-  fusion->addOutput(tv26);
-}
-
-static void MagicScheduler_BiasDropoutAddLayernormFwd(
-    benchmark::State& benchmark_state,
-    DataType dtype) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto x = benchmark_state.range(0);
-  auto y = benchmark_state.range(1);
-  auto z = benchmark_state.range(2);
-
-  setupBiasDropoutAddLayernormFwd(&fusion, dtype);
-
-  auto tvs = ir_utils::allTvs(&fusion);
-
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({z}, options);
-  at::Tensor t1 = at::randn({z}, options);
-  at::Tensor t2 = at::randn({x, y, z}, options);
-  at::Tensor t3 = at::randn({x, y, z}, options);
-  at::Tensor t4 = at::randn({z}, options);
-
-  std::vector<c10::IValue> at_inputs = {t0, t1, t2, t3, t4};
-  std::vector<at::Tensor> cg_outputs;
-
-  auto norm_params = getNormalizationHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  scheduleNormalization(&fusion, norm_params.value());
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.setMeasureKernelTimeFlag(true);
-  // Sync everything up before we start
-
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    CudaKernelTimer timer;
-    cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
-    benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
-  }
-  // Sync everything up before we're finished, don't want to run ahead on the
-  // cpu while benchmarking.
-  cudaDeviceSynchronize();
-
-  int64_t bytes = 0;
-  for (auto inp : at_inputs) {
-    auto tensor = inp.toTensor();
-    bytes += tensor.numel() *
-        (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
-  }
-
-  for (auto tensor : cg_outputs) {
-    bytes += tensor.numel() *
-        (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
-  }
-
-  benchmark_state.SetBytesProcessed(
-      bytes * int64_t(benchmark_state.iterations()));
-}
-
-static void MagicScheduler_fp32_BiasDropoutAddLayernormFwd(
-    benchmark::State& benchmark_state) {
-  MagicScheduler_BiasDropoutAddLayernormFwd(benchmark_state, DataType::Float);
-}
-
-static void setupBiasDropoutAddLayernormBwd1(Fusion* fusion, DataType dtype) {
-  FusionGuard fg(fusion);
-
-  bool is_fp16 = dtype == DataType::Half;
-
-  TensorView* tv1 = makeContigTensor(3, dtype);
-  TensorView* tv2 = makeContigTensor(3, dtype);
-  TensorView* tv3 = TensorViewBuilder()
-                        .ndims(3)
-                        .dtype(dtype)
-                        .contiguity({true, true, true})
-                        .shape({-1, -1, 1})
-                        .build();
-  TensorView* tv4 = TensorViewBuilder()
-                        .ndims(3)
-                        .dtype(dtype)
-                        .contiguity({true, true, true})
-                        .shape({-1, -1, 1})
-                        .build();
-
-  fusion->addInput(tv1);
-  fusion->addInput(tv2);
-  fusion->addInput(tv3);
-  fusion->addInput(tv4);
-
-  if (is_fp16) {
-    tv1 = castOp(DataType::Float, tv1);
-    tv2 = castOp(DataType::Float, tv2);
-    tv3 = castOp(DataType::Float, tv3);
-    tv4 = castOp(DataType::Float, tv4);
-  }
-
-  auto tv7 = sub(tv2, tv3);
-  auto tv8 = mul(tv7, tv4);
-  auto tv24 = sum(tv1, {0, 1});
-  auto tv22 = mul(tv1, tv8);
-  auto tv23 = sum(tv22, {0, 1});
-
-  if (is_fp16) {
-    tv24 = castOp(DataType::Half, tv24);
-    tv23 = castOp(DataType::Half, tv23);
-    tv8 = castOp(DataType::Half, tv8);
-  }
-
-  fusion->addOutput(tv24);
-  fusion->addOutput(tv23);
-  fusion->addOutput(tv8);
-}
-
-static void MagicScheduler_BiasDropoutAddLayernormBwd1(
-    benchmark::State& benchmark_state,
-    DataType dtype) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto x = benchmark_state.range(0);
-  auto y = benchmark_state.range(1);
-  auto z = benchmark_state.range(2);
-
-  setupBiasDropoutAddLayernormBwd1(&fusion, dtype);
-
-  auto tvs = ir_utils::allTvs(&fusion);
-
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({x, y, z}, options);
-  at::Tensor t1 = at::randn({x, y, z}, options);
-  at::Tensor t2 = at::randn({x, y, 1}, options);
-  at::Tensor t3 = at::randn({x, y, 1}, options);
-
-  std::vector<c10::IValue> at_inputs = {t0, t1, t2, t3};
-  std::vector<at::Tensor> cg_outputs;
-
-  auto norm_params = getNormalizationHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  scheduleNormalization(&fusion, norm_params.value());
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.setMeasureKernelTimeFlag(true);
-  // Sync everything up before we start
-
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    clearL2Cache();
-    cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
-    benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
-  }
-  // Sync everything up before we're finished, don't want to run ahead on the
-  // cpu while benchmarking.
-  cudaDeviceSynchronize();
-
-  int64_t bytes = 0;
-  for (auto inp : at_inputs) {
-    auto tensor = inp.toTensor();
-    bytes += tensor.numel() *
-        (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
-  }
-
-  for (auto tensor : cg_outputs) {
-    bytes += tensor.numel() *
-        (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
-  }
-
-  benchmark_state.SetBytesProcessed(
-      bytes * int64_t(benchmark_state.iterations()));
-}
-
-static void setupBiasDropoutAddLayernormBwd2(Fusion* fusion, DataType dtype) {
-  FusionGuard fg(fusion);
-
-  bool is_fp16 = dtype == DataType::Half;
-
-  TensorView* tv4 = TensorViewBuilder()
-                        .ndims(3)
-                        .dtype(dtype)
-                        .contiguity({true, true, true})
-                        .shape({-1, -1, 1})
-                        .build();
-  TensorView* tv5 = makeContigTensor(1, dtype);
-  TensorView* tv1 = makeContigTensor(3, dtype);
-  TensorView* tv8 = makeContigTensor(3, dtype);
-
-  fusion->addInput(tv4);
-  fusion->addInput(tv5);
-  fusion->addInput(tv1);
-  fusion->addInput(tv8);
-
-  if (is_fp16) {
-    tv4 = castOp(DataType::Float, tv4);
-    tv5 = castOp(DataType::Float, tv5);
-    tv1 = castOp(DataType::Float, tv1);
-    tv8 = castOp(DataType::Float, tv8);
-  }
-  auto d36 = mul(new Double(1.0), tv1->axis(2)->extent());
-  auto d47 = unaryOp(UnaryOpType::Reciprocal, d36);
-
-  auto tv9 = broadcast(tv5, {true, true, false});
-  auto tv10 = mul(tv1, tv9);
-  auto tv14 = mul(tv10, tv8);
-  auto tv15 = sum(tv14, {2});
-  auto tv16 = broadcast(tv15, {false, false, true});
-  auto tv17 = mul(tv8, tv16);
-  auto tv12 = sum(tv10, {2});
-  auto tv13 = broadcast(tv12, {false, false, true});
-  auto tv11 = mul(d36, tv10);
-  auto tv18 = sub(tv11, tv13);
-  auto tv20 = mul(d47, tv4);
-  auto tv19 = sub(tv18, tv17);
-  auto tv21 = mul(tv20, tv19);
-
-  if (is_fp16) {
-    tv21 = castOp(DataType::Half, tv21);
-  }
-
-  fusion->addOutput(tv21);
-}
-
-static void MagicScheduler_BiasDropoutAddLayernormBwd2(
-    benchmark::State& benchmark_state,
-    DataType dtype) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto x = benchmark_state.range(0);
-  auto y = benchmark_state.range(1);
-  auto z = benchmark_state.range(2);
-
-  setupBiasDropoutAddLayernormBwd2(&fusion, dtype);
-
-  auto tvs = ir_utils::allTvs(&fusion);
-
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
-  at::Tensor t4 = at::randn({x, y, 1}, options);
-  at::Tensor t5 = at::randn({z}, options);
-  at::Tensor t1 = at::randn({x, y, z}, options);
-  at::Tensor t8 = at::randn({x, y, z}, options);
-
-  std::vector<c10::IValue> at_inputs = {t4, t5, t1, t8};
-  std::vector<at::Tensor> cg_outputs;
-
-  auto norm_params = getNormalizationHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  scheduleNormalization(&fusion, norm_params.value());
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.setMeasureKernelTimeFlag(true);
-  // Sync everything up before we start
-
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    CudaKernelTimer timer;
-    cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
-    benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
-  }
-  // Sync everything up before we're finished, don't want to run ahead on the
-  // cpu while benchmarking.
-  cudaDeviceSynchronize();
-
-  int64_t bytes = 0;
-  for (auto inp : at_inputs) {
-    auto tensor = inp.toTensor();
-    bytes += tensor.numel() *
-        (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
-  }
-
-  for (auto tensor : cg_outputs) {
-    bytes += tensor.numel() *
-        (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
-  }
-
-  benchmark_state.SetBytesProcessed(
-      bytes * int64_t(benchmark_state.iterations()));
-}
-
-static void setupBiasDropoutAddLayernormBwd3(Fusion* fusion, DataType dtype) {
-  FusionGuard fg(fusion);
-
-  bool is_fp16 = dtype == DataType::Half;
-
-  TensorView* tv0 = makeContigTensor(3, dtype);
-  TensorView* tv21 = makeContigTensor(3, dtype);
-
-  fusion->addInput(tv0);
-  fusion->addInput(tv21);
-
-  if (is_fp16) {
-    tv0 = castOp(DataType::Float, tv0);
-    tv21 = castOp(DataType::Float, tv21);
-  }
-
-  // Uncertain this is the right value, but going for it anyways
-  auto d34 = div(new Double(1.0), tv0->axis(2)->extent());
-
-  auto tv25 = mul(tv21, tv0);
-  auto tv26 = mul(tv25, d34);
-  auto tv27 = sum(tv26, {0, 1});
-
-  if (is_fp16) {
-    tv26 = castOp(DataType::Half, tv27);
-    tv27 = castOp(DataType::Half, tv27);
-  }
-
-  fusion->addOutput(tv26);
-  fusion->addOutput(tv27);
-}
-
-static void MagicScheduler_BiasDropoutAddLayernormBwd3(
-    benchmark::State& benchmark_state,
-    DataType dtype) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto x = benchmark_state.range(0);
-  auto y = benchmark_state.range(1);
-  auto z = benchmark_state.range(2);
-
-  setupBiasDropoutAddLayernormBwd3(&fusion, dtype);
-
-  auto tvs = ir_utils::allTvs(&fusion);
-
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({x, y, z}, options);
-  at::Tensor t21 = at::randn({x, y, z}, options);
-
-  std::vector<c10::IValue> at_inputs = {t0, t21};
-  std::vector<at::Tensor> cg_outputs;
-
-  auto norm_params = getNormalizationHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  scheduleNormalization(&fusion, norm_params.value());
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.setMeasureKernelTimeFlag(true);
-  // Sync everything up before we start
-
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    CudaKernelTimer timer;
-    cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
-    benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
-  }
-  // Sync everything up before we're finished, don't want to run ahead on the
-  // cpu while benchmarking.
-  cudaDeviceSynchronize();
-
-  int64_t bytes = 0;
-  for (auto inp : at_inputs) {
-    auto tensor = inp.toTensor();
-    bytes += tensor.numel() *
-        (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
-  }
-
-  for (auto tensor : cg_outputs) {
-    bytes += tensor.numel() *
-        (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
-  }
-
-  benchmark_state.SetBytesProcessed(
-      bytes * int64_t(benchmark_state.iterations()));
-}
-
-//------------------------------------------------------------------------------
-
-static void DivMaxSoftDropFwd_fp32(benchmark::State& benchmark_state) {
-  MagicScheduler_DivMaxSoftDropFwd(benchmark_state, DataType::Float);
-}
-
-static void DivMaxSoftDropBwd_fp32(benchmark::State& benchmark_state) {
-  MagicScheduler_DivMaxSoftDropBwd(benchmark_state, DataType::Float);
-}
-
-static void DivMaxSoftDropFwd_fp16(benchmark::State& benchmark_state) {
-  MagicScheduler_DivMaxSoftDropFwd(benchmark_state, DataType::Half);
-}
-
-static void DivMaxSoftDropBwd_fp16(benchmark::State& benchmark_state) {
-  MagicScheduler_DivMaxSoftDropBwd(benchmark_state, DataType::Half);
-}
-
-static void BiasDropoutAddLayernormBwd1_fp32(
-    benchmark::State& benchmark_state) {
-  MagicScheduler_BiasDropoutAddLayernormBwd1(benchmark_state, DataType::Float);
-}
-
-// Use full ampere wave here
-static void BiasDropoutAddLayernormBwd1_tf32(
-    benchmark::State& benchmark_state) {
-  MagicScheduler_BiasDropoutAddLayernormBwd1(benchmark_state, DataType::Float);
-}
-
-static void BiasDropoutAddLayernormBwd2_fp32(
-    benchmark::State& benchmark_state) {
-  MagicScheduler_BiasDropoutAddLayernormBwd2(benchmark_state, DataType::Float);
-}
-
-static void BiasDropoutAddLayernormBwd3_fp32(
-    benchmark::State& benchmark_state) {
-  MagicScheduler_BiasDropoutAddLayernormBwd3(benchmark_state, DataType::Float);
-}
-
-//------------------------------------------------------------------------------
-
-BENCHMARK(DivMaxSoftDropFwd_fp32)
-    ->RangeMultiplier(8)
-    ->Ranges({{8, 8}, {16, 16}, {128, 128}, {128, 128}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(DivMaxSoftDropBwd_fp32)
-    ->RangeMultiplier(8)
-    ->Ranges({{8, 8}, {16, 16}, {128, 128}, {128, 128}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(DivMaxSoftDropFwd_fp16)
-    ->RangeMultiplier(8)
-    ->Ranges({{8, 8}, {16, 16}, {128, 128}, {128, 128}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(DivMaxSoftDropBwd_fp16)
-    ->RangeMultiplier(8)
-    ->Ranges({{8, 8}, {16, 16}, {128, 128}, {128, 128}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(BiasDropoutAddLayernormBwd1_fp32)
-    ->RangeMultiplier(2)
-    ->Ranges({{32, 1024}, {128, 128}, {1024, 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-// Use full ampere wave here
-BENCHMARK(BiasDropoutAddLayernormBwd1_tf32)
-    ->RangeMultiplier(2)
-    ->Ranges({{32, 1024}, {128, 128}, {864, 864}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(BiasDropoutAddLayernormBwd2_fp32)
-    ->Ranges({{32, 1024}, {128, 128}, {1024, 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(BiasDropoutAddLayernormBwd3_fp32)
-    ->Ranges({{32, 1024}, {128, 128}, {1024, 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
diff --git a/benchmarks/cpp/nvfuser/broadcast.cpp b/benchmarks/cpp/nvfuser/broadcast.cpp
deleted file mode 100644 (file)
index ac8d392..0000000
+++ /dev/null
@@ -1,217 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include <sstream>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-// Return broadcast tensor view and output of broadcast
-static void setupBroadcast(Fusion* fusion, DataType dtype, int bcast_axis) {
-  FusionGuard fg(fusion);
-
-  bool is_fp16 = dtype == DataType::Half;
-
-  TensorView* tv0 = makeContigTensor(2, dtype);
-  TensorView* tv1 = makeContigTensor(1, dtype);
-
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
-
-  std::vector<bool> bcast_pattern(2, false);
-  bcast_pattern[bcast_axis] = true;
-
-  if (is_fp16) {
-    tv0 = castOp(DataType::Float, tv0);
-    tv1 = castOp(DataType::Float, tv1);
-  }
-
-  TensorView* tv2 = broadcast(tv1, bcast_pattern);
-  TensorView* tv3 = add(tv0, tv2);
-
-  if (is_fp16) {
-    tv3 = castOp(DataType::Half, tv3);
-  }
-
-  fusion->addOutput(tv3);
-}
-
-static void NvFuserScheduler_Broadcast(
-    benchmark::State& benchmark_state,
-    FusionExecutorCache* fusion_executor_cache,
-    DataType dtype,
-    int bcast_dim) {
-  auto bcast_size = benchmark_state.range(0);
-  auto iter_size = benchmark_state.range(1);
-
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
-  at::Tensor t0 =
-      (bcast_dim ? at::randn({iter_size, bcast_size}, options)
-                 : at::randn({bcast_size, iter_size}, options));
-
-  at::Tensor t1 = at::randn({iter_size}, options);
-
-  fusion_executor_cache->profile(true);
-  fusion_executor_cache->runFusionWithInputs({t0, t1});
-
-  auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
-  auto executor_instance = compile_log.fusion_executor;
-  TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
-  TORCH_INTERNAL_ASSERT(compile_log.launch_constraints.has_value());
-  auto params = toString(compile_log.pointwise_params.value());
-  auto lparams = toString(compile_log.launch_constraints.value());
-
-  benchmark_state.SetLabel(params + lparams);
-
-  fusion_executor_cache->profile(false);
-  executor_instance->setMeasureKernelTimeFlag(true);
-  // Sync everything up before we start
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    auto cg_outputs = fusion_executor_cache->runFusionWithInputs({t0, t1});
-    benchmark_state.SetIterationTime(
-        executor_instance->kernelTimeMs() / 1000.0);
-    clearL2Cache();
-  }
-  // Sync everything up before we're finished, don't want to run ahead on the
-  // cpu while benchmarking.
-  cudaDeviceSynchronize();
-
-  benchmark_state.SetBytesProcessed(
-      int64_t(benchmark_state.iterations()) *
-      (iter_size * bcast_size * 2 + iter_size) * int64_t(dataTypeSize(dtype)));
-}
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Broadcast_Outer_fp32,
-    setupBroadcast,
-    NvFuserScheduler_Broadcast,
-    DataType::Float,
-    0);
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Broadcast_Outer_fp16,
-    setupBroadcast,
-    NvFuserScheduler_Broadcast,
-    DataType::Half,
-    0);
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Broadcast_Inner_fp32,
-    setupBroadcast,
-    NvFuserScheduler_Broadcast,
-    DataType::Float,
-    1);
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Broadcast_Inner_fp16,
-    setupBroadcast,
-    NvFuserScheduler_Broadcast,
-    DataType::Half,
-    1);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp32)
-    ->RangeMultiplier(8)
-    ->Ranges({{1, 1024 * 1024}, {160, 320}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp32)
-    ->RangeMultiplier(8)
-    ->Ranges({{32768, 64 * 1024 * 1024}, {2, 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp32)
-    ->RangeMultiplier(8)
-    ->Ranges({{2, 16}, {32768, 64 * 1024 * 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp16)
-    ->RangeMultiplier(8)
-    ->Ranges({{1, 1024 * 1024}, {160, 320}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp16)
-    ->RangeMultiplier(8)
-    ->Ranges({{32768, 64 * 1024 * 1024}, {2, 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp16)
-    ->RangeMultiplier(8)
-    ->Ranges({{2, 16}, {32768, 64 * 1024 * 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp32)
-    ->RangeMultiplier(8)
-    ->Ranges({{1, 1024 * 1024}, {160, 320}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp32)
-    ->RangeMultiplier(8)
-    ->Ranges({{32768, 64 * 1024 * 1024}, {2, 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp32)
-    ->RangeMultiplier(8)
-    ->Ranges({{2, 16}, {32768, 64 * 1024 * 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp16)
-    ->RangeMultiplier(8)
-    ->Ranges({{1, 1024 * 1024}, {160, 320}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp16)
-    ->RangeMultiplier(8)
-    ->Ranges({{32768, 64 * 1024 * 1024}, {2, 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp16)
-    ->RangeMultiplier(8)
-    ->Ranges({{2, 16}, {32768, 64 * 1024 * 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
diff --git a/benchmarks/cpp/nvfuser/gelu_backward.cpp b/benchmarks/cpp/nvfuser/gelu_backward.cpp
deleted file mode 100644 (file)
index 9d53d9c..0000000
+++ /dev/null
@@ -1,244 +0,0 @@
-
-// Based on NVFuserTest.FusionBiasGeluBwd_CUDA
-
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-static void setupFusion(Fusion* fusion) {
-  FusionGuard fg(fusion);
-
-  const float k_079 = 0.79788456;
-  const float k_004 = 0.044715;
-  const float k_010 = 0.1070322243;
-
-  // gradient tensor
-  auto t0 = makeContigTensor(3, DataType::Half);
-  fusion->addInput(t0);
-
-  auto t1 = castOp(DataType::Float, t0);
-
-  // bias tensor
-  auto t2 = makeContigTensor(1, DataType::Half);
-  fusion->addInput(t2);
-
-  auto t3 = castOp(DataType::Float, t2);
-
-  // input tensor
-  auto t4 = makeContigTensor(3, DataType::Half);
-  fusion->addInput(t4);
-
-  auto t5 = castOp(DataType::Float, t4);
-  auto t6 = broadcast(t3, {true, true, false});
-  auto t7 = add(t6, t5);
-  auto t8 = mul(t7, new Double(k_079));
-  auto t9 = mul(t7, new Double(k_004));
-  auto t10 = mul(t9, t7);
-  auto t11 = add(t10, new Int(1));
-  auto t12 = mul(t8, t11);
-  auto t13 = unaryOp(UnaryOpType::Tanh, t12);
-  auto t14 = mul(t7, new Double(0.5));
-  auto t15 = mul(t13, t13);
-  auto t16 = unaryOp(UnaryOpType::Neg, t15);
-  auto t17 = add(t16, new Int(1));
-  auto t18 = mul(t7, new Double(k_010));
-  auto t19 = mul(t18, t7);
-  auto t20 = add(t19, new Double(k_079));
-  auto t21 = mul(t17, t20);
-  auto t22 = mul(t14, t21);
-  auto t23 = add(t13, new Int(1));
-  auto t24 = mul(t23, new Double(0.5));
-  auto t25 = add(t22, t24);
-  auto t26 = mul(t25, t1);
-
-  // Save float output for validation
-  fusion->addOutput(t26);
-  auto t27 = castOp(DataType::Half, t26);
-  fusion->addOutput(t27);
-}
-
-static std::vector<c10::IValue> setupInputs() {
-  at::manual_seed(0);
-
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
-  std::vector<int64_t> input_shape{6, 512, 4096};
-  std::vector<int64_t> bias_shape{4096};
-  auto at_input = at::randn(input_shape, options);
-  auto at_bias = at::randn(bias_shape, options);
-  auto at_grad = at::randn(input_shape, options);
-
-  return {at_grad, at_bias, at_input};
-}
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_SetupFusion(benchmark::State& benchmark_state) {
-  for (auto _ : benchmark_state) {
-    Fusion fusion;
-    setupFusion(&fusion);
-  }
-}
-
-BENCHMARK(GeluBackward_SetupFusion)->Unit(benchmark::kMicrosecond);
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_AutoSchedule(benchmark::State& benchmark_state) {
-  for (auto _ : benchmark_state) {
-    // Setup (not included in the measurement)
-    benchmark_state.PauseTiming();
-    Fusion fusion;
-    setupFusion(&fusion);
-    std::vector<c10::IValue> inputs = setupInputs();
-    benchmark_state.ResumeTiming();
-
-    // Auto-schedule
-    schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-  }
-}
-
-BENCHMARK(GeluBackward_AutoSchedule)->Unit(benchmark::kMicrosecond);
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_Lower(benchmark::State& benchmark_state) {
-  constexpr int kHiddenFeatures = 512;
-  constexpr int kBatchSize = 64;
-
-  Fusion fusion;
-
-  // setup fusion
-  setupFusion(&fusion);
-
-  // inputs
-  std::vector<c10::IValue> inputs = setupInputs();
-
-  schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
-  for (auto _ : benchmark_state) {
-    GpuLower gpu_lower(&fusion);
-  }
-}
-
-BENCHMARK(GeluBackward_Lower)->Unit(benchmark::kMillisecond);
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_Compile(benchmark::State& benchmark_state) {
-  Fusion fusion;
-
-  // setup fusion
-  setupFusion(&fusion);
-
-  // inputs
-  std::vector<c10::IValue> inputs = setupInputs();
-
-  schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
-  for (auto _ : benchmark_state) {
-    FusionExecutor executor;
-    executor.compileFusion(&fusion);
-  }
-}
-
-BENCHMARK(GeluBackward_Compile)->Unit(benchmark::kMillisecond);
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_RunFusion(benchmark::State& benchmark_state) {
-  Fusion fusion;
-
-  // setup fusion
-  setupFusion(&fusion);
-
-  // inputs
-  std::vector<c10::IValue> inputs = setupInputs();
-
-  // outputs
-  std::vector<at::Tensor> outputs;
-
-  auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
-  FusionExecutor executor;
-  executor.compileFusion(&fusion);
-
-  cudaDeviceSynchronize();
-
-  for (auto _ : benchmark_state) {
-    outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
-    cudaDeviceSynchronize();
-    clearL2Cache();
-  }
-}
-
-BENCHMARK(GeluBackward_RunFusion)->Unit(benchmark::kMicrosecond);
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_RunFusion_GpuOnly(benchmark::State& benchmark_state) {
-  Fusion fusion;
-
-  // setup fusion
-  setupFusion(&fusion);
-
-  // inputs
-  std::vector<c10::IValue> inputs = setupInputs();
-
-  // outputs
-  std::vector<at::Tensor> outputs;
-
-  auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
-  FusionExecutor executor;
-  executor.setMeasureKernelTimeFlag(true);
-  executor.compileFusion(&fusion);
-
-  cudaDeviceSynchronize();
-
-  for (auto _ : benchmark_state) {
-    outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
-    benchmark_state.SetIterationTime(executor.kernelTimeMs() / 1000.0);
-    clearL2Cache();
-  }
-}
-
-BENCHMARK(GeluBackward_RunFusion_GpuOnly)
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_RunFusion_CpuOnly(benchmark::State& benchmark_state) {
-  Fusion fusion;
-
-  // setup fusion
-  setupFusion(&fusion);
-
-  // inputs
-  std::vector<c10::IValue> inputs = setupInputs();
-
-  // outputs
-  std::vector<at::Tensor> outputs;
-
-  auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
-  FusionExecutor executor;
-  executor.setExecuteKernelFlag(false);
-  executor.compileFusion(&fusion);
-
-  for (auto _ : benchmark_state) {
-    outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
-  }
-}
-
-BENCHMARK(GeluBackward_RunFusion_CpuOnly)->Unit(benchmark::kMicrosecond);
diff --git a/benchmarks/cpp/nvfuser/heuristic_cache.cpp b/benchmarks/cpp/nvfuser/heuristic_cache.cpp
deleted file mode 100644 (file)
index 22b8ec4..0000000
+++ /dev/null
@@ -1,177 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
-  return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
-}
-
-// Make a non-contiguous tensor of compile-time known sizes
-TensorView* makeConcreteTensor(
-    std::vector<int64_t> shape,
-    DataType dtype = DataType::Float) {
-  return TensorViewBuilder().shape(shape).dtype(dtype).build();
-}
-
-static auto getLayerBackwardNormRuntime(
-    std::unique_ptr<Fusion> fusion_ptr,
-    std::unique_ptr<FusionExecutorCache>& fec,
-    std::vector<at::IValue>& aten_inputs,
-    std::vector<int64_t>& shape,
-    std::vector<int64_t>& norm_shape) {
-  Fusion& fusion = *fusion_ptr.get();
-
-  const size_t kM = shape.size();
-  const size_t kN = norm_shape.size();
-  const size_t kOuterNumDims = kM - kN;
-
-  std::vector<int64_t> outer_shape;
-  for (size_t idx = 0; idx < kOuterNumDims; ++idx) {
-    outer_shape.push_back(shape[idx]);
-  }
-  for (size_t idx = kOuterNumDims; idx < kM; ++idx) {
-    outer_shape.push_back(1);
-  }
-
-  auto grad_out = makeSymbolicTensor(shape.size());
-  auto input = makeSymbolicTensor(shape.size());
-  auto mean = makeConcreteTensor(outer_shape);
-  auto rstd = makeConcreteTensor(outer_shape);
-  auto weight = makeSymbolicTensor(norm_shape.size());
-  auto bias = makeSymbolicTensor(norm_shape.size());
-  fusion.addInput(grad_out);
-  fusion.addInput(input);
-  fusion.addInput(mean);
-  fusion.addInput(rstd);
-  fusion.addInput(weight);
-  fusion.addInput(bias);
-
-  auto grads = layer_norm_backward(
-      grad_out,
-      input,
-      norm_shape,
-      mean,
-      rstd,
-      weight,
-      bias,
-      {true, true, true});
-
-  fusion.addOutput(grads.grad_input);
-  fusion.addOutput(grads.grad_weight);
-  fusion.addOutput(grads.grad_bias);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_grad_out = at::randn(shape, options);
-  at::Tensor aten_input = at::randn(shape, options);
-  at::Tensor aten_weight = at::randn(norm_shape, options);
-  at::Tensor aten_bias = at::randn(norm_shape, options);
-  auto at_weight = c10::optional<at::Tensor>(aten_weight);
-  auto at_bias = c10::optional<at::Tensor>(aten_bias);
-
-  const float kEps = 1e-5;
-  auto aten_results =
-      at::native_layer_norm(aten_input, norm_shape, at_weight, at_bias, kEps);
-  auto aten_output = std::get<0>(aten_results);
-  auto aten_mean = std::get<1>(aten_results);
-  auto aten_rstd = std::get<2>(aten_results);
-
-  fec = std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
-  aten_inputs = {
-      aten_grad_out, aten_input, aten_mean, aten_rstd, aten_weight, aten_bias};
-  auto cg_outputs = fec->runFusionWithInputs(aten_inputs);
-
-  return fec->getMostRecentKernelRuntime();
-}
-
-static void LayerNormBackward_HeuristicLookup(
-    benchmark::State& benchmark_state) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  FusionGuard fg(fusion_ptr.get());
-
-  // PreAllocate
-  std::unique_ptr<FusionExecutorCache> fec;
-  std::vector<at::IValue> aten_inputs;
-
-  std::vector<int64_t> shape{20, 100, 35, 67};
-  std::vector<int64_t> norm_shape{67};
-
-  auto runtime = getLayerBackwardNormRuntime(
-      std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);
-  TORCH_INTERNAL_ASSERT(
-      runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
-
-  for (auto _ : benchmark_state) {
-    // Setup (not included in the measurement)
-    runtime->getMaybeHeuristicsFor(aten_inputs);
-  }
-}
-
-static auto getLayerForwardNormRuntime(
-    std::unique_ptr<Fusion> fusion_ptr,
-    std::unique_ptr<FusionExecutorCache>& fec,
-    std::vector<at::IValue>& aten_inputs,
-    std::vector<int64_t>& shape,
-    std::vector<int64_t>& norm_shape) {
-  Fusion& fusion = *fusion_ptr.get();
-
-  const float kEps = 1e-5;
-  Double* eps_ptr = new Double(kEps);
-
-  auto input = makeSymbolicTensor(shape.size());
-  fusion.addInput(input);
-
-  auto result = layer_norm(input, norm_shape, nullptr, nullptr, eps_ptr);
-
-  fusion.addOutput(result.output);
-  fusion.addOutput(result.mean);
-  fusion.addOutput(result.invstd);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn(shape, options);
-
-  fec = std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
-  aten_inputs = {aten_input};
-  auto cg_outputs = fec->runFusionWithInputs(aten_inputs);
-
-  return fec->getMostRecentKernelRuntime();
-}
-
-static void LayerNormForward_HeuristicLookup(
-    benchmark::State& benchmark_state) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  FusionGuard fg(fusion_ptr.get());
-
-  // PreAllocate
-  std::unique_ptr<FusionExecutorCache> fec;
-  std::vector<at::IValue> aten_inputs;
-
-  std::vector<int64_t> shape{20, 100, 35, 67};
-  std::vector<int64_t> norm_shape{67};
-
-  auto runtime = getLayerForwardNormRuntime(
-      std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);
-  TORCH_INTERNAL_ASSERT(
-      runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
-
-  for (auto _ : benchmark_state) {
-    // Setup (not included in the measurement)
-    runtime->getMaybeHeuristicsFor(aten_inputs);
-  }
-}
-
-BENCHMARK(LayerNormBackward_HeuristicLookup)->Unit(benchmark::kMicrosecond);
-BENCHMARK(LayerNormForward_HeuristicLookup)->Unit(benchmark::kMicrosecond);
diff --git a/benchmarks/cpp/nvfuser/heuristic_lookup.cpp b/benchmarks/cpp/nvfuser/heuristic_lookup.cpp
deleted file mode 100644 (file)
index 22b8ec4..0000000
+++ /dev/null
@@ -1,177 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
-  return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
-}
-
-// Make a non-contiguous tensor of compile-time known sizes
-TensorView* makeConcreteTensor(
-    std::vector<int64_t> shape,
-    DataType dtype = DataType::Float) {
-  return TensorViewBuilder().shape(shape).dtype(dtype).build();
-}
-
-static auto getLayerBackwardNormRuntime(
-    std::unique_ptr<Fusion> fusion_ptr,
-    std::unique_ptr<FusionExecutorCache>& fec,
-    std::vector<at::IValue>& aten_inputs,
-    std::vector<int64_t>& shape,
-    std::vector<int64_t>& norm_shape) {
-  Fusion& fusion = *fusion_ptr.get();
-
-  const size_t kM = shape.size();
-  const size_t kN = norm_shape.size();
-  const size_t kOuterNumDims = kM - kN;
-
-  std::vector<int64_t> outer_shape;
-  for (size_t idx = 0; idx < kOuterNumDims; ++idx) {
-    outer_shape.push_back(shape[idx]);
-  }
-  for (size_t idx = kOuterNumDims; idx < kM; ++idx) {
-    outer_shape.push_back(1);
-  }
-
-  auto grad_out = makeSymbolicTensor(shape.size());
-  auto input = makeSymbolicTensor(shape.size());
-  auto mean = makeConcreteTensor(outer_shape);
-  auto rstd = makeConcreteTensor(outer_shape);
-  auto weight = makeSymbolicTensor(norm_shape.size());
-  auto bias = makeSymbolicTensor(norm_shape.size());
-  fusion.addInput(grad_out);
-  fusion.addInput(input);
-  fusion.addInput(mean);
-  fusion.addInput(rstd);
-  fusion.addInput(weight);
-  fusion.addInput(bias);
-
-  auto grads = layer_norm_backward(
-      grad_out,
-      input,
-      norm_shape,
-      mean,
-      rstd,
-      weight,
-      bias,
-      {true, true, true});
-
-  fusion.addOutput(grads.grad_input);
-  fusion.addOutput(grads.grad_weight);
-  fusion.addOutput(grads.grad_bias);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_grad_out = at::randn(shape, options);
-  at::Tensor aten_input = at::randn(shape, options);
-  at::Tensor aten_weight = at::randn(norm_shape, options);
-  at::Tensor aten_bias = at::randn(norm_shape, options);
-  auto at_weight = c10::optional<at::Tensor>(aten_weight);
-  auto at_bias = c10::optional<at::Tensor>(aten_bias);
-
-  const float kEps = 1e-5;
-  auto aten_results =
-      at::native_layer_norm(aten_input, norm_shape, at_weight, at_bias, kEps);
-  auto aten_output = std::get<0>(aten_results);
-  auto aten_mean = std::get<1>(aten_results);
-  auto aten_rstd = std::get<2>(aten_results);
-
-  fec = std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
-  aten_inputs = {
-      aten_grad_out, aten_input, aten_mean, aten_rstd, aten_weight, aten_bias};
-  auto cg_outputs = fec->runFusionWithInputs(aten_inputs);
-
-  return fec->getMostRecentKernelRuntime();
-}
-
-static void LayerNormBackward_HeuristicLookup(
-    benchmark::State& benchmark_state) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  FusionGuard fg(fusion_ptr.get());
-
-  // PreAllocate
-  std::unique_ptr<FusionExecutorCache> fec;
-  std::vector<at::IValue> aten_inputs;
-
-  std::vector<int64_t> shape{20, 100, 35, 67};
-  std::vector<int64_t> norm_shape{67};
-
-  auto runtime = getLayerBackwardNormRuntime(
-      std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);
-  TORCH_INTERNAL_ASSERT(
-      runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
-
-  for (auto _ : benchmark_state) {
-    // Setup (not included in the measurement)
-    runtime->getMaybeHeuristicsFor(aten_inputs);
-  }
-}
-
-static auto getLayerForwardNormRuntime(
-    std::unique_ptr<Fusion> fusion_ptr,
-    std::unique_ptr<FusionExecutorCache>& fec,
-    std::vector<at::IValue>& aten_inputs,
-    std::vector<int64_t>& shape,
-    std::vector<int64_t>& norm_shape) {
-  Fusion& fusion = *fusion_ptr.get();
-
-  const float kEps = 1e-5;
-  Double* eps_ptr = new Double(kEps);
-
-  auto input = makeSymbolicTensor(shape.size());
-  fusion.addInput(input);
-
-  auto result = layer_norm(input, norm_shape, nullptr, nullptr, eps_ptr);
-
-  fusion.addOutput(result.output);
-  fusion.addOutput(result.mean);
-  fusion.addOutput(result.invstd);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn(shape, options);
-
-  fec = std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
-  aten_inputs = {aten_input};
-  auto cg_outputs = fec->runFusionWithInputs(aten_inputs);
-
-  return fec->getMostRecentKernelRuntime();
-}
-
-static void LayerNormForward_HeuristicLookup(
-    benchmark::State& benchmark_state) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  FusionGuard fg(fusion_ptr.get());
-
-  // PreAllocate
-  std::unique_ptr<FusionExecutorCache> fec;
-  std::vector<at::IValue> aten_inputs;
-
-  std::vector<int64_t> shape{20, 100, 35, 67};
-  std::vector<int64_t> norm_shape{67};
-
-  auto runtime = getLayerForwardNormRuntime(
-      std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);
-  TORCH_INTERNAL_ASSERT(
-      runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
-
-  for (auto _ : benchmark_state) {
-    // Setup (not included in the measurement)
-    runtime->getMaybeHeuristicsFor(aten_inputs);
-  }
-}
-
-BENCHMARK(LayerNormBackward_HeuristicLookup)->Unit(benchmark::kMicrosecond);
-BENCHMARK(LayerNormForward_HeuristicLookup)->Unit(benchmark::kMicrosecond);
diff --git a/benchmarks/cpp/nvfuser/instance_norm.cpp b/benchmarks/cpp/nvfuser/instance_norm.cpp
deleted file mode 100644 (file)
index 1d1dd4a..0000000
+++ /dev/null
@@ -1,221 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-static void setupInstanceNorm(Fusion* fusion, DataType dtype) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  FusionGuard fg(fusion);
-
-  auto input = makeContigTensor(4, dtype);
-  auto weight = makeContigTensor(1, dtype);
-  auto bias = makeContigTensor(1, dtype);
-  auto running_mean = makeContigTensor(1, DataType::Float);
-  auto running_var = makeContigTensor(1, DataType::Float);
-
-  fusion->addInput(input);
-  fusion->addInput(weight);
-  fusion->addInput(bias);
-  fusion->addInput(running_mean);
-  fusion->addInput(running_var);
-
-  if (dtype == DataType::Half) {
-    input = castOp(DataType::Float, input);
-    weight = castOp(DataType::Float, weight);
-    bias = castOp(DataType::Float, bias);
-  }
-
-  const bool kTraining = true;
-  const float kMomentum = 0.1;
-  const float kEps = 1e-5;
-  auto momentum_ptr = new Double(kMomentum);
-  auto eps_ptr = new Double(kEps);
-
-  auto norm = instance_norm(
-      input,
-      weight,
-      bias,
-      running_mean,
-      running_var,
-      kTraining,
-      momentum_ptr,
-      eps_ptr);
-
-  auto output = unaryOp(UnaryOpType::Relu, norm.output);
-
-  if (dtype == DataType::Half) {
-    output = castOp(DataType::Half, output);
-  }
-
-  fusion->addOutput(output);
-}
-
-//------------------------------------------------------------------------------
-
-static void NvFuserScheduler_InstanceNorm(
-    benchmark::State& benchmark_state,
-    FusionExecutorCache* fusion_executor_cache,
-    DataType dtype) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  std::vector<int64_t> input_shape{
-      benchmark_state.range(0),
-      benchmark_state.range(2),
-      benchmark_state.range(1),
-      benchmark_state.range(1)};
-
-  // inputs
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  auto fp32_options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor at_x = at::randn(input_shape, options);
-  at::Tensor at_weight = at::ones({input_shape[1]}, options);
-  at::Tensor at_bias = at::zeros({input_shape[1]}, options);
-  at::Tensor at_mean = at::zeros({input_shape[1]}, fp32_options);
-  at::Tensor at_var = at::ones({input_shape[1]}, fp32_options);
-
-  std::vector<c10::IValue> aten_inputs = {
-      at_x, at_weight, at_bias, at_mean, at_var};
-  std::vector<at::Tensor> outputs;
-
-  runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
-
-  const size_t kSize =
-      input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
-  const size_t kChannels = input_shape[1];
-
-  // Read: x, weight, bias, running_mean, running_var
-  // Write: y, running_mean, running_var
-  benchmark_state.SetBytesProcessed(
-      benchmark_state.iterations() *
-      ((kChannels * 2 + kSize * 2) * dataTypeSize(dtype) +
-       (kChannels * 2 * 2) * dataTypeSize(DataType::Float)));
-}
-
-static void Baseline_InstanceNorm(
-    benchmark::State& benchmark_state,
-    DataType dtype) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  std::vector<int64_t> input_shape{
-      benchmark_state.range(0),
-      benchmark_state.range(2),
-      benchmark_state.range(1),
-      benchmark_state.range(1)};
-  const float kMomentum = 0.1;
-  const float kEps = 1e-5;
-  const auto aten_dtype = data_type_to_aten(dtype);
-
-  at::manual_seed(0);
-  auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);
-  auto fp32_options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor at_x = at::randn(input_shape, options);
-  at::Tensor at_weight = at::ones({input_shape[1]}, options);
-  at::Tensor at_bias = at::zeros({input_shape[1]}, options);
-  at::Tensor at_mean = at::zeros({input_shape[1]}, fp32_options);
-  at::Tensor at_var = at::ones({input_shape[1]}, fp32_options);
-
-  auto ato_weight = c10::optional<at::Tensor>(at_weight);
-  auto ato_bias = c10::optional<at::Tensor>(at_bias);
-  auto ato_running_mean = c10::optional<at::Tensor>(at_mean);
-  auto ato_running_var = c10::optional<at::Tensor>(at_var);
-
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    CudaKernelTimer timer;
-
-    auto norm = at::instance_norm(
-        at_x,
-        ato_weight,
-        ato_bias,
-        ato_running_mean,
-        ato_running_var,
-        true,
-        kMomentum,
-        kEps,
-        false);
-    auto output = at::relu(norm);
-
-    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
-    cudaDeviceSynchronize();
-    clearL2Cache();
-    cudaDeviceSynchronize();
-  }
-
-  const size_t kSize =
-      input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
-  const size_t kChannels = input_shape[1];
-
-  // Read: x, weight, bias, running_mean, running_var
-  // Write: y, running_mean, running_var
-  benchmark_state.SetBytesProcessed(
-      benchmark_state.iterations() *
-      ((kChannels * 2 + kSize * 2) * dataTypeSize(dtype) +
-       (kChannels * 2 * 2) * dataTypeSize(DataType::Float)));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_InstanceNorm_fp32(benchmark::State& benchmark_state) {
-  Baseline_InstanceNorm(benchmark_state, DataType::Float);
-}
-
-static void Baseline_InstanceNorm_fp16(benchmark::State& benchmark_state) {
-  Baseline_InstanceNorm(benchmark_state, DataType::Half);
-}
-
-//------------------------------------------------------------------------------
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_fp32_InstanceNorm,
-    setupInstanceNorm,
-    NvFuserScheduler_InstanceNorm,
-    DataType::Float);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_fp32_InstanceNorm)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_fp16_InstanceNorm,
-    setupInstanceNorm,
-    NvFuserScheduler_InstanceNorm,
-    DataType::Half);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_fp16_InstanceNorm)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-//------------------------------------------------------------------------------
-
-BENCHMARK(Baseline_InstanceNorm_fp32)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_InstanceNorm_fp16)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-//------------------------------------------------------------------------------
diff --git a/benchmarks/cpp/nvfuser/layer_norm.cpp b/benchmarks/cpp/nvfuser/layer_norm.cpp
deleted file mode 100644 (file)
index 5bbe76f..0000000
+++ /dev/null
@@ -1,161 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-//------------------------------------------------------------------------------
-
-static void setupLayerNorm(Fusion* fusion, DataType dtype) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  FusionGuard fg(fusion);
-
-  const int kReductionAxis = 1;
-  const float kEps = 1e-5;
-
-  Double* eps_ptr = new Double(kEps);
-
-  // setup fusion
-  auto input = makeContigTensor(2, dtype);
-  auto weight = makeContigTensor(1, dtype);
-  auto bias = makeContigTensor(1, dtype);
-
-  fusion->addInput(input);
-  fusion->addInput(weight);
-  fusion->addInput(bias);
-
-  if (dtype == DataType::Half) {
-    input = castOp(DataType::Float, input);
-    weight = castOp(DataType::Float, weight);
-    bias = castOp(DataType::Float, bias);
-  }
-
-  auto layer_norm_results = layer_norm(input, 1, weight, bias, eps_ptr);
-
-  auto output = layer_norm_results.output;
-
-  if (dtype == DataType::Half) {
-    output = castOp(DataType::Half, output);
-  }
-
-  fusion->addOutput(output);
-}
-
-static void NvFuserScheduler_LayerNorm(
-    benchmark::State& benchmark_state,
-    FusionExecutorCache* fusion_executor_cache,
-    DataType dtype) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  std::vector<int64_t> input_shape{656, benchmark_state.range(0)};
-  const float kEps = 1e-5;
-
-  // inputs
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  at::Tensor input = at::randn(input_shape, options);
-  at::Tensor weight = at::randn({input_shape[1]}, options);
-  at::Tensor bias = at::randn({input_shape[1]}, options);
-
-  std::vector<c10::IValue> aten_inputs({input, weight, bias});
-
-  runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
-
-  benchmark_state.SetBytesProcessed(
-      int64_t(benchmark_state.iterations()) *
-      (2 * input.numel() + weight.numel() + bias.numel()) *
-      int64_t(dataTypeSize(dtype)));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_LayerNorm(
-    benchmark::State& benchmark_state,
-    DataType dtype) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  std::vector<int64_t> input_shape{656, benchmark_state.range(0)};
-  const int kReductionAxis = 1;
-  std::vector<int64_t> norm_shape;
-  for (int idx = kReductionAxis; idx < input_shape.size(); ++idx) {
-    norm_shape.push_back(input_shape[idx]);
-  }
-
-  // inputs
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  at::Tensor input = at::randn(input_shape, options);
-  at::Tensor weight = at::randn({input_shape[1]}, options);
-  at::Tensor bias = at::randn({input_shape[1]}, options);
-
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    CudaKernelTimer timer;
-    auto output = at::layer_norm(input, norm_shape, weight, bias);
-    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
-    cudaDeviceSynchronize();
-    clearL2Cache();
-    cudaDeviceSynchronize();
-  }
-}
-
-static void Baseline_LayerNorm_fp32(benchmark::State& benchmark_state) {
-  Baseline_LayerNorm(benchmark_state, DataType::Float);
-}
-
-static void Baseline_LayerNorm_fp16(benchmark::State& benchmark_state) {
-  Baseline_LayerNorm(benchmark_state, DataType::Half);
-}
-
-//------------------------------------------------------------------------------
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_fp32_LayerNorm,
-    setupLayerNorm,
-    NvFuserScheduler_LayerNorm,
-    DataType::Float);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_fp32_LayerNorm)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8 << 12}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_fp16_LayerNorm,
-    setupLayerNorm,
-    NvFuserScheduler_LayerNorm,
-    DataType::Half);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_fp16_LayerNorm)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8 << 12}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-BENCHMARK(Baseline_LayerNorm_fp32)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8 << 12}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_LayerNorm_fp16)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8 << 12}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
diff --git a/benchmarks/cpp/nvfuser/lstm_cell.cpp b/benchmarks/cpp/nvfuser/lstm_cell.cpp
deleted file mode 100644 (file)
index e6bffc6..0000000
+++ /dev/null
@@ -1,261 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-// TODO: add LSTM function to composite operations
-// Function Signature: cy, hy = lstm(x, cx)
-static void setupFusion(Fusion* fusion) {
-  FusionGuard fg(fusion);
-
-  TensorView* tvs[16];
-  for (size_t i = 0; i < 16; i++) {
-    tvs[i] = makeContigTensor(2, DataType::Float);
-    fusion->addInput(tvs[i]);
-  }
-
-  const auto cx = makeContigTensor(2, DataType::Float);
-  fusion->addInput(cx);
-
-  const auto in_x = add(add(add(tvs[0], tvs[1]), tvs[2]), tvs[3]);
-  const auto forget_x = add(add(add(tvs[4], tvs[5]), tvs[6]), tvs[7]);
-  const auto cell_x = add(add(add(tvs[8], tvs[9]), tvs[10]), tvs[11]);
-  const auto out_x = add(add(add(tvs[12], tvs[13]), tvs[14]), tvs[15]);
-  auto lstm_result = lstm(cx, in_x, forget_x, cell_x, out_x);
-
-  fusion->addOutput(lstm_result.cell);
-  fusion->addOutput(lstm_result.hidden);
-}
-
-static std::vector<c10::IValue> setupInputs(
-    int hidden_features,
-    int batch_size) {
-  at::manual_seed(0);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const at::Tensor large_tensor0 =
-      at::randn({batch_size, hidden_features * 4}, options);
-  const at::Tensor large_tensor1 =
-      at::randn({batch_size, hidden_features * 4}, options);
-  const at::Tensor large_tensor2 =
-      at::randn({batch_size, hidden_features * 4}, options);
-  const at::Tensor large_tensor3 =
-      at::randn({batch_size, hidden_features * 4}, options);
-
-  const auto chunked0 = large_tensor0.chunk(4, 1);
-  const auto chunked1 = large_tensor1.chunk(4, 1);
-  const auto chunked2 = large_tensor2.chunk(4, 1);
-  const auto chunked3 = large_tensor3.chunk(4, 1);
-
-  std::vector<c10::IValue> inputs;
-  inputs.insert(inputs.end(), chunked0.begin(), chunked0.end());
-  inputs.insert(inputs.end(), chunked1.begin(), chunked1.end());
-  inputs.insert(inputs.end(), chunked2.begin(), chunked2.end());
-  inputs.insert(inputs.end(), chunked3.begin(), chunked3.end());
-
-  const auto at_cx = at::randn({batch_size, hidden_features}, options);
-  inputs.push_back(at_cx);
-
-  return inputs;
-}
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_SetupFusion(benchmark::State& benchmark_state) {
-  for (auto _ : benchmark_state) {
-    Fusion fusion;
-    setupFusion(&fusion);
-  }
-}
-
-BENCHMARK(LstmCell_SetupFusion)->Unit(benchmark::kMicrosecond);
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_AutoSchedule(benchmark::State& benchmark_state) {
-  constexpr int kHiddenFeatures = 512;
-  constexpr int kBatchSize = 64;
-
-  for (auto _ : benchmark_state) {
-    // Setup (not included in the measurement)
-    benchmark_state.PauseTiming();
-    Fusion fusion;
-    setupFusion(&fusion);
-    std::vector<c10::IValue> inputs = setupInputs(kHiddenFeatures, kBatchSize);
-    benchmark_state.ResumeTiming();
-
-    // Auto-schedule
-    schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-  }
-}
-
-BENCHMARK(LstmCell_AutoSchedule)->Unit(benchmark::kMicrosecond);
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_Lower(benchmark::State& benchmark_state) {
-  constexpr int kHiddenFeatures = 512;
-  constexpr int kBatchSize = 64;
-
-  Fusion fusion;
-
-  // setup fusion
-  setupFusion(&fusion);
-
-  // inputs
-  std::vector<c10::IValue> inputs = setupInputs(kHiddenFeatures, kBatchSize);
-
-  schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
-  for (auto _ : benchmark_state) {
-    GpuLower gpu_lower(&fusion);
-  }
-}
-
-BENCHMARK(LstmCell_Lower)->Unit(benchmark::kMillisecond);
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_Compile(benchmark::State& benchmark_state) {
-  constexpr int kHiddenFeatures = 512;
-  constexpr int kBatchSize = 64;
-
-  Fusion fusion;
-
-  // setup fusion
-  setupFusion(&fusion);
-
-  // inputs
-  std::vector<c10::IValue> inputs = setupInputs(kHiddenFeatures, kBatchSize);
-
-  schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
-  for (auto _ : benchmark_state) {
-    FusionExecutor executor;
-    executor.compileFusion(&fusion);
-  }
-}
-
-BENCHMARK(LstmCell_Compile)->Unit(benchmark::kMillisecond);
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_RunFusion(
-    benchmark::State& benchmark_state,
-    int hidden_features,
-    int batch_size) {
-  Fusion fusion;
-
-  // setup fusion
-  setupFusion(&fusion);
-
-  // inputs
-  std::vector<c10::IValue> inputs = setupInputs(hidden_features, batch_size);
-
-  // outputs
-  std::vector<at::Tensor> outputs;
-
-  auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
-  FusionExecutor executor;
-  executor.compileFusion(&fusion);
-
-  cudaDeviceSynchronize();
-
-  for (auto _ : benchmark_state) {
-    outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
-    cudaDeviceSynchronize();
-  }
-}
-
-BENCHMARK_CAPTURE(LstmCell_RunFusion, Small, 512, 64)
-    ->Unit(benchmark::kMicrosecond);
-
-BENCHMARK_CAPTURE(LstmCell_RunFusion, Medium, 1024, 128)
-    ->Unit(benchmark::kMicrosecond);
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_RunFusion_GpuOnly(
-    benchmark::State& benchmark_state,
-    int hidden_features,
-    int batch_size) {
-  Fusion fusion;
-
-  // setup fusion
-  setupFusion(&fusion);
-
-  // inputs
-  std::vector<c10::IValue> inputs = setupInputs(hidden_features, batch_size);
-
-  // outputs
-  std::vector<at::Tensor> outputs;
-
-  auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
-  FusionExecutor executor;
-  executor.setMeasureKernelTimeFlag(true);
-  executor.compileFusion(&fusion);
-
-  cudaDeviceSynchronize();
-
-  for (auto _ : benchmark_state) {
-    outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
-    benchmark_state.SetIterationTime(executor.kernelTimeMs() / 1000.0);
-    cudaDeviceSynchronize();
-    clearL2Cache();
-    cudaDeviceSynchronize();
-  }
-}
-
-BENCHMARK_CAPTURE(LstmCell_RunFusion_GpuOnly, Small, 512, 64)
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK_CAPTURE(LstmCell_RunFusion_GpuOnly, Medium, 1024, 128)
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_RunFusion_CpuOnly(
-    benchmark::State& benchmark_state,
-    int hidden_features,
-    int batch_size) {
-  Fusion fusion;
-
-  // setup fusion
-  setupFusion(&fusion);
-
-  // inputs
-  std::vector<c10::IValue> inputs = setupInputs(hidden_features, batch_size);
-
-  // outputs
-  std::vector<at::Tensor> outputs;
-
-  auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
-  FusionExecutor executor;
-  executor.setExecuteKernelFlag(false);
-  executor.compileFusion(&fusion);
-
-  for (auto _ : benchmark_state) {
-    outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
-  }
-}
-
-BENCHMARK_CAPTURE(LstmCell_RunFusion_CpuOnly, Small, 512, 64)
-    ->Unit(benchmark::kMicrosecond);
-
-BENCHMARK_CAPTURE(LstmCell_RunFusion_CpuOnly, Medium, 1024, 128)
-    ->Unit(benchmark::kMicrosecond);
diff --git a/benchmarks/cpp/nvfuser/main.cpp b/benchmarks/cpp/nvfuser/main.cpp
deleted file mode 100644 (file)
index 71fefa0..0000000
+++ /dev/null
@@ -1,3 +0,0 @@
-#include <benchmark/benchmark.h>
-
-BENCHMARK_MAIN();
diff --git a/benchmarks/cpp/nvfuser/reduction.cpp b/benchmarks/cpp/nvfuser/reduction.cpp
deleted file mode 100644 (file)
index 7e6ab7b..0000000
+++ /dev/null
@@ -1,213 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include <sstream>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-// Return reduction tensor view and output of reduction
-static void setupReduction(Fusion* fusion, DataType dtype, int red_axis) {
-  FusionGuard fg(fusion);
-
-  bool is_fp16 = dtype == DataType::Half;
-
-  TensorView* tv0 = makeContigTensor(2, dtype);
-  fusion->addInput(tv0);
-
-  TensorView* tv0_cast = tv0;
-  if (is_fp16) {
-    tv0_cast = castOp(DataType::Float, tv0);
-  }
-
-  TensorView* tv1 = sum(tv0_cast, {red_axis});
-
-  TensorView* tv1_cast = tv1;
-  if (is_fp16) {
-    tv1_cast = castOp(DataType::Half, tv1);
-  }
-
-  fusion->addOutput(tv1_cast);
-
-  TensorView* output_of_reduction = nullptr;
-  if (is_fp16) {
-    output_of_reduction = tv1_cast;
-  }
-}
-
-static void NvFuserScheduler_Reduction(
-    benchmark::State& benchmark_state,
-    FusionExecutorCache* fusion_executor_cache,
-    DataType dtype,
-    int reduction_dim) {
-  auto reduction_size = benchmark_state.range(0);
-  auto iter_size = benchmark_state.range(1);
-
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  at::Tensor aten_input =
-      (reduction_dim ? at::randn({iter_size, reduction_size}, options)
-                     : at::randn({reduction_size, iter_size}, options));
-
-  fusion_executor_cache->profile(true);
-  fusion_executor_cache->runFusionWithInputs({aten_input});
-
-  auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
-  auto executor_instance = compile_log.fusion_executor;
-  TORCH_INTERNAL_ASSERT(compile_log.reduction_params.has_value());
-  TORCH_INTERNAL_ASSERT(compile_log.launch_constraints.has_value());
-  auto rparams = toString(compile_log.reduction_params.value());
-  auto lparams = toString(compile_log.launch_constraints.value());
-
-  benchmark_state.SetLabel(rparams + lparams);
-
-  fusion_executor_cache->profile(false);
-  executor_instance->setMeasureKernelTimeFlag(true);
-  // Sync everything up before we start
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    auto cg_outputs = fusion_executor_cache->runFusionWithInputs({aten_input});
-    benchmark_state.SetIterationTime(
-        executor_instance->kernelTimeMs() / 1000.0);
-    clearL2Cache();
-  }
-  // Sync everything up before we're finished, don't want to run ahead on the
-  // cpu while benchmarking.
-  cudaDeviceSynchronize();
-
-  benchmark_state.SetBytesProcessed(
-      int64_t(benchmark_state.iterations()) *
-      (iter_size * reduction_size + iter_size) * int64_t(dataTypeSize(dtype)));
-}
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Reduction_Outer_fp32,
-    setupReduction,
-    NvFuserScheduler_Reduction,
-    DataType::Float,
-    0);
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Reduction_Outer_fp16,
-    setupReduction,
-    NvFuserScheduler_Reduction,
-    DataType::Half,
-    0);
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Reduction_Inner_fp32,
-    setupReduction,
-    NvFuserScheduler_Reduction,
-    DataType::Float,
-    1);
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Reduction_Inner_fp16,
-    setupReduction,
-    NvFuserScheduler_Reduction,
-    DataType::Half,
-    1);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp32)
-    ->RangeMultiplier(8)
-    ->Ranges({{1, 1024 * 1024}, {160, 320}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{32768, 128 * 1024 * 1024}, {2, 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{2, 16}, {32768, 128 * 1024 * 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp32)
-    ->RangeMultiplier(2)
-    ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp16)
-    ->RangeMultiplier(8)
-    ->Ranges({{1, 1024 * 1024}, {160, 320}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{32768, 128 * 1024 * 1024}, {2, 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{2, 16}, {32768, 128 * 1024 * 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp16)
-    ->RangeMultiplier(2)
-    ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp32)
-    ->RangeMultiplier(8)
-    ->Ranges({{1, 1024 * 1024}, {160, 320}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{32768, 128 * 1024 * 1024}, {2, 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp32)
-    ->RangeMultiplier(4)
-    ->Ranges({{2, 16}, {32768, 128 * 1024 * 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp32)
-    ->RangeMultiplier(2)
-    ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp16)
-    ->RangeMultiplier(8)
-    ->Ranges({{1, 1024 * 1024}, {160, 320}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{32768, 128 * 1024 * 1024}, {2, 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp16)
-    ->RangeMultiplier(4)
-    ->Ranges({{2, 16}, {32768, 128 * 1024 * 1024}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp16)
-    ->RangeMultiplier(2)
-    ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
diff --git a/benchmarks/cpp/nvfuser/scale_bias_relu.cpp b/benchmarks/cpp/nvfuser/scale_bias_relu.cpp
deleted file mode 100644 (file)
index 6a294ba..0000000
+++ /dev/null
@@ -1,410 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-static void setupSBR(Fusion* fusion, DataType dtype) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  FusionGuard fg(fusion);
-
-  const size_t kNumberOfDims = 4;
-
-  std::vector<int64_t> bcast_shape(kNumberOfDims, 1);
-  bcast_shape[bcast_shape.size() - 1] = -1;
-
-  std::vector<bool> bcast_contig(kNumberOfDims, false);
-  bcast_contig[bcast_contig.size() - 1] = true;
-
-  auto x = makeContigTensor(kNumberOfDims, dtype);
-
-  auto scale = TensorViewBuilder()
-                   .contiguity(bcast_contig)
-                   .shape(bcast_shape)
-                   .dtype(dtype)
-                   .build();
-
-  auto bias = TensorViewBuilder()
-                  .contiguity(bcast_contig)
-                  .shape(bcast_shape)
-                  .dtype(dtype)
-                  .build();
-
-  fusion->addInput(x);
-  fusion->addInput(scale);
-  fusion->addInput(bias);
-
-  if (dtype == DataType::Half) {
-    x = castOp(DataType::Float, x);
-    scale = castOp(DataType::Float, scale);
-    bias = castOp(DataType::Float, bias);
-  }
-
-  auto scale_bias = add(mul(x, scale), bias);
-  auto scale_bias_relu = unaryOp(UnaryOpType::Relu, scale_bias);
-
-  if (dtype == DataType::Half) {
-    scale_bias_relu = castOp(DataType::Half, scale_bias_relu);
-  }
-  fusion->addOutput(scale_bias_relu);
-}
-
-static void setupSBRNorm(Fusion* fusion, DataType dtype) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-  FusionGuard fg(fusion);
-
-  const size_t kNumberOfDims = 4;
-
-  auto x = makeContigTensor(kNumberOfDims, dtype);
-  auto weight = makeContigTensor(1, dtype);
-  auto bias = makeContigTensor(1, dtype);
-  auto mean = makeContigTensor(1, dtype);
-  auto var = makeContigTensor(1, dtype);
-
-  fusion->addInput(x);
-  fusion->addInput(weight);
-  fusion->addInput(bias);
-  fusion->addInput(mean);
-  fusion->addInput(var);
-
-  std::vector<bool> broadcast_mask(kNumberOfDims, true);
-  broadcast_mask[broadcast_mask.size() - 1] = false;
-
-  if (dtype == DataType::Half) {
-    x = castOp(DataType::Float, x);
-    weight = castOp(DataType::Float, weight);
-    bias = castOp(DataType::Float, bias);
-    mean = castOp(DataType::Float, mean);
-    var = castOp(DataType::Float, var);
-  }
-
-  auto rsqrt = unaryOp(UnaryOpType::Rsqrt, var);
-  auto this_scale = mul(weight, rsqrt);
-  auto this_bias = mul(sub(bias, mean), this_scale);
-
-  auto bcast_scale = broadcast(this_scale, broadcast_mask);
-  auto bcast_bias = broadcast(this_bias, broadcast_mask);
-
-  auto scale_bias = add(mul(x, bcast_scale), bcast_bias);
-  auto scale_bias_relu = unaryOp(UnaryOpType::Relu, scale_bias);
-
-  if (dtype == DataType::Half) {
-    scale_bias_relu = castOp(DataType::Half, scale_bias_relu);
-  }
-
-  fusion->addOutput(scale_bias_relu);
-}
-
-//------------------------------------------------------------------------------
-
-static void NvFuserScheduler_SBR(
-    benchmark::State& benchmark_state,
-    FusionExecutorCache* fusion_executor_cache,
-    DataType dtype) {
-  // N, H, W, C format
-  std::vector<int64_t> input_shape{
-      benchmark_state.range(0),
-      benchmark_state.range(1),
-      benchmark_state.range(1),
-      benchmark_state.range(2)};
-  std::vector<int64_t> bcast_shape{1, 1, 1, -1};
-
-  // inputs
-  at::manual_seed(0);
-  std::vector<int64_t> static_bcast_shape{1, 1, 1, benchmark_state.range(2)};
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  at::Tensor at_x = at::randn(input_shape, options);
-  at::Tensor at_scale = at::ones(static_bcast_shape, options);
-  at::Tensor at_bias = at::zeros(static_bcast_shape, options);
-
-  // inputs
-  std::vector<c10::IValue> aten_inputs = {at_x, at_scale, at_bias};
-
-  fusion_executor_cache->profile(true);
-  fusion_executor_cache->runFusionWithInputs(aten_inputs);
-
-  auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
-  auto executor_instance = compile_log.fusion_executor;
-  TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
-  TORCH_INTERNAL_ASSERT(compile_log.launch_constraints.has_value());
-  auto params = toString(compile_log.pointwise_params.value());
-  auto lparams = toString(compile_log.launch_constraints.value());
-
-  benchmark_state.SetLabel(params + lparams);
-  benchmark_state.SetLabel(lparams);
-
-  fusion_executor_cache->profile(false);
-  executor_instance->setMeasureKernelTimeFlag(true);
-  // Sync everything up before we start
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    auto cg_outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
-    benchmark_state.SetIterationTime(
-        executor_instance->kernelTimeMs() / 1000.0);
-    clearL2Cache();
-  }
-
-  // Sync everything up before we're finished, don't want to run ahead on the
-  // cpu while benchmarking.
-  cudaDeviceSynchronize();
-
-  const size_t size =
-      input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
-  const size_t channels = input_shape[3];
-  benchmark_state.SetBytesProcessed(
-      int64_t(benchmark_state.iterations()) * (channels * 2 + size * 2) *
-      int64_t(dataTypeSize(dtype)));
-}
-
-static void Baseline_SBR(benchmark::State& benchmark_state, DataType dtype) {
-  // N, H, W, C format
-  std::vector<int64_t> input_shape{
-      benchmark_state.range(0),
-      benchmark_state.range(1),
-      benchmark_state.range(1),
-      benchmark_state.range(2)};
-  std::vector<int64_t> bcast_shape{benchmark_state.range(2)};
-
-  // inputs
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  at::Tensor at_x = at::randn(input_shape, options);
-  at::Tensor at_y = at::randn(input_shape, options);
-  at::Tensor at_scale = at::ones(bcast_shape, options);
-  at::Tensor at_bias = at::zeros(bcast_shape, options);
-
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    CudaKernelTimer timer;
-
-    auto scale = at::mul(at_x, at_scale);
-    auto bias = at::add(scale, at_bias);
-    auto output = at::relu(bias);
-
-    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
-    cudaDeviceSynchronize();
-    clearL2Cache();
-    cudaDeviceSynchronize();
-  }
-
-  const size_t size =
-      input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
-  const size_t channels = input_shape[3];
-  benchmark_state.SetBytesProcessed(
-      int64_t(benchmark_state.iterations()) * (channels * 2 + size * 2) *
-      int64_t(dataTypeSize(dtype)));
-}
-
-//------------------------------------------------------------------------------
-
-static void NvFuserScheduler_SBR_Norm(
-    benchmark::State& benchmark_state,
-    FusionExecutorCache* fusion_executor_cache,
-    DataType dtype) {
-  // N, H, W, C format
-  std::vector<int64_t> input_shape{
-      benchmark_state.range(0),
-      benchmark_state.range(1),
-      benchmark_state.range(1),
-      benchmark_state.range(2)};
-  std::vector<int64_t> bcast_shape{benchmark_state.range(2)};
-
-  // inputs
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  at::Tensor at_x = at::randn(input_shape, options);
-  at::Tensor at_weight = at::ones(bcast_shape, options);
-  at::Tensor at_bias = at::zeros(bcast_shape, options);
-  at::Tensor at_mean = at::zeros(bcast_shape, options);
-  at::Tensor at_var = at::ones(bcast_shape, options);
-
-  // inputs
-  std::vector<c10::IValue> aten_inputs = {
-      at_x, at_weight, at_bias, at_mean, at_var};
-
-  fusion_executor_cache->profile(true);
-  fusion_executor_cache->runFusionWithInputs(aten_inputs);
-
-  auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
-  auto executor_instance = compile_log.fusion_executor;
-  TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
-  TORCH_INTERNAL_ASSERT(compile_log.launch_constraints.has_value());
-  auto params = toString(compile_log.pointwise_params.value());
-  auto lparams = toString(compile_log.launch_constraints.value());
-
-  benchmark_state.SetLabel(params + lparams);
-
-  fusion_executor_cache->profile(false);
-  executor_instance->setMeasureKernelTimeFlag(true);
-  // Sync everything up before we start
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    auto cg_outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
-    benchmark_state.SetIterationTime(
-        executor_instance->kernelTimeMs() / 1000.0);
-    clearL2Cache();
-  }
-
-  // Sync everything up before we're finished, don't want to run ahead on the
-  // cpu while benchmarking.
-  cudaDeviceSynchronize();
-
-  const size_t size =
-      input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
-  const size_t channels = input_shape[3];
-  benchmark_state.SetBytesProcessed(
-      int64_t(benchmark_state.iterations()) * (channels * 4 + size * 2) *
-      int64_t(dataTypeSize(dtype)));
-}
-
-static void Baseline_SBR_Norm(
-    benchmark::State& benchmark_state,
-    DataType dtype) {
-  // N, H, W, C format
-  std::vector<int64_t> input_shape{
-      benchmark_state.range(0),
-      benchmark_state.range(1),
-      benchmark_state.range(1),
-      benchmark_state.range(2)};
-  std::vector<int64_t> bcast_shape{1, 1, 1, benchmark_state.range(2)};
-
-  // inputs
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  at::Tensor at_x = at::randn(input_shape, options);
-  at::Tensor at_weight = at::ones(bcast_shape, options);
-  at::Tensor at_bias = at::zeros(bcast_shape, options);
-  at::Tensor at_mean = at::zeros(bcast_shape, options);
-  at::Tensor at_var = at::ones(bcast_shape, options);
-
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    CudaKernelTimer timer;
-
-    auto this_scale = at::mul(at_weight, at::rsqrt(at_var));
-    auto this_bias = at::mul(at::sub(at_bias, at_mean), this_scale);
-
-    auto scale = at::mul(at_x, this_scale);
-    auto bias = at::add(scale, this_bias);
-    auto output = at::relu(bias);
-
-    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
-    cudaDeviceSynchronize();
-  }
-
-  const size_t size =
-      input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
-  const size_t channels = input_shape[3];
-  benchmark_state.SetBytesProcessed(
-      int64_t(benchmark_state.iterations()) * (channels * 4 + size * 2) *
-      int64_t(dataTypeSize(dtype)));
-}
-
-//------------------------------------------------------------------------------
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_SBR_fp32,
-    setupSBR,
-    NvFuserScheduler_SBR,
-    DataType::Float);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_SBR_fp32)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_SBR_fp16,
-    setupSBR,
-    NvFuserScheduler_SBR,
-    DataType::Half);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_SBR_fp16)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_SBR_Norm_fp32,
-    setupSBRNorm,
-    NvFuserScheduler_SBR_Norm,
-    DataType::Float);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_SBR_Norm_fp32)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_SBR_Norm_fp16,
-    setupSBRNorm,
-    NvFuserScheduler_SBR_Norm,
-    DataType::Half);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_SBR_Norm_fp16)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-static void Baseline_SBR_fp32(benchmark::State& benchmark_state) {
-  Baseline_SBR(benchmark_state, DataType::Float);
-}
-
-BENCHMARK(Baseline_SBR_fp32)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-static void Baseline_SBR_fp16(benchmark::State& benchmark_state) {
-  Baseline_SBR(benchmark_state, DataType::Half);
-}
-
-BENCHMARK(Baseline_SBR_fp16)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-static void Baseline_SBR_Norm_fp32(benchmark::State& benchmark_state) {
-  Baseline_SBR_Norm(benchmark_state, DataType::Float);
-}
-
-BENCHMARK(Baseline_SBR_Norm_fp32)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-static void Baseline_SBR_Norm_fp16(benchmark::State& benchmark_state) {
-  Baseline_SBR_Norm(benchmark_state, DataType::Half);
-}
-
-BENCHMARK(Baseline_SBR_Norm_fp16)
-    ->RangeMultiplier(2)
-    ->Ranges({{8, 8}, {640, 640}, {64, 256}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
diff --git a/benchmarks/cpp/nvfuser/softmax.cpp b/benchmarks/cpp/nvfuser/softmax.cpp
deleted file mode 100644 (file)
index 9d0cf9b..0000000
+++ /dev/null
@@ -1,525 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-//------------------------------------------------------------------------------
-
-static void setupSoftmax(
-    Fusion* fusion,
-    DataType dtype,
-    const int reduction_axis) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  FusionGuard fg(fusion);
-  // setup fusion
-  auto input = makeContigTensor(2, dtype);
-  fusion->addInput(input);
-
-  if (dtype == DataType::Half) {
-    input = castOp(DataType::Float, input);
-  }
-
-  auto output = softmax(input, reduction_axis);
-
-  if (dtype == DataType::Half) {
-    output = castOp(DataType::Half, output);
-  }
-
-  fusion->addOutput(output);
-}
-
-static void NvFuserScheduler_Softmax(
-    benchmark::State& benchmark_state,
-    FusionExecutorCache* fusion_executor_cache,
-    DataType dtype,
-    const int reduction_axis) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  std::vector<int64_t> input_shape{
-      benchmark_state.range(1), benchmark_state.range(0)};
-
-  // inputs
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn(input_shape, options);
-  std::vector<c10::IValue> aten_inputs({aten_input});
-
-  runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
-
-  benchmark_state.SetBytesProcessed(
-      int64_t(benchmark_state.iterations()) *
-      (2 * aten_input.numel() * int64_t(dataTypeSize(dtype))));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_Softmax(
-    benchmark::State& benchmark_state,
-    DataType dtype) {
-  std::vector<int64_t> input_shape{
-      benchmark_state.range(1), benchmark_state.range(0)};
-  const int kReductionAxis = benchmark_state.range(2);
-
-  // inputs
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn(input_shape, options);
-
-  cudaDeviceSynchronize();
-  for (auto _ : benchmark_state) {
-    CudaKernelTimer timer;
-    auto output = at::_softmax(aten_input, kReductionAxis, false);
-    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
-    cudaDeviceSynchronize();
-    clearL2Cache();
-    cudaDeviceSynchronize();
-  }
-
-  benchmark_state.SetBytesProcessed(
-      int64_t(benchmark_state.iterations()) *
-      (2 * aten_input.numel() * int64_t(dataTypeSize(dtype))));
-}
-
-static void Baseline_Softmax_fp32(benchmark::State& benchmark_state) {
-  Baseline_Softmax(benchmark_state, DataType::Float);
-}
-
-static void Baseline_Softmax_fp16(benchmark::State& benchmark_state) {
-  Baseline_Softmax(benchmark_state, DataType::Half);
-}
-
-//------------------------------------------------------------------------------
-
-static void setupSoftmaxDropout(
-    Fusion* fusion,
-    DataType dtype,
-    const int kReductionAxis) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  FusionGuard fg(fusion);
-
-  constexpr int kHiddenSize = 768;
-  constexpr int kNumAttentionHeads = 12;
-  constexpr int kAttentionHeadSize = kHiddenSize / kNumAttentionHeads;
-  constexpr float kDropoutProbability = 0.9;
-  constexpr float kScale = 1.0f / kDropoutProbability;
-
-  // setup fusion
-  auto attention_scores = makeContigTensor(4, dtype);
-  auto attention_mask = makeContigTensor(4, dtype);
-
-  Double* divisor = new Double();
-
-  fusion->addInput(attention_scores);
-  fusion->addInput(attention_mask);
-  fusion->addInput(divisor);
-
-  if (dtype == DataType::Half) {
-    attention_scores = castOp(DataType::Float, attention_scores);
-    attention_mask = castOp(DataType::Float, attention_mask);
-  }
-
-  attention_scores = div(attention_scores, divisor);
-  attention_scores = add(attention_scores, attention_mask);
-  auto attention_probs = softmax(attention_scores, kReductionAxis);
-  auto prob = new Double(kDropoutProbability);
-  auto scale = new Double(kScale);
-  auto dropout_results = dropout(attention_probs, prob, scale);
-  auto output = dropout_results.output;
-
-  if (dtype == DataType::Half) {
-    attention_scores = castOp(DataType::Half, attention_scores);
-    attention_probs = castOp(DataType::Half, attention_probs);
-    output = castOp(DataType::Half, output);
-  }
-
-  fusion->addOutput(attention_scores);
-  fusion->addOutput(attention_probs);
-  fusion->addOutput(output);
-
-  fusion->addOutput(dropout_results.mask);
-}
-
-static void NvFuserScheduler_SoftmaxDropout(
-    benchmark::State& benchmark_state,
-    FusionExecutorCache* fusion_executor_cache,
-    DataType dtype,
-    const int kReductionAxis) {
-  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  // reduce across 1, [256, 12, 100, 8]
-  std::vector<int64_t> input_shape{256, 12, 100, benchmark_state.range(0)};
-
-  constexpr int kHiddenSize = 768;
-  constexpr int kNumAttentionHeads = 12;
-  constexpr int kAttentionHeadSize = kHiddenSize / kNumAttentionHeads;
-  constexpr float kDropoutProbability = 0.9;
-  constexpr float kScale = 1.0f / kDropoutProbability;
-
-  // inputs
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  at::Tensor at_scores = at::randn(input_shape, options);
-  at::Tensor at_mask = at::randn(input_shape, options);
-  std::vector<c10::IValue> aten_inputs(
-      {at_scores, at_mask, sqrt(kAttentionHeadSize)});
-
-  runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
-
-  // 5 dtype: attention_scores + attention_mask + attention_scores_out +
-  // attention_probs_out + output
-  // 1 bool: dropout_results.mask
-  // All the same size
-  benchmark_state.SetBytesProcessed(
-      int64_t(benchmark_state.iterations()) * 5 * at_scores.numel() *
-          int64_t(dataTypeSize(dtype)) +
-      // bool mask
-      int64_t(benchmark_state.iterations()) * at_scores.numel() *
-          int64_t(dataTypeSize(DataType::Bool)));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_Softmax_Dropout(
-    benchmark::State& benchmark_state,
-    const int kReductionAxis,
-    DataType dtype) {
-  std::vector<int64_t> input_shape{256, 12, 100, benchmark_state.range(0)};
-
-  constexpr int kHiddenSize = 768;
-  constexpr int kNumAttentionHeads = 12;
-  constexpr float kDropoutProbability = 0.1;
-  constexpr int kAttentionHeadSize = kHiddenSize / kNumAttentionHeads;
-
-  // inputs
-  at::manual_seed(0);
-  auto options =
-      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-  at::Tensor attention_scores = at::randn(input_shape, options);
-  at::Tensor at_y = at::randn(input_shape, options);
-
-  cudaDeviceSynchronize();
-
-  for (auto _ : benchmark_state) {
-    // Create
-    CudaKernelTimer timer;
-
-    // Run
-    attention_scores = attention_scores / sqrt(kAttentionHeadSize);
-    attention_scores = attention_scores + at_y;
-    auto attention_probs =
-        at::_softmax(attention_scores, kReductionAxis, false);
-    attention_probs = at::dropout(attention_probs, kDropoutProbability, true);
-
-    // Record
-    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
-    cudaDeviceSynchronize();
-    clearL2Cache();
-    cudaDeviceSynchronize();
-  }
-
-  // 5 dtype: attention_scores + attention_mask + attention_scores_out +
-  // attention_probs_out + output
-  // 1 bool: dropout_results.mask
-  // All the same size
-  benchmark_state.SetBytesProcessed(
-      int64_t(benchmark_state.iterations()) * 5 * attention_scores.numel() *
-          int64_t(dataTypeSize(dtype)) +
-      // bool mask
-      int64_t(benchmark_state.iterations()) * attention_scores.numel() *
-          int64_t(dataTypeSize(DataType::Bool)));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_Softmax_Dropout_Inner_fp32(
-    benchmark::State& benchmark_state) {
-  Baseline_Softmax_Dropout(benchmark_state, 3, DataType::Float);
-}
-
-static void Baseline_Softmax_Dropout_Outer_fp32(
-    benchmark::State& benchmark_state) {
-  Baseline_Softmax_Dropout(benchmark_state, 1, DataType::Float);
-}
-
-static void Baseline_Softmax_Dropout_Inner_fp16(
-    benchmark::State& benchmark_state) {
-  Baseline_Softmax_Dropout(benchmark_state, 3, DataType::Half);
-}
-
-static void Baseline_Softmax_Dropout_Outer_fp16(
-    benchmark::State& benchmark_state) {
-  Baseline_Softmax_Dropout(benchmark_state, 1, DataType::Half);
-}
-
-//------------------------------------------------------------------------------
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Softmax_Outer_fp32,
-    setupSoftmax,
-    NvFuserScheduler_Softmax,
-    DataType::Float,
-    0);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Outer_fp32)
-    ->RangeMultiplier(2)
-    ->Ranges({{656, 656}, {8, 8 << 12}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Softmax_Inner_fp32,
-    setupSoftmax,
-    NvFuserScheduler_Softmax,
-    DataType::Float,
-    1);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Inner_fp32)
-    ->RangeMultiplier(2)
-    ->Ranges({{656, 656}, {8, 8 << 12}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Softmax_Outer_fp16,
-    setupSoftmax,
-    NvFuserScheduler_Softmax,
-    DataType::Half,
-    0);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Outer_fp16)
-    ->RangeMultiplier(2)
-    ->Ranges({{656, 656}, {8, 8 << 12}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Softmax_Inner_fp16,
-    setupSoftmax,
-    NvFuserScheduler_Softmax,
-    DataType::Half,
-    1);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Inner_fp16)
-    ->RangeMultiplier(2)
-    ->Ranges({{656, 656}, {8, 8 << 12}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Softmax_Dropout_Inner_fp32,
-    setupSoftmaxDropout,
-    NvFuserScheduler_SoftmaxDropout,
-    DataType::Float,
-    3);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Dropout_Inner_fp32)
-    ->Arg(8)
-    ->Arg(16)
-    ->Arg(24)
-    ->Arg(32)
-    ->Arg(40)
-    ->Arg(48)
-    ->Arg(56)
-    ->Arg(64)
-    ->Arg(72)
-    ->Arg(80)
-    ->Arg(88)
-    ->Arg(96)
-    ->Arg(104)
-    ->Arg(112)
-    ->Arg(120)
-    ->Arg(128)
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Softmax_Dropout_Outer_fp32,
-    setupSoftmaxDropout,
-    NvFuserScheduler_SoftmaxDropout,
-    DataType::Float,
-    1);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Dropout_Outer_fp32)
-    ->Arg(8)
-    ->Arg(16)
-    ->Arg(24)
-    ->Arg(32)
-    ->Arg(40)
-    ->Arg(48)
-    ->Arg(56)
-    ->Arg(64)
-    ->Arg(72)
-    ->Arg(80)
-    ->Arg(88)
-    ->Arg(96)
-    ->Arg(104)
-    ->Arg(112)
-    ->Arg(120)
-    ->Arg(128)
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Softmax_Dropout_Inner_fp16,
-    setupSoftmaxDropout,
-    NvFuserScheduler_SoftmaxDropout,
-    DataType::Half,
-    3);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Dropout_Inner_fp16)
-    ->Arg(8)
-    ->Arg(16)
-    ->Arg(24)
-    ->Arg(32)
-    ->Arg(40)
-    ->Arg(48)
-    ->Arg(56)
-    ->Arg(64)
-    ->Arg(72)
-    ->Arg(80)
-    ->Arg(88)
-    ->Arg(96)
-    ->Arg(104)
-    ->Arg(112)
-    ->Arg(120)
-    ->Arg(128)
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
-    NvFuserScheduler_Softmax_Dropout_Outer_fp16,
-    setupSoftmaxDropout,
-    NvFuserScheduler_SoftmaxDropout,
-    DataType::Half,
-    1);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Dropout_Outer_fp16)
-    ->Arg(8)
-    ->Arg(16)
-    ->Arg(24)
-    ->Arg(32)
-    ->Arg(40)
-    ->Arg(48)
-    ->Arg(56)
-    ->Arg(64)
-    ->Arg(72)
-    ->Arg(80)
-    ->Arg(88)
-    ->Arg(96)
-    ->Arg(104)
-    ->Arg(112)
-    ->Arg(120)
-    ->Arg(128)
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-BENCHMARK(Baseline_Softmax_fp32)
-    ->RangeMultiplier(2)
-    ->Ranges({{656, 656}, {8, 8 << 12}, {0, 1}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_Softmax_fp16)
-    ->RangeMultiplier(2)
-    ->Ranges({{656, 656}, {8, 8 << 12}, {0, 1}})
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_Softmax_Dropout_Inner_fp32)
-    ->Arg(8)
-    ->Arg(16)
-    ->Arg(24)
-    ->Arg(32)
-    ->Arg(40)
-    ->Arg(48)
-    ->Arg(56)
-    ->Arg(64)
-    ->Arg(72)
-    ->Arg(80)
-    ->Arg(88)
-    ->Arg(96)
-    ->Arg(104)
-    ->Arg(112)
-    ->Arg(120)
-    ->Arg(128)
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_Softmax_Dropout_Outer_fp32)
-    ->Arg(8)
-    ->Arg(16)
-    ->Arg(24)
-    ->Arg(32)
-    ->Arg(40)
-    ->Arg(48)
-    ->Arg(56)
-    ->Arg(64)
-    ->Arg(72)
-    ->Arg(80)
-    ->Arg(88)
-    ->Arg(96)
-    ->Arg(104)
-    ->Arg(112)
-    ->Arg(120)
-    ->Arg(128)
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_Softmax_Dropout_Inner_fp16)
-    ->Arg(8)
-    ->Arg(16)
-    ->Arg(24)
-    ->Arg(32)
-    ->Arg(40)
-    ->Arg(48)
-    ->Arg(56)
-    ->Arg(64)
-    ->Arg(72)
-    ->Arg(80)
-    ->Arg(88)
-    ->Arg(96)
-    ->Arg(104)
-    ->Arg(112)
-    ->Arg(120)
-    ->Arg(128)
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-BENCHMARK(Baseline_Softmax_Dropout_Outer_fp16)
-    ->Arg(8)
-    ->Arg(16)
-    ->Arg(24)
-    ->Arg(32)
-    ->Arg(40)
-    ->Arg(48)
-    ->Arg(56)
-    ->Arg(64)
-    ->Arg(72)
-    ->Arg(80)
-    ->Arg(88)
-    ->Arg(96)
-    ->Arg(104)
-    ->Arg(112)
-    ->Arg(120)
-    ->Arg(128)
-    ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
diff --git a/benchmarks/cpp/nvfuser/utils.cpp b/benchmarks/cpp/nvfuser/utils.cpp
deleted file mode 100644 (file)
index 54ffda5..0000000
+++ /dev/null
@@ -1,154 +0,0 @@
-#include "utils.h"
-
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <sstream>
-
-using namespace torch::jit::fuser::cuda;
-
-std::string toString(ReductionParams rparams) {
-  std::stringstream ss;
-  if (rparams.fastest_dim) {
-    ss << "/Fastest dim";
-  } else {
-    ss << "/Slow dim";
-  }
-  if (rparams.cross_grid) {
-    ss << "/cross grid";
-  }
-  if (rparams.cross_block) {
-    ss << "/cross block";
-  }
-  if (rparams.multiple_reds_per_blk) {
-    ss << "/multiple reductions per block ";
-  }
-  if (rparams.loop_unroll > 1) {
-    ss << (rparams.vectorize ? "/Vectorize " : "/Unroll ")
-       << (rparams.reduction_unroll ? "reduction dim " : "iter dim ")
-       << rparams.loop_unroll;
-  }
-  if (rparams.batches_per_block > 1) {
-    ss << "/batches per block " << rparams.batches_per_block << " ";
-  }
-  if (rparams.persistent_kernel) {
-    ss << "/persistent";
-  }
-
-  if (rparams.split_grid_dim) {
-    ss << "/split grid dim";
-  }
-  return ss.str();
-}
-
-std::string toString(PointwiseParams params) {
-  std::stringstream ss;
-  if (params.break_point) {
-    ss << "2D Schedule at " << params.break_point << "/";
-    if (params.split_block) {
-      ss << " Split block into y-dim/";
-    }
-    if (params.split_grid_y_dim) {
-      ss << " Split y grid dim/";
-    }
-  } else {
-    ss << "1D"
-       << "/";
-  }
-  if (params.inner_factor > 1) {
-    if (params.vectorize) {
-      ss << "Vectorize, Factor: " << params.inner_factor;
-    } else {
-      ss << "Unroll, Factor: " << params.inner_factor;
-    }
-  }
-  return ss.str();
-}
-
-std::string toString(LaunchParams lparams) {
-  std::stringstream ss;
-  lparams.toString();
-  ss << "/Launch_Parameters["
-     << "block(" << lparams.bdimz() << "/" << lparams.bdimy() << "/"
-     << lparams.bdimx() << ")/grid(" << lparams.gdimz() << "/"
-     << lparams.gdimy() << "/" << lparams.gdimx() << ")/" << lparams.smem()
-     << "]";
-  return ss.str();
-}
-
-void clearL2Cache() {
-  torch::NoGradGuard no_grad;
-  auto l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize;
-  auto options =
-      torch::TensorOptions().dtype(torch::kFloat32).device(at::kCUDA, 0);
-
-  auto l2_elems = l2_cache_size / 4;
-  torch::Tensor t0 = torch::empty(l2_elems, options);
-  torch::Tensor t1 = torch::clone(t0);
-};
-
-TensorView* makeContigTensor(size_t ndims, DataType dtype) {
-  return TensorViewBuilder()
-      .ndims(ndims)
-      .dtype(dtype)
-      .contiguity(std::vector<bool>(ndims, true))
-      .build();
-}
-
-void runBenchmarkIterations(
-    benchmark::State& benchmark_state,
-    FusionExecutorCache* fusion_executor_cache,
-    std::vector<c10::IValue>& aten_inputs) {
-  fusion_executor_cache->runFusionWithInputs(aten_inputs);
-  bool segmented =
-      fusion_executor_cache->getMostRecentKernelRuntime()->isSegmented();
-
-  if (!segmented) {
-    fusion_executor_cache->profile(true);
-    fusion_executor_cache->runFusionWithInputs(aten_inputs);
-    auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
-    auto executor_instance = compile_log.fusion_executor;
-    TORCH_INTERNAL_ASSERT(compile_log.reduction_params.has_value());
-    TORCH_INTERNAL_ASSERT(compile_log.launch_constraints.has_value());
-    auto rparams = toString(compile_log.reduction_params.value());
-    auto lparams = toString(compile_log.launch_constraints.value());
-    benchmark_state.SetLabel(rparams + lparams);
-    executor_instance->setMeasureKernelTimeFlag(true);
-
-    // Sync everything up before we start
-    cudaDeviceSynchronize();
-    for (auto _ : benchmark_state) {
-      auto cg_outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
-      benchmark_state.SetIterationTime(
-          executor_instance->kernelTimeMs() / 1000.0);
-      clearL2Cache();
-    }
-    // Sync everything up before we're finished, don't want to run ahead on the
-    // cpu while benchmarking.
-    cudaDeviceSynchronize();
-  } else {
-    // Segmented
-    // Sync everything up before we start
-    {
-      // Compile/warmup
-      auto cg_outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
-    }
-    cudaDeviceSynchronize();
-    CudaKernelTimer timer;
-    for (auto _ : benchmark_state) {
-      timer.restart();
-      auto cg_outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
-      benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
-      clearL2Cache();
-    }
-    // Sync everything up before we're finished, don't want to run ahead on the
-    // cpu while benchmarking.
-    cudaDeviceSynchronize();
-  }
-}
-
-namespace executorCache {
-thread_local ExecutorMap executor_map_;
-ExecutorMap& getGlobalMap() {
-  return executor_map_;
-}
-} // namespace executorCache
diff --git a/benchmarks/cpp/nvfuser/utils.h b/benchmarks/cpp/nvfuser/utils.h
deleted file mode 100644 (file)
index b4a2f3a..0000000
+++ /dev/null
@@ -1,187 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <ATen/cuda/CUDAContext.h>
-#include <torch/torch.h>
-
-#include <cuda_runtime.h>
-
-using namespace torch::jit::fuser::cuda;
-
-std::string toString(ReductionParams rparams);
-std::string toString(PointwiseParams params);
-std::string toString(LaunchParams lparams);
-
-// Run benchmark iterations with provided inputs. If not segmented, report
-// kernel time from the runtime, as well as heuristic parameters. If segmented
-// use timers. Make sure to clear L2 between iterations.
-void runBenchmarkIterations(
-    benchmark::State& benchmark_state,
-    FusionExecutorCache* fusion_executor_cache,
-    std::vector<c10::IValue>& aten_inputs);
-
-void clearL2Cache();
-
-// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
-// but unknown sizes. Taken from test_gpu.cpp
-TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float);
-
-class CudaKernelTimer {
- public:
-  CudaKernelTimer() {
-    // Setup
-    cudaEventCreate(&start_event);
-    cudaEventCreate(&finish_event);
-    cudaEventRecord(start_event);
-  }
-
-  ~CudaKernelTimer() {
-    cudaEventDestroy(start_event);
-    cudaEventDestroy(finish_event);
-  }
-
-  void restart() {
-    cudaEventRecord(start_event);
-  }
-
-  float elapsed() {
-    // Record
-    cudaEventRecord(finish_event);
-    cudaEventSynchronize(start_event);
-    cudaEventSynchronize(finish_event);
-    cudaEventElapsedTime(&kernel_time_ms_, start_event, finish_event);
-    return kernel_time_ms_;
-  }
-
- private:
-  // Create
-  float kernel_time_ms_ = 0;
-  cudaEvent_t start_event = {};
-  cudaEvent_t finish_event = {};
-};
-
-namespace executorCache {
-using ExecutorPtr = std::unique_ptr<FusionExecutorCache>;
-using ExecutorMap = std::unordered_map<std::string, ExecutorPtr>;
-ExecutorMap& getGlobalMap();
-} // namespace executorCache
-
-//! Utility to manage FusionExecutorCache instances for
-//!  all defined benchmarks
-class BenchmarkGraph : public benchmark::Fixture {
- public:
-  using SetupFusionFunction = std::function<void(Fusion*)>;
-  using SetupFusionMap = std::unordered_map<std::string, SetupFusionFunction>;
-
-  virtual std::string graphName() = 0;
-  virtual SetupFusionFunction setupFusion() = 0;
-
-  FusionExecutorCache* getExecutorCache() {
-    auto& executor_ = getExecutorCacheMap()[graphName()];
-    TORCH_INTERNAL_ASSERT(executor_);
-    return executor_.get();
-  }
-
-  void SetUp(const ::benchmark::State& state) {
-    auto& executor_ = getExecutorCacheMap()[graphName()];
-    // Makes sure same graph hasn't been compiled before
-    if (!executor_) {
-      auto fusion_ptr = std::make_unique<Fusion>();
-      FusionGuard(fusion_ptr.get());
-      setupFusion()(fusion_ptr.get());
-      getExecutorCacheMap()[graphName()] =
-          std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
-    }
-  }
-
-  void TearDown(const ::benchmark::State& state) {}
-
- protected:
-  static executorCache::ExecutorMap& getExecutorCacheMap() {
-    return executorCache::getGlobalMap();
-  }
-};
-
-#define NVFUSER_TO_STRING_HELPER(n) std::string(#n)
-#define NVFUSER_TO_STRING(n) NVFUSER_TO_STRING_HELPER(n)
-
-//! NVFUSER_BENCHMARK_RUN utility usage:
-//!  This utility helps create and manage FusionExecutorCaches and tries to use
-//!  the caching
-//! mechanism in NVFuser to avoid re-compilation.
-//!
-//!  There are two macros in this utility: NVFUSER_BENCHMARK_DEFINE, and
-//!  NVFUSER_BENCHMARK_RUN,
-//! and user needs to supply two functions SETUP_FUSION and RUN_FUSION, with
-//! following signatures:
-//!
-//!  SETUP_FUSION(Fusion* , args...);
-//!  RUN_FUSION(benchmark::State&, FusionExecutorCache* , args...);
-//!
-//!  where args... are additional arguments, and they need to be the same for
-//!  SETUP_FUSION and RUN_FUSION.
-//!
-//!  SETUP_FUSION is called once in each definition of benchmark to build the
-//!  fusionIR graph
-//!
-//!  RUN_FUSION is just like the normal benchmark instance, except that a
-//!  FusionExecutorCache
-//!   will be provided for scheduling, running and timing the fusion runs. It is
-//!   called once in each benchmark instance. For example:
-//!   NVFUSER_BENCHMARK_RUN(my_benchmark)
-//!    ->RangeMultiplier(2)
-//!    ->Ranges({{1, 4})
-//!  Calls RUN_FUSION 3 times.
-//!
-//!  To register a benchmark, the API is:
-//!
-//!  NVFUSER_BENCHMARK_DEFINE(my_benchmark,SETUP_FUSION,RUN_FUSION,args...);
-//!
-//!    where my_benchmark is any unique name given for this benchmark,
-//!      SETUP_FUSION, RUN_FUSION as described above,
-//!      args... is the arg list supplied to both setup_fusion and run_fusion
-//!
-//!  each NVFUSER_BENCHMARK_DEFINE registers a benchmark with a single
-//!  FusionExecutorCache, i.e. a single fusion graph, and multiple benchmark
-//!  data points can be registered like:
-//!
-//!  NVFUSER_BENCHMARK_RUN(my_benchmark)
-//!    ->Ranges({{1,2}});
-//!
-//!  NVFUSER_BENCHMARK_RUN(my_benchmark)
-//!    ->Ranges({{3,4}});
-//!
-//!  All datapoints will use the same FusionExecutorCache so recompilation is
-//!  avoided as much as possible.
-
-#define NVFUSER_BENCHMARK_DEFINE(                                       \
-    BENCHMARK_NAME, SETUP_FUSION, RUN_FUSION, ...)                      \
-  class BENCHMARK_NAME##___GRAPH : public BenchmarkGraph {              \
-   public:                                                              \
-    std::string graphName() {                                           \
-      return NVFUSER_TO_STRING(BENCHMARK_NAME##___GRAPH);               \
-    }                                                                   \
-    SetupFusionFunction setupFusion() {                                 \
-      return [](Fusion* fusion) { SETUP_FUSION(fusion, __VA_ARGS__); }; \
-    }                                                                   \
-  };                                                                    \
-  BENCHMARK_DEFINE_F(BENCHMARK_NAME##___GRAPH, BENCHMARK_NAME)          \
-  (benchmark::State & benchmark_state) {                                \
-    RUN_FUSION(                                                         \
-        benchmark_state,                                                \
-        BENCHMARK_NAME##___GRAPH::getExecutorCache(),                   \
-        __VA_ARGS__);                                                   \
-  }
-
-#define NVFUSER_BENCHMARK_RUN(BENCHMARK_NAME) \
-  BENCHMARK_REGISTER_F(BENCHMARK_NAME##___GRAPH, BENCHMARK_NAME)
index 087615f..3c2fb83 100644 (file)
@@ -936,15 +936,12 @@ if(USE_CUDA OR USE_ROCM)
   # The list of NVFUSER runtime files
   list(APPEND NVFUSER_RUNTIME_FILES
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/block_reduction.cu
-    ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu
-    ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/block_sync_default.cu
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/broadcast.cu
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/fp16_support.cu
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/grid_reduction.cu
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/helpers.cu
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/random_numbers.cu
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/tensor.cu
-    ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/welford.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh
     ${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/UnpackRaw.cuh
   )
@@ -1667,10 +1664,6 @@ if(BUILD_TENSOREXPR_BENCHMARK)
   add_subdirectory(${TORCH_ROOT}/benchmarks/cpp/tensorexpr ${CMAKE_BINARY_DIR}/tensorexpr_bench)
 endif()
 
-if(BUILD_NVFUSER_BENCHMARK)
-  add_subdirectory(${TORCH_ROOT}/benchmarks/cpp/nvfuser ${CMAKE_BINARY_DIR}/nvfuser_bench)
-endif()
-
 if(BUILD_CPP_BENCHMARKS)
   add_subdirectory(${TORCH_ROOT}/benchmarks/cpp ${PROJECT_BINARY_DIR}/bin)
 endif()
index 63d08d2..99c41f2 100644 (file)
@@ -29,7 +29,6 @@ function(caffe2_print_configuration_summary)
   message(STATUS "  BUILD_CAFFE2_MOBILE   : ${BUILD_CAFFE2_MOBILE}")
   message(STATUS "  BUILD_STATIC_RUNTIME_BENCHMARK: ${BUILD_STATIC_RUNTIME_BENCHMARK}")
   message(STATUS "  BUILD_TENSOREXPR_BENCHMARK: ${BUILD_TENSOREXPR_BENCHMARK}")
-  message(STATUS "  BUILD_NVFUSER_BENCHMARK: ${BUILD_NVFUSER_BENCHMARK}")
   message(STATUS "  BUILD_BINARY          : ${BUILD_BINARY}")
   message(STATUS "  BUILD_CUSTOM_PROTOBUF : ${BUILD_CUSTOM_PROTOBUF}")
   if(${CAFFE2_LINK_LOCAL_PROTOBUF})
index d398e78..8bd37a1 100644 (file)
@@ -74,7 +74,6 @@ set(JIT_TEST_SRCS
 
 if(USE_CUDA)
   list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu.cpp)
-  list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_shift.cpp)
 endif()
 
 add_executable(test_jit
index 9daa571..1a0ee7b 100644 (file)
@@ -3,28 +3,19 @@
 
 #include <torch/csrc/jit/codegen/cuda/arith.h>
 #include <torch/csrc/jit/codegen/cuda/codegen.h>
-#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
 #include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
 #include <torch/csrc/jit/codegen/cuda/interface.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_graphviz.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/mutator.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
+#include <torch/csrc/jit/codegen/cuda/scheduler.h>
 #include <torch/csrc/jit/codegen/cuda/transform_replay.h>
 #include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
 
 #include <torch/csrc/jit/codegen/cuda/parser.h>
 #include <torch/csrc/jit/ir/irparser.h>
 
-#include "test_gpu_validator.h"
-
 #include <ATen/cuda/Exceptions.h>
 #include <c10/cuda/CUDAStream.h>
 
-#include <algorithm>
 #include <iostream>
 
 // Tests go in torch::jit
@@ -45,60 +33,62 @@ namespace torch {
 namespace jit {
 
 using namespace torch::jit::fuser::cuda;
-using namespace at::indexing;
 
 namespace {
 
-// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
-  return TensorViewBuilder()
-      .ndims(ndims)
-      .dtype(dtype)
-      .contiguity(std::vector<bool>(ndims, true))
-      .build();
+TensorView* makeContigTensor(int nDims, DataType dtype = DataType::Float) {
+  std::vector<IterDomain*> dom;
+  for (int i = 0; i < nDims; i++)
+    dom.push_back(new IterDomain(new Int(0), new Int()));
+  std::vector<bool> contig(dom.size(), true);
+  return new TensorView(new TensorDomain(dom, contig), dtype);
 }
 
-// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
-  return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
+TensorView* makeDummyTensor(int nDims, DataType dtype = DataType::Float) {
+  // We can uncomment the below statement to test all tests with contiguous
+  // tensors. return makeContigTensor(nDims, dtype);
+  std::vector<IterDomain*> dom;
+  for (int i = 0; i < nDims; i++)
+    dom.push_back(new IterDomain(new Int(0), new Int()));
+  return new TensorView(new TensorDomain(dom), dtype);
 }
 
-// Make a non-contiguous tensor of compile-time known sizes
 TensorView* makeConcreteTensor(
-    std::vector<int64_t> shape,
+    std::vector<int> sizes,
+    DataType dtype = DataType::Float) {
+  // We can uncomment the below statement to test all tests with contiguous
+  // tensors. return makeContigTensor(nDims, dtype);
+  std::vector<IterDomain*> dom;
+  for (int size : sizes) {
+    if (size >= 0) {
+      dom.push_back(new IterDomain(new Int(0), new Int(size)));
+    } else {
+      dom.push_back(new IterDomain(new Int(0), new Int()));
+    }
+  }
+  return new TensorView(new TensorDomain(dom), dtype);
+}
+
+TensorView* makeTensorWithContig(
+    int nDims,
+    std::vector<bool> contig_info,
     DataType dtype = DataType::Float) {
-  return TensorViewBuilder().shape(shape).dtype(dtype).build();
+  std::vector<IterDomain*> dom;
+  for (int i = 0; i < nDims; i++)
+    dom.push_back(new IterDomain(new Int(0), new Int()));
+  return new TensorView(new TensorDomain(dom, contig_info), dtype);
 }
 
 void checkIntValue(
-    ExpressionEvaluator& evaluator,
+    StatefulExpressionEvaluator& evaluator,
     Val* val,
     Int::ScalarType expected_value) {
   TORCH_CHECK(val->isAnInt());
-  const auto actual_value = evaluator.evaluate(val);
-  TORCH_CHECK(actual_value.has_value());
-  TORCH_CHECK(actual_value.value() == expected_value);
-}
-
-void checkIntValue(
-    kir::ExpressionEvaluator& evaluator,
-    const kir::Val* val,
-    kir::Int::ScalarType expected_value) {
-  const auto actual_value = evaluator.evaluate(val);
+  const auto actual_value = evaluator.inferValue(val);
   TORCH_CHECK(actual_value.has_value());
   TORCH_CHECK(actual_value.value() == expected_value);
 }
 
-bool isPredicated(TensorView* tv, GpuLower& gpulw) {
-  auto parent_scope = gpulw.lowerValue(tv)->definition()->parentScope();
-  if (parent_scope->isA<kir::IfThenElse>()) {
-    return !parent_scope->predicate()->value()->isConst();
-  }
-  return true;
-};
-
 } // namespace
 
 // 1. Test cases are void() functions.
@@ -118,13 +108,13 @@ TEST(NVFuserTest, IrGraphGenerator_CUDA) {
                    .empty());
 
   // Construct an interesting IR
-  TensorView* tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
 
-  TensorView* tv2 = add(tv0, new Double(3.141));
+  TensorView* tv2 = add(tv0, new Float(3.141));
   TensorView* tv3 = broadcast(tv0, {false, true, false, true});
-  TensorView* tv4 = reductionOp(BinaryOpType::Add, {2}, new Double(0), tv3);
-  TensorView* tv5 = clamp(tv4, new Double(0.f), new Double(1.f));
+  TensorView* tv4 = reductionOp(BinaryOpType::Add, {2}, new Float(0), tv3);
+  TensorView* tv5 = clamp(tv4, new Float(0.f), new Float(1.f));
   TensorView* tv6 = add(tv2, tv2);
 
   // Another checkpoint before adding outputs
@@ -164,14 +154,14 @@ TEST(NVFuserTest, FusionDispatch_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  Double* f = new Double{2.f};
+  Float* f = new Float{2.f};
   std::stringstream ss1, ss2, ss3;
   ss1 << f;
   ss2 << static_cast<Val*>(f);
   ss3 << static_cast<Statement*>(f);
   TORCH_CHECK(
       ss1.str().compare(ss2.str()) == 0 && ss1.str().compare(ss3.str()) == 0,
-      "Error with dispatch system where results differ by passing Double* vs Val* vs Statement*.");
+      "Error with dispatch system where results differ by passing Float* vs Val* vs Statement*.");
 }
 
 // Evaluate basic scalar operations with constant values
@@ -179,7 +169,7 @@ TEST(NVFuserTest, FusionExprEvalConstants_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  ExpressionEvaluator evaluator(&fusion);
+  StatefulExpressionEvaluator evaluator(&fusion);
 
   auto* a = new Int(7);
   auto* b = new Int(3);
@@ -196,7 +186,7 @@ TEST(NVFuserTest, FusionExprEvalBindings_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  ExpressionEvaluator evaluator(&fusion);
+  StatefulExpressionEvaluator evaluator(&fusion);
 
   auto* a = new Int();
   auto* b = new Int();
@@ -205,17 +195,17 @@ TEST(NVFuserTest, FusionExprEvalBindings_CUDA) {
   auto* e = new Int(0);
 
   // trying to evaluate before binding should give empty results
-  TORCH_CHECK(!evaluator.evaluate(a).has_value());
-  TORCH_CHECK(!evaluator.evaluate(d).has_value());
+  TORCH_CHECK(!evaluator.inferValue(a).has_value());
+  TORCH_CHECK(!evaluator.inferValue(d).has_value());
 
-  evaluator.bind(a, 7);
-  evaluator.bind(b, 3);
+  evaluator.safeBind(a, 7);
+  evaluator.safeBind(b, 3);
 
   // can't bind to the results of expressions
-  ASSERT_ANY_THROW(evaluator.bind(c, 100));
+  ASSERT_ANY_THROW(evaluator.safeBind(c, 100));
 
   // can't bind to concrete values
-  ASSERT_ANY_THROW(evaluator.bind(e, 100));
+  ASSERT_ANY_THROW(evaluator.safeBind(e, 100));
 
   checkIntValue(evaluator, c, 10);
   checkIntValue(evaluator, sub(a, b), 4);
@@ -224,10 +214,10 @@ TEST(NVFuserTest, FusionExprEvalBindings_CUDA) {
   checkIntValue(evaluator, d, -4);
 
   // Reset evaluation context
-  evaluator = ExpressionEvaluator(&fusion);
+  evaluator = StatefulExpressionEvaluator(&fusion);
 
-  evaluator.bind(a, 2);
-  evaluator.bind(b, 5);
+  evaluator.safeBind(a, 2);
+  evaluator.safeBind(b, 5);
 
   checkIntValue(evaluator, c, 7);
   checkIntValue(evaluator, sub(a, b), -3);
@@ -242,13 +232,13 @@ TEST(NVFuserTest, FusionExprEvalBasic_CUDA) {
   FusionGuard fg(&fusion);
 
   // Create a non-trivial IR
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = makeDummyTensor(2);
 
   fusion.addInput(tv0);
   fusion.addInput(tv1);
 
-  TensorView* tv2 = add(tv1, new Double(2.0));
+  TensorView* tv2 = add(tv1, new Float(2.0));
   TensorView* tv3 = add(tv0, tv2);
 
   fusion.addOutput(tv3);
@@ -265,7 +255,7 @@ TEST(NVFuserTest, FusionExprEvalBasic_CUDA) {
   tv3->axis(-1)->parallelize(ParallelType::TIDx);
 
   // 1. Create an evaluator
-  ExpressionEvaluator evaluator(&fusion);
+  StatefulExpressionEvaluator evaluator(&fusion);
 
   // 2. Bind values
   //
@@ -275,21 +265,21 @@ TEST(NVFuserTest, FusionExprEvalBasic_CUDA) {
   //  (ex. `tv0->getRootDomain()[0]->extent()`
   //   instead of `tv0->axis(0)->extent()`)
   //
-  evaluator.bind(tv0->getRootDomain()[0]->extent(), 6);
-  evaluator.bind(tv0->getRootDomain()[1]->extent(), 128);
-  evaluator.bind(tv1->getRootDomain()[0]->extent(), 6);
-  evaluator.bind(tv1->getRootDomain()[1]->extent(), 128);
+  evaluator.safeBind(tv0->getRootDomain()[0]->extent(), 6);
+  evaluator.safeBind(tv0->getRootDomain()[1]->extent(), 128);
+  evaluator.safeBind(tv1->getRootDomain()[0]->extent(), 6);
+  evaluator.safeBind(tv1->getRootDomain()[1]->extent(), 128);
 
   // 3. Evaluate and check result values
   TORCH_CHECK(tv2->domain()->nDims() == 3);
-  checkIntValue(evaluator, tv2->axis(0)->extent(), 2);
-  checkIntValue(evaluator, tv2->axis(1)->extent(), 4);
-  checkIntValue(evaluator, tv2->axis(2)->extent(), 128);
+  checkIntValue(evaluator, tv2->axis(0)->rawExtent(), 2);
+  checkIntValue(evaluator, tv2->axis(1)->rawExtent(), 4);
+  checkIntValue(evaluator, tv2->axis(2)->rawExtent(), 128);
 
   TORCH_CHECK(tv3->domain()->nDims() == 3);
-  checkIntValue(evaluator, tv3->axis(0)->extent(), 2);
-  checkIntValue(evaluator, tv3->axis(1)->extent(), 4);
-  checkIntValue(evaluator, tv3->axis(2)->extent(), 128);
+  checkIntValue(evaluator, tv3->axis(0)->rawExtent(), 2);
+  checkIntValue(evaluator, tv3->axis(1)->rawExtent(), 4);
+  checkIntValue(evaluator, tv3->axis(2)->rawExtent(), 128);
 }
 
 // Evaluate expressions in a more complex IR
@@ -297,12 +287,12 @@ TEST(NVFuserTest, FusionExprEvalComplex_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
 
-  TensorView* tv1 = mul(tv0, new Double(-1.0));
-  TensorView* tv2 = add(tv0, new Double(3.0));
-  TensorView* tv3 = mul(tv0, new Double(2.0));
+  TensorView* tv1 = mul(tv0, new Float(-1.0));
+  TensorView* tv2 = add(tv0, new Float(3.0));
+  TensorView* tv3 = mul(tv0, new Float(2.0));
   TensorView* tv4 = add(tv2, tv1);
   TensorView* tv5 = add(tv4, tv3);
   TensorView* tv6 = add(tv0, tv3);
@@ -316,32 +306,32 @@ TEST(NVFuserTest, FusionExprEvalComplex_CUDA) {
   tv5->merge(0);
 
   // 1. Create an evaluator
-  ExpressionEvaluator evaluator(&fusion);
+  StatefulExpressionEvaluator evaluator(&fusion);
 
   // 2. Bind values
-  evaluator.bind(tv0->getRootDomain()[0]->extent(), 129);
-  evaluator.bind(tv0->getRootDomain()[1]->extent(), 127);
+  evaluator.safeBind(tv0->getRootDomain()[0]->extent(), 129);
+  evaluator.safeBind(tv0->getRootDomain()[1]->extent(), 127);
 
   // Evaluate and check extent values
   TORCH_CHECK(tv0->domain()->nDims() == 2);
-  checkIntValue(evaluator, tv0->axis(0)->extent(), 129);
-  checkIntValue(evaluator, tv0->axis(1)->extent(), 127);
+  checkIntValue(evaluator, tv0->axis(0)->rawExtent(), 129);
+  checkIntValue(evaluator, tv0->axis(1)->rawExtent(), 127);
 
   TORCH_CHECK(tv3->domain()->nDims() == 2);
-  checkIntValue(evaluator, tv3->axis(0)->extent(), 129);
-  checkIntValue(evaluator, tv3->axis(1)->extent(), 127);
+  checkIntValue(evaluator, tv3->axis(0)->rawExtent(), 129);
+  checkIntValue(evaluator, tv3->axis(1)->rawExtent(), 127);
 
   TORCH_CHECK(tv4->domain()->nDims() == 2);
-  checkIntValue(evaluator, tv4->axis(0)->extent(), 129);
-  checkIntValue(evaluator, tv4->axis(1)->extent(), 127);
+  checkIntValue(evaluator, tv4->axis(0)->rawExtent(), 129);
+  checkIntValue(evaluator, tv4->axis(1)->rawExtent(), 127);
 
   TORCH_CHECK(tv5->domain()->nDims() == 1);
-  checkIntValue(evaluator, tv5->axis(0)->extent(), 16383);
+  checkIntValue(evaluator, tv5->axis(0)->rawExtent(), 16383);
 
   TORCH_CHECK(tv6->domain()->nDims() == 3);
-  checkIntValue(evaluator, tv6->axis(0)->extent(), 26);
-  checkIntValue(evaluator, tv6->axis(1)->extent(), 5);
-  checkIntValue(evaluator, tv6->axis(2)->extent(), 127);
+  checkIntValue(evaluator, tv6->axis(0)->rawExtent(), 26);
+  checkIntValue(evaluator, tv6->axis(1)->rawExtent(), 5);
+  checkIntValue(evaluator, tv6->axis(2)->rawExtent(), 127);
 }
 
 // Evaluate expressions post lowering
@@ -350,13 +340,13 @@ TEST(NVFuserTest, FusionExprEvalPostLower_CUDA) {
   FusionGuard fg(&fusion);
 
   // Create a non-trivial IR
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = makeDummyTensor(2);
 
   fusion.addInput(tv0);
   fusion.addInput(tv1);
 
-  TensorView* tv2 = add(tv1, new Double(2.0));
+  TensorView* tv2 = add(tv1, new Float(2.0));
   TensorView* tv3 = add(tv0, tv2);
 
   fusion.addOutput(tv3);
@@ -372,101 +362,36 @@ TEST(NVFuserTest, FusionExprEvalPostLower_CUDA) {
   tv2->axis(-1)->parallelize(ParallelType::TIDx);
   tv3->axis(-1)->parallelize(ParallelType::TIDx);
 
-  auto* bid_x = add(tv3->axis(0)->extent(), new Int(0));
-  auto* tid_x = add(tv3->axis(-1)->extent(), new Int(0));
+  auto* bid_x = add(tv3->axis(0)->rawExtent(), new Int(0));
+  auto* tid_x = add(tv3->axis(-1)->rawExtent(), new Int(0));
 
   // Lower
   GpuLower gpulw(&fusion);
 
   // 1. Create an evaluation context
-  ExpressionEvaluator evaluator(&fusion);
+  StatefulExpressionEvaluator evaluator(&fusion);
 
   // 2. Bind values
-  evaluator.bind(tv0->getRootDomain()[0]->extent(), 6);
-  evaluator.bind(tv0->getRootDomain()[1]->extent(), 128);
-  evaluator.bind(tv1->getRootDomain()[0]->extent(), 6);
-  evaluator.bind(tv1->getRootDomain()[1]->extent(), 128);
+  evaluator.safeBind(tv0->getRootDomain()[0]->extent(), 6);
+  evaluator.safeBind(tv0->getRootDomain()[1]->extent(), 128);
+  evaluator.safeBind(tv1->getRootDomain()[0]->extent(), 6);
+  evaluator.safeBind(tv1->getRootDomain()[1]->extent(), 128);
 
   // 3. Evaluate and check result values
   TORCH_CHECK(tv2->domain()->nDims() == 3);
-  checkIntValue(evaluator, tv2->axis(0)->extent(), 2);
-  checkIntValue(evaluator, tv2->axis(1)->extent(), 4);
-  checkIntValue(evaluator, tv2->axis(2)->extent(), 128);
+  checkIntValue(evaluator, tv2->axis(0)->rawExtent(), 2);
+  checkIntValue(evaluator, tv2->axis(1)->rawExtent(), 4);
+  checkIntValue(evaluator, tv2->axis(2)->rawExtent(), 128);
 
   TORCH_CHECK(tv3->domain()->nDims() == 3);
-  checkIntValue(evaluator, tv3->axis(0)->extent(), 2);
-  checkIntValue(evaluator, tv3->axis(1)->extent(), 4);
-  checkIntValue(evaluator, tv3->axis(2)->extent(), 128);
+  checkIntValue(evaluator, tv3->axis(0)->rawExtent(), 2);
+  checkIntValue(evaluator, tv3->axis(1)->rawExtent(), 4);
+  checkIntValue(evaluator, tv3->axis(2)->rawExtent(), 128);
 
   checkIntValue(evaluator, bid_x, 2);
   checkIntValue(evaluator, tid_x, 128);
 }
 
-// Kernel IR: Evaluate basic scalar operations with constant values
-TEST(NVFuserTest, KernelExprEvalConstants_CUDA) {
-  kir::Kernel kernel;
-  kir::IrBuilder ir_builder(&kernel);
-
-  auto a = ir_builder.create<kir::Int>(7);
-  auto b = ir_builder.create<kir::Int>(3);
-  auto c = ir_builder.subExpr(a, b);
-  auto d = ir_builder.divExpr(a, b);
-  auto e = ir_builder.mulExpr(c, d);
-
-  kir::ExpressionEvaluator evaluator;
-
-  checkIntValue(evaluator, ir_builder.negExpr(a), -7);
-  checkIntValue(evaluator, ir_builder.addExpr(a, b), 10);
-  checkIntValue(evaluator, ir_builder.negExpr(e), -8);
-  checkIntValue(evaluator, ir_builder.modExpr(a, b), 1);
-  checkIntValue(evaluator, ir_builder.ceilDivExpr(a, b), 3);
-}
-
-// Kernel IR: Evaluate basic scalar operations with bound values
-TEST(NVFuserTest, KernelExprEvalBindings_CUDA) {
-  kir::Kernel kernel;
-  kir::IrBuilder ir_builder(&kernel);
-
-  kir::ExpressionEvaluator evaluator;
-
-  auto a = ir_builder.create<kir::Int>(c10::nullopt);
-  auto b = ir_builder.create<kir::Int>(c10::nullopt);
-  auto c = ir_builder.addExpr(a, b);
-  auto d = ir_builder.negExpr(ir_builder.ceilDivExpr(c, b));
-  auto e = ir_builder.create<kir::Int>(0);
-
-  // trying to evaluate before binding should give empty results
-  TORCH_CHECK(!evaluator.evaluate(a).has_value());
-  TORCH_CHECK(!evaluator.evaluate(d).has_value());
-
-  evaluator.bind(a, 7);
-  evaluator.bind(b, 3);
-
-  // can't bind to the results of expressions
-  ASSERT_ANY_THROW(evaluator.bind(c, 100));
-
-  // can't bind to concrete values
-  ASSERT_ANY_THROW(evaluator.bind(e, 100));
-
-  checkIntValue(evaluator, c, 10);
-  checkIntValue(evaluator, ir_builder.subExpr(a, b), 4);
-  checkIntValue(evaluator, ir_builder.modExpr(a, b), 1);
-  checkIntValue(evaluator, ir_builder.ceilDivExpr(a, b), 3);
-  checkIntValue(evaluator, d, -4);
-
-  // Reset the evaluation context
-  evaluator = kir::ExpressionEvaluator();
-
-  evaluator.bind(a, 2);
-  evaluator.bind(b, 5);
-
-  checkIntValue(evaluator, c, 7);
-  checkIntValue(evaluator, ir_builder.subExpr(a, b), -3);
-  checkIntValue(evaluator, ir_builder.modExpr(a, b), 2);
-  checkIntValue(evaluator, ir_builder.ceilDivExpr(a, b), 1);
-  checkIntValue(evaluator, d, -2);
-}
-
 TEST(NVFuserTest, FusionClear_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -474,13 +399,13 @@ TEST(NVFuserTest, FusionClear_CUDA) {
   // 1. Create a dummy IR
 
   {
-    TensorView* tv0 = makeSymbolicTensor(2);
-    TensorView* tv1 = makeSymbolicTensor(2);
+    TensorView* tv0 = makeDummyTensor(2);
+    TensorView* tv1 = makeDummyTensor(2);
 
     fusion.addInput(tv0);
     fusion.addInput(tv1);
 
-    TensorView* tv2 = add(tv1, new Double(2.0));
+    TensorView* tv2 = add(tv1, new Float(2.0));
     TensorView* tv3 = add(tv0, tv2);
 
     fusion.addOutput(tv3);
@@ -498,20 +423,22 @@ TEST(NVFuserTest, FusionClear_CUDA) {
 
   fusion.clear();
 
-  TORCH_CHECK(fusion.unordered_exprs().empty());
+  TORCH_CHECK(fusion.exprs().empty());
   TORCH_CHECK(fusion.vals().empty());
 
   TORCH_CHECK(fusion.inputs().empty());
   TORCH_CHECK(fusion.outputs().empty());
 
   TORCH_CHECK(!fusion.hasReduction());
+  TORCH_CHECK(!fusion.hasBlockReduction());
+  TORCH_CHECK(!fusion.hasGridReduction());
 
   // 3. Rebuild the IR
 
   {
-    TensorView* tv0 = makeSymbolicTensor(3);
-    TensorView* tv1 = makeSymbolicTensor(3);
-    TensorView* tv2 = add(tv1, new Double(2.0));
+    TensorView* tv0 = makeDummyTensor(3);
+    TensorView* tv1 = makeDummyTensor(3);
+    TensorView* tv2 = add(tv1, new Float(2.0));
     TensorView* tv3 = add(tv0, tv2);
 
     fusion.addInput(tv0);
@@ -552,9 +479,9 @@ TEST(NVFuserTest, FusionCopy_CUDA) {
   {
     FusionGuard fg(&original_fusion);
 
-    auto tv0 = makeSymbolicTensor(3);
-    auto tv1 = makeSymbolicTensor(3);
-    auto tv2 = add(tv1, new Double(2.0));
+    auto tv0 = makeDummyTensor(3);
+    auto tv1 = makeDummyTensor(3);
+    auto tv2 = add(tv1, new Float(2.0));
     auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2);
 
     original_fusion.addInput(tv0);
@@ -626,9 +553,9 @@ TEST(NVFuserTest, FusionMove_CUDA) {
   {
     FusionGuard fg(&fusion);
 
-    auto tv0 = makeSymbolicTensor(3);
-    auto tv1 = makeSymbolicTensor(3);
-    auto tv2 = add(tv1, new Double(2.0));
+    auto tv0 = makeDummyTensor(3);
+    auto tv1 = makeDummyTensor(3);
+    auto tv2 = add(tv1, new Float(2.0));
     auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2);
 
     fusion.addInput(tv0);
@@ -661,7 +588,7 @@ TEST(NVFuserTest, FusionMove_CUDA) {
   //    standard library containers:
   //    https://en.cppreference.com/w/cpp/utility/move
   //
-  TORCH_CHECK(fusion.unordered_exprs().empty());
+  TORCH_CHECK(fusion.exprs().empty());
   TORCH_CHECK(fusion.vals().empty());
   TORCH_CHECK(fusion.inputs().empty());
   TORCH_CHECK(fusion.outputs().empty());
@@ -695,22 +622,22 @@ TEST(NVFuserTest, FusionSimpleArith_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  Double* d1 = new Double(1.f);
-  Double* d2 = new Double{2.f};
-  Double* d3 = new Double();
+  Float* f1 = new Float(1.f);
+  Float* f2 = new Float{2.f};
+  Float* f3 = new Float();
 
   // Disrupt the fusion to make sure guard works well
   {
     Fusion fusion2;
     FusionGuard fg(&fusion2);
 
-    Double* d1 = new Double(1.f);
-    Double* d2 = new Double(2.f);
-    add(d1, d2);
+    Float* f1 = new Float(1.f);
+    Float* f2 = new Float(2.f);
+    add(f1, f2);
     ss2 << fusion2;
   }
 
-  new BinaryOp(BinaryOpType::Add, d3, d1, d2);
+  new BinaryOp(BinaryOpType::Add, f3, f1, f2);
   ss1 << fusion;
 
   TORCH_CHECK(
@@ -722,24 +649,54 @@ TEST(NVFuserTest, FusionSimpleTypePromote_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  Double* d4 = new Double{4.f};
+  Float* f4 = new Float{4.f};
+  Int* i1 = new Int{3};
+  auto f5 = add(f4, i1);
+
+  TORCH_CHECK(f5->getDataType() == DataType::Float);
+}
+
+class ZeroMutator : public OptOutMutator {
+ public:
+  Statement* mutate(Float* f) {
+    if (f->isConst() && *(f->value()) == 1.0)
+      return new Float(0.0);
+    return f;
+  }
+  void mutate(Fusion* f) {
+    OptOutMutator::mutate(f);
+  }
+};
+
+TEST(NVFuserTest, FusionMutator_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  Float* f4 = new Float{1.f};
   Int* i1 = new Int{3};
-  auto d5 = add(d4, i1);
+  Val* f5 = add(f4, i1);
+  ZeroMutator mutator;
+  mutator.mutate(&fusion);
+  Val* lhs = static_cast<BinaryOp*>(fusion.origin(f5))->lhs();
+  TORCH_CHECK(
+      lhs->getValType().value() == ValType::Scalar &&
+      lhs->getDataType().value() == DataType::Float);
+  Float* flhs = static_cast<Float*>(lhs);
 
-  TORCH_CHECK(d5->getDataType() == DataType::Double);
+  TORCH_CHECK(flhs->value().value() == 0.f);
 }
 
 TEST(NVFuserTest, FusionRegister_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
-  Double* v1 = new Double{1.f};
-  Double* v2 = new Double{2.f};
+  Float* v1 = new Float{1.f};
+  Float* v2 = new Float{2.f};
   Val* v3 = binaryOp(BinaryOpType::Add, v1, v2);
   Val* v4 = binaryOp(BinaryOpType::Add, v1, v2);
   TORCH_CHECK(v1->name() + 1 == v2->name());
   TORCH_CHECK(v2->name() + 1 == v3->name());
   TORCH_CHECK(v3->name() + 1 == v4->name());
-  TORCH_CHECK(v3->definition()->name() + 1 == v4->definition()->name());
+  TORCH_CHECK(fusion.origin(v3)->name() + 1 == fusion.origin(v4)->name());
 }
 
 // dummy expr with 2 outputs only for toposort test.
@@ -768,57 +725,63 @@ TEST(NVFuserTest, FusionTopoSort_CUDA) {
   // e1: v4     =   add(v3, v2)
   // e2: v5     =   add(v2, v4)
   // e3: v6     =   add(v5, v5)
-  Double* v0 = new Double{1.f};
-  Double* v1 = new Double{2.f};
-  Double* v2 = new Double();
-  Double* v3 = new Double();
-  Double* v4 = new Double();
-  Double* v5 = new Double();
-  Double* v6 = new Double();
-
-  std::vector<Val*> inputs = {v0, v1};
-  for (auto val : inputs) {
-    fusion.addInput(val);
-  }
+  Float* v0 = new Float{1.f};
+  Float* v1 = new Float{2.f};
+  Float* v2 = new Float();
+  Float* v3 = new Float();
+  Float* v4 = new Float();
+  Float* v5 = new Float();
+  Float* v6 = new Float();
 
   Expr* e0 = new DummyExpr(v3, v2, v1, v0);
   Expr* e1 = new BinaryOp(BinaryOpType::Add, v4, v3, v2);
   Expr* e2 = new BinaryOp(BinaryOpType::Add, v5, v2, v4);
   Expr* e3 = new BinaryOp(BinaryOpType::Add, v6, v5, v5);
 
+  std::vector<Expr*> exprs = fusion.exprs();
+
+  TORCH_CHECK(exprs.size() == 4);
+  TORCH_CHECK(exprs[0] == e0);
+  TORCH_CHECK(exprs[1] == e1);
+  TORCH_CHECK(exprs[2] == e2);
+  TORCH_CHECK(exprs[3] == e3);
+
   fusion.addOutput(v2);
-  fusion.addOutput(v3);
-  auto exprs = fusion.exprs();
-  TORCH_CHECK(exprs.size() == 1, "Found ", exprs.size(), " but expecting 1");
+  exprs = fusion.exprs(true);
+  TORCH_CHECK(exprs.size() == 1);
   TORCH_CHECK(exprs[0] == e0);
 
   fusion.addOutput(v5);
-  exprs = fusion.exprs();
-  TORCH_CHECK(exprs.size() == 3, "Found ", exprs.size(), " but expecting 3");
+  exprs = fusion.exprs(true);
   TORCH_CHECK(exprs[0] == e0);
   TORCH_CHECK(exprs[1] == e1);
   TORCH_CHECK(exprs[2] == e2);
 
   fusion.addOutput(v4);
-  exprs = fusion.exprs();
-  TORCH_CHECK(exprs.size() == 3, "Found ", exprs.size(), " but expecting 3");
+  exprs = fusion.exprs(true);
+  TORCH_CHECK(exprs[0] == e0);
+  TORCH_CHECK(exprs[1] == e1);
+  TORCH_CHECK(exprs[2] == e2);
+
+  fusion.addOutput(v3);
+  exprs = fusion.exprs(true);
   TORCH_CHECK(exprs[0] == e0);
   TORCH_CHECK(exprs[1] == e1);
   TORCH_CHECK(exprs[2] == e2);
 
   fusion.addOutput(v6);
-  exprs = fusion.exprs();
-  TORCH_CHECK(exprs.size() == 4, "Found ", exprs.size(), " but expecting 4");
+  exprs = fusion.exprs(true);
+  TORCH_CHECK(exprs.size() == 4);
   TORCH_CHECK(exprs[0] == e0);
   TORCH_CHECK(exprs[1] == e1);
   TORCH_CHECK(exprs[2] == e2);
   TORCH_CHECK(exprs[3] == e3);
 
-  TORCH_CHECK(v2->definition()->name() == 0);
-  TORCH_CHECK(v3->definition()->name() == 0);
-  TORCH_CHECK(v4->definition()->name() == 1);
-  TORCH_CHECK(v5->definition()->name() == 2);
-  TORCH_CHECK(v6->definition()->name() == 3);
+  TORCH_CHECK(fusion.origin(v2)->name() == 0);
+  TORCH_CHECK(fusion.origin(v3)->name() == 0);
+  TORCH_CHECK(fusion.origin(v4)->name() == 1);
+  TORCH_CHECK(fusion.origin(v5)->name() == 2);
+  TORCH_CHECK(fusion.origin(v6)->name() == 3);
 }
 
 TEST(NVFuserTest, FusionTensor_CUDA) {
@@ -889,9 +852,9 @@ TEST(NVFuserTest, FusionFilterVals_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
-  auto tv1 = makeSymbolicTensor(1);
-  auto scalar0 = new Double(0);
+  auto tv0 = makeDummyTensor(1);
+  auto tv1 = makeDummyTensor(1);
+  auto scalar0 = new Float(0);
   auto scalar1 = new Int(0);
   auto scalar2 = new Int(1);
 
@@ -904,9 +867,9 @@ TEST(NVFuserTest, FusionFilterVals_CUDA) {
   TORCH_CHECK(tvs[0] == tv0);
   TORCH_CHECK(tvs[1] == tv1);
 
-  std::vector<Double*> floats(
-      ir_utils::filterByType<Double>(vals).begin(),
-      ir_utils::filterByType<Double>(vals).end());
+  std::vector<Float*> floats(
+      ir_utils::filterByType<Float>(vals).begin(),
+      ir_utils::filterByType<Float>(vals).end());
   TORCH_CHECK(floats.size() == 1);
   TORCH_CHECK(floats[0] == scalar0);
 
@@ -927,11 +890,11 @@ TEST(NVFuserTest, FusionTVSplit_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv = makeSymbolicTensor(3);
+  TensorView* tv = makeDummyTensor(3);
 
   tv = tv->split(2, 2);
   TORCH_CHECK(tv->nDims() == 4);
-  Expr* outer = tv->axis(2)->extent()->definition();
+  Expr* outer = tv->axis(2)->extent()->getOrigin();
 
   TORCH_CHECK(
       outer->getExprType().value() == ExprType::BinaryOp &&
@@ -953,10 +916,10 @@ TEST(NVFuserTest, FusionTVMerge_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv = makeSymbolicTensor(3);
+  TensorView* tv = makeDummyTensor(3);
 
   tv = tv->merge(1);
-  Expr* axisOp = tv->axis(1)->extent()->definition();
+  Expr* axisOp = tv->axis(1)->extent()->getOrigin();
 
   TORCH_CHECK(
       tv->nDims() == 2 && axisOp->getExprType() == ExprType::BinaryOp &&
@@ -979,7 +942,7 @@ TEST(NVFuserTest, FusionTVReorder_CUDA) {
 
   std::unordered_map<int, int> swap{{0, 2}, {2, 0}};
 
-  auto tv = makeSymbolicTensor(3);
+  auto tv = makeDummyTensor(3);
   std::vector<IterDomain*> ref;
   ref = std::vector<IterDomain*>(
       tv->domain()->domain().begin(), tv->domain()->domain().end());
@@ -988,7 +951,7 @@ TEST(NVFuserTest, FusionTVReorder_CUDA) {
   for (int i = 0; i < (int)tv->nDims(); i++)
     TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1)));
 
-  tv = makeSymbolicTensor(3);
+  tv = makeDummyTensor(3);
   ref = std::vector<IterDomain*>(
       tv->domain()->domain().begin(), tv->domain()->domain().end());
 
@@ -996,7 +959,7 @@ TEST(NVFuserTest, FusionTVReorder_CUDA) {
   for (int i = 0; i < (int)tv->nDims(); i++)
     TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1)));
 
-  tv = makeSymbolicTensor(3);
+  tv = makeDummyTensor(3);
   ref = std::vector<IterDomain*>(
       tv->domain()->domain().begin(), tv->domain()->domain().end());
 
@@ -1005,7 +968,7 @@ TEST(NVFuserTest, FusionTVReorder_CUDA) {
   for (int i = 1; i < (int)tv->nDims(); i++)
     TORCH_CHECK(ref[i - 1]->sameAs(tv->axis(i)));
 
-  tv = makeSymbolicTensor(3);
+  tv = makeDummyTensor(3);
   ref = std::vector<IterDomain*>(
       tv->domain()->domain().begin(), tv->domain()->domain().end());
   tv->reorder(swap);
@@ -1018,15 +981,15 @@ TEST(NVFuserTest, FusionEquality_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  Double* fval1 = new Double();
-  Double* fval1_copy = fval1;
-  Double* fval2 = new Double();
-  Double* fone = new Double(1.0);
+  Float* fval1 = new Float();
+  Float* fval1_copy = fval1;
+  Float* fval2 = new Float();
+  Float* fone = new Float(1.0);
 
   TORCH_CHECK(fval1->sameAs(fval1_copy));
   TORCH_CHECK(!fval1->sameAs(fval2));
   TORCH_CHECK(!fone->sameAs(fval1));
-  TORCH_CHECK(fone->sameAs(new Double(1.0)));
+  TORCH_CHECK(fone->sameAs(new Float(1.0)));
 
   Int* ival1 = new Int();
   Int* ival1_copy = ival1;
@@ -1038,14 +1001,14 @@ TEST(NVFuserTest, FusionEquality_CUDA) {
   TORCH_CHECK(!ione->sameAs(ival1));
   TORCH_CHECK(ione->sameAs(new Int(1)));
 
-  BinaryOp* add1 = new BinaryOp(BinaryOpType::Add, new Double(), fval1, ival1);
+  BinaryOp* add1 = new BinaryOp(BinaryOpType::Add, new Float(), fval1, ival1);
   BinaryOp* add1_copy =
-      new BinaryOp(BinaryOpType::Add, new Double(), fval1, ival1);
-  BinaryOp* sub1 = new BinaryOp(BinaryOpType::Sub, new Double(), fval1, ival1);
+      new BinaryOp(BinaryOpType::Add, new Float(), fval1, ival1);
+  BinaryOp* sub1 = new BinaryOp(BinaryOpType::Sub, new Float(), fval1, ival1);
 
-  UnaryOp* neg1 = new UnaryOp(UnaryOpType::Neg, new Double(), fval1);
-  UnaryOp* neg2 = new UnaryOp(UnaryOpType::Neg, new Double(), fval2);
-  UnaryOp* neg1_copy = new UnaryOp(UnaryOpType::Neg, new Double(), fval1);
+  UnaryOp* neg1 = new UnaryOp(UnaryOpType::Neg, new Float(), fval1);
+  UnaryOp* neg2 = new UnaryOp(UnaryOpType::Neg, new Float(), fval2);
+  UnaryOp* neg1_copy = new UnaryOp(UnaryOpType::Neg, new Float(), fval1);
 
   TORCH_CHECK(add1->sameAs(add1_copy));
   TORCH_CHECK(!add1->sameAs(sub1));
@@ -1059,79 +1022,73 @@ TEST(NVFuserTest, FusionDependency_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  Double* d0 = new Double(0.f);
-  Double* d1 = new Double(1.f);
-  auto d2 = add(d0, d1);
-
-  auto d3 = add(d2, d2);
-
-  Double* d4 = new Double(4.f);
-  Double* d5 = new Double(5.f);
-  auto d6 = add(d4, d5);
-
-  Double* d7 = new Double(7.f);
-  Double* d8 = new Double(8.f);
-  auto d9 = add(d7, d8);
-
-  auto d10 = add(d6, d9);
-
-  auto d11 = add(d3, d10);
-
-  TORCH_CHECK(DependencyCheck::isDependencyOf(d0, d11));
-  TORCH_CHECK(DependencyCheck::isDependencyOf(d1, d11));
-  TORCH_CHECK(DependencyCheck::isDependencyOf(d2, d11));
-  TORCH_CHECK(DependencyCheck::isDependencyOf(d3, d11));
-  TORCH_CHECK(DependencyCheck::isDependencyOf(d6, d11));
-  TORCH_CHECK(DependencyCheck::isDependencyOf(d9, d11));
-  TORCH_CHECK(DependencyCheck::isDependencyOf(d0, d2));
-  TORCH_CHECK(DependencyCheck::isDependencyOf(d2, d3));
-  TORCH_CHECK(DependencyCheck::isDependencyOf(d4, d6));
-  TORCH_CHECK(DependencyCheck::isDependencyOf(d8, d10));
-
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d0));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d1));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d2));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d3));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d4));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d5));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(d2, d0));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(d3, d2));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(d6, d4));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(d10, d8));
-
-  auto dep_chain = DependencyCheck::getSingleDependencyChain(d0, d11);
-  TORCH_CHECK(dep_chain.back() == d11);
+  Float* f0 = new Float(0.f);
+  Float* f1 = new Float(1.f);
+  auto f2 = add(f0, f1);
+
+  auto f3 = add(f2, f2);
+
+  Float* f4 = new Float(4.f);
+  Float* f5 = new Float(5.f);
+  auto f6 = add(f4, f5);
+
+  Float* f7 = new Float(7.f);
+  Float* f8 = new Float(8.f);
+  auto f9 = add(f7, f8);
+
+  auto f10 = add(f6, f9);
+
+  auto f11 = add(f3, f10);
+
+  TORCH_CHECK(DependencyCheck::isDependencyOf(f0, f11));
+  TORCH_CHECK(DependencyCheck::isDependencyOf(f1, f11));
+  TORCH_CHECK(DependencyCheck::isDependencyOf(f2, f11));
+  TORCH_CHECK(DependencyCheck::isDependencyOf(f3, f11));
+  TORCH_CHECK(DependencyCheck::isDependencyOf(f6, f11));
+  TORCH_CHECK(DependencyCheck::isDependencyOf(f9, f11));
+  TORCH_CHECK(DependencyCheck::isDependencyOf(f0, f2));
+  TORCH_CHECK(DependencyCheck::isDependencyOf(f2, f3));
+  TORCH_CHECK(DependencyCheck::isDependencyOf(f4, f6));
+  TORCH_CHECK(DependencyCheck::isDependencyOf(f8, f10));
+
+  TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f0));
+  TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f1));
+  TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f2));
+  TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f3));
+  TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f4));
+  TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f5));
+  TORCH_CHECK(!DependencyCheck::isDependencyOf(f2, f0));
+  TORCH_CHECK(!DependencyCheck::isDependencyOf(f3, f2));
+  TORCH_CHECK(!DependencyCheck::isDependencyOf(f6, f4));
+  TORCH_CHECK(!DependencyCheck::isDependencyOf(f10, f8));
+
+  auto dep_chain = DependencyCheck::getSingleDependencyChain(f0, f11);
+  TORCH_CHECK(dep_chain.back() == f11);
   dep_chain.pop_back();
-  TORCH_CHECK(dep_chain.back() == d3);
+  TORCH_CHECK(dep_chain.back() == f3);
   dep_chain.pop_back();
-  TORCH_CHECK(dep_chain.back() == d2);
+  TORCH_CHECK(dep_chain.back() == f2);
   dep_chain.pop_back();
 
-  dep_chain = DependencyCheck::getSingleDependencyChain(d6, d11);
-  TORCH_CHECK(dep_chain.back() == d11);
+  dep_chain = DependencyCheck::getSingleDependencyChain(f6, f11);
+  TORCH_CHECK(dep_chain.back() == f11);
   dep_chain.pop_back();
-  TORCH_CHECK(dep_chain.back() == d10);
+  TORCH_CHECK(dep_chain.back() == f10);
   dep_chain.pop_back();
 
-  dep_chain = DependencyCheck::getSingleDependencyChain(d4, d11);
-  TORCH_CHECK(dep_chain.back() == d11);
+  dep_chain = DependencyCheck::getSingleDependencyChain(f4, f11);
+  TORCH_CHECK(dep_chain.back() == f11);
   dep_chain.pop_back();
-  TORCH_CHECK(dep_chain.back() == d10);
+  TORCH_CHECK(dep_chain.back() == f10);
   dep_chain.pop_back();
-  TORCH_CHECK(dep_chain.back() == d6);
+  TORCH_CHECK(dep_chain.back() == f6);
   dep_chain.pop_back();
 
-  dep_chain = DependencyCheck::getSingleDependencyChain(d11, d2);
+  dep_chain = DependencyCheck::getSingleDependencyChain(f11, f2);
   TORCH_CHECK(dep_chain.empty());
 }
 
 TEST(NVFuserTest, FusionParser_CUDA) {
-  // This test may not pass if using a custom block sync as there may
-  // be additional calls. Skip the test as it's not specifically
-  // relevant with block synchronizatin.
-  if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
-    return;
-  }
   auto g = std::make_shared<Graph>();
   const auto graph0_string = R"IR(
     graph(%0 : Float(2, strides=[1]),
@@ -1156,43 +1113,38 @@ TEST(NVFuserTest, FusionParser_CUDA) {
   auto fusion = parseJitIR(g);
   FusionGuard fg(fusion.get());
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  // Avoid vectorization here as those kernels can't be lowered twice at the
-  // moment
   at::Tensor input1 = at::randn({16}, options);
   at::Tensor input2 = at::randn({16}, options);
-  auto lparams = schedulePointwise(fusion.get(), {input1, input2});
+  scheduleFusion(fusion.get(), {input1, input2});
 
   // CONSIDER:
   // 1. this can be moved to a dedicated "golden" file
   // 2. use a fuzzy compare (ignore non-significant whitespaces for example)
   const std::string expected_kernel = R"(
 __global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3) {
-  if ((((((((((nvfuser_index_t)blockIdx.x) * 1) + (1 - 1)) * 1) + (1 - 1)) * 128) + ((nvfuser_index_t)threadIdx.x)) < T0.size[0])) {
-    constexpr nvfuser_index_t ki169 = 0;
-    float T5[1];
-    constexpr nvfuser_index_t ki203 = 0;
-    T5[ki203] = 0;
-    constexpr nvfuser_index_t ki194 = 0;
-    T5[ki194]
-       = T1[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki169) * 1) + ki194) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)];
-    float T4[1];
-    constexpr nvfuser_index_t ki209 = 0;
-    T4[ki209] = 0;
-    constexpr nvfuser_index_t ki189 = 0;
-    T4[ki189]
-       = T0[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki169) * 1) + ki189) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)];
-    float T6[1];
-    constexpr nvfuser_index_t ki178 = 0;
-    float T2[1];
-    T2[0]
-      = T4[ki178]
-      * T5[ki178];
-    T6[ki178]
-      = T2[0]
-      * T4[ki178];
-    constexpr nvfuser_index_t ki171 = 0;
-    T3[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki169) * 1) + ki171) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]
-       = T6[ki171];
+  float T2[1];
+  if ((((((blockIdx.x * 1) + (1 - 1)) * 128) + threadIdx.x) < T0.size[0])) {
+    for(size_t i6 = 0; i6 < 1; ++i6) {
+      T2[i6]
+        = T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
+        * T1[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
+      T3[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
+        = T2[i6]
+        * T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
+    }
+  } else {
+    for(size_t i6 = 0; i6 < 1; ++i6) {
+      if ((((((blockIdx.x * 1) + i6) * 128) + threadIdx.x) < T0.size[0])) {
+        T2[i6]
+          = T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
+          * T1[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
+      }
+      if ((((((blockIdx.x * 1) + i6) * 128) + threadIdx.x) < T0.size[0])) {
+        T3[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
+          = T2[i6]
+          * T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
+      }
+    }
   }
 }
 )";
@@ -1206,23 +1158,12 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Te
         << " \n ========= EXPECTED ========= \n"
         << expected_kernel << "\n========= ACTUAL ========== \n"
         << actual_kernel << "\n=================" << std::endl;
-    auto it = std::mismatch(
-        expected_kernel.begin(),
-        expected_kernel.end(),
-        actual_kernel.begin(),
-        actual_kernel.end());
-    std::string actual_mismatched_snippet(it.second, actual_kernel.end());
-    actual_mismatched_snippet = actual_mismatched_snippet.substr(0, 10);
-    std::string expected_mismatched_snippet(it.first, expected_kernel.end());
-    expected_mismatched_snippet = expected_mismatched_snippet.substr(0, 10);
-    std::cerr << "First mismatch found at: " << actual_mismatched_snippet
-              << ", expected: " << expected_mismatched_snippet << std::endl;
     TORCH_CHECK(false);
   }
 
   FusionExecutor fe;
   fe.compileFusion(fusion.get());
-  auto outputs = fe.runFusion({input1, input2}, lparams);
+  auto outputs = fe.runFusion({input1, input2});
   at::Tensor output_ref = input1 * input2 * input1;
   TORCH_CHECK(output_ref.equal(outputs[0]));
 }
@@ -1248,7 +1189,7 @@ TEST(NVFuserTest, FusionForLoop_CUDA) {
   auto ID0 = new kir::IterDomain(new IterDomain(new Int(0), new Int(8)));
 
   TensorView* TV2 = add(TV0, TV1);
-  BinaryOp* op = static_cast<BinaryOp*>(TV2->definition();
+  BinaryOp* op = static_cast<BinaryOp*>(TV2->getOrigin());
   fusion.addOutput(TV2);
 
   auto fl = new kir::ForLoop(new kir::Int(c10::nullopt), ID0, {op});
@@ -1268,53 +1209,15 @@ TEST(NVFuserTest, FusionForLoop_CUDA) {
 #endif
 }
 
-TEST(NVFuserTest, FusionOuterSplit_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(3);
-
-  new BinaryOp(BinaryOpType::Add, tv0, new Double(0.0), new Double(1.0));
-  TensorView* tv1 = add(tv0, new Double(2.0));
-  TensorView* tv2 = add(tv1, new Double(3.0));
-  fusion.addOutput(tv2);
-
-  //[I0, I1, I2]
-  tv2->split(-1, 4, false);
-  //[I0, I1, I2o{4}, I2i]
-  tv2->merge(0);
-  tv2->merge(0);
-  //[I0*I1*I2o{4}, I2i]
-  tv2->split(0, 2);
-  //[I0*I1*I2o{4}o, I0*I1*I2o{4}i{2}, I2i]
-  tv2->reorder({{0, 1}, {1, 0}});
-  // I0*I1*I2o{4}i{2}, [I0*I1*I2o{4}o, I2i]
-
-  tv0->computeAt(tv2, -1);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor output = at::empty({2, 6, 32}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({}, {output});
-
-  at::Tensor output_ref = at::zeros_like(output, options);
-  output_ref = output_ref + 0.0 + 1.0 + 2.0 + 3.0;
-
-  TORCH_CHECK(output_ref.equal(output));
-}
-
 TEST(NVFuserTest, FusionCodeGen_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(3);
+  TensorView* tv0 = makeDummyTensor(3);
 
-  new BinaryOp(BinaryOpType::Add, tv0, new Double(0.0), new Double(1.0));
-  TensorView* tv1 = add(tv0, new Double(2.0));
-  TensorView* tv2 = add(tv1, new Double(3.0));
+  new BinaryOp(BinaryOpType::Add, tv0, new Float(0.0), new Float(1.0));
+  TensorView* tv1 = add(tv0, new Float(2.0));
+  TensorView* tv2 = add(tv1, new Float(3.0));
   fusion.addOutput(tv2);
 
   //[I0, I1, I2]
@@ -1347,9 +1250,9 @@ TEST(NVFuserTest, FusionCodeGen2_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(3);
-  TensorView* tv1 = makeSymbolicTensor(3);
-  TensorView* tv2 = add(tv1, new Double(2.0));
+  TensorView* tv0 = makeDummyTensor(3);
+  TensorView* tv1 = makeDummyTensor(3);
+  TensorView* tv2 = add(tv1, new Float(2.0));
   TensorView* tv3 = add(tv0, tv2);
 
   fusion.addInput(tv0);
@@ -1401,7 +1304,7 @@ TEST(NVFuserTest, FusionSimplePWise_CUDA) {
 
   // Do math with it, it returns a `Val*` but can be static_casted back to
   // TensorView
-  TensorView* tv2 = add(tv1, new Double(2.0));
+  TensorView* tv2 = add(tv1, new Float(2.0));
   TensorView* tv3 = add(tv0, tv2);
 
   // Register your outputs
@@ -1447,8 +1350,8 @@ TEST(NVFuserTest, FusionExecKernel_CUDA) {
   FusionGuard fg(&fusion);
 
   // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = makeDummyTensor(2);
 
   // Register your inputs
   fusion.addInput(tv0);
@@ -1456,7 +1359,7 @@ TEST(NVFuserTest, FusionExecKernel_CUDA) {
 
   // Do math with it, it returns a `Val*` but can be static_casted back to
   // TensorView
-  TensorView* tv2 = add(tv1, new Double(2.0));
+  TensorView* tv2 = add(tv1, new Float(2.0));
   TensorView* tv3 = add(tv0, tv2);
 
   // Register your outputs
@@ -1508,13 +1411,13 @@ TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
 
-  TensorView* tv1 = mul(tv0, new Double(0.5));
-  TensorView* tv2 = mul(tv1, new Double(-1.0));
-  TensorView* tv3 = add(tv1, new Double(3.0));
-  TensorView* tv4 = mul(tv1, new Double(2.0));
+  TensorView* tv1 = mul(tv0, new Float(0.5));
+  TensorView* tv2 = mul(tv1, new Float(-1.0));
+  TensorView* tv3 = add(tv1, new Float(3.0));
+  TensorView* tv4 = mul(tv1, new Float(2.0));
   TensorView* tv5 = add(tv3, tv2);
 
   TensorView* tv6 = add(tv5, tv4);
@@ -1532,20 +1435,13 @@ TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) {
 
   tv0->computeAt(tv7, 1);
 
-  GpuLower gpulw(&fusion);
-
-  // The this-position of the last tensor should be zero.
-  TORCH_CHECK(
-      tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 &&
-      tv7->getMaxProducerPosition() == 1);
-  TORCH_CHECK(
-      tv7->nDims() == 3 && tv6->getComputeAtPosition() == 0 &&
-      tv6->getMaxProducerPosition() == 1);
-  // The position of every other tensor should be 1.
-  for (auto tv : {tv1, tv2, tv3, tv4, tv5}) {
-    TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1);
-    TORCH_CHECK(gpulw.caLoopMap().areMapped(tv7->axis(0), tv->axis(0)));
-  }
+  TORCH_CHECK(tv1->hasComputeAt() && tv1->nDims() == 3);
+  TORCH_CHECK(tv2->getComputeAtView() == tv5 && tv2->nDims() == 3);
+  TORCH_CHECK(tv3->getComputeAtView() == tv5 && tv3->nDims() == 3);
+  TORCH_CHECK(tv4->hasComputeAt() && tv4->nDims() == 3);
+  TORCH_CHECK(tv5->getComputeAtView() == tv6 && tv5->nDims() == 3);
+  TORCH_CHECK(tv6->getComputeAtView() == tv7 && tv6->nDims() == 3);
+  TORCH_CHECK(!tv7->hasComputeAt());
 
   for (Val* val : fusion.vals()) {
     if (!fusion.hasInput(val) &&
@@ -1558,9 +1454,9 @@ TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) {
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  at::Tensor aten_input = at::randn({129, 127}, options);
+  at::Tensor t0 = at::randn({129, 127}, options);
 
-  auto t1 = aten_input.mul({0.5});
+  auto t1 = t0.mul({0.5});
   auto t2 = t1.mul({-1.0});
   auto t3 = t1.add({3.0});
   auto t4 = t1.mul({2.0});
@@ -1568,16 +1464,15 @@ TEST(NVFuserTest, FusionAdvancedComputeAt1_CUDA) {
   auto t6 = t5.add(t4);
   auto t7 = t1.add(t4);
 
-  std::vector<at::Tensor> aten_outputs = {t6, t7};
-  std::vector<at::Tensor> cg_outputs = {
-      at::empty_like(aten_input, options), at::empty_like(aten_input, options)};
+  at::Tensor kernel_tv6 = at::empty_like(t0, options);
+  at::Tensor kernel_tv7 = at::empty_like(t0, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  fe.runFusion({aten_input}, cg_outputs);
+  fe.runFusion({t0}, {kernel_tv6, kernel_tv7});
 
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
+  TORCH_CHECK(at::allclose(kernel_tv6, t6));
+  TORCH_CHECK(at::allclose(kernel_tv7, t7));
 }
 
 TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) {
@@ -1591,12 +1486,12 @@ TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
 
-  TensorView* tv1 = mul(tv0, new Double(-1.0));
-  TensorView* tv2 = add(tv0, new Double(3.0));
-  TensorView* tv3 = mul(tv0, new Double(2.0));
+  TensorView* tv1 = mul(tv0, new Float(-1.0));
+  TensorView* tv2 = add(tv0, new Float(3.0));
+  TensorView* tv3 = mul(tv0, new Float(2.0));
   TensorView* tv4 = add(tv2, tv1);
 
   TensorView* tv5 = add(tv4, tv3);
@@ -1625,22 +1520,21 @@ TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) {
   }
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({129, 127}, options);
+  at::Tensor t0 = at::randn({129, 127}, options);
 
-  auto t1 = input.mul({-1.0});
-  auto t2 = input.add({3.0});
-  auto t3 = input.mul({2.0});
+  auto t1 = t0.mul({-1.0});
+  auto t2 = t0.add({3.0});
+  auto t3 = t0.mul({2.0});
   auto t4 = t2.add(t1);
   auto t5 = t4.add(t3);
   auto t6 = t5.add(t3);
 
-  std::vector<at::Tensor> aten_outputs = {t5, t6};
-
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
+  auto outputs = fe.runFusion({t0});
 
-  testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__);
+  TORCH_CHECK(at::allclose(outputs[0], t5));
+  TORCH_CHECK(at::allclose(outputs[1], t6));
 }
 
 TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) {
@@ -1650,13 +1544,13 @@ TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(4);
+  TensorView* tv0 = makeDummyTensor(4);
   fusion.addInput(tv0);
 
-  TensorView* tv1 = makeSymbolicTensor(4);
+  TensorView* tv1 = makeDummyTensor(4);
   fusion.addInput(tv1);
 
-  TensorView* tv2 = mul(tv1, new Double(.979361));
+  TensorView* tv2 = mul(tv1, new Float(.979361));
   TensorView* tv3 = mul(tv2, tv0);
 
   fusion.addOutput(tv3);
@@ -1687,18 +1581,15 @@ TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) {
   at::Tensor t1 = at::rand_like(t0, options);
 
   auto t2 = t1.mul({0.979361});
-  auto aten_output = t2.mul(t0);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
+  auto t3 = t2.mul(t0);
 
-  at::Tensor cg_output = at::empty_like(t0, options);
+  at::Tensor kernel_tv3 = at::empty_like(t0, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  fe.runFusion(aten_inputs, {cg_output});
+  fe.runFusion({t0, t1}, {kernel_tv3});
 
-  testValidate(
-      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  TORCH_CHECK(at::allclose(kernel_tv3, t3));
 }
 
 TEST(NVFuserTest, FusionAdvancedComputeAt4_CUDA) {
@@ -1709,16 +1600,16 @@ TEST(NVFuserTest, FusionAdvancedComputeAt4_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(4);
+  TensorView* tv0 = makeDummyTensor(4);
   fusion.addInput(tv0);
 
-  TensorView* tv1 = makeSymbolicTensor(4);
+  TensorView* tv1 = makeDummyTensor(4);
   fusion.addInput(tv1);
 
-  TensorView* tv2 = makeSymbolicTensor(4);
+  TensorView* tv2 = makeDummyTensor(4);
   fusion.addInput(tv2);
 
-  TensorView* tv3 = makeSymbolicTensor(4);
+  TensorView* tv3 = makeDummyTensor(4);
   fusion.addInput(tv3);
 
   TensorView* tv4 = sub(tv2, tv3);
@@ -1758,16 +1649,13 @@ TEST(NVFuserTest, FusionAdvancedComputeAt4_CUDA) {
 
   auto t4 = t2.sub(t3);
   auto t5 = t1.add(t4);
-  auto aten_output = t5.sub(t0);
-
-  std::vector<IValue> aten_inputs = {t0, t1, t2, t3};
+  auto t6 = t5.sub(t0);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  auto outputs = fe.runFusion({t0, t1, t2, t3});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  TORCH_CHECK(at::allclose(outputs[0], t6));
 }
 
 TEST(NVFuserTest, FusionAdvancedComputeAt5_CUDA) {
@@ -1778,11 +1666,11 @@ TEST(NVFuserTest, FusionAdvancedComputeAt5_CUDA) {
   FusionGuard fg(&fusion);
 
   // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  TensorView* tv1 = makeSymbolicTensor(2);
+  TensorView* tv1 = makeDummyTensor(2);
   fusion.addInput(tv1);
-  TensorView* tv2 = add(tv0, new Double(2.0));
+  TensorView* tv2 = add(tv0, new Float(2.0));
   TensorView* tv3 = mul(tv1, tv2);
   fusion.addOutput(tv3);
 
@@ -1798,27 +1686,24 @@ TEST(NVFuserTest, FusionAdvancedComputeAt5_CUDA) {
   at::Tensor t1 = at::rand_like(t0, options);
 
   auto t2 = t0.add(2.0);
-  auto aten_output = t1.mul(t2);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
+  auto t3 = t1.mul(t2);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  auto outputs = fe.runFusion({t0, t1});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  TORCH_CHECK(at::allclose(outputs[0], t3));
 }
 
 TEST(NVFuserTest, FusionAdvancedComputeAt6_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  TensorView* tv1 = makeSymbolicTensor(2);
+  TensorView* tv1 = makeDummyTensor(2);
   fusion.addInput(tv1);
-  TensorView* tv2 = add(tv0, new Double(2.0));
+  TensorView* tv2 = add(tv0, new Float(2.0));
   TensorView* tv3 = mul(tv1, tv2);
   fusion.addOutput(tv3);
 
@@ -1837,209 +1722,208 @@ TEST(NVFuserTest, FusionAdvancedComputeAt6_CUDA) {
   at::Tensor t1 = at::rand_like(t0, options);
 
   auto t2 = t0.add(2.0);
-  auto aten_output = t1.mul(t2);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
+  auto t3 = t1.mul(t2);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  auto outputs = fe.runFusion({t0, t1});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  TORCH_CHECK(at::allclose(outputs[0], t3));
 }
 
-TEST(NVFuserTest, FusionAdvancedComputeAt7_CUDA) {
+TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) {
+  // tv1 = tv0 * 0.5
+  // tv2 = tv1 * -1
+  // tv3 = tv2 * -2
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
+  TensorView* tv0 = makeDummyTensor(1);
   fusion.addInput(tv0);
 
-  auto tv1 = add(tv0, new Double(1.0));
-
-  auto tv2 = makeSymbolicTensor(1);
-  fusion.addInput(tv2);
-
-  auto tv3 = add(tv2, new Double(3.0));
-
-  auto tv4 = add(tv1, tv3);
-  fusion.addOutput(tv4);
-
-  auto tv5 = broadcast(tv1, {false, true});
+  TensorView* tv1 = mul(tv0, new Float(0.5));
+  TensorView* tv2 = mul(tv1, new Float(-1.0));
+  TensorView* tv3 = mul(tv1, new Float(-2.0));
+  fusion.addOutput(tv2);
+  fusion.addOutput(tv3);
 
-  auto tv6 = makeSymbolicTensor(2);
-  fusion.addInput(tv6);
+  // This computeAt will affect tv2 as well, even though tv2 is not in
+  // the data-flow path between tv1 and tv3. The reason is that tv1 is
+  // now computed at tv3, so tv2 must also be computed at the same
+  // location. Overall, what will happen is basically we merge
+  // expressions of all tensors and compute them in a single loop
+  // nest.
+  TensorView* computeAtTarget = tv3;
+  computeAtTarget->split(0, 128);
+  tv1->computeAt(computeAtTarget, 1);
 
-  auto tv7 = mul(tv5, tv6);
+  TensorView* affected_tensors[] = {tv1, tv2, tv3};
+  for (auto tv : affected_tensors) {
+    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
+  }
 
-  fusion.addOutput(tv7);
+  // Note that tv2 is also computed at tv3.
+  TORCH_CHECK(tv1->getComputeAtView() == computeAtTarget);
+  TORCH_CHECK(tv2->getComputeAtView() == tv3);
+  TORCH_CHECK(!tv3->hasComputeAt());
 
-  tv7->split(1, 2);
-  tv7->merge(0);
-  tv7->split(0, 4);
-  tv7->split(0, 128);
+  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
+  for (auto tv : affected_tensors) {
+    tv->axis(-1)->parallelize(ParallelType::TIDx);
+  }
 
-  tv7->axis(0)->parallelize(ParallelType::BIDx);
-  tv7->axis(1)->parallelize(ParallelType::TIDx);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  tv0->computeAt(tv7, 1);
-  auto tv5_domain = tv5->domain()->domain();
+  at::Tensor t0 = at::randn({1000}, options);
 
-  // These computeAt transformations should not affect the TV5 domain
-  tv0->computeAt(tv4, -1);
-  tv2->computeAt(tv4, -1);
+  auto t1 = t0 * 0.5;
+  auto t2 = t1 * -1.0;
+  auto t3 = t1 * -2.0;
 
-  auto tv5_domain_current = tv5->domain()->domain();
-  TORCH_CHECK(tv5_domain == tv5_domain_current, "Invalid TV5 domain");
+  at::Tensor kernel_tv2 = at::empty_like(t0, options);
+  at::Tensor kernel_tv3 = at::empty_like(t0, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
+  fe.runFusion({t0}, {kernel_tv2, kernel_tv3});
 
-  const int numel_x = 100;
-  const int numel_y = 200;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto t0 = at::randn({numel_x}, options);
-  auto t2 = at::randn({numel_x}, options);
-  auto t6 = at::randn({numel_x, numel_y}, options);
-
-  auto t1 = t0.add(1.0);
-  auto t3 = t2.add(3.0);
-  auto t4 = t1.add(t3);
-  auto t5 = t1.unsqueeze(1);
-  auto t7 = t5.mul(t6);
-
-  std::vector<IValue> aten_inputs = {t0, t2, t6};
-  std::vector<at::Tensor> aten_outputs = {t4, t7};
-
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
+  TORCH_CHECK(at::allclose(kernel_tv2, t2));
+  TORCH_CHECK(at::allclose(kernel_tv3, t3));
 }
 
-TEST(NVFuserTest, FusionAdvancedComputeAt8_CUDA) {
+// Similar to ComputeAtMultiConsumers, but with a common consumer.
+TEST(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) {
+  // tv1 = tv0 * 0.5
+  // tv2 = tv1 * -1
+  // tv3 = tv2 * -2
+  // tv4 = tv2 + tv3
+  // tv5 = tv4 * 5
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
+  TensorView* tv0 = makeDummyTensor(1);
   fusion.addInput(tv0);
 
-  auto tv1 = add(tv0, new Double(1.0));
-
-  auto tv2 = makeSymbolicTensor(1);
-  fusion.addInput(tv2);
+  TensorView* tv1 = mul(tv0, new Float(0.5));
+  TensorView* tv2 = mul(tv1, new Float(-1.0));
+  TensorView* tv3 = mul(tv1, new Float(-2.0));
+  TensorView* tv4 = add(tv2, tv3);
+  TensorView* tv5 = mul(tv4, new Float(5.0));
+  fusion.addOutput(tv3);
+  fusion.addOutput(tv4);
+  fusion.addOutput(tv5);
 
-  auto tv3 = add(tv2, new Double(3.0));
+  // Computing tv1 at tv3. This will affect tv2 as discussed in
+  // ComplexComputeAt1. Additionally, in this case, notice that tv4 is
+  // the common consumer of tv2 and tv3, so they are computed at
+  // tv4. The indirect propagation of the computeAt should stop at the
+  // common consumer, and no further change should occur. More
+  // specifically, tv4 and tv5 should not have a computeAt tensor.
+  TensorView* computeAtTarget = tv3;
+  computeAtTarget->split(0, 128);
+  tv1->computeAt(computeAtTarget, 1);
 
-  auto tv4 = add(tv1, tv3);
-  fusion.addOutput(tv4);
+  TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4};
+  for (auto tv : affected_tensors) {
+    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
+  }
 
-  auto tv5 = broadcast(tv1, {false, true});
+  TORCH_CHECK(tv1->getComputeAtView() == computeAtTarget);
+  TORCH_CHECK(tv2->getComputeAtView() == tv4);
+  TORCH_CHECK(tv3->getComputeAtView() == tv4);
+  TORCH_CHECK(!tv4->hasComputeAt());
+  TORCH_CHECK(!tv5->hasComputeAt());
 
-  auto tv6 = makeSymbolicTensor(2);
-  fusion.addInput(tv6);
+  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
 
-  auto tv7 = mul(tv5, tv6);
+  for (auto tv : affected_tensors) {
+    tv->axis(-1)->parallelize(ParallelType::TIDx);
+  }
 
-  fusion.addOutput(tv7);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  tv7->split(1, 2);
-  tv7->merge(0);
-  tv7->split(0, 128, false);
-  tv7->split(0, 4, false);
+  at::Tensor t0 = at::randn({1000}, options);
 
-  tv7->axis(0)->parallelize(ParallelType::BIDx);
-  tv7->axis(1)->parallelize(ParallelType::TIDx);
+  auto t1 = t0 * 0.5;
+  auto t2 = t1 * -1.0;
+  auto t3 = t1 * -2.0;
+  auto t4 = t2 + t3;
+  auto t5 = t4 * 5.0;
 
-  // Reverse computeAt structure from previous test
-  tv0->computeAt(tv4, -1);
-  tv2->computeAt(tv4, -1);
-  tv0->computeAt(tv7, -1);
+  at::Tensor kernel_tv3 = at::empty_like(t0, options);
+  at::Tensor kernel_tv4 = at::empty_like(t0, options);
+  at::Tensor kernel_tv5 = at::empty_like(t0, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
+  fe.runFusion({t0}, {kernel_tv3, kernel_tv4, kernel_tv5});
 
-  const int numel_x = 100;
-  const int numel_y = 200;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto t0 = at::randn({numel_x}, options);
-  auto t2 = at::randn({numel_x}, options);
-  auto t6 = at::randn({numel_x, numel_y}, options);
-
-  auto t1 = t0.add(1.0);
-  auto t3 = t2.add(3.0);
-  auto t4 = t1.add(t3);
-  auto t5 = t1.unsqueeze(1);
-  auto t7 = t5.mul(t6);
-
-  std::vector<IValue> aten_inputs = {t0, t2, t6};
-  std::vector<at::Tensor> aten_outputs = {t4, t7};
-
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
+  TORCH_CHECK(at::allclose(kernel_tv3, t3));
+  TORCH_CHECK(at::allclose(kernel_tv4, t4));
+  TORCH_CHECK(at::allclose(kernel_tv5, t5));
 }
 
-TEST(NVFuserTest, FusionAdvancedComputeWith1_CUDA) {
-  // Case 1
+TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) {
   // tv1 = tv0 * 0.5
   // tv2 = tv1 * -1
-  // tv3 = tv1 + 3
-  // tv4 = tv1 * 2
-  // tv5 = tv3 + tv2
-  // tv6 = tv5 + tv4
-  // tv7 = tv1 + tv4
+  // tv3 = tv2 * -1
+  // tv4 = tv1 + 4
+  // tv5 = tv3 + tv4
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
 
-  TensorView* tv1 = mul(tv0, new Double(0.5));
-  TensorView* tv2 = mul(tv1, new Double(-1.0));
-  TensorView* tv3 = add(tv1, new Double(3.0));
-  TensorView* tv4 = mul(tv1, new Double(2.0));
-  TensorView* tv5 = add(tv3, tv2);
-
-  TensorView* tv6 = add(tv5, tv4);
-  TensorView* tv7 = add(tv1, tv4);
-
-  fusion.addOutput(tv6);
-  fusion.addOutput(tv7);
-
-  // Lets setup to actually run
-  tv0->merge(0);
-  tv0->split(0, 128);
-  tv0->split(0, 4);
+  TensorView* tv1 = mul(tv0, new Float(0.5));
+  TensorView* tv2 = mul(tv1, new Float(-1.0));
+  TensorView* tv3 = mul(tv2, new Float(-1.0));
+  TensorView* tv4 = add(tv1, new Float(4.0));
+  TensorView* tv5 = add(tv3, tv4);
 
-  tv0->axis(0)->parallelize(ParallelType::BIDx);
+  fusion.addOutput(tv5);
 
-  tv0->computeWith(tv7, 1);
+  TensorView* computeAtTarget = tv3;
 
-  GpuLower gpulw(&fusion);
+  computeAtTarget->merge(0);
+  computeAtTarget->split(0, 128);
+  computeAtTarget->split(0, 4);
 
-  // The this-position of the last tensor should be zero.
-  TORCH_CHECK(
-      tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 &&
-      tv7->getMaxProducerPosition() == 1);
-  TORCH_CHECK(
-      tv7->nDims() == 3 && tv6->getComputeAtPosition() == 0 &&
-      tv6->getMaxProducerPosition() == 1);
+  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
 
-  // The position of every other tensor should be 1.
-  for (auto tv : {tv1, tv2, tv3, tv4, tv5}) {
-    TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1);
-    TORCH_CHECK(gpulw.caLoopMap().areMapped(tv7->axis(0), tv->axis(0)));
-  }
+  // This computeAt will affect all tensors including tv3, tv4 and
+  // tv5, even though it appears to impact only tv1 and tv2. The
+  // reason is that tv1 is now computed at tv3, so tv4 must also be
+  // computed at the same location. Similarly, the consumer of tv4,
+  // tv5, must also be computed at the same location. Overall, what
+  // will happen is basically we merge expressions of all tensors and
+  // compute them in a single loop nest. Internally, this will be
+  // realized by making all tensors, except for those in the path
+  // between tv1 and tv3, computed at tv5, which we call the common
+  // consumer.
+  tv1->computeAt(computeAtTarget, 1);
 
+  // All tensors should have the same dimenionality as the target
   for (Val* val : fusion.vals()) {
-    if (!fusion.hasInput(val) &&
-        val->getValType().value() == ValType::TensorView) {
-      TensorView* tv = static_cast<TensorView*>(val);
+    if (fusion.hasInput(val) ||
+        val->getValType().value() != ValType::TensorView) {
+      continue;
+    }
+    TensorView* tv = val->as<TensorView>();
+    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
+  }
+
+  TORCH_CHECK(tv1->getComputeAtView() == tv2);
+  TORCH_CHECK(tv2->getComputeAtView() == tv3);
+  // tv3 and tv4 are computed at tv5
+  TORCH_CHECK(tv3->getComputeAtView() == tv5);
+  TORCH_CHECK(tv4->getComputeAtView() == tv5);
+  TORCH_CHECK(!tv5->hasComputeAt());
+
+  for (Val* val : fusion.vals()) {
+    if (!fusion.hasInput(val) &&
+        val->getValType().value() == ValType::TensorView) {
+      TensorView* tv = val->as<TensorView>();
       tv->axis(1)->parallelize(ParallelType::Unroll);
       tv->axis(-1)->parallelize(ParallelType::TIDx);
     }
@@ -2047,10254 +1931,2090 @@ TEST(NVFuserTest, FusionAdvancedComputeWith1_CUDA) {
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  at::Tensor aten_input = at::randn({129, 127}, options);
+  at::Tensor t0 = at::randn({129, 127}, options);
 
-  auto t1 = aten_input.mul({0.5});
+  auto t1 = t0.mul({0.5});
   auto t2 = t1.mul({-1.0});
-  auto t3 = t1.add({3.0});
-  auto t4 = t1.mul({2.0});
-  auto t5 = t3.add(t2);
-  auto t6 = t5.add(t4);
-  auto t7 = t1.add(t4);
+  auto t3 = t2.mul({-1.0});
+  auto t4 = t1.add({4.0});
+  auto t5 = t3 + t4;
 
-  std::vector<at::Tensor> aten_outputs = {t6, t7};
-  std::vector<at::Tensor> cg_outputs = {
-      at::empty_like(aten_input, options), at::empty_like(aten_input, options)};
+  at::Tensor kernel_tv5 = at::empty_like(t0, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  fe.runFusion({aten_input}, cg_outputs);
+  fe.runFusion({t0}, {kernel_tv5});
 
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
+  TORCH_CHECK(at::allclose(kernel_tv5, t5));
 }
 
-TEST(NVFuserTest, FusionAdvancedComputeWith2_CUDA) {
-  // Case 2
-  // tv1 = tv0 * -1
-  // tv2 = tv0 + 3
-  // tv3 = tv0 * 2
-  // tv4 = tv2 + tv1
-  // tv5 = tv4 + tv3
-  // tv6 = tv5 + tv3
+// Similar to the above common consumer test but adds an additional
+// tensor that has no common consumer with the other tensors.
+TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) {
+  // tv1 = tv0 * 0.5
+  // tv2 = tv1 * -1
+  // tv3 = tv2 * -1
+  // tv4 = tv1 + 4
+  // tv5 = tv2 + tv3
+  // tv6 = tv1 + 6
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
 
-  TensorView* tv1 = mul(tv0, new Double(-1.0));
-  TensorView* tv2 = add(tv0, new Double(3.0));
-  TensorView* tv3 = mul(tv0, new Double(2.0));
-  TensorView* tv4 = add(tv2, tv1);
-
-  TensorView* tv5 = add(tv4, tv3);
-  TensorView* tv6 = add(tv5, tv3);
+  TensorView* tv1 = mul(tv0, new Float(0.5));
+  TensorView* tv2 = mul(tv1, new Float(-1.0));
+  TensorView* tv3 = mul(tv2, new Float(-1.0));
+  TensorView* tv4 = add(tv1, new Float(4.0));
+  TensorView* tv5 = add(tv3, tv4);
+  TensorView* tv6 = add(tv1, new Float(6.0));
 
   fusion.addOutput(tv5);
   fusion.addOutput(tv6);
 
-  // Lets setup to actually run
-  tv0->merge(0);
-  tv0->split(0, 128);
-  tv0->split(0, 4);
+  TensorView* computeAtTarget = tv3;
+
+  computeAtTarget->merge(0);
+  computeAtTarget->split(0, 128);
+  computeAtTarget->split(0, 4);
+
+  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
+
+  // This will have the same impact on the tensors except for tv5 and
+  // tv6. tv6 does not have any common consumer with the computeAt
+  // target, but since it uses tv1, it must be also computed at the
+  // same location as the other impacted tensors. We can either make
+  // tv5 computed at tv6 or tv6 computed at tv5. In this case, tv5
+  // should be computed at tv6 just because the current implementation
+  // orders the computeAt relationship based on the order in which
+  // tensors are specified as outputs.
+
+  tv1->computeAt(computeAtTarget, 1);
+
+  // All tensors should have the same dimenionality as the target
+  for (Val* val : fusion.vals()) {
+    if (fusion.hasInput(val) ||
+        val->getValType().value() != ValType::TensorView) {
+      continue;
+    }
+    TensorView* tv = val->as<TensorView>();
+    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
+  }
 
-  tv0->axis(0)->parallelize(ParallelType::BIDx);
+  TORCH_CHECK(tv1->getComputeAtView() == tv2);
+  TORCH_CHECK(tv2->getComputeAtView() == tv3);
 
-  tv0->computeWith(tv6, 1);
+  // tv3 and tv4 are computed at tv5
+  TORCH_CHECK(tv3->getComputeAtView() == tv5);
+  TORCH_CHECK(tv4->getComputeAtView() == tv5);
+
+  // tv5 should be computed at tv6 since tv5 is added as an output
+  // before tv6. If we call fusion.addOutput(tv6) first, tv6 should be
+  // computed at tv5.
+  TORCH_CHECK(tv5->getComputeAtView() == tv6);
+  TORCH_CHECK(!tv6->hasComputeAt());
 
   for (Val* val : fusion.vals()) {
     if (!fusion.hasInput(val) &&
         val->getValType().value() == ValType::TensorView) {
-      TensorView* tv = static_cast<TensorView*>(val);
-
+      TensorView* tv = val->as<TensorView>();
       tv->axis(1)->parallelize(ParallelType::Unroll);
       tv->axis(-1)->parallelize(ParallelType::TIDx);
     }
   }
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({129, 127}, options);
 
-  auto t1 = input.mul({-1.0});
-  auto t2 = input.add({3.0});
-  auto t3 = input.mul({2.0});
-  auto t4 = t2.add(t1);
-  auto t5 = t4.add(t3);
-  auto t6 = t5.add(t3);
+  at::Tensor t0 = at::randn({129, 127}, options);
 
-  std::vector<at::Tensor> aten_outputs = {t5, t6};
+  auto t1 = t0.mul({0.5});
+  auto t2 = t1.mul({-1.0});
+  auto t3 = t2.mul({-1.0});
+  auto t4 = t1.add({4.0});
+  auto t5 = t3 + t4;
+  auto t6 = t1.add({6.0});
+
+  at::Tensor kernel_tv5 = at::empty_like(t0, options);
+  at::Tensor kernel_tv6 = at::empty_like(t0, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
+  fe.runFusion({t0}, {kernel_tv5, kernel_tv6});
 
-  testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__);
+  TORCH_CHECK(at::allclose(kernel_tv5, t5));
+  TORCH_CHECK(at::allclose(kernel_tv6, t6));
 }
 
-TEST(NVFuserTest, FusionAdvancedComputeWith3_CUDA) {
-  // Case 3
-  // T2 = T1 * 0.979361
-  // T3 = T2 * T0
+// Similar to ComputeAtCommonConsumer1 but with an addtiona ltensor
+// that does not have data dependency with the consumer.
+TEST(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) {
+  // tv1 = tv0 * 0.5
+  // tv2 = tv1 * -1
+  // tv3 = tv1 * -2
+  // tv4 = tv2 + tv3
+  // tv5 = tv4 * 5
+  // tv6 = tv1 * 6
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(4);
+  TensorView* tv0 = makeDummyTensor(1);
   fusion.addInput(tv0);
 
-  TensorView* tv1 = makeSymbolicTensor(4);
-  fusion.addInput(tv1);
-
-  TensorView* tv2 = mul(tv1, new Double(.979361));
-  TensorView* tv3 = mul(tv2, tv0);
-
+  TensorView* tv1 = mul(tv0, new Float(0.5));
+  TensorView* tv2 = mul(tv1, new Float(-1.0));
+  TensorView* tv3 = mul(tv1, new Float(-2.0));
+  TensorView* tv4 = add(tv2, tv3);
+  TensorView* tv5 = mul(tv4, new Float(5.0));
+  // Notice that tv6 is not a consumer of tv4.
+  TensorView* tv6 = mul(tv1, new Float(6.0));
   fusion.addOutput(tv3);
+  fusion.addOutput(tv4);
+  fusion.addOutput(tv5);
+  fusion.addOutput(tv6);
 
-  // Lets setup to actually run
-  while (tv0->nDims() > 1)
-    tv0->merge(0);
-  tv0->split(0, 128);
-  tv0->split(0, 4);
-
-  while (tv1->nDims() > 1)
-    tv1->merge(0);
-  tv1->split(0, 128);
-  tv1->split(0, 4);
+  TensorView* computeAtTarget = tv3;
+  computeAtTarget->split(0, 128);
+  tv1->computeAt(computeAtTarget, 1);
 
-  tv0->computeWith(tv3, 1);
-  tv1->computeWith(tv3, 1);
+  TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4, tv6};
+  for (auto tv : affected_tensors) {
+    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
+  }
 
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
+  TORCH_CHECK(tv1->getComputeAtView() == computeAtTarget);
+  TORCH_CHECK(tv2->getComputeAtView() == tv4);
+  TORCH_CHECK(tv3->getComputeAtView() == tv4);
+  TORCH_CHECK(tv4->getComputeAtView() == tv5);
+  TORCH_CHECK(tv5->getComputeAtView() == tv6);
+  TORCH_CHECK(!tv6->hasComputeAt());
 
-  for (Val* val : fusion.vals()) {
-    if (!fusion.hasInput(val) &&
-        val->getValType().value() == ValType::TensorView) {
-      TensorView* tv = static_cast<TensorView*>(val);
+  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
 
-      tv->axis(1)->parallelize(ParallelType::Unroll);
-      tv->axis(-1)->parallelize(ParallelType::TIDx);
-    }
+  for (auto tv : affected_tensors) {
+    tv->axis(-1)->parallelize(ParallelType::TIDx);
   }
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
-  at::Tensor t1 = at::rand_like(t0, options);
 
-  auto t2 = t1.mul({0.979361});
-  auto aten_output = t2.mul(t0);
+  at::Tensor t0 = at::randn({1000}, options);
 
-  std::vector<IValue> aten_inputs = {t0, t1};
+  auto t1 = t0 * 0.5;
+  auto t2 = t1 * -1.0;
+  auto t3 = t1 * -2.0;
+  auto t4 = t2 + t3;
+  auto t5 = t4 * 5.0;
+  auto t6 = t1 * 6.0;
 
-  at::Tensor cg_output = at::empty_like(t0, options);
+  at::Tensor kernel_tv3 = at::empty_like(t0, options);
+  at::Tensor kernel_tv4 = at::empty_like(t0, options);
+  at::Tensor kernel_tv5 = at::empty_like(t0, options);
+  at::Tensor kernel_tv6 = at::empty_like(t0, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  fe.runFusion(aten_inputs, {cg_output});
+  fe.runFusion({t0}, {kernel_tv3, kernel_tv4, kernel_tv5, kernel_tv6});
 
-  testValidate(
-      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  TORCH_CHECK(at::allclose(kernel_tv3, t3));
+  TORCH_CHECK(at::allclose(kernel_tv4, t4));
+  TORCH_CHECK(at::allclose(kernel_tv5, t5));
+  TORCH_CHECK(at::allclose(kernel_tv6, t6));
 }
 
-TEST(NVFuserTest, FusionAdvancedComputeWith4_CUDA) {
-  // Case 4
-  // T4 = T2 - T3
-  // T5 = T1 + T4
-  // T6 = T5 - T0
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+namespace {
 
-  TensorView* tv0 = makeSymbolicTensor(4);
-  fusion.addInput(tv0);
+void checkConcretized(
+    TensorView* v0,
+    int a0,
+    TensorView* v1,
+    int a1,
+    bool should_concretize) {
+  if (should_concretize) {
+    TORCH_CHECK(
+        IterDomain::concretizeDomain(v0->axis(a0))->sameAs(v1->axis(a1)));
+  } else {
+    TORCH_CHECK(
+        !IterDomain::concretizeDomain(v0->axis(a0))->sameAs(v1->axis(a1)));
+  }
+}
 
-  TensorView* tv1 = makeSymbolicTensor(4);
-  fusion.addInput(tv1);
+} // namespace
 
-  TensorView* tv2 = makeSymbolicTensor(4);
-  fusion.addInput(tv2);
+TEST(NVFuserTest, FusionBCastConcretizeBasic_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  TensorView* tv3 = makeSymbolicTensor(4);
-  fusion.addInput(tv3);
+  // tv0: [I I]
+  TensorView* tv0 = makeDummyTensor(2);
 
-  TensorView* tv4 = sub(tv2, tv3);
-  TensorView* tv5 = add(tv1, tv4);
-  TensorView* tv6 = sub(tv5, tv0);
+  // tv1: [I I I]
+  TensorView* tv1 = makeDummyTensor(3);
 
-  fusion.addOutput(tv6);
-  std::vector<TensorView*> tvs = {tv0, tv1, tv2};
-  for (auto tv : tvs) {
-    // Lets setup to actually run
-    while (tv->nDims() > 1) {
-      tv->merge(0);
-    }
-    tv->split(0, 128);
-    tv->split(0, 4);
-    tv->computeWith(tv6, 1);
-  }
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
 
-  tv6->axis(0)->parallelize(ParallelType::BIDx);
+  // tv2*: [B I I]
+  auto tv2_0 = broadcast(tv0, {true, false, false});
+  auto tv2_1 = broadcast(tv0, {true, false, false});
+  auto tv2 = add(tv2_0, tv2_1);
 
-  for (Val* val : fusion.vals()) {
-    if (!fusion.hasInput(val) &&
-        val->getValType().value() == ValType::TensorView) {
-      TensorView* tv = static_cast<TensorView*>(val);
+  // tv3: [I I I]
+  auto tv3 = add(tv2, tv1);
 
-      tv->axis(1)->parallelize(ParallelType::Unroll);
-      tv->axis(-1)->parallelize(ParallelType::TIDx);
-    }
-  }
+  fusion.addOutput(tv3);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
-  at::Tensor t1 = at::rand_like(t0, options);
-  at::Tensor t2 = at::rand_like(t0, options);
-  at::Tensor t3 = at::rand_like(t0, options);
+  checkConcretized(tv2, 0, tv1, 0, true);
+  checkConcretized(tv2_0, 0, tv1, 0, true);
+  checkConcretized(tv2_1, 0, tv1, 0, true);
+  checkConcretized(tv2_0, 1, tv1, 0, false);
+  checkConcretized(tv2_0, 0, tv1, 1, false);
+}
 
-  auto t4 = t2.sub(t3);
-  auto t5 = t1.add(t4);
-  auto aten_output = t5.sub(t0);
+TEST(NVFuserTest, FusionBCastConcretizeRfactor_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  std::vector<IValue> aten_inputs = {t0, t1, t2, t3};
+  // both tv0 and tv1 = [I, I]
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = makeDummyTensor(2);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  //[B,I,I]
+  auto tv2 = broadcast(tv1, {true, false, false});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+  //[B,I,R]
+  auto tv3 = sum(tv2, {2});
 
-TEST(NVFuserTest, FusionAdvancedComputeWith5_CUDA) {
-  // Case 5
-  // tv2 = tv0 + 2.0
-  // tv3 = tv1 * tv2
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  auto tv5 = add(tv3, tv1);
 
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
   fusion.addInput(tv0);
-  TensorView* tv1 = makeSymbolicTensor(2);
   fusion.addInput(tv1);
-  TensorView* tv2 = add(tv0, new Double(2.0));
-  TensorView* tv3 = mul(tv1, tv2);
-  fusion.addOutput(tv3);
-
-  tv2->merge(0);
-  tv2->split(-1, 8);
-  tv2->split(-1, 4);
-
-  tv2->computeWith(tv3, 1);
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
+  fusion.addOutput(tv5);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({63, 65}, options);
-  at::Tensor t1 = at::rand_like(t0, options);
+  // scheduling:
+  //[B,I,R0,R1=128], root = [B,I,R]
+  tv3->split(2, 128);
 
-  auto t2 = t0.add(2.0);
-  auto aten_output = t1.mul(t2);
+  // root=[B,I,Irf], rfactor=[B,I,Irf,Rrf]
+  auto tv4 = tv3->rFactor({3});
 
-  std::vector<IValue> aten_inputs = {t0, t1};
+  checkConcretized(tv2, 0, tv5, 0, true);
+  checkConcretized(tv4, 0, tv5, 0, true);
+  checkConcretized(tv3, 0, tv5, 0, true);
+}
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+namespace {
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+void checkIdProvedEquivalent(
+    TensorView* v0,
+    int a0,
+    TensorView* v1,
+    int a1,
+    bool should_prove) {
+  if (should_prove) {
+    TORCH_CHECK(IterDomain::proveEquivalent(v0->axis(a0), v1->axis(a1)));
+  } else {
+    TORCH_CHECK(!IterDomain::proveEquivalent(v0->axis(a0), v1->axis(a1)));
+  }
 }
 
-TEST(NVFuserTest, FusionAdvancedComputeWith6_CUDA) {
+} // namespace
+
+TEST(NVFuserTest, FusionProveIdEqBasic_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = makeDummyTensor(2);
+  TensorView* tv2 = makeDummyTensor(3);
+
   fusion.addInput(tv0);
-  TensorView* tv1 = makeSymbolicTensor(2);
   fusion.addInput(tv1);
-  TensorView* tv2 = add(tv0, new Double(2.0));
-  TensorView* tv3 = mul(tv1, tv2);
-  fusion.addOutput(tv3);
+  auto tv3 = broadcast(tv0, {true, false, false});
+  auto tv4 = broadcast(tv1, {false, true, false});
+  auto tv5 = add(tv3, tv4);
+  fusion.addOutput(tv5);
 
-  tv2->merge(0);
-  tv2->split(-1, 8);
-  tv2->split(-1, 4);
-  tv3->merge(0);
-  tv3->split(-1, 8);
+  checkIdProvedEquivalent(tv0, 0, tv4, 1, true);
+  checkIdProvedEquivalent(tv1, 0, tv4, 0, true);
+  checkIdProvedEquivalent(tv1, 1, tv0, 1, true);
+  checkIdProvedEquivalent(tv0, 0, tv5, 1, true);
+  checkIdProvedEquivalent(tv1, 1, tv5, 2, true);
+  checkIdProvedEquivalent(tv0, 0, tv1, 0, false);
+  checkIdProvedEquivalent(tv0, 1, tv1, 0, false);
+  checkIdProvedEquivalent(tv0, 0, tv1, 1, false);
+}
 
-  tv2->computeWith(tv3, 1);
+TEST(NVFuserTest, FusionProveIdEqRfactor_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
+  // [I,I]
+  TensorView* tv0 = makeDummyTensor(2);
+  // [I,I,I]
+  TensorView* tv1 = makeDummyTensor(3);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({63, 65}, options);
-  at::Tensor t1 = at::rand_like(t0, options);
+  //[I,I,R]
+  auto tv2 = sum(tv1, {2});
 
-  auto t2 = t0.add(2.0);
-  auto aten_output = t1.mul(t2);
+  auto tv5 = add(tv2, tv0);
+
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  fusion.addOutput(tv5);
 
-  std::vector<IValue> aten_inputs = {t0, t1};
+  // scheduling:
+  //[B,I,R0,R1=128], root = [B,I,R]
+  tv2->split(2, 128);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  // root=[B,I,Irf], rfactor=[B,I,Irf,Rrf]
+  auto tv3 = tv2->rFactor({3});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  checkIdProvedEquivalent(tv1, 0, tv0, 0, true);
+  checkIdProvedEquivalent(tv2, 0, tv0, 0, true);
+  checkIdProvedEquivalent(tv3, 0, tv0, 0, true);
 }
 
-TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) {
-  // tv1 = tv0 * 0.5
-  // tv2 = tv1 * -1
-  // tv3 = tv2 * -2
+TEST(NVFuserTest, FusionScalarInputs_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(1);
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
+  TensorView* tv1 = makeDummyTensor(2);
+  fusion.addInput(tv1);
 
-  TensorView* tv1 = mul(tv0, new Double(0.5));
-  TensorView* tv2 = mul(tv1, new Double(-1.0));
-  TensorView* tv3 = mul(tv1, new Double(-2.0));
-  fusion.addOutput(tv2);
-  fusion.addOutput(tv3);
-
-  // This computeAt will affect tv2 as well, even though tv2 is not in
-  // the data-flow path between tv1 and tv3. The reason is that tv1 is
-  // now computed at tv3, so tv2 must also be computed at the same
-  // location. Overall, what will happen is basically we merge
-  // expressions of all tensors and compute them in a single loop
-  // nest.
-  TensorView* computeAtTarget = tv3;
-  computeAtTarget->split(0, 128);
-  tv1->computeAt(computeAtTarget, 1);
+  Float* f0 = new Float();
+  fusion.addInput(f0);
+  Float* f1 = new Float();
+  fusion.addInput(f1);
+  Float* f2 = new Float();
+  fusion.addInput(f2);
+  Float* f3 = new Float();
+  fusion.addInput(f3);
+  Val* f4 = mul(f0, f1);
+  Val* f5 = sub(f2, f3);
+
+  TensorView* tv2 = sub(tv1, f4);
+  TensorView* tv3 = add(tv0, f5);
+  TensorView* tv4 = mul(tv3, tv2);
 
-  TensorView* affected_tensors[] = {tv1, tv2, tv3};
-  for (auto tv : affected_tensors) {
-    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
-  }
+  fusion.addOutput(tv4);
 
-  GpuLower gpulw(&fusion);
+  // Lets setup to actually run
+  while (tv4->nDims() > 1)
+    tv4->merge(0);
+  tv4->split(0, 128);
+  tv4->split(0, 4);
 
-  TORCH_CHECK(tv1->getComputeAtPosition() == 1);
-  TORCH_CHECK(
-      tv2->getComputeAtPosition() == 0 && tv2->getMaxProducerPosition() == 1);
-  TORCH_CHECK(
-      tv3->getComputeAtPosition() == 0 && tv3->getMaxProducerPosition() == 1);
+  tv0->computeAt(tv4, 1);
+  tv1->computeAt(tv4, 1);
 
-  // Note that tv2 is also computed at tv3.
-  for (auto tv : {tv1, tv2}) {
-    TORCH_CHECK(
-        gpulw.caLoopMap().areMapped(tv->axis(0), computeAtTarget->axis(0)));
-  }
+  tv4->axis(0)->parallelize(ParallelType::BIDx);
 
-  TORCH_CHECK(tv3->getComputeAtPosition() == 0);
+  for (Val* val : fusion.vals()) {
+    if (!fusion.hasInput(val) &&
+        val->getValType().value() == ValType::TensorView) {
+      TensorView* tv = static_cast<TensorView*>(val);
 
-  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
-  for (auto tv : affected_tensors) {
-    tv->axis(-1)->parallelize(ParallelType::TIDx);
+      tv->axis(1)->parallelize(ParallelType::Unroll);
+      tv->axis(-1)->parallelize(ParallelType::TIDx);
+    }
   }
 
+  // f4 = f0 * f1
+  // f5 = f2 - f3
+  // t2 = t1 - f4
+  // t3 = t0 + f5
+  // t4 = t3 * t2
+
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  at::Tensor aten_input = at::randn({1000}, options);
+  float fl0 = 0.1;
+  float fl1 = -0.2;
+  float fl2 = 0.3;
+  float fl3 = -0.4;
+  float fl4 = fl0 * fl1;
+  float fl5 = fl2 - fl3;
+
+  at::Tensor t0 = at::randn({129, 127}, options);
+  at::Tensor t1 = at::rand_like(t0, options);
 
-  auto t1 = aten_input * 0.5;
-  auto t2 = t1 * -1.0;
-  auto t3 = t1 * -2.0;
+  auto t2 = t1.sub(fl4);
+  auto t3 = t0.add(fl5);
+  auto t4 = t3.mul(t2);
 
-  std::vector<at::Tensor> aten_outputs = {t2, t3};
+  at::Tensor kernel_tv4 = at::empty_like(t0, options);
 
-  std::vector<at::Tensor> cg_outputs = {
-      at::empty_like(aten_input, options), at::empty_like(aten_input, options)};
+  at::Scalar test(fl0);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  fe.runFusion({aten_input}, cg_outputs);
+  fe.runFusion(
+      {t0,
+       t1,
+       at::Scalar(fl0),
+       at::Scalar(fl1),
+       at::Scalar(fl2),
+       at::Scalar(fl3)},
+      {kernel_tv4});
 
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
+  TORCH_CHECK(at::allclose(kernel_tv4, t4));
 }
 
-// Similar to ComputeAtMultiConsumers, but with a common consumer.
-TEST(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) {
-  // tv1 = tv0 * 0.5
-  // tv2 = tv1 * -1
-  // tv3 = tv2 * -2
-  // tv4 = tv2 + tv3
-  // tv5 = tv4 * 5
+TEST(NVFuserTest, FusionLoopUnroll_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(1);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(3);
+  TensorView* tv1 = makeDummyTensor(3);
+
+  // Register your inputs
   fusion.addInput(tv0);
+  fusion.addInput(tv1);
 
-  TensorView* tv1 = mul(tv0, new Double(0.5));
-  TensorView* tv2 = mul(tv1, new Double(-1.0));
-  TensorView* tv3 = mul(tv1, new Double(-2.0));
-  TensorView* tv4 = add(tv2, tv3);
-  TensorView* tv5 = mul(tv4, new Double(5.0));
-  fusion.addOutput(tv3);
-  fusion.addOutput(tv4);
-  fusion.addOutput(tv5);
+  // Do math with it, it returns a `Val*` but can be static_casted back to
+  // TensorView
+  TensorView* tv2 = add(tv1, new Float(2.0));
+  TensorView* tv3 = add(tv0, tv2);
 
-  // Computing tv1 at tv3. This will affect tv2 as discussed in
-  // ComplexComputeAt1. Additionally, in this case, notice that tv4 is
-  // the common consumer of tv2 and tv3, so they are computed at
-  // tv4. The indirect propagation of the computeAt should stop at the
-  // common consumer, and no further change should occur. More
-  // specifically, the computeAT position of tv4 and tv5 should be zero.
-  TensorView* computeAtTarget = tv3;
-  computeAtTarget->split(0, 128);
-  tv1->computeAt(computeAtTarget, 1);
+  // Register your outputs
+  fusion.addOutput(tv3);
 
-  TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4};
-  for (auto tv : affected_tensors) {
-    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
-  }
+  int block_size = 16;
 
-  TORCH_CHECK(tv1->getComputeAtPosition() == 1);
-  TORCH_CHECK(tv2->getComputeAtPosition() == 1);
-  TORCH_CHECK(tv3->getComputeAtPosition() == 1);
-  TORCH_CHECK(tv4->getComputeAtPosition() == 0);
-  TORCH_CHECK(tv5->getComputeAtPosition() == 0);
+  tv3->merge(0, 1);
+  tv3->merge(0, 1);
 
-  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
+  tv3->split(0, block_size);
+  tv3->split(0, 4);
 
-  for (auto tv : affected_tensors) {
-    tv->axis(-1)->parallelize(ParallelType::TIDx);
-  }
+  // For all inputs, computeAt the output inline, temporaries should be squeezed
+  // between them
+  tv0->computeAt(tv3, 1);
+  tv1->computeAt(tv3, 1);
 
-  // Transform tv5 to make it look like the rest
-  tv5->split(0, 128);
-  tv5->axis(1)->parallelize(ParallelType::TIDx);
-  tv5->axis(0)->parallelize(ParallelType::BIDx);
+  // Parallelize
+  tv2->axis(1)->parallelize(ParallelType::Unroll);
+  tv3->axis(1)->parallelize(ParallelType::Unroll);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(0)->parallelize(ParallelType::BIDx);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  at::Tensor aten_input = at::randn({1000}, options);
-
-  auto t1 = aten_input * 0.5;
-  auto t2 = t1 * -1.0;
-  auto t3 = t1 * -2.0;
-  auto t4 = t2 + t3;
-  auto t5 = t4 * 5.0;
-
-  std::vector<at::Tensor> aten_outputs = {t3, t4, t5};
-  std::vector<at::Tensor> cg_outputs = {
-      at::empty_like(aten_input, options),
-      at::empty_like(aten_input, options),
-      at::empty_like(aten_input, options)};
+  at::Tensor input0 = at::rand({129, 13, 3}, options);
+  at::Tensor input1 = at::rand({129, 13, 3}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  fe.runFusion({aten_input}, cg_outputs);
+  auto outputs = fe.runFusion({input0, input1});
 
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
+  TORCH_CHECK(outputs[0].equal(input0.add(input1.add(2.0))));
 }
 
-TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) {
-  // tv1 = tv0 * 0.5
-  // tv2 = tv1 * -1
-  // tv3 = tv2 * -1
-  // tv4 = tv1 + 4
-  // tv5 = tv3 + tv4
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = mul(tv0, new Double(0.5));
-  TensorView* tv2 = mul(tv1, new Double(-1.0));
-  TensorView* tv3 = mul(tv2, new Double(-1.0));
-  TensorView* tv4 = add(tv1, new Double(4.0));
-  TensorView* tv5 = add(tv3, tv4);
-
-  fusion.addOutput(tv5);
-
-  TensorView* computeAtTarget = tv3;
-
-  computeAtTarget->merge(0);
-  computeAtTarget->split(0, 128);
-  computeAtTarget->split(0, 4);
+/*
+ * Helper function for single op testing that generates a codegen operand
+ */
 
-  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
+Val* gen_jit_operand(std::pair<ValType, DataType> desc) {
+  if (desc.first == ValType::TensorView) {
+    return makeDummyTensor(2, desc.second);
+  } else if (desc.first == ValType::Scalar) {
+    if (desc.second == DataType::Float)
+      return new Float();
+    else if (desc.second == DataType::Int)
+      return new Int();
+    else
+      TORCH_CHECK(false, "Not currently supported type", desc.first);
+  } else {
+    TORCH_CHECK(false, "Not currently supported type", desc.first);
+  }
+  return nullptr;
+}
 
-  // This computeAt will affect all tensors including tv3, tv4 and
-  // tv5, even though it appears to impact only tv1 and tv2. The
-  // reason is that tv1 is now computed at tv3, so tv4 must also be
-  // computed at the same location. Similarly, the consumer of tv4,
-  // tv5, must also be computed at the same location. Overall, what
-  // will happen is basically we merge expressions of all tensors and
-  // compute them in a single loop nest. Internally, this will be
-  // realized by making all tensors, except for those in the path
-  // between tv1 and tv3, computed at tv5, which we call the common
-  // consumer.
-  tv1->computeAt(computeAtTarget, 1);
+/*
+ * Helper function for single op testing that generates an ATen operand
+ */
 
-  // All tensors should have the same dimenionality as the target
-  for (Val* val : fusion.vals()) {
-    if (fusion.hasInput(val) ||
-        val->getValType().value() != ValType::TensorView) {
-      continue;
-    }
-    TensorView* tv = val->as<TensorView>();
-    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
-    if (tv == tv5) {
-      TORCH_CHECK(tv->getComputeAtPosition() == 0);
+IValue gen_aten_operand(
+    std::pair<ValType, DataType> desc,
+    int blocks,
+    int threads,
+    bool rand) {
+  if (desc.first == ValType::TensorView) {
+    if (desc.second == DataType::Float) {
+      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+      if (rand)
+        return IValue(at::rand({blocks, threads}, options));
+      else
+        return IValue(at::empty({blocks, threads}, options));
+    } else if (desc.second == DataType::Half) {
+      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+      if (rand)
+        return IValue(at::rand({blocks, threads}, options));
+      else
+        return IValue(at::empty({blocks, threads}, options));
+    } else if (desc.second == DataType::Bool) {
+      if (rand) {
+        auto options =
+            at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+        return IValue(at::rand({blocks, threads}, options).to(at::kBool));
+      } else {
+        auto options =
+            at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0);
+        return IValue(at::empty({blocks, threads}, options));
+      }
     } else {
-      TORCH_CHECK(tv->getComputeAtPosition() == 1);
-    }
-  }
-
-  for (auto tv : ir_utils::filterByType<TensorView>(fusion.vals())) {
-    if (!fusion.hasInput(tv)) {
-      tv->axis(1)->parallelize(ParallelType::Unroll);
-      tv->axis(-1)->parallelize(ParallelType::TIDx);
+      TORCH_CHECK("Not currently supported type", desc.second)
     }
+  } else if (desc.first == ValType::Scalar) {
+    if (desc.second == DataType::Float)
+      return IValue(at::Scalar(1.f));
+    else if (desc.second == DataType::Int)
+      return IValue(at::Scalar(1));
+    else
+      TORCH_CHECK("Not currently supported type", desc.first);
+  } else {
+    TORCH_CHECK("Not currently supported type", desc.first);
   }
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor aten_input = at::randn({129, 127}, options);
-
-  auto t1 = aten_input.mul({0.5});
-  auto t2 = t1.mul({-1.0});
-  auto t3 = t2.mul({-1.0});
-  auto t4 = t1.add({4.0});
-  auto aten_output = t3 + t4;
-
-  at::Tensor cg_output = at::empty_like(aten_input, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({aten_input}, {cg_output});
-
-  testValidate(
-      &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
+  return nullptr;
 }
 
-// Similar to the above common consumer test but adds an additional
-// tensor that has no common consumer with the other tensors.
-TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) {
-  // tv1 = tv0 * 0.5
-  // tv2 = tv1 * -1
-  // tv3 = tv2 * -1
-  // tv4 = tv1 + 4
-  // tv5 = tv2 + tv3
-  // tv6 = tv1 + 6
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = mul(tv0, new Double(0.5));
-  TensorView* tv2 = mul(tv1, new Double(-1.0));
-  TensorView* tv3 = mul(tv2, new Double(-1.0));
-  TensorView* tv4 = add(tv1, new Double(4.0));
-  TensorView* tv5 = add(tv3, tv4);
-  TensorView* tv6 = add(tv1, new Double(6.0));
-
-  fusion.addOutput(tv5);
-  fusion.addOutput(tv6);
-
-  TensorView* computeAtTarget = tv3;
-
-  computeAtTarget->merge(0);
-  computeAtTarget->split(0, 128);
-  computeAtTarget->split(0, 4);
-
-  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
-
-  // This will have the same impact on the tensors except for tv5 and
-  // tv6. tv6 does not have any common consumer with the computeAt
-  // target, but since it uses tv1, it must be also computed at the
-  // same location as the other impacted tensors. We can either make
-  // tv5 computed at tv6 or tv6 computed at tv5. In this case, tv5
-  // should be computed at tv6 just because the current implementation
-  // orders the computeAt relationship based on the order in which
-  // tensors are specified as outputs.
-
-  tv1->computeAt(computeAtTarget, 1);
-
-  // All tensors should have the same dimenionality as the target
-  for (auto tv : ir_utils::filterByType<TensorView>(fusion.vals())) {
-    if (fusion.hasInput(tv)) {
-      continue;
-    }
-    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
-    if (tv == tv5 || tv == tv6) {
-      TORCH_CHECK(tv->getComputeAtPosition() == 0);
-      TORCH_CHECK(tv->getMaxProducerPosition() == 1);
-    } else {
-      TORCH_CHECK(tv->getComputeAtPosition() == 1);
-    }
-  }
-
-  for (Val* val : fusion.vals()) {
-    if (!fusion.hasInput(val) &&
-        val->getValType().value() == ValType::TensorView) {
-      TensorView* tv = val->as<TensorView>();
-      tv->axis(1)->parallelize(ParallelType::Unroll);
-      tv->axis(-1)->parallelize(ParallelType::TIDx);
-    }
-  }
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor aten_input = at::randn({129, 127}, options);
-
-  auto t1 = aten_input.mul({0.5});
-  auto t2 = t1.mul({-1.0});
-  auto t3 = t2.mul({-1.0});
-  auto t4 = t1.add({4.0});
-  auto t5 = t3 + t4;
-  auto t6 = t1.add({6.0});
-
-  std::vector<at::Tensor> aten_outputs = {t5, t6};
-  std::vector<at::Tensor> cg_outputs = {
-      at::empty_like(aten_input, options), at::empty_like(aten_input, options)};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({aten_input}, cg_outputs);
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-// Similar to ComputeAtCommonConsumer1 but with an addtiona ltensor
-// that does not have data dependency with the consumer.
-TEST(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) {
-  // tv1 = tv0 * 0.5
-  // tv2 = tv1 * -1
-  // tv3 = tv1 * -2
-  // tv4 = tv2 + tv3
-  // tv5 = tv4 * 5
-  // tv6 = tv1 * 6
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = mul(tv0, new Double(0.5));
-  TensorView* tv2 = mul(tv1, new Double(-1.0));
-  TensorView* tv3 = mul(tv1, new Double(-2.0));
-  TensorView* tv4 = add(tv2, tv3);
-  TensorView* tv5 = mul(tv4, new Double(5.0));
-  // Notice that tv6 is not a consumer of tv4.
-  TensorView* tv6 = mul(tv1, new Double(6.0));
-  fusion.addOutput(tv3);
-  fusion.addOutput(tv4);
-  fusion.addOutput(tv5);
-  fusion.addOutput(tv6);
-
-  TensorView* computeAtTarget = tv3;
-  computeAtTarget->split(0, 128);
-  tv1->computeAt(computeAtTarget, 1);
-
-  TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4, tv5, tv6};
-  for (auto tv : affected_tensors) {
-    TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
-    if (tv == tv6 || tv == tv5) {
-      TORCH_CHECK(tv->getComputeAtPosition() == 0);
-    } else {
-      TORCH_CHECK(tv->getComputeAtPosition() == 1);
-    }
-  }
-
-  computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
-
-  for (auto tv : affected_tensors) {
-    tv->axis(-1)->parallelize(ParallelType::TIDx);
-  }
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor aten_input = at::randn({1000}, options);
-
-  auto t1 = aten_input * 0.5;
-  auto t2 = t1 * -1.0;
-  auto t3 = t1 * -2.0;
-  auto t4 = t2 + t3;
-  auto t5 = t4 * 5.0;
-  auto t6 = t1 * 6.0;
-
-  std::vector<at::Tensor> aten_outputs = {t3, t4, t5, t6};
-  std::vector<at::Tensor> cg_outputs = {
-      at::empty_like(aten_input, options),
-      at::empty_like(aten_input, options),
-      at::empty_like(aten_input, options),
-      at::empty_like(aten_input, options)};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({aten_input}, cg_outputs);
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-namespace {
-
-void checkConcretized(
-    TensorView* v0,
-    int a0,
-    TensorView* v1,
-    int a1,
-    bool should_concretize) {
-  if (should_concretize) {
-    TORCH_CHECK(
-        IterDomain::concretizeDomain(v0->axis(a0))->sameAs(v1->axis(a1)));
-  } else {
-    TORCH_CHECK(
-        !IterDomain::concretizeDomain(v0->axis(a0))->sameAs(v1->axis(a1)));
-  }
-}
-
-} // namespace
-
-TEST(NVFuserTest, FusionBCastConcretizeBasic_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // tv0: [I I]
-  TensorView* tv0 = makeSymbolicTensor(2);
-
-  // tv1: [I I I]
-  TensorView* tv1 = makeSymbolicTensor(3);
-
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  // tv2*: [B I I]
-  auto tv2_0 = broadcast(tv0, {true, false, false});
-  auto tv2_1 = broadcast(tv0, {true, false, false});
-  auto tv2 = add(tv2_0, tv2_1);
-
-  // tv3: [I I I]
-  auto tv3 = add(tv2, tv1);
-
-  fusion.addOutput(tv3);
-
-  checkConcretized(tv2, 0, tv1, 0, true);
-  checkConcretized(tv2_0, 0, tv1, 0, true);
-  checkConcretized(tv2_1, 0, tv1, 0, true);
-  checkConcretized(tv2_0, 1, tv1, 0, false);
-  checkConcretized(tv2_0, 0, tv1, 1, false);
-}
-
-TEST(NVFuserTest, FusionBCastConcretizeRfactor_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // both tv0 and tv1 = [I, I]
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
-
-  //[B,I,I]
-  auto tv2 = broadcast(tv1, {true, false, false});
-
-  //[B,I,R]
-  auto tv3 = sum(tv2, {2});
-
-  auto tv5 = add(tv3, tv1);
-
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addOutput(tv5);
-
-  // scheduling:
-  //[B,I,R0,R1=128], root = [B,I,R]
-  tv3->split(2, 128);
-
-  // root=[B,I,Irf], rfactor=[B,I,Irf,Rrf]
-  auto tv4 = tv3->rFactor({3});
-
-  checkConcretized(tv2, 0, tv5, 0, true);
-  checkConcretized(tv4, 0, tv5, 0, true);
-  checkConcretized(tv3, 0, tv5, 0, true);
-}
-
-namespace {
-
-void checkIdMapped(
-    ComputeAtRootDomainMap& root_map,
-    TensorView* v0,
-    IterDomain* id0,
-    TensorView* v1,
-    IterDomain* id1,
-    bool should_map) {
-  if (should_map) {
-    TORCH_CHECK(
-        root_map.canMap(v0->domain(), id0, v1->domain(), id1),
-        "Should be mappable: ",
-        id0,
-        " of ",
-        v0,
-        " and ",
-        id1,
-        " of ",
-        v1);
-  } else {
-    TORCH_CHECK(
-        !root_map.canMap(v0->domain(), id0, v1->domain(), id1),
-        "Should not be mappable: ",
-        id0,
-        " of ",
-        v0,
-        " and ",
-        id1,
-        " of ",
-        v1);
-  }
-}
-
-void checkIdMapped(
-    TensorView* v0,
-    const std::vector<IterDomain*>& root0,
-    const std::vector<bool> should_map0,
-    TensorView* v1,
-    const std::vector<IterDomain*>& root1,
-    const std::vector<bool> should_map1) {
-  ComputeAtRootDomainMap map;
-  map.build();
-  TORCH_INTERNAL_ASSERT(root0.size() == should_map0.size());
-  TORCH_INTERNAL_ASSERT(root1.size() == should_map1.size());
-  size_t idx0 = 0;
-  for (size_t i = 0; i < root0.size(); ++i) {
-    size_t idx1 = 0;
-    for (size_t j = 0; j < root1.size(); ++j) {
-      if (should_map0[i] && should_map1[j] && idx0 == idx1) {
-        checkIdMapped(map, v0, root0[i], v1, root1[j], true);
-      } else {
-        checkIdMapped(map, v0, root0[i], v1, root1[j], false);
-      }
-      if (should_map1[j])
-        ++idx1;
-    }
-    if (should_map0[i])
-      ++idx0;
-  }
-}
-
-void checkIdMapped(
-    TensorView* v0,
-    const std::vector<IterDomain*>& root0,
-    TensorView* v1,
-    const std::vector<IterDomain*>& root1) {
-  checkIdMapped(
-      v0,
-      root0,
-      std::vector<bool>(root0.size(), true),
-      v1,
-      root1,
-      std::vector<bool>(root1.size(), true));
-}
-
-} // namespace
-
-TEST(NVFuserTest, FusionRootMappingBasic_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
-
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  auto tv3 = broadcast(tv0, {true, false, false});
-  auto tv4 = broadcast(tv1, {false, true, false});
-  auto tv5 = add(tv3, tv4);
-  fusion.addOutput(tv5);
-
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, true},
-      tv4,
-      tv4->getRootDomain(),
-      {false, true, true});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, true},
-      tv4,
-      tv4->getRootDomain(),
-      {true, false, true});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {false, true},
-      tv1,
-      tv1->getRootDomain(),
-      {false, true});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, true},
-      tv5,
-      tv5->getRootDomain(),
-      {false, true, true});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, true},
-      tv5,
-      tv5->getRootDomain(),
-      {true, false, true});
-  checkIdMapped(tv3, tv3->getRootDomain(), tv4, tv4->getRootDomain());
-  checkIdMapped(tv3, tv3->getRootDomain(), tv5, tv5->getRootDomain());
-  checkIdMapped(tv4, tv4->getRootDomain(), tv5, tv5->getRootDomain());
-}
-
-TEST(NVFuserTest, FusionRootMappingRfactor_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // [I,I]
-  TensorView* tv0 = makeSymbolicTensor(2);
-  // [I,I,I]
-  TensorView* tv1 = makeSymbolicTensor(3);
-
-  //[I,I,R]
-  auto tv2 = sum(tv1, {2});
-  auto tv3 = add(tv2, tv0);
-
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addOutput(tv3);
-
-  // scheduling:
-  //[B,I,R0,R1=128], root = [B,I,R]
-  tv2->split(2, 128);
-
-  // root=[B,I,Irf], rfactor=[B,I,Irf,Rrf]
-  auto tv4 = tv2->rFactor({3});
-
-  checkIdMapped(tv1, tv1->getRootDomain(), tv4, tv4->getRootDomain());
-  checkIdMapped(
-      tv4,
-      tv4->getRFactorDomain(),
-      {true, true, true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {true, true, true});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {true, true, false});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, true, false},
-      tv3,
-      tv3->getRootDomain(),
-      {true, true});
-  checkIdMapped(
-      tv2,
-      tv2->getRootDomain(),
-      {true, true, false},
-      tv3,
-      tv3->getRootDomain(),
-      {true, true});
-  checkIdMapped(tv0, tv0->getRootDomain(), tv3, tv3->getRootDomain());
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, true},
-      tv1,
-      tv1->getRootDomain(),
-      {true, true, false});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, true},
-      tv2,
-      tv2->getRootDomain(),
-      {true, true, false});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, true},
-      tv4,
-      tv4->getRFactorDomain(),
-      {true, true, false, false});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, true},
-      tv4,
-      tv4->getRootDomain(),
-      {true, true, false});
-}
-
-TEST(NVFuserTest, FusionRootMappingReductionDependency1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  auto tv1 = sum(tv0, {1});
-  auto tv2 = broadcast(tv1, {false, true});
-  fusion.addOutput(tv2);
-
-  // The second dimension cannot be mapped as it would require recomputation.
-  checkIdMapped(tv0, tv0->getRootDomain(), tv1, tv1->getRootDomain());
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {true, false});
-}
-
-TEST(NVFuserTest, FusionRootMappingReductionDependency2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  auto tv1 = sum(tv0, {1});
-  auto tv2 = broadcast(tv1, {false, true});
-  auto tv3 = add(tv0, tv2);
-  fusion.addOutput(tv3);
-
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, false},
-      tv1,
-      tv1->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, false},
-      tv3,
-      tv3->getRootDomain(),
-      {true, false});
-  checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain());
-}
-
-TEST(NVFuserTest, FusionRootMappingReductionDependency3_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  auto tv1 = sum(tv0, {1});
-  auto tv2 = broadcast(tv1, {false, true});
-  fusion.addOutput(tv2);
-
-  tv1->split(-1, 4);
-  auto tv3 = tv1->rFactor({-2});
-
-  checkIdMapped(tv0, tv0->getRootDomain(), tv3, tv3->getRootDomain());
-  checkIdMapped(
-      tv3,
-      tv3->getMaybeRFactorDomain(),
-      {true, false, true},
-      tv1,
-      tv1->getRootDomain(),
-      {true, true});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {true, false});
-}
-
-TEST(NVFuserTest, FusionRootMappingReductionDependency4_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  auto tv1 = sum(tv0, {1});
-  auto tv2 = broadcast(tv1, {false, true});
-  auto tv3 = add(tv0, tv2);
-  fusion.addOutput(tv3);
-
-  tv1->split(-1, 4);
-  auto tv4 = tv1->rFactor({-2});
-
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, false},
-      tv4,
-      tv4->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv4,
-      tv4->getMaybeRFactorDomain(),
-      {true, false, true},
-      tv1,
-      tv1->getRootDomain(),
-      {true, true});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {true, false});
-  checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain());
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {true, false});
-}
-
-// Reproducer of issue #749
-TEST(NVFuserTest, FusionRootMappingReductionDependency5_CUDA_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = sum(tv1, {1});
-  auto tv3 = broadcast(tv2, {false, true});
-  auto tv4 = add(tv0, tv3);
-  auto tv5 = add(tv4, tv1);
-  fusion.addOutput(tv5);
-
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, false},
-      tv1,
-      tv1->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv2,
-      tv2->getRootDomain(),
-      {true, false},
-      tv3,
-      tv3->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv3,
-      tv3->getRootDomain(),
-      {true, true},
-      tv4,
-      tv4->getRootDomain(),
-      {true, true});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, false},
-      tv4,
-      tv4->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv4,
-      tv4->getRootDomain(),
-      {true, true},
-      tv5,
-      tv5->getRootDomain(),
-      {true, true});
-}
-
-// Similar to RootMappingReductionDependency5 but with rFactor
-TEST(NVFuserTest, FusionRootMappingReductionDependency6_CUDA_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = sum(tv1, {1});
-  auto tv3 = broadcast(tv2, {false, true});
-  auto tv4 = add(tv0, tv3);
-  auto tv5 = add(tv4, tv1);
-  fusion.addOutput(tv5);
-
-  tv2->split(1, 4);
-  auto tv6 = tv2->rFactor({-1});
-
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, false},
-      tv1,
-      tv1->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, false},
-      tv6,
-      tv6->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv6,
-      tv6->getMaybeRFactorDomain(),
-      {true, true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {true, true});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv2,
-      tv2->getRootDomain(),
-      {true, false},
-      tv3,
-      tv3->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv3,
-      tv3->getRootDomain(),
-      {true, true},
-      tv4,
-      tv4->getRootDomain(),
-      {true, true});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true, false},
-      tv4,
-      tv4->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv4,
-      tv4->getRootDomain(),
-      {true, true},
-      tv5,
-      tv5->getRootDomain(),
-      {true, true});
-}
-
-TEST(NVFuserTest, FusionRootMappingMultipleBroadcast_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(1);
-  auto tv1 = broadcast(tv0, {false, true});
-  auto tv2 = broadcast(tv0, {true, false});
-  auto tv3 = add(tv1, tv2);
-  fusion.addOutput(tv3);
-
-  // tv0 cannot be mapped with the consumers as it would mean its only
-  // domain would be mapped to both the first and second domains of
-  // the two consumers, thus computing tv0 at both corresponding loops.
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {false},
-      tv1,
-      tv1->getRootDomain(),
-      {false, false});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {false},
-      tv2,
-      tv2->getRootDomain(),
-      {false, false});
-  checkIdMapped(tv1, tv1->getRootDomain(), tv3, tv3->getRootDomain());
-  checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain());
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {false},
-      tv3,
-      tv3->getRootDomain(),
-      {false, false});
-}
-
-TEST(NVFuserTest, FusionRootMappingMultipleBroadcastWithNoCommonConsumer_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(1);
-  auto tv1 = broadcast(tv0, {false, true});
-  auto tv2 = broadcast(tv0, {true, false});
-  fusion.addOutput(tv1);
-  fusion.addOutput(tv2);
-
-  // If there is no common consumer, there is no recomputation constraint.
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true},
-      tv1,
-      tv1->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true},
-      tv2,
-      tv2->getRootDomain(),
-      {false, true});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {false, true});
-}
-
-TEST(NVFuserTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  auto tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-  auto tv2 = makeSymbolicTensor(2);
-  fusion.addInput(tv2);
-  auto tv3 = broadcast(tv0, {false, true});
-  auto tv4 = add(tv1, tv3);
-  fusion.addOutput(tv4);
-  auto tv5 = add(tv2, tv3);
-  fusion.addOutput(tv5);
-
-  // Broadcast domains can be used with multiple domains with
-  // different sizes. In this test, the broadcast domain of tv3 has
-  // two consumers, tv4 and tv5, which may have different sizes. Each
-  // of the consumers is used with the broadcast domain of tv3, but
-  // the two consumers may not have the same size, it is not possible
-  // to map those domains.
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true},
-      tv3,
-      tv3->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true},
-      tv1,
-      tv1->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true},
-      tv2,
-      tv2->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, false},
-      tv2,
-      tv2->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, false},
-      tv3,
-      tv3->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv2,
-      tv2->getRootDomain(),
-      {true, false},
-      tv3,
-      tv3->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv3,
-      tv3->getRootDomain(),
-      {true, false},
-      tv4,
-      tv4->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv3,
-      tv3->getRootDomain(),
-      {true, false},
-      tv5,
-      tv5->getRootDomain(),
-      {true, false});
-  checkIdMapped(
-      tv4,
-      tv4->getRootDomain(),
-      {true, false},
-      tv5,
-      tv5->getRootDomain(),
-      {true, false});
-}
-
-TEST(NVFuserTest, FusionRootMappingBroadcast_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  // tv0[I0]
-  fusion.addInput(tv0);
-  auto tv1 = broadcast(tv0, {true, false});
-  // tv1[B1, I0]
-  auto tv2 = broadcast(tv1, {true, false, false});
-  // tv2[B2, B1, I0]
-  fusion.addOutput(tv2);
-
-  // In this case, tv1 and tv2 has one and two broadcast domains,
-  // respectively. It is the second broadcast domain that is mapped to
-  // the broadcast of tv1.
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true},
-      tv1,
-      tv1->getRootDomain(),
-      {false, true});
-  checkIdMapped(
-      tv1,
-      tv1->getRootDomain(),
-      {true, true},
-      tv2,
-      tv2->getRootDomain(),
-      {false, true, true}); // Not {true, false, true}
-  checkIdMapped(
-      tv0,
-      tv0->getRootDomain(),
-      {true},
-      tv2,
-      tv2->getRootDomain(),
-      {false, false, true});
-}
-
-// Reproducer of issue #723
-TEST(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  auto tv1 = makeSymbolicTensor(2);
-
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = broadcast(tv0, {true, false});
-  auto tv3 = sum(tv2, {0});
-  auto tv4 = add(tv2, tv1);
-
-  fusion.addOutput(tv3);
-  fusion.addOutput(tv4);
-
-  ComputeAtRootDomainMap map;
-  map.build();
-
-  checkIdMapped(
-      map, tv2, tv2->getRootDomain()[0], tv4, tv4->getRootDomain()[0], true);
-  checkIdMapped(
-      map, tv2, tv2->getRootDomain()[0], tv3, tv3->getRootDomain()[0], true);
-
-  tv2->computeAt(tv4, -1);
-
-  const int x = 11;
-  const int y = 12;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({x}, options);
-  at::Tensor t1 = at::randn({y, x}, options);
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(aten_inputs);
-
-  auto t3 = t0;
-  auto t4 = t0.unsqueeze(0).expand({y, x}) + t1;
-
-  testValidate(&fusion, outputs, aten_inputs, {t3, t4}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionComputeAtFailDueToRootMapping_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = broadcast(tv1, {true, false});
-  auto tv3 = broadcast(tv1, {false, true});
-  auto tv4 = add(tv2, tv3);
-  fusion.addOutput(tv4);
-
-  // computeAt should fail as there is no valid root mapping.
-  ASSERT_ANY_THROW(tv1->computeAt(tv4, 1));
-}
-
-TEST(NVFuserTest, FusionScalarInputs_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  TensorView* tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-
-  Double* d0 = new Double();
-  fusion.addInput(d0);
-  Double* d1 = new Double();
-  fusion.addInput(d1);
-  Double* d2 = new Double();
-  fusion.addInput(d2);
-  Double* d3 = new Double();
-  fusion.addInput(d3);
-  Val* d4 = mul(d0, d1);
-  Val* d5 = sub(d2, d3);
-
-  TensorView* tv2 = sub(tv1, d4);
-  TensorView* tv3 = add(tv0, d5);
-  TensorView* tv4 = mul(tv3, tv2);
-
-  fusion.addOutput(tv4);
-
-  // Lets setup to actually run
-  while (tv4->nDims() > 1)
-    tv4->merge(0);
-  tv4->split(0, 128);
-  tv4->split(0, 4);
-
-  tv0->computeAt(tv4, 1);
-  tv1->computeAt(tv4, 1);
-
-  tv4->axis(0)->parallelize(ParallelType::BIDx);
-
-  for (Val* val : fusion.vals()) {
-    if (!fusion.hasInput(val) &&
-        val->getValType().value() == ValType::TensorView) {
-      TensorView* tv = static_cast<TensorView*>(val);
-
-      tv->axis(1)->parallelize(ParallelType::Unroll);
-      tv->axis(-1)->parallelize(ParallelType::TIDx);
-    }
-  }
-
-  // d4 = d0 * d1
-  // d5 = d2 - d3
-  // t2 = t1 - d4
-  // t3 = t0 + d5
-  // t4 = t3 * t2
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  float fl0 = 0.1;
-  float fl1 = -0.2;
-  float fl2 = 0.3;
-  float fl3 = -0.4;
-  float fl4 = fl0 * fl1;
-  float fl5 = fl2 - fl3;
-
-  at::Tensor t0 = at::randn({129, 127}, options);
-  at::Tensor t1 = at::rand_like(t0, options);
-
-  auto t2 = t1.sub(fl4);
-  auto t3 = t0.add(fl5);
-  auto aten_output = t3.mul(t2);
-
-  at::Tensor cg_output = at::empty_like(t0, options);
-
-  at::Scalar test(fl0);
-
-  std::vector<IValue> aten_inputs = {
-      t0,
-      t1,
-      at::Scalar(fl0),
-      at::Scalar(fl1),
-      at::Scalar(fl2),
-      at::Scalar(fl3)};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion(aten_inputs, {cg_output});
-
-  testValidate(
-      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionLoopUnroll_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(3);
-  TensorView* tv1 = makeSymbolicTensor(3);
-
-  // Register your inputs
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  // Do math with it, it returns a `Val*` but can be static_casted back to
-  // TensorView
-  TensorView* tv2 = add(tv1, new Double(2.0));
-  TensorView* tv3 = add(tv0, tv2);
-
-  // Register your outputs
-  fusion.addOutput(tv3);
-
-  int block_size = 16;
-
-  tv3->merge(0, 1);
-  tv3->merge(0, 1);
-
-  tv3->split(0, block_size);
-  tv3->split(0, 4);
-
-  // For all inputs, computeAt the output inline, temporaries should be squeezed
-  // between them
-  tv0->computeAt(tv3, 1);
-  tv1->computeAt(tv3, 1);
-
-  // Parallelize
-  tv2->axis(1)->parallelize(ParallelType::Unroll);
-  tv3->axis(1)->parallelize(ParallelType::Unroll);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor input0 = at::randn({129, 13, 3}, options);
-  at::Tensor input1 = at::randn({129, 13, 3}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion({input0, input1});
-
-  TORCH_CHECK(outputs[0].equal(input0.add(input1.add(2.0))));
-}
-
-/*
- * Helper function for single op testing that generates a codegen operand
- */
-
-Val* gen_jit_operand(std::pair<ValType, DataType> desc) {
-  if (desc.first == ValType::TensorView) {
-    return makeSymbolicTensor(2, desc.second);
-  } else if (desc.first == ValType::Scalar) {
-    if (desc.second == DataType::Float) {
-      return new Double();
-    } else if (desc.second == DataType::Double) {
-      return new Double();
-    } else if (desc.second == DataType::Int) {
-      return new Int();
-    } else {
-      TORCH_CHECK(false, "Not currently supported type: ", desc.first);
-    }
-  } else {
-    TORCH_CHECK(false, "Not currently supported type: ", desc.first);
-  }
-  return nullptr;
-}
-
-/*
- * Helper function for single op testing that generates an ATen operand
- */
-
-IValue gen_aten_operand(
-    std::pair<ValType, DataType> desc,
-    int blocks,
-    int threads,
-    bool rand) {
-  if (desc.first == ValType::TensorView) {
-    if (desc.second == DataType::Double || desc.second == DataType::Float ||
-        desc.second == DataType::Half) {
-      auto options = at::TensorOptions()
-                         .dtype(data_type_to_aten(desc.second))
-                         .device(at::kCUDA, 0);
-      if (rand) {
-        return IValue(at::rand({blocks, threads}, options));
-      } else {
-        return IValue(at::empty({blocks, threads}, options));
-      }
-    } else if (desc.second == DataType::Int || desc.second == DataType::Int32) {
-      auto dtype = desc.second == DataType::Int32 ? at::kInt : at::kLong;
-      if (rand) {
-        auto options =
-            at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-        return IValue(at::randn({blocks, threads}, options).mul(5).to(dtype));
-      } else {
-        auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
-        return IValue(at::empty({blocks, threads}, options));
-      }
-    } else if (desc.second == DataType::Bool) {
-      if (rand) {
-        auto options =
-            at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-        return IValue(
-            at::rand({blocks, threads}, options).round().to(at::kBool));
-      } else {
-        auto options =
-            at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0);
-        return IValue(at::empty({blocks, threads}, options));
-      }
-    } else {
-      TORCH_CHECK(false, "Not currently supported type: ", desc.second)
-    }
-  } else if (desc.first == ValType::Scalar) {
-    // IValue scalars can only be double int64 or bool
-    if (desc.second == DataType::Double || desc.second == DataType::Float ||
-        desc.second == DataType::Half) {
-      return IValue(at::Scalar(1.f));
-    } else if (desc.second == DataType::Int) {
-      return IValue(at::Scalar(1));
-    } else {
-      TORCH_CHECK(false, "Not currently supported type: ", desc.first);
-    }
-  } else {
-    TORCH_CHECK(false, "Not currently supported type: ", desc.first);
-  }
-  return nullptr;
-}
-
-/*
- * Templatized Helper Function To generate single Op comparison between the
- * JIT codegen for Cuda and the ATen Library.
- */
-
-using OutputPair = std::pair<ValType, DataType>;
-template <
-    typename AtenFunc,
-    typename JitFunc,
-    typename InputTuple,
-    size_t... NumInputs>
-void test_op(
-    int blocks,
-    int threads,
-    std::string op_str,
-    AtenFunc af,
-    JitFunc jf,
-    OutputPair op,
-    InputTuple it,
-    std::index_sequence<NumInputs...>) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Generate Input JIT function Inputs and add them as Inputs to the Fusion
-  // Graph
-  std::array<Val*, sizeof...(NumInputs)> jit_inputs = {
-      gen_jit_operand(std::get<NumInputs>(it))...};
-  std::for_each(jit_inputs.begin(), jit_inputs.end(), [&fusion](Val* v) {
-    fusion.addInput(v);
-  });
-  TensorView* out =
-      static_cast<TensorView*>(jf(std::get<NumInputs>(jit_inputs)...));
-  fusion.addOutput(out);
-
-  std::for_each(jit_inputs.begin(), jit_inputs.end(), [out](Val* v) {
-    if (v->getValType() == ValType::TensorView)
-      static_cast<TensorView*>(v)->computeAt(out, -1);
-  });
-  out->axis(0)->parallelize(ParallelType::BIDx);
-  out->axis(-1)->parallelize(ParallelType::TIDx);
-
-  std::array<IValue, sizeof...(NumInputs)> aten_inputs = {gen_aten_operand(
-      std::get<NumInputs>(it), blocks, threads, /*rand*/ true)...};
-  const at::ArrayRef<IValue> aten_inputs_ivalues(aten_inputs);
-
-  at::Tensor cg_output =
-      gen_aten_operand(op, blocks, threads, /*rand*/ false).toTensor();
-  std::vector<at::Tensor> output_vect = {cg_output};
-  cudaDeviceSynchronize();
-  if (fusion.isStochastic())
-    at::manual_seed(0);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion(aten_inputs_ivalues, output_vect);
-  cudaDeviceSynchronize();
-
-  if (fusion.isStochastic())
-    at::manual_seed(0);
-  at::Tensor aten_output = af(aten_inputs);
-  cudaDeviceSynchronize(); // This sync shouldn't be necessary;
-
-  std::string op_msg = "Operation " + op_str;
-
-  testValidate(
-      &fusion,
-      {cg_output},
-      aten_inputs,
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      op_msg);
-}
-
-/*
- *  Templatized Helper Function that uses variadic templates to
- *  process a variable length Input Tuple of different Operand Type.
- */
-template <typename AtenFunc, typename JitFunc, typename InputTuple>
-void test_op(
-    int blocks,
-    int threads,
-    std::string op_str,
-    AtenFunc af,
-    JitFunc jf,
-    OutputPair op,
-    InputTuple it) {
-  static constexpr auto size = std::tuple_size<InputTuple>::value;
-  test_op(
-      blocks,
-      threads,
-      op_str,
-      af,
-      jf,
-      op,
-      it,
-      std::make_index_sequence<size>{});
-}
-
-TEST(NVFuserTest, FusionUnaryOps_CUDA) {
-  using OpTuple =
-      std::tuple<at::Tensor (*)(const at::Tensor&), UnaryOpType, std::string>;
-
-  // [Note: explicit tuple type for uniform initialization list]
-  // Tuple type must be explicitly specified for each uniform initialization
-  // list within the vector to make this code compatible with some old env
-  // which we still need to support. eg. gcc 5.4 + cuda 9.2.
-  std::vector<OpTuple> ops{
-      OpTuple{at::abs, UnaryOpType::Abs, "abs"},
-      OpTuple{at::acos, UnaryOpType::Acos, "acos"},
-      OpTuple{at::asin, UnaryOpType::Asin, "asin"},
-      OpTuple{at::atan, UnaryOpType::Atan, "atan"},
-      // There does not appear to be an appropriate ATen function for atanh
-      // OpTuple{at::atanh,      UnaryOpType::Atanh,      "atanh"      },
-      OpTuple{at::ceil, UnaryOpType::Ceil, "ceil"},
-      OpTuple{at::cos, UnaryOpType::Cos, "cos"},
-      OpTuple{at::cosh, UnaryOpType::Cosh, "cosh"},
-      OpTuple{at::erf, UnaryOpType::Erf, "erf"},
-      OpTuple{at::erfc, UnaryOpType::Erfc, "erfc"},
-      OpTuple{at::exp, UnaryOpType::Exp, "exp"},
-      OpTuple{at::expm1, UnaryOpType::Expm1, "expm1"},
-      OpTuple{at::floor, UnaryOpType::Floor, "floor"},
-      OpTuple{at::frac, UnaryOpType::Frac, "frac"},
-      // OpTuple{at::gelu, UnaryOpType::Gelu, "gelu"},
-      OpTuple{at::lgamma, UnaryOpType::Lgamma, "lgamma"},
-      OpTuple{at::log, UnaryOpType::Log, "log"},
-      OpTuple{at::log10, UnaryOpType::Log10, "log10"},
-      OpTuple{at::log1p, UnaryOpType::Log1p, "log1p"},
-      OpTuple{at::log2, UnaryOpType::Log2, "log2"},
-      OpTuple{at::neg, UnaryOpType::Neg, "neg"},
-      OpTuple{at::reciprocal, UnaryOpType::Reciprocal, "reciprocal"},
-      OpTuple{at::relu, UnaryOpType::Relu, "relu"},
-      OpTuple{at::round, UnaryOpType::Round, "round"},
-      OpTuple{at::rsqrt, UnaryOpType::Rsqrt, "rsqrt"},
-      OpTuple{at::sigmoid, UnaryOpType::Sigmoid, "sigmoid"},
-      OpTuple{at::sin, UnaryOpType::Sin, "sin"},
-      OpTuple{at::sinh, UnaryOpType::Sinh, "sinh"},
-      OpTuple{at::sqrt, UnaryOpType::Sqrt, "sqrt"},
-      OpTuple{at::tan, UnaryOpType::Tan, "tan"},
-      OpTuple{at::tanh, UnaryOpType::Tanh, "tanh"},
-      OpTuple{at::trunc, UnaryOpType::Trunc, "trunc"}};
-
-  std::vector<DataType> dtypes = {DataType::Float, DataType::Double};
-
-  for (auto dtype : dtypes) {
-    std::for_each(ops.begin(), ops.end(), [&](OpTuple& op) {
-      test_op(
-          /*blocks*/ 640,
-          /*threads*/ 64,
-          /*name*/ std::get<2>(op),
-          /*Aten Func   */
-          [&op](std::array<IValue, 1>& vals) {
-            return std::get<0>(op)(vals[0].toTensor());
-          },
-          /*JIT  Func   */
-          [&op](Val* in1) -> Val* { return unaryOp(std::get<1>(op), in1); },
-          /*Output      */ std::make_pair(ValType::TensorView, dtype),
-          /*Inputs Tuple*/
-          std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
-    });
-
-    test_op(
-        /*blocks*/ 128,
-        /*threads*/ 64,
-        /*name*/ "rand_like",
-        /*Aten Func   */
-        [](std::array<IValue, 1>& vals) {
-          return at::rand_like(vals[0].toTensor());
-        },
-        /*JIT  Func   */
-        [](Val* in1) -> Val* { return unaryOp(UnaryOpType::RandLike, in1); },
-        /*Output      */ std::make_pair(ValType::TensorView, dtype),
-        /*Inputs Tuple*/
-        std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
-  }
-
-  dtypes = {DataType::Int, DataType::Int32, DataType::Bool};
-  for (auto dtype : dtypes) {
-    test_op(
-        /*blocks*/ 128,
-        /*threads*/ 64,
-        /*name*/ "bitwise_not",
-        /*Aten Func   */
-        [](std::array<IValue, 1>& vals) {
-          return at::bitwise_not(vals[0].toTensor());
-        },
-        /*JIT  Func   */
-        [](Val* in1) -> Val* { return unaryOp(UnaryOpType::Not, in1); },
-        /*Output      */ std::make_pair(ValType::TensorView, dtype),
-        /*Inputs Tuple*/
-        std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
-  }
-}
-
-TEST(NVFuserTest, FusionBinaryOps_CUDA) {
-  using AtenFuncSig = at::Tensor (*)(const at::Tensor&, const at::Tensor&);
-  using OpTuple = std::tuple<AtenFuncSig, BinaryOpType, std::string>;
-
-  // see [Note: explicit tuple type for uniform initialization list]
-  std::vector<OpTuple> logic_ops{
-      OpTuple{at::eq, BinaryOpType::Eq, "eq"},
-      OpTuple{at::ge, BinaryOpType::GE, "ge"},
-      OpTuple{at::gt, BinaryOpType::GT, "gt"},
-      OpTuple{at::le, BinaryOpType::LE, "le"},
-      OpTuple{at::lt, BinaryOpType::LT, "lt"},
-      OpTuple{at::ne, BinaryOpType::NE, "ne"}};
-  std::vector<DataType> dtypes = {DataType::Double, DataType::Float};
-
-  for (auto dtype : dtypes) {
-    std::for_each(logic_ops.begin(), logic_ops.end(), [&](OpTuple& op) {
-      test_op(
-          /*blocks*/ 640,
-          /*threads*/ 64,
-          /*name*/ std::get<2>(op),
-          /*Aten Func   */
-          [&op](std::array<IValue, 2>& vals) {
-            return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor());
-          },
-          /*JIT  Func   */
-          [&op](Val* in1, Val* in2) -> Val* {
-            return binaryOp(std::get<1>(op), in1, in2);
-          },
-          /*Output      */ std::make_pair(ValType::TensorView, DataType::Bool),
-          /*Inputs Tuple*/
-          std::make_tuple(
-              std::make_pair(ValType::TensorView, dtype),
-              std::make_pair(ValType::TensorView, dtype)));
-    });
-
-    // see [Note: explicit tuple type for uniform initialization list]
-    std::vector<OpTuple> math_ops{
-        OpTuple{at::atan2, BinaryOpType::Atan2, "atan2"},
-        OpTuple{at::div, BinaryOpType::Div, "div"},
-        OpTuple{at::fmod, BinaryOpType::Fmod, "fmod"},
-        OpTuple{at::max, BinaryOpType::Max, "max"},
-        OpTuple{at::min, BinaryOpType::Min, "min"},
-        OpTuple{at::mul, BinaryOpType::Mul, "mul"},
-        OpTuple{at::pow, BinaryOpType::Pow, "pow"},
-        // NOTE: Remainder does not match the Aten impl exactly
-        // despite using an identical function.
-        OpTuple{at::remainder, BinaryOpType::Remainder, "remainder"},
-    };
-
-    std::for_each(math_ops.begin(), math_ops.end(), [&](OpTuple& op) {
-      test_op(
-          /*blocks*/ 640,
-          /*threads*/ 64,
-          /*name*/ std::get<2>(op),
-          /*Aten Func   */
-          [&op](std::array<IValue, 2>& vals) {
-            return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor());
-          },
-          /*JIT  Func   */
-          [&op](Val* in1, Val* in2) -> Val* {
-            return binaryOp(std::get<1>(op), in1, in2);
-          },
-          /*Output      */ std::make_pair(ValType::TensorView, dtype),
-          /*Inputs Tuple*/
-          std::make_tuple(
-              std::make_pair(ValType::TensorView, dtype),
-              std::make_pair(ValType::TensorView, dtype)));
-    });
-
-    test_op(
-        /*blocks*/ 640,
-        /*threads*/ 64,
-        /*name*/ "add_alpha",
-        /*Aten Func   */
-        [](std::array<IValue, 3>& vals) {
-          return at::add(
-              vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
-        },
-        /*JIT  Func   */ static_cast<Val* (*)(Val*, Val*, Val*)>(&add_alpha),
-        /*Output      */ std::make_pair(ValType::TensorView, dtype),
-        /*Inputs Tuple*/
-        std::make_tuple(
-            std::make_pair(ValType::TensorView, dtype),
-            std::make_pair(ValType::TensorView, dtype),
-            std::make_pair(ValType::Scalar, dtype)));
-
-    test_op(
-        /*blocks*/ 640,
-        /*threads*/ 64,
-        /*name*/ "sub_alpha",
-        /*Aten Func   */
-        [](std::array<IValue, 3>& vals) {
-          return at::sub(
-              vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
-        },
-        /*JIT  Func   */ static_cast<Val* (*)(Val*, Val*, Val*)>(&sub_alpha),
-        /*Output      */ std::make_pair(ValType::TensorView, dtype),
-        /*Inputs Tuple*/
-        std::make_tuple(
-            std::make_pair(ValType::TensorView, dtype),
-            std::make_pair(ValType::TensorView, dtype),
-            std::make_pair(ValType::Scalar, dtype)));
-  }
-}
-
-TEST(NVFuserTest, FusionTernaryOps_CUDA) {
-  std::vector<DataType> dtypes = {DataType::Double, DataType::Float};
-
-  for (auto dtype : dtypes) {
-    test_op(
-        /*blocks*/ 640,
-        /*threads*/ 64,
-        /*name*/ "clamp",
-        /*Aten Func   */
-        [](std::array<IValue, 1>& vals) {
-          return at::clamp(vals[0].toTensor(), 0.f, 1.f);
-        },
-        /*JIT  Func   */
-        [&](Val* in1) -> Val* {
-          if (dtype == DataType::Float) {
-            return clamp(in1, new Double(0.f), new Double(1.f));
-          } else {
-            return clamp(in1, new Double(0.f), new Double(1.f));
-          }
-        },
-        /*Output      */ std::make_pair(ValType::TensorView, dtype),
-        /*Inputs Tuple*/
-        std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
-    test_op(
-        /*blocks*/ 640,
-        /*threads*/ 64,
-        /*name*/ "threshold",
-        /*Aten Func   */
-        [](std::array<IValue, 1>& vals) {
-          return at::threshold(vals[0].toTensor(), 0.f, 1.f);
-        },
-        /*JIT  Func   */
-        [&](Val* in1) -> Val* {
-          if (dtype == DataType::Float) {
-            return threshold(in1, new Double(0.f), new Double(1.f));
-          } else {
-            return threshold(in1, new Double(0.f), new Double(1.f));
-          }
-        },
-        /*Output      */ std::make_pair(ValType::TensorView, dtype),
-        /*Inputs Tuple*/
-        std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
-    test_op(
-        /*blocks*/ 640,
-        /*threads*/ 64,
-        /*name*/ "where",
-        /*Aten Func   */
-        [](std::array<IValue, 3>& vals) {
-          return at::where(
-              vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
-        },
-        /*JIT  Func   */ static_cast<Val* (*)(Val*, Val*, Val*)>(&where),
-        /*Output      */ std::make_pair(ValType::TensorView, dtype),
-        /*Inputs Tuple*/
-        std::make_tuple(
-            std::make_pair(ValType::TensorView, DataType::Bool),
-            std::make_pair(ValType::TensorView, dtype),
-            std::make_pair(ValType::TensorView, dtype)));
-  }
-}
-
-TEST(NVFuserTest, FusionCompoundOps_CUDA) {
-  std::vector<DataType> dtypes = {DataType::Double, DataType::Float};
-
-  for (auto dtype : dtypes) {
-    test_op(
-        /*blocks*/ 640,
-        /*threads*/ 64,
-        /*name*/ "lerp",
-        /*Aten Func   */
-        [](std::array<IValue, 3>& vals) {
-          return at::lerp(
-              vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
-        },
-        /*JIT  Func   */ static_cast<Val* (*)(Val*, Val*, Val*)>(&lerp),
-        /*Output      */ std::make_pair(ValType::TensorView, dtype),
-        /*Inputs Tuple*/
-        std::make_tuple(
-            std::make_pair(ValType::TensorView, dtype),
-            std::make_pair(ValType::TensorView, dtype),
-            std::make_pair(ValType::TensorView, dtype)));
-    test_op(
-        /*blocks*/ 640,
-        /*threads*/ 64,
-        /*name*/ "addcmul",
-        /*Aten Func   */
-        [](std::array<IValue, 4>& vals) {
-          return at::addcmul(
-              vals[0].toTensor(),
-              vals[1].toTensor(),
-              vals[2].toTensor(),
-              vals[3].toScalar());
-        },
-        /*JIT  Func   */
-        static_cast<Val* (*)(Val*, Val*, Val*, Val*)>(&addcmul),
-        /*Output      */ std::make_pair(ValType::TensorView, dtype),
-        /*Inputs Tuple*/
-        std::make_tuple(
-            std::make_pair(ValType::TensorView, dtype),
-            std::make_pair(ValType::TensorView, dtype),
-            std::make_pair(ValType::TensorView, dtype),
-            std::make_pair(ValType::Scalar, dtype)));
-  }
-}
-
-TEST(NVFuserTest, FusionCastOps_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2, DataType::Half);
-
-  TensorView* intrm1 = castOp(DataType::Float, tv0);
-  TensorView* out = castOp(DataType::Half, intrm1);
-
-  fusion.addInput(tv0);
-  fusion.addOutput(out);
-  tv0->computeAt(out, -1);
-
-  out->axis(0)->parallelize(ParallelType::BIDx);
-  out->axis(-1)->parallelize(ParallelType::TIDx);
-
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
-
-  at::Tensor input1 = at::randn({1, 4}, options);
-  at::Tensor ref_output = at::empty_like(input1);
-
-  std::array<IValue, 1> inputs = {input1};
-  const at::ArrayRef<IValue> input_ivalues(inputs);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(input_ivalues);
-
-  ref_output = at::_cast_Half(at::_cast_Double(input1));
-
-  TORCH_CHECK(
-      outputs[0].equal(ref_output),
-      "\nOp Type: -- ",
-      "cast FP16->FP32->FP16",
-      " -- had a mismatch.\n",
-      "\nABS MAX DIFF: ",
-      outputs[0].sub(ref_output).abs().max(),
-      "\n");
-}
-
-// Start off simple, block on the outer dim
-// block stride + thread all reduce + unrolling on inner dim
-TEST(NVFuserTest, FusionReduction1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  // tv1[I0, R1] = tv0[I0, I1]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
-  tv1->split(1, 128);
-  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
-  tv1->split(1, 4);
-  // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1]
-
-  TensorView* tv2 = tv1->rFactor({1});
-  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1]
-  // tv1[I0,        R1oi{4},  R1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}]
-
-  TensorView* tv3 = tv1->rFactor({1});
-  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1]
-  // tv3[I0,        R1oi{4}, Ir1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}]
-  // tv1[I0,                  R1i{128}] = tv3[I0,        R1oi{4}, Ir1i{128}]
-
-  // Incrementally, can print in between for debugging
-  tv0->computeAt(tv2, 1);
-  tv2->computeAt(tv3, 1);
-  tv3->computeAt(tv1, 1);
-
-  // Re do it all at once, because why not.
-  tv0->computeAt(tv1, 1);
-
-  tv2->axis(2)->parallelize(ParallelType::Unroll);
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
-  int numel_x = 65000;
-  int numel_y = 1025;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y}, options);
-  at::Tensor cg_output = at::empty({numel_x}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({input}, {cg_output});
-
-  auto aten_output = input.to(at::kDouble).sum({1});
-
-  testValidate(
-      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReduction2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  // tv1[I0, R1] = tv0[I0, I1]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-
-  fusion.addOutput(tv1);
-
-  // switches to try some different scenarios. maybe we should iterate on all
-  // permutations.
-  bool bind_bidx = true;
-  bool bind_tidx = true;
-  bool bind_tidy = true;
-  bool bind_unroll = true;
-
-  int numel_x = 1025; // Cannot exceed block dim max size / tidy
-  int numel_y = 129;
-  int tidx = 16;
-  int tidy = 8;
-  int unroll_factor = 4;
-
-  tv1->split(1, tidx);
-  // tv1[I0, R1o, R1i{tidx}] = tv0[I0, I1]
-
-  tv1->split(1, unroll_factor);
-  // tv1[I0, R1oo, R1oi{unroll}, R1i{tidx}] = tv0[I0, I1]
-
-  tv1->split(0, tidy);
-
-  TensorView* tv2 = tv1->rFactor({-3});
-  // tv2[I0,             >R1oo<, Ir1oi{unroll}, Ir1i{tidx}]
-  // tv1[I0o, I0i{tidy},          R1oi{unroll},  R1i{tidx}]
-
-  TensorView* tv3 = tv1->rFactor({-2});
-  // tv2[I0,             >R1oo<, Ir1oi{unroll}, Ir1i{tidx}]
-  // tv3[I0,                      R1oi{unroll}, Ir1i{tidx}]
-  // tv1[I0o, I0i{tidy},                         R1i{tidx}]
-
-  tv0->computeAt(tv1, -2);
-
-  if (bind_unroll)
-    tv2->axis(-2)->parallelize(ParallelType::Unroll);
-  if (bind_bidx)
-    tv1->axis(0)->parallelize(ParallelType::BIDx);
-  if (bind_tidy)
-    tv1->axis(1)->parallelize(ParallelType::TIDy);
-
-  if (bind_tidx) {
-    tv2->axis(-1)->parallelize(ParallelType::TIDx);
-    tv3->axis(-1)->parallelize(ParallelType::TIDx);
-    tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  }
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
-
-  auto aten_output = input.to(at::kDouble).sum({1});
-  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReduction3_CUDA) {
-  // What if Z participates in the reduction with X?
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  // tv1[I0, R1] = tv0[I0, I1]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-
-  fusion.addOutput(tv1);
-
-  int numel_x = 1025; // Cannot exceed block dim max size / tidy
-  int numel_y = 129;
-  int tidx = 16;
-  int tidz = 8;
-
-  tv1->split(1, tidz);
-  // tv1[I0, R1o, R1i{tidz}] = tv0[I0, I1]
-
-  tv1->split(1, tidx);
-  // tv1[I0, R1oo, R1oi{tidx}, R1i{tidz}] = tv0[I0, I1]
-
-  TensorView* tv2 = tv1->rFactor({-3});
-  // tv2[I0,  >R1oo<, Ir1oi{tidx}, Ir1i{tidz}]
-  // tv1[I0o,          R1oi{tidx},  R1i{tidz}]
-
-  tv0->computeAt(tv1, -3);
-
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(-2)->parallelize(ParallelType::TIDx);
-  tv1->axis(-1)->parallelize(ParallelType::TIDz);
-
-  tv2->axis(-2)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDz);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
-  at::Tensor cg_output = at::empty({numel_x}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({aten_input}, {cg_output});
-
-  auto aten_output = aten_input.to(at::kDouble).sum({1});
-
-  testValidate(
-      &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReduction4_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
-
-  TensorView* tv2 = add(tv0, tv1);
-  // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1]
-
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  TensorView* tv3 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv2);
-  // tv3[I0, R1] = tv2[I0, I1]
-
-  TensorView* tv4 = makeSymbolicTensor(1);
-  fusion.addInput(tv4);
-
-  // tv5[I0] = tv3[I0, R1] * tv4[I0]
-  TensorView* tv5 = mul(tv3, tv4);
-  fusion.addOutput(tv5);
-
-  int tidx = 16;
-
-  // RFactor the reduction
-  tv3->split(1, tidx);
-  // tv3[I0, R1o, R1i{tidx}] = tv2[I0, I1]
-
-  TensorView* tv6 = tv3->rFactor({-2});
-  // tv6[I0, R1o, iR1i{tidx}] = tv2[I0, I1]
-  // tv3[I0,       R1i{tidx}] = tv3[I0, I1]
-  tv2->computeAt(tv6, 2);
-
-  // Compute at inline with tv5 (only 1D)
-  tv6->computeAt(tv3, 1);
-  tv3->computeAt(tv5, 1);
-
-  tv5->axis(0)->parallelize(ParallelType::BIDx);
-
-  // Intermediate tensors only need this, but doesn't hurt to do on inputs
-  // tv0, 1, 4
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv6->axis(-1)->parallelize(ParallelType::TIDx);
-
-  int numel_x = 1025;
-  int numel_y = 129;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  at::Tensor t1 = at::randn({numel_x, numel_y}, options);
-  at::Tensor t4 = at::randn({numel_x}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({t0, t1, t4});
-
-  auto t2 = t0.add(t1);
-  auto t3 = t2.to(at::kDouble).sum({1});
-  auto aten_output = t3.mul(t4);
-
-  testValidate(
-      &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReduction5_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(3);
-
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-
-  fusion.addOutput(tv1);
-
-  int bidy = 2;
-  int tidy = 4;
-  int tidx = 5;
-
-  int dim1 = 11;
-
-  tv1->split(-2, tidy);
-
-  TensorView* tv2 = tv1->rFactor({-3});
-
-  tv0->computeAt(tv1, 1);
-  tv1->axis(0)->parallelize(ParallelType::BIDy);
-
-  for (auto* val : fusion.vals()) {
-    if (!fusion.hasInput(val) &&
-        val->getValType().value() == ValType::TensorView) {
-      val->as<TensorView>()->axis(-1)->parallelize(ParallelType::TIDx);
-    }
-  }
-
-  tv2->axis(-2)->parallelize(ParallelType::TIDy);
-  tv1->axis(-2)->parallelize(ParallelType::TIDy);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({bidy, dim1, tidx}, options);
-
-  at::Tensor cg_output = at::empty({bidy, tidx}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({input}, {cg_output});
-
-  auto aten_output = input.to(at::kDouble).sum({1});
-  testValidate(
-      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReduction6_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const int bdimx = 64;
-  const int bdimy = 8;
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(3);
-  fusion.addInput(tv0);
-
-  // tv1[I0, R1, R2] = tv0[I0, I1, I2]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
-  tv1->split(2, bdimx);
-  // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]
-  tv1->split(1, bdimy);
-  // tv1[I0, R1o, R1i{8}, R2o, R2i{128}] = tv0[I0, I1, I2]
-
-  TensorView* tv2 = tv1->rFactor({3});
-  // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2]
-  // tv1[I0, R1o, R1i{8},      R2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}]
-
-  TensorView* tv3 = tv1->rFactor({1});
-  // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2]
-  // tv3[I0, R1o, I1i{8},      I2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}]
-  // tv1[I0,      R1i{8},      R2i{128}] = tv3[I0, R1o, I1i{8},      I2i{128}]
-
-  tv3->computeAt(tv1, 1);
-  tv2->computeAt(tv3, 2);
-
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
-
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv1->axis(-2)->parallelize(ParallelType::TIDy);
-  tv3->axis(-2)->parallelize(ParallelType::TIDy);
-  tv2->axis(-3)->parallelize(ParallelType::TIDy);
-
-  int numel_x = 650;
-  int numel_y = 1000;
-  int numel_z = 4;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y, numel_z}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
-
-  auto aten_output = input.to(at::kDouble).sum({1, 2});
-  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionMultiGridReduction_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  TensorView* tv1 = max(tv0, {0});
-  TensorView* tv2 = sum(tv0, {0});
-
-  fusion.addOutput(tv1);
-  fusion.addOutput(tv2);
-
-  int numel_x = 4;
-  int numel_y = 2;
-
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(1)->parallelize(ParallelType::TIDx);
-
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(1)->parallelize(ParallelType::TIDx);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
-
-  std::vector<at::Tensor> aten_outputs = {
-      std::get<0>(input.to(at::kDouble).max(0)), input.to(at::kDouble).sum(0)};
-  testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionMultiGridReduction2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = sum(tv0, {0});
-  auto tv2 = sum(tv1, {0});
-  fusion.addOutput(tv2);
-
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(1)->parallelize(ParallelType::BIDy);
-  tv2->axis(0)->parallelize(ParallelType::BIDy);
-
-  FusionExecutor fe;
-  ASSERT_ANY_THROW(fe.compileFusion(&fusion));
-}
-
-TEST(NVFuserTest, FusionReductionTFT_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  // tv1[I0, R1] = tv0[I0, I1]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-
-  fusion.addOutput(tv1);
-
-  int numel_x = 1025;
-  int numel_y = 129;
-  int tidx = 16;
-  int tidy = 8;
-  int tidz = 8;
-
-  tv1->split(1, tidx);
-  // tv1[I0, R1o, R1i{tidx}]
-
-  tv1->split(1, tidz);
-  // tv1[I0, R1oo, R1Oi{tidz}, R1R1i{tidx}]
-
-  tv1->split(0, tidy);
-  // tv1[I0o, I0i, R1oo, R1Oi{tidz}, R1R1i{tidx}]
-
-  TensorView* tv2 = tv1->rFactor({2});
-  // tv2[I0o, I0i, R1oo, I1Oi{tidz}, I11i{tidx}]
-  // tv1[I0o, I0i,       R1Oi{tidz}, R1R1i{tidx}]
-
-  tv2->computeAt(tv1, 2);
-
-  tv1->axis(1)->parallelize(ParallelType::TIDy);
-
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv1->axis(-2)->parallelize(ParallelType::TIDz);
-  tv2->axis(-2)->parallelize(ParallelType::TIDz);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y}, options);
-  at::Tensor cg_output = at::empty({numel_x}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({input}, {cg_output});
-
-  auto aten_output = input.to(at::kDouble).sum({1});
-  testValidate(
-      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReductionOuterSplit_CUDA) {
-  // based off FusionReduction4
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
-
-  TensorView* tv2 = add(tv0, tv1);
-  // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1]
-
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  TensorView* tv3 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv2);
-  // tv3[I0, R1] = tv2[I0, I1]
-
-  TensorView* tv4 = makeSymbolicTensor(1);
-  fusion.addInput(tv4);
-
-  // tv5[I0] = tv3[I0, R1] * tv4[I0]
-  TensorView* tv5 = mul(tv3, tv4);
-  fusion.addOutput(tv5);
-
-  // RFactor the reduction
-  tv3->split(1, 16, false);
-  // tv3[I0, R1o{16}, R1i{tidx}] = tv2[I0, I1]
-
-  TensorView* tv6 = tv3->rFactor({-2});
-  // tv6[I0, R1o{16}, iR1i{tidx}] = tv2[I0, I1]
-  // tv3[I0,           R1i{tidx}] = tv3[I0, I1]
-  tv2->computeAt(tv6, 2);
-
-  // Compute at inline with tv5 (only 1D)
-  tv6->computeAt(tv3, 1);
-  tv3->computeAt(tv5, 1);
-
-  tv5->axis(0)->parallelize(ParallelType::BIDx);
-
-  // Intermediate tensors only need this, but doesn't hurt to do on inputs
-  // tv0, 1, 4
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv6->axis(-1)->parallelize(ParallelType::TIDx);
-
-  int numel_x = 1025;
-  int numel_y = 129;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  at::Tensor t1 = at::randn({numel_x, numel_y}, options);
-  at::Tensor t4 = at::randn({numel_x}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({t0, t1, t4});
-
-  auto t2 = t0.add(t1);
-  auto t3 = t2.to(at::kDouble).sum({1});
-  auto aten_output = t3.mul(t4);
-
-  testValidate(
-      &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionBranches_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
-  TensorView* tv2 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addInput(tv2);
-
-  auto tv3 = add(tv0, new Double(1.0));
-  auto tv4 = add(tv3, tv1);
-  auto tv5 = add(tv3, tv2);
-  auto tv6 = add(tv4, tv5);
-
-  fusion.addOutput(tv6);
-
-  constexpr int x = 63, y = 33;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({x, y}, options);
-  at::Tensor t1 = at::randn({x, y}, options);
-  at::Tensor t2 = at::randn({x, y}, options);
-
-  FusionExecutor fe;
-  tv6->merge(0);
-  tv6->split(0, 128);
-  tv6->split(0, 4);
-
-  tv6->axis(0)->parallelize(ParallelType::BIDx);
-
-  tv0->computeAt(tv6, 1);
-  tv1->computeAt(tv6, 1);
-  tv2->computeAt(tv6, 1);
-
-  tv3->axis(-2)->parallelize(ParallelType::Unroll);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv4->axis(-2)->parallelize(ParallelType::Unroll);
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
-  tv5->axis(-2)->parallelize(ParallelType::Unroll);
-  tv5->axis(-1)->parallelize(ParallelType::TIDx);
-  tv6->axis(-1)->parallelize(ParallelType::TIDx);
-
-  std::vector<IValue> aten_inputs = {t0, t1, t2};
-
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  auto t3 = t0.add(1.0);
-  auto t4 = t3.add(t1);
-  auto t5 = t3.add(t2);
-  auto aten_output = t4.add(t5);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSimpleBCast1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  TensorView* tv1 = add(tv0, new Double(1.5));
-
-  TensorView* tv2 = makeSymbolicTensor(2);
-  fusion.addInput(tv2);
-  TensorView* tv3 = makeSymbolicTensor(2);
-  fusion.addInput(tv3);
-  TensorView* tv4 = sub(tv2, tv3);
-
-  TensorView* tv5 = broadcast(tv1, {false, false, true});
-  TensorView* tv6 = broadcast(tv4, {true, false, false});
-
-  TensorView* tv7 = add(tv5, tv6);
-  fusion.addOutput(tv7);
-
-  tv7->split(-1, 4);
-  tv7->split(0, 8);
-
-  tv0->computeAt(tv7, -1);
-  tv2->computeAt(tv7, -1);
-
-  tv7->axis(0)->parallelize(ParallelType::BIDx);
-  tv7->axis(-1)->parallelize(ParallelType::TIDx);
-
-  constexpr int x = 63, y = 33, z = 15;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({x, y}, options);
-  at::Tensor t1 = t0.add(1.5);
-
-  at::Tensor t2 = at::randn({y, z}, options);
-  at::Tensor t3 = at::randn({y, z}, options);
-
-  at::Tensor t4 = t2.sub(t3);
-  at::Tensor t5 = t1.unsqueeze(-1).expand({x, y, z});
-
-  at::Tensor t6 = t4.expand({x, y, z});
-
-  at::Tensor aten_output = t5.add(t6);
-
-  std::vector<IValue> aten_inputs = {t0, t2, t3};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSimpleBCast2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  TensorView* tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-
-  TensorView* tv2 = add(tv0, tv1);
-
-  TensorView* tv3 = broadcast(tv2, {false, false, true});
-
-  TensorView* tv4 = makeSymbolicTensor(2);
-  fusion.addInput(tv4);
-
-  TensorView* tv5 = sub(tv4, new Double(0.1));
-
-  TensorView* tv6 = broadcast(tv5, {true, false, false});
-
-  TensorView* tv7 = add(tv3, tv6);
-
-  fusion.addOutput(tv7);
-
-  tv7->merge(0, 1);
-
-  tv0->computeAt(tv7, -1);
-  tv4->computeAt(tv7, -1);
-
-  tv7->axis(0)->parallelize(ParallelType::BIDx);
-  tv7->axis(-1)->parallelize(ParallelType::TIDx);
-
-  constexpr int x = 63, y = 33, z = 15;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({x, y}, options);
-  at::Tensor t1 = at::randn({x, y}, options);
-  at::Tensor t2 = t0.add(t1);
-  at::Tensor t3 = t2.unsqueeze(-1).expand({x, y, z});
-
-  at::Tensor t4 = at::randn({y, z}, options);
-  at::Tensor t5 = t4.sub(0.1);
-  at::Tensor t6 = t5.expand({x, y, z});
-  at::Tensor aten_output = t3.add(t6);
-
-  at::Tensor cg_output = at::empty({x, y, z}, options);
-
-  std::vector<IValue> aten_inputs = {t0, t1, t4};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion(aten_inputs, {cg_output});
-
-  testValidate(
-      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSimpleBCast3_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  std::vector<IterDomain*> dom;
-  dom.push_back(new IterDomain(new Int(0), new Int()));
-  dom.push_back(new IterDomain(
-      new Int(0),
-      new Int(1),
-      ParallelType::Serial,
-      IterType::BroadcastWithStride));
-
-  // tv0[I1, B{1}]
-  TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float);
-  fusion.addInput(tv0);
-
-  // tv1[I0, I1, I2]
-  TensorView* tv2 = makeSymbolicTensor(3);
-  fusion.addInput(tv2);
-
-  TensorView* tv3 = add(tv0, tv2);
-
-  fusion.addOutput(tv3);
-
-  tv3->merge(0);
-  tv3->merge(0);
-
-  tv0->computeAt(tv3, -1);
-  tv2->computeAt(tv3, -1);
-
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
-
-  constexpr int x = 2, y = 3, z = 4;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({y, 1}, options);
-  at::Tensor t2 = at::randn({x, y, z}, options);
-  auto aten_output = t0.add(t2);
-
-  std::vector<IValue> aten_inputs = {t0, t2};
-  at::Tensor cg_output = at::empty({x, y, z}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion(aten_inputs, {cg_output});
-
-  testValidate(
-      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSimpleBCast4_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  std::vector<IterDomain*> dom;
-  dom.push_back(new IterDomain(
-      new Int(0),
-      new Int(1),
-      ParallelType::Serial,
-      IterType::BroadcastWithStride));
-  dom.push_back(new IterDomain(new Int(0), new Int()));
-  TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float);
-
-  TensorView* tv1 = makeSymbolicTensor(3);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  TensorView* tv3 = add(tv0, tv1);
-
-  tv3->merge(0);
-  tv3->merge(0);
-  tv3->split(0, 128);
-  tv3->split(0, 4);
-
-  fusion.addOutput(tv3);
-
-  tv0->computeAt(tv3, -1);
-  tv1->computeAt(tv3, -1);
-
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-2)->parallelize(ParallelType::Unroll);
-
-  constexpr int x = 63, y = 33, z = 15;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({1, z}, options);
-  at::Tensor t1 = at::randn({x, y, z}, options);
-
-  auto aten_output = t0.add(t1);
-
-  at::Tensor cg_output = at::empty({x, y, z}, options);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion(aten_inputs, {cg_output});
-
-  testValidate(
-      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSimpleBCast5_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  constexpr int m = 2, k = 3, n = 4;
-
-  auto zero = new Int(0);
-  auto M = new IterDomain(zero, new Int(m));
-  auto K = new IterDomain(zero, new Int(k));
-  auto N = new IterDomain(zero, new Int(n));
-
-  // Set up your input tensor views
-  TensorView* tv0 =
-      new TensorView(new TensorDomain({M, K}, {true, true}), DataType::Float);
-  // Note: IterDomain must not be reused, so K needs to be cloned.
-  TensorView* tv1 = new TensorView(
-      new TensorDomain({K->clone(), N}, {true, true}), DataType::Float);
-
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  TensorView* tv2 = broadcast(tv0, {false, false, true});
-  TensorView* tv3 = broadcast(tv1, {true, false, false});
-
-  TensorView* tv4 = add(tv2, tv3);
-
-  fusion.addOutput(tv4);
-
-  tv4->merge(0);
-  tv4->merge(0);
-
-  tv0->computeAt(tv4, -1);
-  tv1->computeAt(tv4, -1);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({m, k}, options);
-  at::Tensor t1 = at::randn({k, n}, options);
-
-  auto t2 = t0.unsqueeze(-1).expand({m, k, n});
-  auto t3 = t1.expand({m, k, n});
-  auto aten_output = t2.add(t3);
-
-  at::Tensor cg_output = at::empty({m, k, n}, options);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion(aten_inputs, {cg_output});
-
-  testValidate(
-      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionComplexBCast1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  int x = 2, y = 3, z = 4;
-
-  auto tv0 = makeConcreteTensor({y});
-  auto tv1 = div(tv0, new Double(2.0));
-  auto tv2 = broadcast(tv1, {false, true});
-  auto tv3 = makeConcreteTensor({y, z});
-  auto tv4 = mul(tv2, tv3);
-  auto tv5 = broadcast(tv4, {true, false, false});
-  auto tv6 = makeConcreteTensor({x, y, z});
-  auto tv7 = add(tv5, tv6);
-
-  // tv0[    i1    ] = input
-  // tv1[    i1    ] = tv0/2.0
-  // tv2[    i1, b2] = bcast(tv1)
-  // tv3[    i1, i2] = input
-  // tv4[    i1, i2] = tv2 * tv3
-  // tv5[b0, i1, i2] = bcast(tv4)
-  // tv6[i0, i1, i2] = input
-  // tv7[i0, i1, i2] = tv5 + tv6
-
-  // tv4 = bcast(tv1) * tv3
-  // tv7 = bcast(tv4) + tv6
-
-  fusion.addInput(tv0);
-  fusion.addInput(tv3);
-  fusion.addInput(tv6);
-
-  fusion.addOutput(tv7);
-
-  tv7->merge(0);
-  tv7->merge(0);
-  tv0->computeAt(tv7, -1);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({y}, options);
-  at::Tensor t3 = at::randn({y, z}, options);
-  at::Tensor t6 = at::randn({x, y, z}, options);
-
-  auto t4 = t0.div(2.0).unsqueeze(-1).expand({y, z}) * t3;
-  auto aten_output = t4.unsqueeze(0).expand({x, y, z}) + t6;
-
-  std::vector<IValue> aten_inputs = {t0, t3, t6};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionComplexBCast2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  int x = 2, y = 3, z = 4;
-
-  auto tv0 = makeConcreteTensor({y, z});
-  auto tv1 = div(tv0, new Double(2.0));
-  auto tv2 = sum(tv1, {1});
-  auto tv3 = broadcast(tv2, {true, false});
-  auto tv4 = makeConcreteTensor({x, y});
-  auto tv5 = add(tv3, tv4);
-
-  // tv0[    i1, i2] = input
-  // tv1[    i1, i2] = tv0/2.0
-  // tv2[    i1    ] = sum(tv1, 1)
-  // tv3[b0, i1    ] = bcast(tv2)
-  // tv4[i0, i1    ] = input
-  // tv5[i0, i1    ] = tv3 + tv4
-
-  // tv2 = sum(tv0/2.0, 1)
-  // tv5 = bcast(tv2) + tv4
-
-  fusion.addInput(tv0);
-  fusion.addInput(tv4);
-
-  fusion.addOutput(tv5);
-
-  tv5->merge(0);
-  tv0->computeAt(tv5, -1);
-  tv1->computeAt(tv2, -1);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({y, z}, options);
-  at::Tensor t4 = at::randn({x, y}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({t0, t4});
-
-  auto t1 = t0.div(2.0);
-  auto t2 = t1.to(at::kDouble).sum(1);
-  auto t3 = t2.unsqueeze(0).expand({x, y});
-  auto aten_output = t3.add(t4);
-
-  testValidate(
-      &fusion, {cg_outputs}, {t0, t4}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  int w = 3, x = 4, y = 7, z = 8;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  auto tv0 = makeSymbolicTensor(3);
-  auto tv1 = makeSymbolicTensor(4);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = add(tv0, new Double(1.0));
-  auto tv3 = broadcast(tv2, {true, false, false, false});
-  auto tv4 = add(tv3, tv1);
-
-  fusion.addOutput(tv4);
-
-  tv4->merge(0);
-  tv4->merge(0);
-  tv4->merge(0);
-
-  tv4->split(0, 128);
-  tv4->split(0, 4);
-
-  tv2->computeAt(tv4, 1);
-
-  tv4->axis(0)->parallelize(ParallelType::BIDx);
-  tv4->axis(1)->parallelize(ParallelType::Unroll);
-  tv4->axis(2)->parallelize(ParallelType::TIDx);
-
-  tv3->axis(1)->parallelize(ParallelType::Unroll);
-  tv3->axis(2)->parallelize(ParallelType::TIDx);
-
-  tv2->axis(1)->parallelize(ParallelType::Unroll);
-  tv2->axis(2)->parallelize(ParallelType::TIDx);
-
-  FusionExecutor fe;
-
-  at::Tensor t0 = at::randn({x, y, z}, options);
-  at::Tensor t1 = at::randn({w, x, y, z}, options);
-
-  auto t3 = t0.add(1.0);
-  auto aten_output = t3.add(t1);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  int w = 3, x = 4, y = 7, z = 8;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  auto tv0 = makeSymbolicTensor(3);
-  auto tv1 = makeSymbolicTensor(4);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = add(tv0, new Double(1.0));
-  auto tv3 = broadcast(tv2, {true, false, false, false});
-  auto tv4 = add(tv3, tv1);
-
-  fusion.addOutput(tv4);
-
-  tv4->merge(-2);
-  tv4->merge(-2);
-  tv4->merge(-2);
-
-  tv4->split(0, 128);
-  tv4->split(0, 4);
-
-  tv2->computeAt(tv4, 1);
-
-  tv4->axis(0)->parallelize(ParallelType::BIDx);
-  tv4->axis(1)->parallelize(ParallelType::Unroll);
-  tv4->axis(2)->parallelize(ParallelType::TIDx);
-
-  tv3->axis(1)->parallelize(ParallelType::Unroll);
-  tv3->axis(2)->parallelize(ParallelType::TIDx);
-
-  tv2->axis(1)->parallelize(ParallelType::Unroll);
-  tv2->axis(2)->parallelize(ParallelType::TIDx);
-
-  FusionExecutor fe;
-
-  at::Tensor t0 = at::randn({x, y, z}, options);
-  at::Tensor t1 = at::randn({w, x, y, z}, options);
-
-  auto t3 = t0.add(1.0);
-  auto aten_output = t3.add(t1);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing3_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  int w = 3, x = 4, y = 7, z = 8;
-
-  auto tv0 = makeSymbolicTensor(3);
-  auto tv1 = makeSymbolicTensor(4);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = add(tv0, new Double(1.0));
-  auto tv3 = add(tv2, tv1);
-  fusion.addOutput(tv3);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({x, y, z}, options);
-  at::Tensor t1 = at::randn({w, x, y, z}, options);
-
-  auto t2 = t0.add(1.0);
-  auto aten_output = t2.add(t1);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  auto lparams = schedulePointwise(&fusion, aten_inputs);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing4_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeConcreteTensor({10, 20});
-  fusion.addInput(tv0);
-  TensorView* tv1 = makeConcreteTensor({10, 10, 20});
-  fusion.addInput(tv1);
-
-  TensorView* tv2 = add(tv0, new Double(1));
-  TensorView* tv3 = broadcast(tv2, {true, false, false});
-  TensorView* tv4 = add(tv3, tv1);
-  fusion.addOutput(tv4);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({10, 20}, options);
-  at::Tensor t1 = at::randn({10, 10, 20}, options);
-
-  auto t2 = t0.add(1.0);
-  auto aten_output = t2.add(t1);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing5_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  TensorView* tv1 = makeSymbolicTensor(3);
-  fusion.addInput(tv1);
-
-  TensorView* tv2 = add(tv0, new Double(1));
-  TensorView* tv3 = broadcast(tv2, {true, false, true});
-  TensorView* tv4 = add(tv3, tv1);
-  fusion.addOutput(tv4);
-
-  tv3->merge(0)->merge(0)->split(0, 2)->split(0, 3);
-  tv4->merge(0)->merge(0)->split(0, 2)->split(0, 3);
-
-  tv0->computeAt(tv4, 1);
-  tv1->computeAt(tv4, 1);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({7}, options);
-  at::Tensor t1 = at::randn({5, 7, 11}, options);
-
-  auto t2 = t0.add(1.0);
-  auto aten_output = t2.unsqueeze(-1).add(t1);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing6_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  std::vector<int64_t> tensor0_shape{7, 4, 7};
-  std::vector<int64_t> tensor1_shape{4, 7};
-
-  TensorView* tv0 = makeSymbolicTensor(tensor0_shape.size());
-  fusion.addInput(tv0);
-  TensorView* tv1 = makeSymbolicTensor(tensor1_shape.size());
-  fusion.addInput(tv1);
-
-  TensorView* tv2 = add(tv0, tv1);
-  TensorView* tv3 = sum(tv2, {0, 1});
-  fusion.addOutput(tv3);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor input0 = at::randn(tensor0_shape, options);
-  at::Tensor input1 = at::randn(tensor1_shape, options);
-
-  std::vector<int64_t> reduction_axes{0, 1};
-  auto reduction_params = getReductionHeuristics(&fusion, {input0, input1});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs =
-      fe.runFusion({input0, input1}, reduction_params.value().lparams);
-
-  auto aten_output = input0.add(input1).to(at::kDouble).sum(reduction_axes);
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {input0, input1},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      reduction_params.value().lparams);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing7_CUDA) {
-  // Might be able to use this one without 6 as the heuristics in 6 may change
-  // and this test is to cover the same issue.
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-
-  auto tv1 = broadcast(tv0, {false, true});
-
-  auto tv2 = makeSymbolicTensor(2);
-  fusion.addInput(tv2);
-
-  auto tv3 = add(tv1, tv2);
-  auto tv4 = sum(tv3, {0, 1});
-  fusion.addOutput(tv4);
-
-  tv4->merge(0, 1);
-  tv4->split(0, 128);
-  tv4->split(0, 4);
-
-  auto tv5 = tv4->rFactor({0, 1});
-
-  tv5->computeAt(tv4, -1);
-  tv0->computeAt(tv5, -1);
-
-  tv4->axis(0)->parallelize(ParallelType::TIDx);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  const int numel_x = 100;
-  const int numel_y = 200;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto at_t0 = at::randn({numel_x}, options);
-  auto at_t1 = at::randn({numel_x, numel_y}, options);
-
-  auto cg_outputs = fe.runFusion({at_t0, at_t1});
-
-  auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1)
-                         .to(at::kDouble)
-                         .sum();
-
-  testValidate(
-      &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing8_CUDA) {
-  // Same as 7 but with outer splits instead of inner
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-
-  auto tv1 = broadcast(tv0, {false, true});
-
-  auto tv2 = makeSymbolicTensor(2);
-  fusion.addInput(tv2);
-
-  auto tv3 = add(tv1, tv2);
-  auto tv4 = sum(tv3, {0, 1});
-  fusion.addOutput(tv4);
-
-  tv4->merge(0, 1);
-  tv4->split(0, 128, false);
-  tv4->split(0, 4, false);
-
-  auto tv5 = tv4->rFactor({0, 1});
-
-  tv5->computeAt(tv4, -1);
-  tv0->computeAt(tv5, -1);
-
-  tv4->axis(0)->parallelize(ParallelType::TIDx);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  const int numel_x = 100;
-  const int numel_y = 200;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto at_t0 = at::randn({numel_x}, options);
-  auto at_t1 = at::randn({numel_x, numel_y}, options);
-
-  auto cg_outputs = fe.runFusion({at_t0, at_t1});
-
-  auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1)
-                         .to(at::kDouble)
-                         .sum();
-
-  testValidate(
-      &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing9_CUDA) {
-  // Same as 7 but with outer splits instead of inner
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-
-  auto tv1 = broadcast(tv0, {false, true});
-
-  auto tv2 = mul(tv1, new Double(2));
-  fusion.addOutput(tv2);
-
-  auto tv3 = makeSymbolicTensor(3);
-  fusion.addInput(tv3);
-
-  auto tv4 = add(tv3, tv2);
-  fusion.addOutput(tv4);
-
-  const int numel_x = 200;
-  const int numel_y = 300;
-  const int numel_z = 400;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto at_t0 = at::randn({numel_y}, options);
-  auto at_t3 = at::randn({numel_x, numel_y, numel_z}, options);
-  std::vector<IValue> aten_inputs = {at_t0, at_t3};
-
-  auto lparams = schedulePointwise(&fusion, aten_inputs);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
-  auto at_t1 = at_t0.unsqueeze(-1);
-  auto at_t2 = at_t1.mul(2.0);
-
-  auto at_t4 = at_t3.add(at_t2);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {at_t2, at_t4}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing10_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeContigTensor(2);
-  TensorView* tv1 = makeContigTensor(2);
-
-  // Register your inputs
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  // Do math with it, it returns a `Val*` but can be static_casted back to
-  // TensorView
-  TensorView* tv2 = add(tv1, new Double(2.0));
-  TensorView* tv3 = add(tv0, tv2);
-
-  // Register your outputs
-  fusion.addOutput(tv3);
-
-  auto tv0_cache = tv0->cache_after();
-  auto tv1_cache = tv1->cache_after();
-
-  std::vector<TensorView*> tvs = {tv0_cache, tv1_cache, tv2, tv3};
-
-  for (auto tv : tvs) {
-    tv->split(1, 2, false);
-    tv->split(1, 1);
-    tv->split(-1, 4);
-    // [I0, 2, 1, I1/2/4, 4]
-    tv->reorder({{1, 2}, {2, 3}, {3, 1}});
-    tv->axis(0)->parallelize(ParallelType::BIDx);
-    tv->axis(1)->parallelize(ParallelType::TIDx);
-  }
-
-  // For all inputs, computeAt the output inline, temporaries should be squeezed
-  // between them
-  tv0->computeAt(tv3, 1);
-  tv1->computeAt(tv3, 1);
-
-  tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize);
-  tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor input1 = at::randn({64, 128}, options);
-  at::Tensor input2 = at::rand_like(input1);
-  at::Tensor output = at::empty_like(input1);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({input1, input2}, {output});
-
-  at::Tensor tv2_ref = input2 + 2.0;
-  at::Tensor output_ref = input1 + tv2_ref;
-
-  TORCH_CHECK(output_ref.equal(output));
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing11_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  int w = 3, x = 4, y = 7, z = 8;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  auto tv0 = makeSymbolicTensor(4);
-  auto tv1 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = add(tv1, new Double(1.0));
-  auto tv3 = broadcast(tv2, {true, false, true, true});
-  auto tv4 = add(tv3, tv0);
-
-  fusion.addOutput(tv4);
-
-  tv4->merge(0);
-  tv4->merge(1);
-
-  tv4->split(1, 32);
-  tv4->split(0, 1);
-
-  tv4->reorder({{2, 1}});
-
-  tv2->computeAt(tv4, 3);
-
-  tv2->setMemoryType(MemoryType::Global);
-
-  tv4->axis(0)->parallelize(ParallelType::BIDx);
-  tv4->axis(1)->parallelize(ParallelType::BIDy);
-  tv4->axis(2)->parallelize(ParallelType::Unswitch);
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
-  FusionExecutor fe;
-
-  at::Tensor t0 = at::randn({w, x, y, z}, options);
-  at::Tensor t1 = at::randn({x}, options);
-
-  auto t3 = t1.add(1.0).unsqueeze(-1).unsqueeze(-1);
-  auto aten_output = t3.add(t0);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-// Intended to stress the lowering of our code generator
-TEST(NVFuserTest, FusionAdvancedLowering1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeConcreteTensor({9, 5});
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = add(tv0, new Double(1));
-  TensorView* tv2 = add(tv1, new Double(2));
-  TensorView* tv3 = add(tv1, new Double(3));
-  TensorView* tv4 = sum(tv3, {1});
-
-  fusion.addOutput(tv2);
-  fusion.addOutput(tv4);
-
-  tv4->split(1, 4);
-  auto tv5 = tv4->rFactor({2});
-
-  tv1->computeAt(tv5, 2);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(1);
-  at::Tensor aten_input = at::randn({9, 5}, options);
-
-  auto t1 = aten_input.add(1.0);
-  auto t2 = t1.add(2.0);
-  auto t3 = t1.add(3.0);
-  auto t4 = t3.sum(1);
-
-  std::vector<at::Tensor> aten_outputs = {t2, t4};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedLowering2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Progressively broadcast tensors
-  TensorView* tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  TensorView* tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-  TensorView* tv2 = makeSymbolicTensor(3);
-  fusion.addInput(tv2);
-
-  TensorView* tv3 = add(tv0, new Double(1));
-  TensorView* tv4 = broadcast(tv3, {false, true});
-  TensorView* tv5 = add(tv4, tv1);
-  TensorView* tv6 = add(tv5, tv2);
-
-  fusion.addOutput(tv6);
-
-  // Split inner dimension
-  tv6->split(1, 4);
-  // Merge middle dims with outer dimensions
-  tv6->merge(2);
-  tv6->merge(0);
-
-  // tv6[I0*I1o, I1i*I2]
-
-  // Compute everything inline
-  tv0->computeAt(tv6, -1);
-
-  tv6->axis(0)->parallelize(ParallelType::BIDx);
-  tv6->axis(1)->parallelize(ParallelType::TIDx);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  int x = 13, y = 9, z = 5;
-  at::Tensor t0 = at::randn({y}, options);
-  at::Tensor t1 = at::randn({y, z}, options);
-  at::Tensor t2 = at::randn({x, y, z}, options);
-
-  auto t3 = t0.add(1.0);
-  auto t4 = t3.unsqueeze(-1);
-  auto t5 = t4.add(t1);
-  auto t6 = t5.add(t2);
-
-  std::vector<IValue> aten_inputs = {t0, t1, t2};
-  std::vector<at::Tensor> aten_outputs = {t6};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
-}
-
-// TODO: Complete test
-TEST(NVFuserTest, FusionAdvancedLowering3_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeConcreteTensor({1, -1});
-  auto tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  // [b0, i1]
-  auto tv2 = add(tv0, new Double(2.0));
-
-  // [i0, i1]
-  auto tv3 = add(tv1, new Double(3.0));
-
-  // [b0, i1]
-  auto tv4 = add(tv2, new Double(4.0));
-
-  // [io, i1]
-  auto tv5 = add(tv2, tv3);
-
-  fusion.addOutput(tv4);
-  fusion.addOutput(tv5);
-
-  tv0->computeAt(tv4, -1);
-
-  tv3->setMemoryType(MemoryType::Global);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  int x = 13, y = 9;
-  at::Tensor t0 = at::randn({1, y}, options);
-  at::Tensor t1 = at::randn({x, y}, options);
-
-  auto t4 = t0 + 2 + 4;
-  auto t5 = t0 + 2 + t1 + 3;
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-  std::vector<at::Tensor> aten_outputs = {t4, t5};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
-}
-
-// This excercises indexing with broadcast root axes. Non-broadcast
-// axes need to be preferred when propagating index exprs to root
-// axes. See, e.g., Index::getConsumerIndex_impl.
-TEST(NVFuserTest, FusionAdvancedLowering4_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  auto tv1 = broadcast(tv0, {false, true});
-  auto tv2 = broadcast(tv1, {false, false, true});
-  auto tv3 = makeSymbolicTensor(3);
-  fusion.addInput(tv3);
-  auto tv4 = add(tv2, tv3);
-  fusion.addOutput(tv4);
-
-  tv4->merge(1)->merge(0);
-  tv4->split(0, 8);
-  tv0->computeAt(tv4, 1);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int bx = 10;
-  const int by = 20;
-  const int bz = 30;
-  at::Tensor t0 = at::randn({bx}, options);
-  at::Tensor t3 = at::randn({bx, by, bz}, options);
-  std::vector<IValue> aten_inputs = {t0, t3};
-
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  auto aten_output =
-      t0.unsqueeze(-1).expand({bx, by}).unsqueeze(-1).expand({bx, by, bz}) + t3;
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedLowering5_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeConcreteTensor({5, 4, 3});
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = makeConcreteTensor({5, 3});
-  fusion.addInput(tv1);
-
-  auto tv2 = broadcast(tv1, {false, true, false});
-
-  auto tv3 = add(tv0, tv2);
-
-  fusion.addOutput(tv3);
-
-  tv2->merge(0);
-  tv1->computeAt(tv2, 1);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(1);
-  at::Tensor t0 = at::randn({5, 4, 3}, options);
-  at::Tensor t1 = at::randn({5, 3}, options);
-  auto t2 = t1.unsqueeze(1);
-  auto t3 = t0 + t2;
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-  std::vector<at::Tensor> aten_outputs = {t3};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
-}
-
-// Test a simple Gemm but also play around with fusion executor features
-TEST(NVFuserTest, FusionSimpleGemm_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2); // M, K
-  TensorView* tv1 = makeSymbolicTensor(2); // K, N
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  TensorView* tv2 = broadcast(tv0, {false, false, true});
-  // tv2[I0, I1, B] = tv0[I0, I1]
-
-  TensorView* tv3 = broadcast(tv1, {true, false, false});
-  // tv3[B, I1, I2] = tv1[I1, I2]
-
-  // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
-  TensorView* tv4 = mul(tv2, tv3);
-  // tv5[I0, R1, I2] = tv4[I0, I1, I2]
-  TensorView* tv5 = sum(tv4, {1});
-  fusion.addOutput(tv5);
-
-  tv5->split(1, 32);
-  // tv5[I0, R1o, R1i{32}, I2]
-
-  auto tv6 = tv5->rFactor({1});
-  // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2]
-  // tv5[I0,    , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2]
-
-  tv5->split(0, 4);
-  tv5->split(-1, 4);
-  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
-  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
-
-  tv0->computeAt(tv5, -1);
-  tv1->computeAt(tv5, -1);
-
-  // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
-  // tv5[I0o, I0i{4},    , R1i{32}, I2o, I2i{4}]
-  //--> (line symbolizes compute at location)
-  // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
-  // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
-  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
-
-  tv0->computeAt(tv6, -1);
-  tv1->computeAt(tv6, -1);
-  // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
-  // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
-  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
-
-  tv5->axis(0)->parallelize(ParallelType::BIDz);
-  tv5->axis(1)->parallelize(ParallelType::TIDz);
-
-  tv5->axis(-2)->parallelize(ParallelType::BIDy);
-  tv5->axis(-1)->parallelize(ParallelType::TIDy);
-
-  tv5->axis(2)->parallelize(ParallelType::TIDx);
-  tv6->axis(2)->parallelize(ParallelType::TIDx);
-
-  constexpr int M = 65, K = 33, N = 17;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({M, K}, options);
-  at::Tensor t1 = at::randn({K, N}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  // Lets specify a few bounds in launch params to make sure it works
-  fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4));
-
-  // Make sure bad launch params throws
-  // TODO: Re-enable once we have parallelization validation in.
-  // ASSERT_ANY_THROW(fe.runFusion({t0, t1}, LaunchParams(1, 2, 3, 4, 5, 6)));
-
-  // Don't specify any launch params
-  auto cg_outputs = fe.runFusion({t0, t1});
-
-  auto aten_output = t0.to(at::kDouble).matmul(t1.to(at::kDouble));
-
-  testValidate(
-      &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Softmax with a 1D tensor. Parallelized only with a single thread block.
-TEST(NVFuserTest, FusionSoftmax1D_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const int tidx = 128;
-  const int dimx = 1000;
-
-  // Set up your input tensor views
-  TensorView* input_tv0 = makeSymbolicTensor(1);
-  fusion.addInput(input_tv0);
-
-  TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0);
-  TensorView* sum_exp_tv2 = sum(exp_tv1, {-1});
-  TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {true});
-
-  // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
-  // computed at sum_exp_rf_tv8.
-  TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0);
-
-  TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3);
-
-  fusion.addOutput(output_tv4);
-
-  bcast_sum_tv3->split(0, tidx);
-
-  sum_exp_tv2->split(-1, tidx);
-  TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2});
-
-  output_tv4->split(-1, tidx);
-
-  exp_tv1->computeAt(sum_exp_rf_tv5, -1);
-  exp_tv1_copy->computeAt(output_tv4, -1);
-
-  TensorView* tensors_to_parallelize[] = {
-      sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5};
-
-  for (auto tv : tensors_to_parallelize) {
-    tv->axis(-1)->parallelize(ParallelType::TIDx);
-  }
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({dimx}, options);
-  at::Tensor cg_output = at::empty({dimx}, options);
-  at::Tensor t3_output = at::empty_like(cg_output, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({t0}, {cg_output});
-
-  auto aten_output = at::_softmax(t0.to(at::kDouble), -1, false);
-
-  testValidate(&fusion, {cg_output}, {t0}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Softmax with a 1D tensor with input normalization.
-TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const int tidx = 128;
-  const int dimx = 1000;
-
-  // Set up your input tensor views
-  TensorView* input_tv0 = makeSymbolicTensor(1);
-  fusion.addInput(input_tv0);
-
-  // Normalize with the max value before computing exp.
-  TensorView* max_val_tv1 =
-      reductionOp(BinaryOpType::Max, {-1}, new Double(0), input_tv0);
-  TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {true});
-  TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2);
-  TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3);
-  TensorView* sum_exp_tv5 = sum(exp_tv4, {-1});
-  TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {true});
-
-  // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
-  // computed at sum_exp_rf_tv8.
-  TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2);
-  TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy);
-
-  TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6);
-
-  fusion.addOutput(output_tv7);
-  bcast_max_tv2->split(0, tidx);
-  bcast_sum_tv6->split(0, tidx);
-
-  max_val_tv1->split(-1, tidx);
-  TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2});
-
-  sum_exp_tv5->split(-1, tidx);
-  TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2});
-
-  output_tv7->split(-1, tidx);
-
-  sub_tv3->computeAt(sum_exp_rf_tv9, -1);
-  sub_tv3_copy->computeAt(output_tv7, -1);
-
-  TensorView* tensors_to_parallelize[] = {
-      max_val_tv1,
-      bcast_max_tv2,
-      sum_exp_tv5,
-      bcast_sum_tv6,
-      output_tv7,
-      max_val_rf_tv8,
-      sum_exp_rf_tv9};
-
-  for (auto tv : tensors_to_parallelize) {
-    tv->axis(-1)->parallelize(ParallelType::TIDx);
-  }
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({dimx}, options);
-  at::Tensor t3_output = at::empty({dimx}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
-
-  auto aten_output = at::_softmax(input.to(at::kDouble), -1, false);
-
-  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Softmax with a 3D tensor, where the inner-most 3rd dimension is
-// normalized. Pallelized with multiple thread blocks.
-TEST(NVFuserTest, FusionSoftmax3D_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const int tidx = 32;
-  const int dimx = 32;
-  const int dimy = 16;
-  const int dimz = 130;
-
-  // Set up your input tensor views
-  TensorView* input_tv0 = makeSymbolicTensor(3);
-  fusion.addInput(input_tv0);
-
-  TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0);
-  TensorView* sum_exp_tv2 = sum(exp_tv1, {-1});
-  TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true});
-
-  // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
-  // computed at sum_exp_rf_tv8.
-  TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0);
-
-  TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3);
-
-  fusion.addOutput(output_tv4);
-
-  bcast_sum_tv3->split(-1, tidx);
-
-  sum_exp_tv2->split(-1, tidx);
-  TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2});
-
-  output_tv4->split(-1, tidx);
-
-  exp_tv1->computeAt(sum_exp_rf_tv5, -1);
-  exp_tv1_copy->computeAt(output_tv4, -1);
-
-  TensorView* tensors_to_parallelize[] = {
-      sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5};
-
-  for (auto tv : tensors_to_parallelize) {
-    tv->axis(0)->parallelize(ParallelType::BIDx);
-    tv->axis(1)->parallelize(ParallelType::BIDy);
-    tv->axis(-1)->parallelize(ParallelType::TIDx);
-  }
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({dimx, dimy, dimz}, options);
-
-  at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({input}, {cg_output});
-
-  auto aten_output = at::_softmax(input.to(at::kDouble), -1, false);
-
-  testValidate(
-      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Softmax with a 3D tensor with input normalization.
-TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const int tidx = 32;
-  const int dimx = 32;
-  const int dimy = 16;
-  const int dimz = 130;
-
-  // Set up your input tensor views
-  TensorView* input_tv0 = makeSymbolicTensor(3);
-  fusion.addInput(input_tv0);
-
-  // Normalize with the max value before computing exp.
-  TensorView* max_val_tv1 =
-      reductionOp(BinaryOpType::Max, {-1}, new Double(0), input_tv0);
-  TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true});
-  TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2);
-  TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3);
-  TensorView* sum_exp_tv5 = sum(exp_tv4, {-1});
-  TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {false, false, true});
-
-  // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
-  // computed at sum_exp_rf_tv8.
-  TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2);
-  TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy);
-
-  TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6);
-
-  fusion.addOutput(output_tv7);
-
-  bcast_max_tv2->split(-1, tidx);
-  bcast_sum_tv6->split(-1, tidx);
-
-  max_val_tv1->split(-1, tidx);
-  TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2});
-
-  sum_exp_tv5->split(-1, tidx);
-  TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2});
-
-  output_tv7->split(-1, tidx);
-
-  sub_tv3->computeAt(sum_exp_rf_tv9, -1);
-  sub_tv3_copy->computeAt(output_tv7, -1);
-
-  TensorView* tensors_to_parallelize[] = {
-      max_val_tv1,
-      bcast_max_tv2,
-      sum_exp_tv5,
-      bcast_sum_tv6,
-      output_tv7,
-      max_val_rf_tv8,
-      sum_exp_rf_tv9};
-
-  for (auto tv : tensors_to_parallelize) {
-    tv->axis(0)->parallelize(ParallelType::BIDx);
-    tv->axis(1)->parallelize(ParallelType::BIDy);
-    tv->axis(-1)->parallelize(ParallelType::TIDx);
-  }
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({dimx, dimy, dimz}, options);
-  at::Tensor t3_output = at::empty({dimx, dimy, dimz}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
-
-  auto aten_output = at::_softmax(input.to(at::kDouble), -1, false);
-
-  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSoftmaxComputeAt_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  auto tv1 = sum(tv0, {1});
-  auto tv2 = broadcast(tv1, {false, true});
-
-  auto tv3 = add(tv0, new Double(1.0));
-
-  auto tv4 = mul(tv2, tv3);
-
-  auto tv5 = sum(tv4, {1});
-  auto tv6 = broadcast(tv5, {false, true});
-
-  auto tv7 = sub(tv6, tv4);
-  fusion.addOutput(tv7);
-
-  tv1->computeAt(tv7, 1);
-  ASSERT_ANY_THROW(tv1->computeAt(tv7, -1));
-}
-
-// Similar to FusionReduction but uses grid reduction
-TEST(NVFuserTest, FusionGridReduction1_CUDA) {
-  const int gdimx = 32;
-  const int bdimx = 128;
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  // tv1[I0, R1] = tv0[I0, I1]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
-  tv1->split(1, bdimx);
-  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
-  tv1->split(1, gdimx);
-  // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
-
-  TensorView* tv2 = tv1->rFactor({1});
-  // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
-  // tv1[I0,        R1oi{32},  R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
-
-  // Incrementally, can print in between for debugging
-  tv0->computeAt(tv2, 1);
-  tv2->computeAt(tv1, 1);
-
-  // Re do it all at once, because why not.
-  tv0->computeAt(tv1, 1);
-
-  tv1->axis(0)->parallelize(ParallelType::BIDy);
-  tv1->axis(1)->parallelize(ParallelType::BIDx);
-  tv2->axis(2)->parallelize(ParallelType::BIDx);
-
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
-  // reduced shape for OOM on upstream CI
-  int numel_x = 1000;
-  int numel_y = 65000;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y}, options);
-  at::Tensor cg_output = at::empty({numel_x}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({input}, {cg_output});
-
-  auto aten_output = input.to(at::kDouble).sum({1});
-
-  testValidate(
-      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Same test as the above but uses BIDy and TIDx for reduction
-TEST(NVFuserTest, FusionGridReduction2_CUDA) {
-  const int gdimy = 32;
-  const int bdimx = 128;
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  // tv1[I0, R1] = tv0[I0, I1]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
-  tv1->split(1, bdimx);
-  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
-  tv1->split(1, gdimy);
-  // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
-
-  TensorView* tv2 = tv1->rFactor({1});
-  // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
-  // tv1[I0,        R1oi{32},  R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
-
-  // Incrementally, can print in between for debugging
-  tv0->computeAt(tv2, 1);
-  tv2->computeAt(tv1, 1);
-
-  // Re do it all at once, because why not.
-  tv0->computeAt(tv1, 1);
-
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(1)->parallelize(ParallelType::BIDy);
-  tv2->axis(2)->parallelize(ParallelType::BIDy);
-
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
-  // reduced shape for OOM on upstream CI
-  int numel_x = 1000;
-  int numel_y = 65000;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
-
-  auto aten_output = input.to(at::kDouble).sum({1});
-
-  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Same test but uses BIDy and BIDz for reduction. No TID used.
-TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) {
-  // Grid reductions when there aren't any threads are serial reductions
-  // keep these numbers low so our error isn't too high compared to normal cuda
-  // reductions
-  const int gdimz = 15;
-  const int gdimy = 9;
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  // tv1[I0, R1] = tv0[I0, I1]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
-  tv1->split(1, gdimy);
-  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
-  tv1->split(1, gdimz);
-  // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
-
-  TensorView* tv2 = tv1->rFactor({1});
-  // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
-  // tv1[I0,        R1oi{32},  R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
-
-  // Incrementally, can print in between for debugging
-  tv0->computeAt(tv2, 1);
-  tv2->computeAt(tv1, 1);
-
-  // Re do it all at once, because why not.
-  tv0->computeAt(tv1, 1);
-
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(1)->parallelize(ParallelType::BIDz);
-  tv2->axis(2)->parallelize(ParallelType::BIDz);
-  tv1->axis(-1)->parallelize(ParallelType::BIDy);
-  tv2->axis(-1)->parallelize(ParallelType::BIDy);
-
-  int numel_x = 100;
-  int numel_y = 6500;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y}, options);
-  at::Tensor cg_output = at::empty({numel_x}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({input}, {cg_output});
-
-  auto aten_output = input.to(at::kDouble).sum({1});
-  testValidate(
-      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Same as testGPU_FusionGridReduction3dim1 but reduces dimension 0
-TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) {
-  // Grid reductions when there aren't any threads are serial reductions
-  // keep these numbers low so our error isn't too high compared to normal cuda
-  // reductions
-  const int gdimz = 15;
-  const int gdimy = 9;
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  // tv1[R0, I1] = tv0[I0, I1]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {0}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
-  tv1->split(0, gdimy);
-  // tv1[R0o, R0i{128}, I1] = tv0[I0, I1]
-  tv1->split(0, gdimz);
-  // tv1[R0oo, R0oi{32}, R0i{128}, I1] = tv0[I0, I1]
-
-  TensorView* tv2 = tv1->rFactor({0});
-  // tv2[R0oo, I0oi{32}, I0i{128}, I1] = tv0[I0, I1]
-  // tv1[      R0oi{32}, R0i{128}, I1] = tv2[R0oo, I0oi{32}, I0i{128}, I1]
-
-  // Note that computeAt isn't going to make anything better as there
-  // is no dynamically sized dimension.
-
-  // Map parallelism as [Serial, BIDz, BIDy, BIDx]
-  tv1->axis(-1)->parallelize(ParallelType::BIDx);
-  tv2->axis(-1)->parallelize(ParallelType::BIDx);
-  tv1->axis(-2)->parallelize(ParallelType::BIDy);
-  tv2->axis(-2)->parallelize(ParallelType::BIDy);
-  tv1->axis(-3)->parallelize(ParallelType::BIDz);
-  tv2->axis(-3)->parallelize(ParallelType::BIDz);
-
-  int numel_x = 6500;
-  int numel_y = 100;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
-
-  auto aten_output = input.to(at::kDouble).sum({0});
-
-  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// This is similar to the FusionReduction, but swaps BIDx and TIDx
-TEST(NVFuserTest, FusionGridReduction4_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const int bdimx = 128;
-  const int gdimx = 1024;
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  // tv1[I0, R1] = tv0[I0, I1]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
-  tv1->split(1, gdimx);
-  // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1]
-  tv1->split(1, 4);
-  // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1]
-
-  TensorView* tv2 = tv1->rFactor({1});
-  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
-  // tv1[I0,        R1oi{4},  R1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
-
-  TensorView* tv3 = tv1->rFactor({1});
-  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
-  // tv3[I0,        R1oi{4}, Ir1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
-  // tv1[I0,                  R1i{1024}] = tv3[I0,        R1oi{4}, Ir1i{1024}]
-
-  // Incrementally, can print in between for debugging
-  tv0->computeAt(tv2, 1);
-  tv2->computeAt(tv3, 1);
-  tv3->computeAt(tv1, 1);
-
-  // Re do it all at once, because why not.
-  tv0->computeAt(tv1, 1);
-
-  tv2->axis(2)->parallelize(ParallelType::Unroll);
-  tv1->axis(0)->parallelize(ParallelType::TIDx);
-
-  tv1->axis(-1)->parallelize(ParallelType::BIDx);
-  tv2->axis(-1)->parallelize(ParallelType::BIDx);
-  tv3->axis(-1)->parallelize(ParallelType::BIDx);
-
-  int numel_x = bdimx;
-  int numel_y = 65000;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y}, options);
-  at::Tensor cg_output = at::empty({numel_x}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({input}, {cg_output});
-
-  auto aten_output = input.to(at::kDouble).sum({1});
-  testValidate(
-      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Grid reduction with 2D thread blocks but only TIDx and BIDx are
-// mapped to a reduction dim
-TEST(NVFuserTest, FusionGridReduction5_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const int bdimx = 64;
-  const int bdimy = 16;
-  const int gdimx = 4;
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  // tv1[I0, R1] = tv0[I0, I1]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
-  tv1->split(1, bdimx);
-  // tv1[I0, R1o, R1i{64}] = tv0[I0, I1]
-  tv1->split(1, gdimx);
-  // tv1[I0, R1oo, R1oi{4}, R1i{64}] = tv0[I0, I1]
-
-  TensorView* tv2 = tv1->rFactor({1});
-  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] = tv0[I0, I1]
-  // tv1[I0,        R1oi{4},  R1i{64}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}]
-
-  tv0->computeAt(tv1, 1);
-
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv1->axis(-2)->parallelize(ParallelType::BIDx);
-  tv2->axis(-2)->parallelize(ParallelType::BIDx);
-
-  tv1->axis(0)->parallelize(ParallelType::TIDy);
-
-  int numel_x = bdimy;
-  int numel_y = 6500;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
-
-  auto aten_output = input.to(at::kDouble).sum({1});
-  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Similar to FusionGridReduction1 but with 3D tensors
-TEST(NVFuserTest, FusionGridReduction6_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(3);
-  fusion.addInput(tv0);
-
-  // tv1[I0, R1, R2] = tv0[I0, I1, I2]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
-  // Splitting for TID
-  tv1->split(2, 128);
-  // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]
-
-  // Splitting for BID
-  tv1->split(1, 128);
-
-  // tv1[I0, R1o, R1i{128}, R2o, R2i{128}] = tv0[I0, I1, I2]
-
-  TensorView* tv2 = tv1->rFactor({3});
-  // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
-  // tv1[I0, R1o, R1i{128},      R2i{128}]
-
-  TensorView* tv3 = tv1->rFactor({1});
-  // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
-  // tv3[I0, R1o, I1i{128},      I2i{128}]
-  // tv1[I0,      R1i{128},      R2i{128}]
-
-  tv3->computeAt(tv1, 1);
-  tv2->computeAt(tv3, 3);
-
-  tv1->axis(0)->parallelize(ParallelType::BIDy);
-
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv1->axis(-2)->parallelize(ParallelType::BIDx);
-  tv2->axis(-3)->parallelize(ParallelType::BIDx);
-  tv3->axis(-2)->parallelize(ParallelType::BIDx);
-
-  int numel_x = 6500;
-  int numel_y = 200;
-  int numel_z = numel_y;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y, numel_z}, options);
-  at::Tensor cg_output = at::empty({numel_x}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({input}, {cg_output});
-
-  auto aten_output = input.to(at::kDouble).sum({1, 2});
-
-  testValidate(
-      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// See issue #1049
-TEST(NVFuserTest, FusionGridReduction7_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-
-  auto tv1 = sum(tv0, {0});
-  fusion.addOutput(tv1);
-
-  tv1->split(0, 1000);
-
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(1)->parallelize(ParallelType::BIDy);
-
-  const int numel_x = 1;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x}, options);
-  at::Tensor cg_output = at::empty({numel_x}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto out = fe.runFusion({input});
-
-  auto aten_output = input.sum({0});
-
-  testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) {
-  int bid_x = 3;
-  int tid_x = 2;
-  int red_dim = 0;
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  TensorView* tv1 =
-      reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  tv1->split(-1, tid_x);
-  tv1->axis(-2)->parallelize(ParallelType::BIDx);
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({16, bid_x * tid_x}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
-
-  auto aten_output = input.to(at::kDouble).sum({red_dim});
-
-  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSplitBCast_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* input_tv0 = makeSymbolicTensor(3);
-  TensorView* input_tv1 = makeSymbolicTensor(3);
-  fusion.addInput(input_tv0);
-  fusion.addInput(input_tv1);
-
-  TensorView* sum_tv2 =
-      reductionOp(BinaryOpType::Add, {2}, new Double(0), input_tv0);
-  TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true});
-  TensorView* output_tv4 = div(input_tv1, bcast_tv3);
-
-  sum_tv2->split(-1, 32);
-  TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2});
-
-  bcast_tv3->split(-1, 32);
-  output_tv4->split(-1, 32);
-
-  sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx);
-  sum_tv2->axis(0)->parallelize(ParallelType::BIDx);
-  bcast_tv3->axis(0)->parallelize(ParallelType::BIDx);
-  output_tv4->axis(0)->parallelize(ParallelType::BIDx);
-
-  sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy);
-  sum_tv2->axis(1)->parallelize(ParallelType::BIDy);
-  bcast_tv3->axis(1)->parallelize(ParallelType::BIDy);
-  output_tv4->axis(1)->parallelize(ParallelType::BIDy);
-
-  sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx);
-  sum_tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  output_tv4->axis(-1)->parallelize(ParallelType::TIDx);
-
-  fusion.addOutput(output_tv4);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({32, 32, 128}, options);
-  at::Tensor t1 = at::randn({32, 32, 128}, options);
-  at::Tensor cg_output = at::empty({32, 32, 128}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({t0, t1}, {cg_output});
-}
-
-TEST(NVFuserTest, FusionBCastInnerDim_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  // reduce then broadcast
-  auto tv1 = sum(tv0, {0});
-  auto tv2 = broadcast(tv1, {false, true});
-
-  TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast());
-}
-
-TEST(NVFuserTest, FusionBCastReduce_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-
-  auto tv1 = broadcast(tv0, {true, false, false});
-  auto tv2 = sum(tv1, {1});
-  TORCH_CHECK(
-      tv2->axis(0)->isBroadcast() && tv2->axis(1)->isReduction() &&
-      !tv2->axis(2)->isBroadcast() && !tv2->axis(2)->isReduction());
-}
-
-// Multiple consumer reduction with computeAt
-// https://github.com/csarofeen/pytorch/issues/110
-TEST(NVFuserTest, FusionReductionMultiConsumer_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = unaryOp(UnaryOpType::Exp, tv0);
-  auto tv2 = reductionOp(BinaryOpType::Max, {-1}, new Double(0), tv1);
-  auto tv3 = reductionOp(BinaryOpType::Min, {-1}, new Double(0), tv1);
-  auto tv4 = add(tv2, tv3);
-  fusion.addOutput(tv4);
-  tv1->computeAt(tv2, -1, ComputeAtMode::BestEffort);
-
-  TORCH_CHECK(tv1->getComputeAtPosition() == 2);
-}
-
-TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) {
-  for (int i = 0; i < 2; ++i) {
-    Fusion fusion;
-    FusionGuard fg(&fusion);
-
-    // Set up your input tensor views
-    TensorView* tv0 = makeSymbolicTensor(1);
-    fusion.addInput(tv0);
-
-    auto tv1 = add(tv0, new Double(1));
-    auto tv2 = add(tv0, new Double(1));
-    TensorView* tv3 = add(tv1, tv2);
-    // Set outputs tv2 or tv1 and then tv3
-    if (i == 0) {
-      fusion.addOutput(tv2);
-    } else {
-      fusion.addOutput(tv1);
-    }
-    fusion.addOutput(tv3);
-
-    if (i == 0) {
-      tv1->computeAt(tv3, -1);
-    } else {
-      tv2->computeAt(tv3, -1);
-    }
-
-    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-    at::Tensor aten_input = at::randn({100}, options);
-    std::vector<at::Tensor> aten_outputs = {
-        aten_input + 1, (aten_input + 1) * 2};
-
-    FusionExecutor fe;
-    fe.compileFusion(&fusion);
-    auto cg_outputs = fe.runFusion({aten_input});
-
-    testValidate(
-        &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-  }
-}
-
-TEST(NVFuserTest, FusionComputeAtExprOrder2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv0, new Double(1));
-  TensorView* tv3 = add(tv1, tv2);
-  fusion.addOutput(tv3);
-
-  tv3->split(-1, 32);
-
-  tv1->computeAt(tv3, -1);
-  tv2->computeAt(tv3, -2);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({100, 100}, options);
-  auto aten_output = (aten_input + 1) * 2;
-
-  at::Tensor cg_output = at::empty_like(aten_input, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({aten_input}, {cg_output});
-
-  testValidate(
-      &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionComputeAtExprOrder3_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const size_t dimx = 13;
-  const size_t dimy = 15;
-
-  TensorView* tv0 = makeConcreteTensor({dimx, dimy});
-  fusion.addInput(tv0);
-  TensorView* tv1 = add(tv0, new Double(1));
-  TensorView* tv2 = add(tv1, new Double(2));
-  TensorView* tv3 = add(tv2, new Double(3));
-  TensorView* tv4 = add(tv3, new Double(4));
-  TensorView* tv5 = mul(tv2, tv4);
-  fusion.addOutput(tv5);
-
-  tv1->computeAt(tv2, 2);
-  tv3->computeAt(tv4, 1);
-  tv4->computeAt(tv5, 2);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({dimx, dimy}, options);
-  auto t1 = aten_input.add(1.);
-  auto t2 = t1.add(2.);
-  auto t3 = t2.add(3.);
-  auto t4 = t3.add(4.);
-  auto aten_output = t2.mul(t4);
-
-  torch::jit::fuser::cuda::FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionZeroDimComputeAt_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-
-  auto tv1 = sum(tv0, {0});
-  auto tv2 = add(tv1, new Double(1));
-  fusion.addOutput(tv2);
-  TORCH_CHECK(tv2->nDims() == 0);
-  tv1->computeAt(tv2, 0);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({100}, options);
-  auto aten_output = aten_input.to(at::kDouble).sum() + 1;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionZeroDimBroadcast_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(0);
-  fusion.addInput(tv0);
-
-  auto tv1 = broadcast(tv0, {true, true});
-  TORCH_CHECK(tv1->nDims() == 2);
-
-  TensorView* tv2 = makeSymbolicTensor(2);
-  fusion.addInput(tv2);
-
-  auto tv3 = add(tv1, tv2);
-  auto tv4 = sum(tv3, {0, 1});
-  fusion.addOutput(tv4);
-
-  tv3->computeAt(tv4, -1);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({}, options);
-  at::Tensor t1 = at::randn({10, 10}, options);
-
-  auto aten_output = (t0.unsqueeze(-1).unsqueeze(-1).expand({10, 10}) + t1)
-                         .to(at::kDouble)
-                         .sum();
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-  at::Tensor cg_output = at::empty({}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion(aten_inputs, {cg_output});
-
-  testValidate(
-      &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionZeroDimReduction_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const int bdimx = 32;
-  const int gdimx = 32;
-
-  TensorView* tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-
-  auto tv1 = sum(tv0, {0});
-  fusion.addOutput(tv1);
-
-  tv1->split(0, bdimx);
-  tv1->split(0, gdimx);
-  auto tv2 = tv1->rFactor({0});
-
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv1->axis(-2)->parallelize(ParallelType::BIDx);
-  tv2->axis(-2)->parallelize(ParallelType::BIDx);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({1000}, options);
-  auto aten_output = aten_input.to(at::kDouble).sum();
-
-  at::Tensor cg_output = at::empty({}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({aten_input}, {cg_output});
-
-  testValidate(
-      &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionBCastAfterReduce_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-  const int tidx = 128;
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  auto tv1 = sum(tv0, {1});
-  auto tv2 = broadcast(tv1, {false, true});
-
-  tv1->split(1, tidx);
-  auto tv3 = tv1->rFactor({-2});
-
-  TensorView* tv4 = makeSymbolicTensor(2);
-  fusion.addInput(tv4);
-
-  auto tv5 = add(tv2, tv4);
-  fusion.addOutput(tv5);
-  tv5->split(1, tidx);
-
-  tv3->computeAt(tv5, 1);
-
-  tv2->split(1, tidx);
-
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv5->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv5->axis(0)->parallelize(ParallelType::BIDx);
-
-  int x = 63, y = 200;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor t0 = at::randn({x, y}, options);
-  at::Tensor t4 = at::randn({x, y}, options);
-
-  auto t3 = t0.to(at::kDouble).sum({1}).unsqueeze(-1).expand({x, y});
-  auto aten_output = t3.add(t4);
-
-  std::vector<IValue> aten_inputs = {t0, t4};
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({t0, t4});
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionOutputBroadcast_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeConcreteTensor({2, 3});
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = broadcast(tv0, {true, false, true, false, true});
-
-  fusion.addOutput(tv1);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor aten_input = at::randn({2, 3}, options);
-  auto aten_output = aten_input.unsqueeze(2).unsqueeze(1).unsqueeze(0);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReductionKeepDimBasic_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeConcreteTensor({2, 3, 4, 5, 6});
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = sum(tv0, {0, 2, 4}, /*keep_dim=*/true);
-
-  fusion.addOutput(tv1);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor aten_input = at::randn({2, 3, 4, 5, 6}, options);
-  auto aten_output =
-      aten_input.to(at::kDouble).sum({0, 2, 4}, /*keepdim=*/true);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) {
-  constexpr int bid_x = 80;
-  constexpr int tid_x = 4096;
-  constexpr int red_dim = 1;
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeConcreteTensor({bid_x, tid_x});
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = reductionOp(
-      BinaryOpType::Add, {red_dim}, new Double(0), tv0, /*keep_dim=*/true);
-
-  fusion.addOutput(tv1);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor aten_input = at::randn({bid_x, tid_x}, options);
-  auto aten_output =
-      aten_input.to(at::kDouble).sum({red_dim}, /*keepdim=*/true);
-
-  // Apply reduction heuristic
-  auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto lparams = reduction_params.value().lparams;
-
-  auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionSumTo_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  std::vector<int64_t> tensor_shape{2, 3, 4, 5, 6};
-  std::vector<int64_t> sum_to_shape{1, 5, 6};
-
-  std::vector<int64_t> tensor_shape_ref{2, 3, 4, 5, 6};
-  std::vector<int64_t> sum_to_shape_ref{1, 5, 6};
-
-  std::vector<Int*> sum_to_symb;
-  std::transform(
-      sum_to_shape.begin(),
-      sum_to_shape.end(),
-      std::back_inserter(sum_to_symb),
-      [](int s) -> Int* { return new Int(s); });
-
-  TensorView* tv0 = makeConcreteTensor(tensor_shape);
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = sum_to(tv0, sum_to_symb);
-  fusion.addOutput(tv1);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor aten_input = at::randn(tensor_shape_ref, options);
-  auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  TORCH_CHECK(
-      cg_outputs[0].dim() == sum_to_shape.size(),
-      "sum_to not keeping the final dimension");
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSumToNoop_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  std::vector<int64_t> tensor_shape{4, 5, 6};
-  std::vector<int64_t> sum_to_shape{4, 5, 6};
-
-  std::vector<int64_t> tensor_shape_ref{4, 5, 6};
-  std::vector<int64_t> sum_to_shape_ref{4, 5, 6};
-
-  std::vector<Int*> sum_to_symb;
-  std::transform(
-      sum_to_shape.begin(),
-      sum_to_shape.end(),
-      std::back_inserter(sum_to_symb),
-      [](int s) -> Int* { return new Int(s); });
-
-  TensorView* tv0 = makeConcreteTensor(tensor_shape);
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = sum_to(tv0, sum_to_symb);
-
-  // Dummy operator to avoid tv0 both input and output
-  TensorView* tv2 = add(tv1, new Double(0));
-  fusion.addOutput(tv2);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor aten_input = at::randn(tensor_shape_ref, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto cg_outputs = fe.runFusion({aten_input});
-  auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref);
-
-  TORCH_CHECK(
-      cg_outputs[0].dim() == sum_to_shape.size(),
-      "sum_to not keeping the final dimension");
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReductionScheduler_CUDA) {
-  constexpr int bid_x = 80;
-  constexpr int tid_x = 4096;
-  constexpr int red_dim = 1;
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  TensorView* tv1 =
-      reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor aten_input = at::randn({bid_x, tid_x}, options);
-  auto aten_output = aten_input.to(at::kDouble).sum({red_dim});
-
-  // Apply reduction heuristic
-  auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-
-  auto lparams = reduction_params.value().lparams;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  // no broadcasting needed, omitting the last optional argument;
-  auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-// Simple reduction parallelized on a symbolic size.
-TEST(NVFuserTest, FusionSymbolicReduction_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  // tv1[I0, R1] = tv0[I0, I1]
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  // Interface should just be a direct split with a Parallel type. We can
-  // include the parallelize call if we do this.
-  tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
-  // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1]
-
-  TensorView* tv2 = tv1->rFactor({1});
-  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}] = tv0[I0, I1]
-  // tv1[I0,        R1oi{4},  R1i{BIDx}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}]
-
-  // Incrementally, can print in between for debugging
-  tv0->computeAt(tv2, 1);
-  tv2->computeAt(tv1, 1);
-
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-
-  int numel_x = 65000;
-  int numel_y = 1025;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
-  auto aten_output = aten_input.to(at::kDouble).sum({1});
-
-  // How many threads to use for the block reduction
-  int runtime_threadIdx_dim = 128;
-
-  LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) {
-  const std::vector<int> red_dims = {0, 2};
-  // Copy is because CodeGen requires int and Pytorch requires int64_t
-  // for a vector of reduction dimensions
-  const std::vector<int64_t> red_dims64 = {0, 2};
-  const std::vector<int64_t> tensor_dims_in = {5, 10, 15, 20};
-  const std::vector<int64_t> tensor_dims_out = {10, 20};
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size());
-  fusion.addInput(tv0);
-
-  TensorView* tv1 =
-      reductionOp(BinaryOpType::Add, red_dims, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn(tensor_dims_in, options);
-  auto aten_output = aten_input.to(at::kDouble).sum(red_dims64);
-  at::Tensor cg_output = at::empty(tensor_dims_out, options);
-
-  // Apply reduction heuristic
-  auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-  auto lparams = reduction_params.value().lparams;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({aten_input}, {cg_output}, lparams);
-
-  testValidate(
-      &fusion,
-      {cg_output},
-      {aten_input},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) {
-  const std::vector<int> red_dims = {1, 3};
-  // Copy is because CodeGen requires int and Pytorch requires int64_t
-  // for a vector of reduction dimensions
-  const std::vector<int64_t> red_dims64 = {1, 3};
-  const std::vector<int64_t> tensor_dims_in = {5, 10, 15, 20};
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size());
-  fusion.addInput(tv0);
-
-  TensorView* tv1 =
-      reductionOp(BinaryOpType::Add, red_dims, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn(tensor_dims_in, options);
-  auto aten_output = aten_input.to(at::kDouble).sum(red_dims64);
-
-  auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-  auto lparams = reduction_params.value().lparams;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) {
-  std::vector<DataType> dtypes = {
-      DataType::Double, DataType::Float, DataType::Half};
-  std::vector<int> red_dims;
-
-  // Tried to cut down the number iterations with just
-  // doing every other power of 2.
-  for (int i = 1; i <= 1024 * 1024; i <<= 2) {
-    red_dims.push_back(i);
-  }
-
-  for (auto dtype : dtypes) {
-    at::ScalarType aten_dtype = data_type_to_aten(dtype);
-    for (auto& rdim : red_dims) {
-      Fusion fusion;
-      FusionGuard fg(&fusion);
-
-      bool is_fp16 = dtype == DataType::Half;
-
-      TensorView* tv0 = makeSymbolicTensor(1, dtype);
-      fusion.addInput(tv0);
-
-      TensorView* tv0_cast = tv0;
-      if (is_fp16) {
-        tv0_cast = castOp(DataType::Float, tv0);
-      }
-
-      TensorView* tv1 = sum(tv0_cast, {0});
-
-      TensorView* tv1_cast = tv1;
-      if (is_fp16) {
-        tv1_cast = castOp(DataType::Half, tv1);
-      }
-
-      fusion.addOutput(tv1_cast);
-
-      auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);
-
-      at::Tensor aten_input = at::randn({rdim}, options);
-      auto aten_output = aten_input.to(at::kDouble).sum({0});
-
-      auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-      TORCH_CHECK(reduction_params.has_value(), "Reduction is not found!");
-      scheduleReduction(&fusion, reduction_params.value());
-      auto lparams = reduction_params.value().lparams;
-
-      FusionExecutor fe;
-      fe.compileFusion(&fusion);
-
-      auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
-      testValidate(
-          &fusion,
-          cg_outputs,
-          {aten_input},
-          {aten_output},
-          __LINE__,
-          __FILE__,
-          "",
-          lparams);
-    }
-  }
-}
-
-TEST(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) {
-  std::vector<DataType> dtypes = {
-      DataType::Double, DataType::Float, DataType::Half};
-  std::vector<int> red_axis = {1, 0};
-  std::vector<int> output_dims = {160, 320};
-  std::vector<int> red_dims;
-
-  // Tried to cut down the number iterations with just
-  // doing every other power of 2.
-  for (int i = 1; i <= 1024 * 1024; i <<= 2) {
-    red_dims.push_back(i);
-  }
-
-  for (auto dtype : dtypes) {
-    at::ScalarType aten_dtype = data_type_to_aten(dtype);
-    for (auto& axis : red_axis) {
-      for (auto& odim : output_dims) {
-        for (auto& rdim : red_dims) {
-          Fusion fusion;
-          FusionGuard fg(&fusion);
-
-          bool is_fp16 = dtype == DataType::Half;
-
-          TensorView* tv0 = makeSymbolicTensor(2, dtype);
-          fusion.addInput(tv0);
-
-          TensorView* tv0_cast = tv0;
-          if (is_fp16) {
-            tv0_cast = castOp(DataType::Float, tv0);
-          }
-
-          TensorView* tv1 = sum(tv0_cast, {axis});
-
-          TensorView* tv1_cast = tv1;
-          if (is_fp16) {
-            tv1_cast = castOp(DataType::Half, tv1);
-          }
-
-          fusion.addOutput(tv1_cast);
-
-          auto options =
-              at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);
-
-          at::Tensor aten_input =
-              (axis ? at::randn({odim, rdim}, options)
-                    : at::randn({rdim, odim}, options));
-
-          auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-          TORCH_CHECK(reduction_params.has_value(), "Reduction is not found!");
-          scheduleReduction(&fusion, reduction_params.value());
-          auto lparams = reduction_params.value().lparams;
-
-          FusionExecutor fe;
-          fe.compileFusion(&fusion);
-
-          auto cg_outputs = fe.runFusion({aten_input}, lparams);
-          auto aten_output = aten_input.to(at::kDouble).sum({axis});
-          testValidate(
-              &fusion,
-              cg_outputs,
-              {aten_input},
-              {aten_output},
-              __LINE__,
-              __FILE__,
-              "",
-              lparams);
-        }
-      }
-    }
-  }
-}
-
-TEST(NVFuserTest, FusionCacheBefore_CUDA) {
-  // TVM Cache Write
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = add(tv0, new Double(1.0));
-  TensorView* tv2 = mul(tv1, new Double(3.0));
-  fusion.addInput(tv0);
-  fusion.addOutput(tv2);
-
-  // Before: TV2 = TV1 * 3
-  // After:  TV3 = TV1 * 3;
-  //         TV2 = TV3;
-  TensorView* tv3 = tv2->cache_before();
-
-  constexpr int BSX = 32;
-  tv2->split(-1, BSX);
-  tv0->computeAt(tv2, -1);
-
-  // Thread and Block binding
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
-  constexpr int M = 32, N = 750;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({M, N}, options);
-  at::Tensor aten_output = (aten_input + 1.0) * 3.0;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionCacheAfter_CUDA) {
-  // TVM Cache Read
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = add(tv0, new Double(1.0));
-  TensorView* tv2 = mul(tv1, new Double(3.0));
-  fusion.addInput(tv0);
-  fusion.addOutput(tv2);
-
-  // Before: TV1 = TV0 + 1
-  // After:  TV3 = TV0;
-  //         TV1 = TV3 + 1
-  TensorView* tv3 = tv0->cache_after();
-
-  constexpr int BSX = 32;
-  tv2->split(-1, BSX);
-  tv0->computeAt(tv2, -1);
-
-  // Thread and Block binding
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
-  constexpr int M = 32, N = 457;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({M, N}, options);
-  at::Tensor aten_output = (aten_input + 1.0) * 3.0;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionCacheFork_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = add(tv0, new Double(1.0));
-  TensorView* tv2 = mul(tv1, new Double(3.0));
-  fusion.addInput(tv0);
-  fusion.addOutput(tv1);
-  fusion.addOutput(tv2);
-  // Before:  TV1 = TV0 + 1
-  //          TV2 = TV1 * 1
-  // Output:  TV1, TV2
-
-  // After:   TV1 = TV0 + 1
-  //          TV3 = TV1
-  //          TV2 = TV1 * 1
-  // Output:  TV3, TV2
-
-  // cache_fork !!does not!! automatically apply ComputeAt to the cache
-  auto tv3 = tv1->cache_fork();
-
-  constexpr int BSX = 32;
-  tv2->split(-1, BSX);
-  tv0->computeAt(tv2, -1);
-
-  // Thread and Block binding
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
-  constexpr int M = 32, N = 457;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({M, N}, options);
-  at::Tensor aten_output1 = aten_input + 1.0;
-  at::Tensor aten_output2 = aten_output1 * 3.0;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {aten_output1, aten_output2},
-      __LINE__,
-      __FILE__);
-}
-
-TEST(NVFuserTest, FusionCacheIndirect_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
-  TensorView* tv2 = makeSymbolicTensor(2);
-  TensorView* tv3 = makeSymbolicTensor(2);
-  TensorView* tv4 = sub(tv2, tv3);
-  TensorView* tv5 = add(tv1, tv4);
-  TensorView* tv6 = sub(tv5, tv0);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addInput(tv2);
-  fusion.addInput(tv3);
-  fusion.addOutput(tv6);
-  // t6 = ((t1 + (t2 - t3)) - t0)
-
-  tv5->cache_after();
-  tv5->cache_before();
-
-  // cache_after on inputs placed before schedule
-  constexpr int BSX = 32;
-  tv6->split(-1, BSX);
-  tv2->computeAt(tv6, -1);
-
-  // Thread and Block binding
-  tv6->axis(0)->parallelize(ParallelType::BIDx);
-  tv6->axis(-1)->parallelize(ParallelType::TIDx);
-
-  constexpr int M = 32, N = 810;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({M, N}, options);
-  at::Tensor t1 = at::randn({M, N}, options);
-  at::Tensor t2 = at::randn({M, N}, options);
-  at::Tensor t3 = at::randn({M, N}, options);
-
-  std::vector<IValue> aten_inputs = {t0, t1, t2, t3};
-  at::Tensor aten_output = (t1 + (t2 - t3)) - t0;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionCacheBcast_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Algorithm
-  TensorView* tv0 = makeSymbolicTensor(1); // (M, 1)
-  TensorView* tv1 = broadcast(tv0, {false, true});
-  TensorView* tv2 = makeSymbolicTensor(1); // (1, N)
-  TensorView* tv3 = broadcast(tv2, {true, false});
-  TensorView* tv4 = mul(tv1, tv3);
-  fusion.addInput(tv0);
-  fusion.addInput(tv2);
-  fusion.addOutput(tv4);
-
-  // Case 1
-  tv0->cache_after();
-
-  // Case 2
-  tv1->cache_before();
-
-  // Case 3
-  tv1->cache_after();
-
-  // Case 4
-  TensorView* tv8 = tv4->cache_before();
-
-  constexpr int BSX = 128;
-  tv4->split(0, BSX);
-  tv4->split(-1, BSX);
-  tv4->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
-  // M/BSX, N/BSY, BSX, BSY
-  tv0->computeAt(tv4, 2);
-  tv2->computeAt(tv4, 2);
-  // 0, 1 | 2, 3, 4
-
-  tv4->axis(0)->parallelize(ParallelType::BIDx);
-  tv4->axis(1)->parallelize(ParallelType::BIDy);
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
-  // Manual Replay on TV3
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv8->axis(-1)->parallelize(ParallelType::TIDx);
-
-  constexpr int M = 92, N = 500;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({M}, options);
-  at::Tensor t1 = at::randn({N}, options);
-  std::vector<IValue> aten_inputs = {t0, t1};
-  at::Tensor aten_output =
-      t0.to(at::kDouble).unsqueeze(1).matmul(t1.to(at::kDouble).unsqueeze(0));
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionCacheMultiConsumer_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(1);
-  TensorView* tv1 = add(tv0, new Double(1));
-  TensorView* tv2 = add(tv1, new Double(2));
-  TensorView* tv3 = add(tv0, new Double(1));
-  TensorView* tv4 = add(tv3, new Double(2));
-
-  fusion.addInput(tv0);
-  fusion.addOutput(tv2);
-  fusion.addOutput(tv4);
-
-  auto tv5 = tv1->cache_before();
-  auto tv6 = tv3->cache_before();
-  tv5->setMemoryType(MemoryType::Shared);
-  tv6->setMemoryType(MemoryType::Shared);
-
-  tv1->computeAt(tv2, -1);
-  tv3->computeAt(tv4, -1);
-
-  // Fails because tensor must be recomputed twice
-  // auto tv7 = tv0->cache_after();
-
-  constexpr int N = 800;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({N}, options);
-  auto aten_output = (aten_input + 1) + 2;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {aten_output, aten_output},
-      __LINE__,
-      __FILE__);
-}
-
-TEST(NVFuserTest, FusionSmem_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Algorithm
-  TensorView* tv0 = makeSymbolicTensor(2); // (M, N)
-  TensorView* tv1 = makeSymbolicTensor(2); // (M, N)
-  TensorView* tv2 = mul(tv0, tv1);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addOutput(tv2);
-
-  // Schedule
-  TensorView* tv3 = tv0->cache_after();
-  TensorView* tv4 = tv1->cache_after();
-  tv3->setMemoryType(MemoryType::Shared);
-  tv4->setMemoryType(MemoryType::Shared);
-
-  constexpr int BSY = 32;
-  constexpr int BSX = 128;
-  tv2->split(0, BSY);
-  tv2->split(2, BSX);
-  // M/BSX, BSX, N/BSX, BSX
-  tv2->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
-  // M/BSX, N/BSX, BSX, BSX
-
-  tv0->computeAt(tv2, 2);
-  tv1->computeAt(tv2, 2);
-
-  // Thread and Block binding
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(1)->parallelize(ParallelType::BIDy);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  // Manual Binding
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
-
-  constexpr int M = 128, N = 10240;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({M, N}, options);
-  at::Tensor t1 = at::randn({M, N}, options);
-  at::Tensor aten_output = mul(t0, t1);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({t0, t1});
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-
-  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
-}
-
-TEST(NVFuserTest, FusionSmemReduce_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Algorithm
-  TensorView* tv0 = makeSymbolicTensor(3); // M, K, N
-  TensorView* tv1 = sum(tv0, {1}); // M, R, N
-  fusion.addInput(tv0);
-  fusion.addOutput(tv1);
-
-  TensorView* tv2 = tv0->cache_after();
-  tv2->setMemoryType(MemoryType::Shared);
-
-  // Schedule
-  constexpr int BSX = 32;
-  tv1->split(2, BSX);
-  tv1->split(1, 128);
-  tv1->split(0, BSX);
-  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
-  tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}});
-  TensorView* tv3 = tv1->rFactor({-2});
-
-  tv0->computeAt(tv1, -2);
-  tv0->computeAt(tv3, -2);
-
-  // Thread and Block binding
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(1)->parallelize(ParallelType::BIDy);
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  // Manual Binding
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
-  constexpr int M = 154, K = 45, N = 1524;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({M, K, N}, options);
-  at::Tensor aten_output = sum(aten_input.to(at::kDouble), {1});
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
-}
-
-TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Algorithm
-  TensorView* tv0 = makeSymbolicTensor(2); // (M, K)
-  TensorView* tv1 = makeSymbolicTensor(2); // (K, N)
-  TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
-  TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
-  TensorView* tv4 = mul(tv2, tv3); // M, K, N
-  TensorView* tv5 = sum(tv4, {1}); // M, R, N
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addOutput(tv5);
-
-  // Schedule
-  constexpr int BSX = 16;
-  tv5->split(2, BSX);
-  tv5->split(1, BSX);
-  tv5->split(0, BSX);
-  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
-  tv5->reorder({{0, 0}, {1, 3}, {2, 2}, {3, 5}, {4, 1}, {5, 4}});
-  // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
-  TensorView* tv6 = tv5->rFactor({-1});
-
-  tv2->setMemoryType(MemoryType::Shared);
-  tv3->setMemoryType(MemoryType::Shared);
-  tv4->setMemoryType(MemoryType::Shared);
-  tv6->setMemoryType(MemoryType::Shared);
-
-  tv0->computeAt(tv5, 3);
-  tv1->computeAt(tv5, 3);
-
-  // Thread and Block binding
-  tv5->axis(0)->parallelize(ParallelType::BIDx);
-  tv5->axis(1)->parallelize(ParallelType::BIDy);
-  tv5->axis(-2)->parallelize(ParallelType::TIDy);
-  tv5->axis(-1)->parallelize(ParallelType::TIDx);
-  // Manual Binding
-  tv2->axis(-3)->parallelize(ParallelType::TIDy);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv4->axis(-3)->parallelize(ParallelType::TIDy);
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
-  tv6->axis(-3)->parallelize(ParallelType::TIDy);
-  tv6->axis(-2)->parallelize(ParallelType::TIDx);
-
-  constexpr int M = 154, K = 45, N = 1524;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({M, K}, options);
-  at::Tensor t1 = at::randn({K, N}, options);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-  at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble));
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({t0, t1});
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-
-  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
-}
-
-TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Algorithm
-  TensorView* tv0 = makeSymbolicTensor(2); // (M, K)
-  TensorView* tv1 = makeSymbolicTensor(2); // (K, N)
-  TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
-  TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
-  TensorView* tv4 = mul(tv2, tv3); // M, K, N
-  TensorView* tv5 = sum(tv4, {1}); // M, R, N
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addOutput(tv5);
-
-  // Schedule
-  // Remove reduction axis from tv5
-  // tv6 = (M, R, N)
-  // tv5 = (M, N)
-  TensorView* tv6 = tv5->cache_before();
-
-  constexpr int BSX = 16;
-  tv5->split(1, BSX);
-  tv5->split(0, BSX);
-  // M/BSX, BSX, N/BSX, BSX
-  tv5->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
-  // tv5 = M/BSX, N/BSX, MSX, NSX
-
-  tv6->computeAt(tv5, 2);
-  tv6->computeAt(tv5, 2);
-
-  tv6->split(-1, BSX);
-  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
-  tv6->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 4}, {4, 2}, {5, 5}});
-  // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
-  TensorView* tv7 = tv6->rFactor({-1});
-  // tv7 = M/BSX, N/BSX, K/BSXrf, MSX, NSX, KSXr
-  // tv6 = M/BSX, N/BSX, K/BSXr, MSX, NSX
-
-  tv0->computeAt(tv6, 3);
-  tv1->computeAt(tv6, 3);
-
-  tv0->computeAt(tv7, 3);
-  tv1->computeAt(tv7, 3);
-
-  tv2->setMemoryType(MemoryType::Shared);
-  tv3->setMemoryType(MemoryType::Shared);
-  tv4->setMemoryType(MemoryType::Shared);
-  tv6->setMemoryType(MemoryType::Shared);
-  tv7->setMemoryType(MemoryType::Shared);
-  // Memory Type
-
-  // Thread and Block binding
-  tv5->axis(0)->parallelize(ParallelType::BIDx);
-  tv5->axis(1)->parallelize(ParallelType::BIDy);
-  tv5->axis(-2)->parallelize(ParallelType::TIDy);
-  tv5->axis(-1)->parallelize(ParallelType::TIDx);
-  // Manual Binding
-  tv2->axis(-3)->parallelize(ParallelType::TIDy);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv4->axis(-3)->parallelize(ParallelType::TIDy);
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv7->axis(-3)->parallelize(ParallelType::TIDy);
-  tv7->axis(-2)->parallelize(ParallelType::TIDx);
-
-  tv6->axis(-2)->parallelize(ParallelType::TIDy);
-  tv6->axis(-1)->parallelize(ParallelType::TIDx);
-
-  constexpr int M = 154, K = 45, N = 1524;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({M, K}, options);
-  at::Tensor t1 = at::randn({K, N}, options);
-  at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble));
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-
-  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
-}
-
-TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* x = makeSymbolicTensor(2);
-  fusion.addInput(x);
-  TensorView* max_val =
-      reductionOp(BinaryOpType::Max, {-1}, new Double(FLT_MIN), x); // (M)
-  TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B)
-  TensorView* x_max_sub = sub(x, bcast_max); // (M, N)
-  TensorView* exp = unaryOp(UnaryOpType::Exp, x_max_sub); // (M, N)
-  TensorView* sum_exp = sum(exp, {-1}); // (M, R)
-  TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B)
-  TensorView* softmax = div(exp, bcast_sum); // (M, N)
-  fusion.addOutput(softmax);
-
-  // Read Input into Shared Memory
-  // Load Input + Pwise into shared memory
-  auto cache_x = x->cache_after();
-  cache_x->setMemoryType(MemoryType::Shared);
-  exp->setMemoryType(MemoryType::Shared);
-
-  std::vector<TensorView*> all_tensors(
-      {x,
-       cache_x,
-       max_val,
-       bcast_max,
-       x_max_sub,
-       exp,
-       sum_exp,
-       bcast_sum,
-       softmax});
-
-  auto tidx = new Int();
-  fusion.addInput(tidx);
-
-  for (auto tensor : all_tensors) {
-    tensor->split(-1, tidx);
-  }
-
-  auto sum_exp_rf = sum_exp->rFactor({1});
-  all_tensors.push_back(sum_exp_rf);
-
-  // computeAt
-  x->computeAt(x_max_sub, 1);
-  exp->computeAt(softmax, 1);
-  x_max_sub->computeAt(exp, 2);
-
-  softmax->axis(0)->parallelize(ParallelType::BIDx);
-  for (auto tensor : all_tensors) {
-    tensor->axis(-1)->parallelize(ParallelType::TIDx);
-  }
-
-  const size_t dimx = 1024;
-  const size_t dimy = 4096;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({dimx, dimy}, options);
-  auto aten_output = at::_softmax(aten_input.to(at::kDouble), -1, false);
-
-  torch::jit::fuser::cuda::FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input, 128});
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input, 128},
-      {aten_output},
-      __LINE__,
-      __FILE__);
-}
-
-TEST(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const int kReductionAxis = 3;
-  std::vector<int64_t> input_shape{10, 10, 10, 67};
-  TensorView* input = makeSymbolicTensor(input_shape.size());
-  fusion.addInput(input);
-
-  auto output = softmax(input, kReductionAxis);
-
-  fusion.addOutput(output);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn(input_shape, options);
-  auto aten_output =
-      at::_softmax(aten_input.to(at::kDouble), kReductionAxis, false);
-
-  auto reduction_params = getNormalizationHeuristics(&fusion, {aten_input});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-
-  scheduleNormalization(&fusion, reduction_params.value());
-
-  auto lparams = reduction_params.value().lparams;
-
-  torch::jit::fuser::cuda::FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionMagicSchedulerLayerNormBackward_CUDA) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
-  FusionGuard fg(&fusion);
-
-  std::vector<int64_t> shape{20, 100, 35, 67};
-  std::vector<int64_t> norm_shape{67};
-
-  const size_t kM = shape.size();
-  const size_t kN = norm_shape.size();
-  const size_t kOuterNumDims = kM - kN;
-
-  std::vector<int64_t> outer_shape;
-  for (size_t idx = 0; idx < kOuterNumDims; ++idx) {
-    outer_shape.push_back(shape[idx]);
-  }
-  for (size_t idx = kOuterNumDims; idx < kM; ++idx) {
-    outer_shape.push_back(1);
-  }
-
-  auto grad_out = makeSymbolicTensor(shape.size());
-  auto input = makeSymbolicTensor(shape.size());
-  auto mean = makeConcreteTensor(outer_shape);
-  auto rstd = makeConcreteTensor(outer_shape);
-  auto weight = makeSymbolicTensor(norm_shape.size());
-  auto bias = makeSymbolicTensor(norm_shape.size());
-  fusion.addInput(grad_out);
-  fusion.addInput(input);
-  fusion.addInput(mean);
-  fusion.addInput(rstd);
-  fusion.addInput(weight);
-  fusion.addInput(bias);
-
-  auto grads = layer_norm_backward(
-      grad_out,
-      input,
-      norm_shape,
-      mean,
-      rstd,
-      weight,
-      bias,
-      {true, true, true});
-
-  fusion.addOutput(grads.grad_input);
-  fusion.addOutput(grads.grad_weight);
-  fusion.addOutput(grads.grad_bias);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_grad_out = at::randn(shape, options);
-  at::Tensor aten_input = at::randn(shape, options);
-  at::Tensor aten_weight = at::randn(norm_shape, options);
-  at::Tensor aten_bias = at::randn(norm_shape, options);
-  auto at_weight = c10::optional<at::Tensor>(aten_weight);
-  auto at_bias = c10::optional<at::Tensor>(aten_bias);
-
-  const float kEps = 1e-5;
-  auto aten_results =
-      at::native_layer_norm(aten_input, norm_shape, at_weight, at_bias, kEps);
-  auto aten_output = std::get<0>(aten_results);
-  auto aten_mean = std::get<1>(aten_results);
-  auto aten_rstd = std::get<2>(aten_results);
-
-  FusionExecutorCache fec(std::move(fusion_ptr));
-  std::vector<IValue> aten_inputs = {
-      aten_grad_out, aten_input, aten_mean, aten_rstd, aten_weight, aten_bias};
-  auto cg_outputs = fec.runFusionWithInputs(aten_inputs);
-
-  auto aten_gradients = at::native_layer_norm_backward(
-      aten_grad_out.to(at::kDouble),
-      aten_input.to(at::kDouble),
-      norm_shape,
-      aten_mean.to(at::kDouble),
-      aten_rstd.to(at::kDouble),
-      c10::optional<at::Tensor>(aten_weight.to(at::kDouble)),
-      c10::optional<at::Tensor>(aten_bias.to(at::kDouble)),
-      {true, true, true});
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      aten_inputs,
-      {std::get<0>(aten_gradients),
-       std::get<1>(aten_gradients),
-       std::get<2>(aten_gradients)},
-      __LINE__,
-      __FILE__);
-}
-
-TEST(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
-  FusionGuard fg(&fusion);
-
-  const float kEps = 1e-5;
-  Double* eps_ptr = new Double(kEps);
-
-  std::vector<int64_t> input_shape{20, 100, 35, 67};
-  std::vector<int64_t> norm_shape{67};
-
-  auto input = makeSymbolicTensor(input_shape.size());
-  fusion.addInput(input);
-
-  auto result = layer_norm(input, norm_shape, nullptr, nullptr, eps_ptr);
-
-  fusion.addOutput(result.output);
-  fusion.addOutput(result.mean);
-  fusion.addOutput(result.invstd);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn(input_shape, options);
-  c10::optional<at::Tensor> aten_weight = c10::nullopt;
-  c10::optional<at::Tensor> aten_bias = c10::nullopt;
-  auto aten_outputs = at::native_layer_norm(
-      aten_input, norm_shape, aten_weight, aten_bias, kEps);
-
-  // Check reduction axis is same for all reductions
-  // Generate Launch Parameters
-  auto reduction_params = getNormalizationHeuristics(&fusion, {aten_input});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-
-  scheduleNormalization(&fusion, reduction_params.value());
-  auto lparams = reduction_params.value().lparams;
-
-  torch::jit::fuser::cuda::FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {std::get<0>(aten_outputs),
-       std::get<1>(aten_outputs),
-       std::get<2>(aten_outputs)},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
-
-  const float kMomentum = 0.1;
-  const float kEps = 1e-5;
-  const bool kTraining = true;
-  std::vector<int64_t> input_shape{20, 100, 35, 45};
-
-  auto input = makeSymbolicTensor(input_shape.size());
-  auto weight = makeSymbolicTensor(1);
-  auto bias = makeSymbolicTensor(1);
-  auto running_mean = makeSymbolicTensor(1);
-  auto running_var = makeSymbolicTensor(1);
-  fusion->addInput(input);
-  fusion->addInput(weight);
-  fusion->addInput(bias);
-  fusion->addInput(running_mean);
-  fusion->addInput(running_var);
-
-  Double* momentum = new Double(kMomentum);
-  Double* eps = new Double(kEps);
-
-  auto result = batch_norm(
-      input, weight, bias, running_mean, running_var, kTraining, momentum, eps);
-
-  fusion->addOutput(result.output);
-  fusion->addOutput(result.mean);
-  fusion->addOutput(result.invstd);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto at_input = at::randn(input_shape, options);
-  auto at_weight = at::ones({input_shape[1]}, options);
-  auto at_bias = at::zeros({input_shape[1]}, options);
-  auto at_run_mean = at::zeros({input_shape[1]}, options);
-  auto at_run_var = at::ones({input_shape[1]}, options);
-
-  std::vector<IValue> aten_inputs = {
-      at_input, at_weight, at_bias, at_run_mean, at_run_var};
-
-  FusionExecutorCache executor_cache(std::move(fusion));
-
-  auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);
-
-  auto aten_outputs = at::native_batch_norm(
-      at_input,
-      c10::optional<at::Tensor>(at_weight),
-      c10::optional<at::Tensor>(at_bias),
-      c10::optional<at::Tensor>(at_run_mean),
-      c10::optional<at::Tensor>(at_run_var),
-      kTraining,
-      kMomentum,
-      kEps);
-
-  testValidate(
-      executor_cache.fusion(),
-      cg_outputs,
-      aten_inputs,
-      {at_run_mean,
-       at_run_var,
-       std::get<0>(aten_outputs),
-       std::get<1>(aten_outputs),
-       std::get<2>(aten_outputs)},
-      __LINE__,
-      __FILE__,
-      "");
-}
-
-// Disabling for now because memory reuse pass needs to be fixed.
-#if 0
-TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const int pixels_per_thread = 64;
-  const int TIDX = 128;
-  const int static_size = pixels_per_thread * TIDX;
-
-  TensorView* sx = makeConcreteTensor({-1, static_size});
-  TensorView* dx = makeSymbolicTensor(2);
-  fusion.addInput(sx);
-  fusion.addInput(dx);
-
-  TensorView* max_sx =
-      reductionOp(BinaryOpType::Max, {-1}, new Double(FLT_MIN), sx); // (M)
-  TensorView* max_dx =
-      reductionOp(BinaryOpType::Max, {-1}, new Double(FLT_MIN), dx); // (M)
-
-  // Reduction => merge local and shared memory TensorViews
-  TensorView* max_val = binaryOp(BinaryOpType::Max, max_sx, max_dx);
-  TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B)
-
-  TensorView* sx_max_sub = sub(sx, bcast_max); // (M, N)
-  TensorView* dx_max_sub = sub(dx, bcast_max); // (M, N)
-
-  TensorView* sx_exp = unaryOp(UnaryOpType::Exp, sx_max_sub); // (M, N)
-  TensorView* dx_exp = unaryOp(UnaryOpType::Exp, dx_max_sub); // (M, N)
-
-  TensorView* sx_sum_exp = sum(sx_exp, {-1}); // (M, R)
-  TensorView* dx_sum_exp = sum(dx_exp, {-1}); // (M, R)
-
-  // Reduction => merge local and shared memory TensorViews
-  TensorView* sum_exp = binaryOp(BinaryOpType::Add, sx_sum_exp, dx_sum_exp);
-  TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B)
-
-  TensorView* sx_softmax = div(sx_exp, bcast_sum); // (M, N)
-  TensorView* dx_softmax = div(dx_exp, bcast_sum); // (M, N)
-  fusion.addOutput(sx_softmax);
-  fusion.addOutput(dx_softmax);
-
-  auto sx_cache = sx->cache_after();
-  auto dx_cache = dx->cache_after();
-  dx_cache->setMemoryType(MemoryType::Shared);
-  dx_exp->setMemoryType(MemoryType::Shared);
-
-  // Reduction and Broadcast Tensors common to both memory TVs
-  std::vector<TensorView*> common_tensors(
-      {max_val, sum_exp, bcast_max, bcast_sum});
-
-  // Static Local Memory TVs
-  std::vector<TensorView*> static_tensors(
-      {sx, sx_cache, max_sx, sx_max_sub, sx_exp, sx_sum_exp, sx_softmax});
-
-  // Dynamic Local Memory TVs
-  std::vector<TensorView*> dynamic_tensors(
-      {dx, dx_cache, max_dx, dx_max_sub, dx_exp, dx_sum_exp, dx_softmax});
-
-  std::vector<TensorView*> all_tensors;
-  all_tensors.insert(
-      all_tensors.end(), common_tensors.begin(), common_tensors.end());
-  all_tensors.insert(
-      all_tensors.end(), static_tensors.begin(), static_tensors.end());
-  all_tensors.insert(
-      all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end());
-
-  // M => M
-  // M, N => M, N/128, 128
-  for (auto tensor : all_tensors) {
-    if (tensor->nDims() > 1) {
-      tensor->split(-1, TIDX);
-    }
-  }
-
-  auto sx_sum_exp_rf = sx_sum_exp->rFactor({1});
-  auto dx_sum_exp_rf = dx_sum_exp->rFactor({1});
-  all_tensors.push_back(sx_sum_exp_rf);
-  all_tensors.push_back(dx_sum_exp_rf);
-
-  // computeAt
-  sx->computeAt(sx_max_sub, 1);
-  dx->computeAt(dx_max_sub, 1);
-
-  sx_exp->computeAt(sx_softmax, 1);
-  dx_exp->computeAt(dx_softmax, 1);
-
-  sx_max_sub->computeAt(sx_exp, 2);
-  dx_max_sub->computeAt(dx_exp, 2);
-
-  sx_softmax->axis(0)->parallelize(ParallelType::BIDx);
-  dx_softmax->axis(0)->parallelize(ParallelType::BIDx);
-  for (auto tensor : all_tensors) {
-    if (tensor->nDims() > 1) {
-      tensor->axis(-1)->parallelize(ParallelType::TIDx);
-    }
-  }
-
-  const size_t dimx = 1024;
-  const size_t dimy = 16384;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({dimx, dimy}, options);
-  at::Tensor aten_static_in = aten_input.narrow(1, 0, static_size);
-  at::Tensor aten_dynamic_in =
-      aten_input.narrow(1, static_size, dimy - static_size);
-
-  at::Tensor out = at::zeros({dimx, dimy}, options);
-  at::Tensor cg_static_out = out.narrow(1, 0, static_size);
-  at::Tensor cg_dynamic_out = out.narrow(1, static_size, dimy - static_size);
-
-  std::vector<at::Tensor> aten_outputs;
-
-  auto aten_output = at::_softmax(aten_input.to(at::kDouble), -1, false);
-  at::Tensor aten_static_out = aten_output.narrow(1, 0, static_size);
-  at::Tensor aten_dynamic_out =
-      aten_output.narrow(1, static_size, dimy - static_size);
-
-  torch::jit::fuser::cuda::FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion(
-      {aten_static_in, aten_dynamic_in}, {cg_static_out, cg_dynamic_out});
-
-  testValidate(
-      &fusion,
-      {cg_static_out, cg_dynamic_out},
-      {aten_static_in, aten_dynamic_in},
-      {cg_static_out, cg_dynamic_out},
-      __LINE__,
-      __FILE__);
-}
-#endif
-
-// DISABLED. TODO: https://github.com/csarofeen/pytorch/issues/743
-TEST(NVFuserTest, FusionPersistentNormLocalShared_CUDA) {
-  return;
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const int pixels_per_thread = 64;
-  const int TIDX = 128;
-  const int static_size = pixels_per_thread * TIDX;
-
-  TensorView* sx = makeConcreteTensor({-1, static_size});
-  TensorView* dx = makeSymbolicTensor(2);
-  fusion.addInput(sx);
-  fusion.addInput(dx);
-
-  Double* gamma = new Double();
-  Double* beta = new Double();
-  Double* eps = new Double();
-  Int* N = new Int();
-  fusion.addInput(gamma);
-  fusion.addInput(beta);
-  fusion.addInput(eps);
-  fusion.addInput(N);
-
-  // Reduction
-  auto sx_sum = sum(sx, {-1}); // (M, R)
-  auto dx_sum = sum(dx, {-1}); // (M, R)
-  // Reduction => merge local and shared memory TensorViews
-  auto x_sum = binaryOp(BinaryOpType::Add, sx_sum, dx_sum);
-
-  // Broadcast
-  auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B)
-  // Pwise
-  auto x_mean = div(x_sum_bcast, N); // (M, B)
-
-  auto sx_mean_sub = sub(sx, x_mean); // (M, N)
-  auto dx_mean_sub = sub(dx, x_mean); // (M, N)
-
-  auto sx_mean_sub_pow = mul(sx_mean_sub, sx_mean_sub); // (M, N)
-  auto dx_mean_sub_pow = mul(dx_mean_sub, dx_mean_sub); // (M, N)
-
-  // Reduction
-  auto sx_var_sum = sum(sx_mean_sub_pow, {-1}); // (M, R)
-  auto dx_var_sum = sum(dx_mean_sub_pow, {-1}); // (M, R)
-  // Reduction => merge local and shared memory TensorViews
-  auto var_sum = binaryOp(BinaryOpType::Add, sx_var_sum, dx_var_sum);
-
-  // Broadcast
-  auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B)
-  // Pwise
-  auto var = div(var_sum_bcast, N); // (M, B)
-  auto var_eps = add(var, eps); // (M, B)
-  auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B)
-
-  auto sx_norm = mul(sx_mean_sub, rvar);
-  auto dx_norm = mul(dx_mean_sub, rvar);
-
-  auto sx_norm_gamma = mul(sx_norm, gamma);
-  auto dx_norm_gamma = mul(dx_norm, gamma);
-
-  auto sx_norm_gamma_beta = add(sx_norm_gamma, beta);
-  auto dx_norm_gamma_beta = add(dx_norm_gamma, beta);
-
-  fusion.addOutput(sx_norm_gamma_beta);
-  fusion.addOutput(dx_norm_gamma_beta);
-
-  // Read Input into Shared Memory
-  // Read Input minus Input_Mean into Shared Memory
-  auto sx_cache = sx->cache_after();
-  auto dx_cache = dx->cache_after();
-  dx_cache->setMemoryType(MemoryType::Shared);
-  dx_mean_sub->setMemoryType(MemoryType::Shared);
-
-  std::vector<TensorView*> common_tensors(
-      {x_sum, x_sum_bcast, x_mean, var_sum, var_sum_bcast, var, var_eps, rvar});
-
-  std::vector<TensorView*> static_tensors(
-      {sx,
-       sx_cache,
-       sx_sum,
-       sx_mean_sub,
-       sx_mean_sub_pow,
-       sx_var_sum,
-       sx_norm,
-       sx_norm_gamma,
-       sx_norm_gamma_beta});
-
-  std::vector<TensorView*> dynamic_tensors(
-      {dx,
-       dx_cache,
-       dx_sum,
-       dx_mean_sub,
-       dx_mean_sub_pow,
-       dx_var_sum,
-       dx_norm,
-       dx_norm_gamma,
-       dx_norm_gamma_beta});
-
-  std::vector<TensorView*> all_tensors;
-  all_tensors.insert(
-      all_tensors.end(), common_tensors.begin(), common_tensors.end());
-  all_tensors.insert(
-      all_tensors.end(), static_tensors.begin(), static_tensors.end());
-  all_tensors.insert(
-      all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end());
-
-  // M => M
-  // M, N => M, N/128, 128
-  for (auto tensor : all_tensors) {
-    if (tensor->nDims() > 1) {
-      tensor->split(-1, TIDX);
-    }
-  }
-
-  // Local Sum => Block Broadcast
-  TensorView* sx_sum_rf = sx_sum->rFactor({1});
-  TensorView* sx_var_sum_rf = sx_var_sum->rFactor({1});
-  TensorView* dx_sum_rf = dx_sum->rFactor({1});
-  TensorView* dx_var_sum_rf = dx_var_sum->rFactor({1});
-  all_tensors.push_back(sx_sum_rf);
-  all_tensors.push_back(sx_var_sum_rf);
-  all_tensors.push_back(dx_sum_rf);
-  all_tensors.push_back(dx_var_sum_rf);
-
-  // ComputeAt
-  sx->computeAt(sx_mean_sub_pow, 1);
-  dx->computeAt(dx_mean_sub_pow, 1);
-
-  var_sum->computeAt(rvar, 1);
-
-  sx_mean_sub_pow->computeAt(sx_var_sum_rf, 2);
-  dx_mean_sub_pow->computeAt(dx_var_sum_rf, 2);
-
-  sx_norm->computeAt(sx_norm_gamma_beta, 2);
-  dx_norm->computeAt(dx_norm_gamma_beta, 2);
-
-  sx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx);
-  dx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx);
-  for (auto tensor : all_tensors) {
-    if (tensor->nDims() > 1) {
-      tensor->axis(-1)->parallelize(ParallelType::TIDx);
-    }
-  }
-
-  const int dimx = 1024;
-  const int dimy = 16384;
-  const float kGamma = 1.0f;
-  const float kBeta = 0.0f;
-  const float kEps = 1e-5;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor aten_input = at::randn({dimx, dimy}, options);
-  at::Tensor aten_static_in = aten_input.narrow(1, 0, static_size);
-  at::Tensor aten_dynamic_in =
-      aten_input.narrow(1, static_size, dimy - static_size);
-
-  at::Tensor out = at::zeros({dimx, dimy}, options);
-  at::Tensor cg_static_out = out.narrow(1, 0, static_size);
-  at::Tensor cg_dynamic_out = out.narrow(1, static_size, dimy - static_size);
-
-  std::vector<IValue> aten_inputs = {
-      aten_static_in, aten_dynamic_in, kGamma, kBeta, kEps, dimy};
-
-  torch::jit::fuser::cuda::FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion(aten_inputs, {cg_static_out, cg_dynamic_out});
-
-  auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1);
-  auto at_var = at::var(aten_input.to(at::kDouble), -1, false).unsqueeze(1);
-  auto at_rvar = at::rsqrt(at::add(at_var, kEps));
-  auto at_norm = at::mul(at::sub(aten_input, at_mu), at_rvar);
-  auto aten_output = at::add(at::mul(at_norm, kGamma), kBeta);
-  at::Tensor aten_static_out = aten_output.narrow(1, 0, static_size);
-  at::Tensor aten_dynamic_out =
-      aten_output.narrow(1, static_size, dimy - static_size);
-
-  testValidate(
-      &fusion,
-      {cg_static_out, cg_dynamic_out},
-      aten_inputs,
-      {aten_static_out, aten_dynamic_out},
-      __LINE__,
-      __FILE__);
-}
-
-TEST(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  auto x = makeSymbolicTensor(2);
-  Double* gamma = new Double();
-  Double* beta = new Double();
-  Double* eps = new Double();
-  Int* N = new Int();
-  fusion.addInput(x);
-  fusion.addInput(gamma);
-  fusion.addInput(beta);
-  fusion.addInput(eps);
-  fusion.addInput(N);
-
-  // Reduction
-  auto x_sum = sum(x, {-1}); // (M, R)
-  // Broadcast
-  auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B)
-  // Pwise
-  auto x_mean = div(x_sum_bcast, N); // (M, B)
-  auto x_mean_sub = sub(x, x_mean); // (M, N)
-  auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub); // (M, N)
-  // Reduction
-  auto var_sum = sum(x_mean_sub_pow, {-1}); // (M, R)
-  // Broadcast
-  auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B)
-  // Pwise
-  auto var = div(var_sum_bcast, N); // (M, B)
-  auto var_eps = add(var, eps); // (M, B)
-  auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B)
-  auto norm = mul(x_mean_sub, rvar);
-  auto norm_gamma = mul(norm, gamma);
-  auto norm_gamma_beta = add(norm_gamma, beta);
-  fusion.addOutput(norm_gamma_beta);
-
-  // Read Input into Shared Memory
-  // Read Input minus Input_Mean into Shared Memory
-  auto cache_x = x->cache_after();
-  cache_x->setMemoryType(MemoryType::Shared);
-  x_mean_sub->setMemoryType(MemoryType::Shared);
-
-  std::vector<TensorView*> all_tensors(
-      {x_sum,
-       x_mean,
-       cache_x,
-       x_sum_bcast,
-       x_mean_sub,
-       x_mean_sub_pow,
-       var_sum,
-       var_sum_bcast,
-       var,
-       var_eps,
-       rvar,
-       norm,
-       norm_gamma,
-       norm_gamma_beta});
-
-  auto tidx = new Int();
-  fusion.addInput(tidx);
-
-  for (auto tensor : all_tensors) {
-    tensor->split(-1, tidx);
-  }
-
-  // Local Sum => Block Broadcast
-  TensorView* x_sum_rf = x_sum->rFactor({1});
-  TensorView* var_sum_rf = var_sum->rFactor({1});
-  all_tensors.push_back(x_sum_rf);
-  all_tensors.push_back(var_sum_rf);
-
-  // ComputeAt
-  x->computeAt(x_mean_sub_pow, 1);
-  var_sum->computeAt(rvar, 1);
-  x_mean_sub_pow->computeAt(var_sum_rf, 2);
-  norm->computeAt(norm_gamma_beta, 2);
-
-  for (auto tv : all_tensors) {
-    tv->axis(0)->parallelize(ParallelType::BIDx);
-    tv->axis(-1)->parallelize(ParallelType::TIDx);
-  }
-
-  const int dimx = 128;
-  const int dimy = 2048;
-  const float kGamma = 1.0f;
-  const float kBeta = 0.0f;
-  const float kEps = 1e-5;
-  const int TIDX = 128;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({dimx, dimy}, options);
-  auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1);
-  auto at_var = at::var(aten_input.to(at::kDouble), -1).unsqueeze(1);
-  auto at_rvar = at::rsqrt(at::add(at_var, kEps));
-  auto at_norm = at::mul(at::sub(aten_input, at_mu), at_rvar);
-  auto aten_output = at::add(at::mul(at_norm, kGamma), kBeta);
-
-  std::vector<IValue> aten_inputs = {
-      aten_input, kGamma, kBeta, kEps, dimy, TIDX};
-
-  torch::jit::fuser::cuda::FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-  fusion.addInput(tv0);
-  fusion.addOutput(tv1);
-  // tv1[I0, R1] = tv0[I0, I1]
-
-  // Interface should just be a direct split with a Parallel type. We can
-  // include the parallelize call if we do this.
-  tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
-  // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1]
-
-  TensorView* tv2 = tv1->rFactor({2});
-  tv2->setMemoryType(MemoryType::Shared);
-  // tv2[I0, R1oo, Ir1i{BIDx}] = tv0[I0, I1]
-  // tv1[I0,        R1i{BIDx}] = tv2[I0, R1oo, Ir1i{BIDx}]
-
-  tv0->computeAt(tv1, 1);
-
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-
-  constexpr int numel_x = 65000, numel_y = 1024;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
-  auto aten_output = aten_input.to(at::kDouble).sum({1});
-
-  // How many threads to use for the block reduction
-  constexpr int runtime_threadIdx_dim = 128;
-
-  LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
-}
-
-TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Algorithm
-  Int* sym_bsx = new Int();
-  TensorView* tv0 = makeSymbolicTensor(3); // M, K, N
-  fusion.addInput(tv0);
-  fusion.addInput(sym_bsx);
-
-  TensorView* tv1 = sum(tv0, {1}); // M, R, N
-  fusion.addOutput(tv1);
-
-  TensorView* tv2 = tv0->cache_after();
-  tv2->setMemoryType(MemoryType::Shared);
-
-  // Schedule
-  constexpr int BSX = 32;
-  tv1->split(2, BSX);
-  tv1->split(1, sym_bsx);
-  tv1->split(0, BSX);
-  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
-  tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}});
-  TensorView* tv3 = tv1->rFactor({-2});
-
-  tv0->computeAt(tv1, -2);
-  tv0->computeAt(tv3, -2);
-
-  // Thread and Block binding
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(1)->parallelize(ParallelType::BIDy);
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  // Manual Binding
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
-  constexpr int M = 154, K = 45, N = 1524;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({M, K, N}, options);
-  at::Tensor aten_output = aten_input.to(at::kDouble).sum({1});
-
-  // How many threads to use for the block reduction
-  constexpr int runtime_threadIdx_dim = 128;
-
-  auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input, runtime_threadIdx_dim}, lparams);
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input, runtime_threadIdx_dim},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-
-  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
-}
-
-TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  Int* sym_bsx = new Int();
-  TensorView* tv0 = makeSymbolicTensor(2); // (M, K)
-  TensorView* tv1 = makeSymbolicTensor(2); // (K, N)
-  TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
-  TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
-  TensorView* tv4 = mul(tv2, tv3); // M, K, N
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addInput(sym_bsx);
-  fusion.addOutput(tv4);
-  // Algorithm
-
-  tv2->setMemoryType(MemoryType::Shared);
-  tv3->setMemoryType(MemoryType::Shared);
-
-  constexpr int BSX = 32;
-  tv4->split(2, BSX);
-  tv4->split(1, sym_bsx);
-  tv4->split(0, BSX);
-  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
-  tv4->reorder({{0, 0}, {1, 3}, {2, 1}, {3, 4}, {4, 2}, {5, 5}});
-  // M/BSX, K/BSX, N/BSX, MSX, KSX, NSX
-
-  tv0->computeAt(tv4, 3);
-  tv1->computeAt(tv4, 3);
-  // Schedule
-
-  tv4->axis(0)->parallelize(ParallelType::BIDx);
-  tv4->axis(2)->parallelize(ParallelType::BIDy);
-  // Manual Binding
-  tv2->axis(-2)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  // Thread and Block binding
-
-  constexpr int M = 128, K = 457, N = 1024;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({M, K}, options);
-  at::Tensor t1 = at::randn({K, N}, options);
-  at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0));
-  std::vector<IValue> aten_inputs = {t0, t1, BSX};
-
-  LaunchParams lparams(-1, -1, -1, BSX, -1, -1);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      aten_inputs,
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-
-  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
-}
-
-TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Symbolic integers we will use for runtime tiling
-  Int* symbolic_m_tile_dim = new Int(); // bound to threadIdx.z
-  Int* symbolic_split_k_tile_dim = new Int(); // bound to blockIdx.x
-  Int* symbolic_block_k_tile_dim = new Int(); // bound to threadIdx.x
-  // Compile-time integer for tiling
-  int n_smem_tile = 8; // bound to threadIdx.y
-
-  // Symbolic 2D tensors TV0[M, K], TV1[K, N]
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
-
-  // Broadcast tv0 to [M, K, *]
-  TensorView* tv2 = broadcast(tv0, {false, false, true});
-  // Broadcast tv1 to [*, K, N]
-  TensorView* tv3 = broadcast(tv1, {true, false, false});
-
-  // Pointwise multiplication resulting in tv3[M, K, N]
-  TensorView* tv4 = mul(tv2, tv3);
-
-  // Turn the K-dimension of tv4 into a reduction dimension
-  TensorView* tv5 = sum(tv4, {1});
-
-  // Register inputs and outputs
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addOutput(tv5);
-
-  // Register runtime tile dims as inputs
-  fusion.addInput(symbolic_m_tile_dim);
-  fusion.addInput(symbolic_split_k_tile_dim);
-  fusion.addInput(symbolic_block_k_tile_dim);
-
-  // Make a 3D tile, mix of symbolic and constant, do in reverse order because
-  // dims are inserted
-  tv5->split(2, n_smem_tile);
-  tv5->split(1, symbolic_block_k_tile_dim);
-  tv5->split(1, symbolic_split_k_tile_dim);
-  tv5->split(0, symbolic_m_tile_dim);
-
-  // Reorder so all outer tiles are in the leftmost 3 positions
-  tv5->reorder({{1, 5}, {5, 1}});
-
-  // Factor out the outer reduction IterDomain, then run the inter-cta
-  // reduction, and intra-cta reduction
-  auto tv6 = tv5->rFactor({2});
-
-  // Scope computations
-  tv6->computeAt(tv5, 2);
-
-  // RFactor moves reduction axes around, reorder to match ordering of tv5
-  tv6->reorder({
-      {2, -2},
-      {3, -1},
-      {4, 2},
-      {5, 3},
-      {6, 4},
-  });
-
-  // Setup compute at schedule
-  tv0->computeAt(tv6, 3);
-  tv1->computeAt(tv6, 3);
-  tv4->computeAt(tv6, -1);
-  //
-  // T2[Mo,  bNo, Koo, Koi,  Kii,  Mi, bNi] CA(4, 3)
-  // T3[bMo,  No, Koo, Koi,  Kii, bMi,  Ni] CA(4, 3)
-  // T4[ Mo,  No, Koo, Koi,  Kii,  Mi,  Ni]
-  // T6[ Mo,  No, rKoo, Koi, Kii,  Mi,  Ni]
-  // T5[ Mo,  No,      rKoi, rKii, Mi,  Ni]
-
-  // Cache smem tiles
-  tv2->setMemoryType(MemoryType::Shared);
-  tv3->setMemoryType(MemoryType::Shared);
-  tv4->setMemoryType(MemoryType::Local);
-  tv6->setMemoryType(MemoryType::Local);
-
-  tv5->axis(0)->parallelize(ParallelType::BIDz);
-  tv5->axis(1)->parallelize(ParallelType::BIDy);
-
-  std::vector<TensorView*> tv_list = {tv2, tv3, tv4, tv5, tv6};
-  for (auto tv : tv_list) {
-    tv->axis(-2)->parallelize(ParallelType::TIDz);
-    tv->axis(-1)->parallelize(ParallelType::TIDy);
-  }
-  tv2->axis(3)->parallelize(ParallelType::TIDx);
-  tv3->axis(3)->parallelize(ParallelType::TIDx);
-  tv4->axis(3)->parallelize(ParallelType::TIDx);
-  tv6->axis(3)->parallelize(ParallelType::TIDx);
-  tv5->axis(2)->parallelize(ParallelType::TIDx);
-
-  tv2->axis(4)->parallelize(ParallelType::BIDx);
-  tv3->axis(4)->parallelize(ParallelType::BIDx);
-  tv4->axis(4)->parallelize(ParallelType::BIDx);
-  tv6->axis(4)->parallelize(ParallelType::BIDx);
-  tv5->axis(3)->parallelize(ParallelType::BIDx);
-
-  constexpr int M = 31, K = 65, N = 33;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({M, K}, options);
-  at::Tensor t1 = at::randn({K, N}, options);
-
-  FusionExecutor fe;
-  // Generate CUDA and compile with nvRTC
-  fe.compileFusion(&fusion);
-
-  // Runtime tiling
-  int m_tile = 4; // bound to threadIdx.z
-  int split_k = 7; // bound to blockIdx.x
-  int intra_cta = 8; // bound to threadIdx.x
-
-  std::vector<IValue> aten_inputs = {t0, t1, m_tile, split_k, intra_cta};
-  at::Tensor aten_output =
-      mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1);
-
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-
-  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
-}
-
-TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-  fusion.addInput(tv0);
-  fusion.addOutput(tv1);
-  // tv1[I0, R1] = tv0[I0, I1]
-
-  // Interface should just be a direct split with a Parallel type. We can
-  // include the parallelize call if we do this.
-  tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
-  // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1]
-
-  TensorView* tv2 = tv1->rFactor({2});
-  tv2->setMemoryType(MemoryType::Global);
-  // tv2[I0, R1oo, Ir1i{BIDx}] = tv0[I0, I1]
-  // tv1[I0,        R1i{BIDx}] = tv2[I0, R1oo, Ir1i{BIDx}]
-
-  tv0->computeAt(tv1, 1);
-
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-
-  constexpr int numel_x = 65000, numel_y = 1024;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y}, options);
-
-  // How many threads to use for the block reduction
-  constexpr int runtime_threadIdx_dim = 128;
-
-  auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input}, lparams);
-
-  auto aten_output = input.to(at::kDouble).sum({1});
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {input},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
-  TensorView* tv2 = makeSymbolicTensor(2);
-  TensorView* tv3 = makeSymbolicTensor(2);
-  TensorView* tv4 = sub(tv2, tv3);
-  TensorView* tv5 = add(tv1, tv4);
-  TensorView* tv6 = sub(tv5, tv0);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addInput(tv2);
-  fusion.addInput(tv3);
-  fusion.addOutput(tv6);
-  // t6 = ((t1 + (t2 - t3)) - t0)
-
-  tv4->setMemoryType(MemoryType::Global);
-  tv5->setMemoryType(MemoryType::Global);
-  tv6->setMemoryType(MemoryType::Global);
-
-  constexpr int M = 32, N = 810;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({M, N}, options);
-  at::Tensor t1 = at::randn({M, N}, options);
-  at::Tensor t2 = at::randn({M, N}, options);
-  at::Tensor t3 = at::randn({M, N}, options);
-
-  at::Tensor aten_output = (t1 + (t2 - t3)) - t0;
-
-  std::vector<IValue> aten_inputs = {t0, t1, t2, t3};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({t0, t1, t2, t3});
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionConstCheck_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto one = new Int(1);
-  TORCH_CHECK(one->isConstScalar());
-
-  auto one_x2 = mul(one, one);
-  TORCH_CHECK(one_x2->isConstScalar());
-
-  auto one_x3 = mul(one_x2, one);
-  TORCH_CHECK(one_x3->isConstScalar());
-
-  auto one_x4 = mul(one_x3, one);
-  TORCH_CHECK(one_x4->isConstScalar());
-}
-
-TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) {
-  const std::vector<int64_t> tensor_dims_in = {128, 128};
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size());
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = add(tv0, new Double(0));
-  TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv1);
-  fusion.addOutput(tv2);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn(tensor_dims_in, options);
-  at::Tensor cg_output = at::empty({tensor_dims_in[0]}, options);
-
-  // Schedule
-  tv2->split(1, 32);
-  tv2->split(1, 4); // unroll
-
-  auto tv2_rf = tv2->rFactor({-3, -2});
-
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv2_rf->axis(0)->parallelize(ParallelType::BIDx);
-  tv2_rf->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2_rf->axis(-2)->parallelize(ParallelType::Unroll);
-
-  tv1->computeAt(tv2_rf, -1);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
-
-  auto aten_output = (input + 0).to(at::kDouble).sum(1);
-
-  testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Test isZeroInt
-TEST(NVFuserTest, FusionIsZeroInt_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  Int* x = new Int(0);
-  Int* y = new Int(1);
-  Val* z = mul(x, y);
-  TORCH_CHECK(x->isZeroInt());
-  TORCH_CHECK(!y->isZeroInt());
-  TORCH_CHECK(!z->isZeroInt());
-}
-
-// Test isOneInt
-TEST(NVFuserTest, FusionIsOneInt_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  Int* x = new Int(1);
-  Int* y = new Int(1);
-  Val* z = mul(x, y);
-  TORCH_CHECK(x->isOneInt());
-  TORCH_CHECK(y->isOneInt());
-  TORCH_CHECK(!z->isOneInt());
-}
-
-// This is to verify no cycle of computeAt is created. A more complex
-// variation of this pattern appears in one of the Python tests
-// (test_random_topo).
-TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-
-  // Common intermediate tensor
-  auto tv1 = add(tv0, new Double(1));
-  // tv1 -> tv2
-  auto tv2 = add(tv1, new Double(2));
-  // tv1 -> tv3 -> tv4
-  auto tv3 = add(tv1, new Double(3));
-  auto tv4 = add(tv3, new Double(4));
-
-  // NOTE: This should no longer occur as of PR #201.
-  // The order of adding outputs matters. If tv3 is added before tv4,
-  // it should be fine. However, if tv4 is added before tv3, there
-  // will be a cycle of tv3->tv4 and tv4->tv3. tv3->tv4 is created
-  // first, and then tv4->tv3 is created at the final phase of
-  // computeAt (ComputeAt::setupOutputs).
-  fusion.addOutput(tv2);
-  fusion.addOutput(tv4);
-  fusion.addOutput(tv3);
-
-  tv0->computeAt(tv2, -1);
-
-  TORCH_CHECK(tv3->hasComputeAt());
-  TORCH_CHECK(!tv4->hasComputeAt());
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn(100, options);
-
-  auto t1 = aten_input + 1;
-  auto t2 = t1 + 2;
-  auto t3 = t1 + 3;
-  auto t4 = t3 + 4;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  std::vector<at::Tensor> aten_outputs = {t2, t4, t3};
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTraversalOrder1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = add(tv0, new Double(1));
-  TensorView* tv2 = add(tv0, new Double(2));
-  TensorView* tv3 = add(tv1, new Double(3));
-  TensorView* tv4 = add(tv1, new Double(4));
-
-  fusion.addOutput(tv2);
-  fusion.addOutput(tv3);
-  fusion.addOutput(tv4);
-
-  tv1->computeAt(tv3, -1);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({10, 10}, options);
-
-  auto t1 = aten_input + 1;
-  auto t2 = aten_input + 2;
-  auto t3 = t1 + 3;
-  auto t4 = t1 + 4;
-
-  std::vector<at::Tensor> aten_outputs = {t2, t3, t4};
-
-  std::vector<at::Tensor> cg_outputs = {
-      at::empty_like(aten_input, options),
-      at::empty_like(aten_input, options),
-      at::empty_like(aten_input, options)};
-
-  fe.runFusion({aten_input}, cg_outputs);
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTraversalOrder2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = add(tv0, new Double(1));
-  TensorView* tv2 = add(tv1, new Double(2));
-
-  TensorView* tv3 = add(tv0, new Double(3));
-  TensorView* tv4 = add(tv3, new Double(4));
-
-  TensorView* tv5 = add(tv1, tv3);
-
-  fusion.addOutput(tv2);
-  fusion.addOutput(tv4);
-  fusion.addOutput(tv5);
-
-  tv1->computeAt(tv5, -1);
-  tv3->computeAt(tv5, -1);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({10, 10}, options);
-
-  auto t1 = aten_input + 1;
-  auto t2 = t1 + 2;
-  auto t3 = aten_input + 3;
-  auto t4 = t3 + 4;
-  auto t5 = t1 + t3;
-
-  std::vector<at::Tensor> aten_outputs = {t2, t4, t5};
-
-  std::vector<at::Tensor> cg_outputs = {
-      at::empty_like(aten_input, options),
-      at::empty_like(aten_input, options),
-      at::empty_like(aten_input, options)};
-
-  fe.runFusion({aten_input}, cg_outputs);
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTraversalOrder3_CUDA) {
-  for (int i = 0; i < 2; ++i) {
-    Fusion fusion;
-    FusionGuard fg(&fusion);
-
-    TensorView* tv0 = makeSymbolicTensor(1);
-    fusion.addInput(tv0);
-
-    TensorView* tv1 = add(tv0, new Double(1));
-    TensorView* tv2 = add(tv1, new Double(2));
-
-    TensorView* tv3 = add(tv0, new Double(3));
-    TensorView* tv4 = add(tv3, new Double(4));
-
-    TensorView* tv5 = add(tv1, tv3);
-
-    fusion.addOutput(tv2);
-    fusion.addOutput(tv4);
-    fusion.addOutput(tv5);
-
-    const int tile = 32;
-
-    tv1->split(-1, tile);
-    tv2->split(-1, tile);
-    tv3->split(-1, tile);
-    tv4->split(-1, tile);
-    tv5->split(-1, tile);
-
-    auto compute_at_outer = tv1;
-    auto compute_at_inner = tv3;
-    if (i == 1) {
-      std::swap(compute_at_inner, compute_at_outer);
-    }
-
-    compute_at_outer->computeAt(tv5, -2);
-    compute_at_inner->computeAt(tv5, -1);
-
-    FusionExecutor fe;
-    fe.compileFusion(&fusion);
-
-    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-    at::Tensor aten_input = at::randn({100}, options);
-    auto t1 = aten_input + 1;
-    auto t2 = t1 + 2;
-    auto t3 = aten_input + 3;
-    auto t4 = t3 + 4;
-    auto t5 = t1 + t3;
-
-    std::vector<at::Tensor> aten_outputs = {t2, t4, t5};
-
-    std::vector<at::Tensor> cg_outputs = {
-        at::empty_like(aten_input, options),
-        at::empty_like(aten_input, options),
-        at::empty_like(aten_input, options)};
-
-    fe.runFusion({aten_input}, cg_outputs);
-
-    testValidate(
-        &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-  }
-}
-
-TEST(NVFuserTest, FusionTraversalOrder4_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // First tree
-  TensorView* tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  TensorView* tv1 = add(tv0, new Double(1));
-  TensorView* tv2 = add(tv1, new Double(2));
-  TensorView* tv3 = add(tv1, new Double(3));
-  fusion.addOutput(tv2);
-  fusion.addOutput(tv3);
-
-  // Second tree
-  TensorView* tv4 = makeSymbolicTensor(1);
-  fusion.addInput(tv4);
-  TensorView* tv5 = add(tv4, new Double(5));
-  TensorView* tv6 = add(tv5, new Double(6));
-  TensorView* tv7 = add(tv5, new Double(7));
-  fusion.addOutput(tv6);
-  fusion.addOutput(tv7);
-
-  tv1->computeAt(tv2, -1);
-  tv5->computeAt(tv6, -1);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({100}, options);
-  at::Tensor t4 = at::rand_like(t0, options);
-
-  auto t1 = t0 + 1;
-  auto t2 = t1 + 2;
-  auto t3 = t1 + 3;
-  auto t5 = t4 + 5;
-  auto t6 = t5 + 6;
-  auto t7 = t5 + 7;
-
-  std::vector<at::Tensor> aten_outputs = {t2, t3, t6, t7};
-  std::vector<IValue> aten_inputs = {t0, t4};
-  std::vector<at::Tensor> cg_outputs = {
-      at::empty_like(t0, options),
-      at::empty_like(t0, options),
-      at::empty_like(t0, options),
-      at::empty_like(t0, options)};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion(aten_inputs, cg_outputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTraversalOrder5_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  TensorView* tv1 = add(tv0, new Double(1));
-  TensorView* tv2 = add(tv1, new Double(2));
-  TensorView* tv3 = add(tv0, new Double(3));
-  TensorView* tv4 = add(tv3, new Double(4));
-  TensorView* tv5 = add(tv2, tv4);
-
-  fusion.addOutput(tv1);
-  fusion.addOutput(tv3);
-  fusion.addOutput(tv5);
-
-  tv2->computeAt(tv5, -1);
-  tv4->computeAt(tv5, -1);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({100}, options);
-  std::vector<at::Tensor> cg_outputs = {
-      at::empty_like(aten_input, options),
-      at::empty_like(aten_input, options),
-      at::empty_like(aten_input, options)};
-
-  fe.runFusion({aten_input}, cg_outputs);
-
-  auto t1 = aten_input + 1;
-  auto t2 = t1 + 2;
-  auto t3 = aten_input + 3;
-  auto t4 = t3 + 4;
-  auto t5 = t2 + t4;
-
-  std::vector<at::Tensor> aten_outputs = {t1, t3, t5};
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTraversalOrder6_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  TensorView* tv1 = add(tv0, new Double(1));
-  TensorView* tv2 = add(tv0, new Double(2));
-  TensorView* tv3 = add(tv1, tv2);
-  TensorView* tv4 = add(tv3, new Double(4));
-
-  fusion.addOutput(tv4);
-
-  tv1->split(0, 32);
-  tv2->split(0, 32);
-  tv3->split(0, 32);
-  tv4->split(0, 32);
-
-  tv3->computeAt(tv4, -2);
-  tv1->computeAt(tv3, -1);
-  tv2->computeAt(tv3, -2);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({100}, options);
-
-  auto t1 = aten_input + 1;
-  auto t2 = aten_input + 2;
-  auto t3 = t1 + t2;
-  auto aten_output = t3 + 4;
-
-  at::Tensor cg_output = at::empty_like(aten_input, options);
-
-  fe.runFusion({aten_input}, {cg_output});
-
-  testValidate(
-      &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTraversalOrder7_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  TensorView* tv1 = add(tv0, new Double(1));
-  TensorView* tv2 = add(tv1, new Double(2));
-  TensorView* tv3 = add(tv0, new Double(3));
-  TensorView* tv4 = add(tv3, new Double(4));
-  TensorView* tv5 = add(tv2, tv4);
-
-  fusion.addOutput(tv5);
-
-  TensorView* tvs[] = {tv1, tv2, tv3, tv4, tv5};
-  for (auto tv : tvs) {
-    tv->split(0, 2);
-    tv->split(0, 4);
-    tv->split(0, 8);
-  }
-
-  // computeAt into inner loop nests
-  tv1->computeAt(tv2, -1);
-  tv3->computeAt(tv4, -2);
-
-  tv2->computeAt(tv5, -4);
-  tv4->computeAt(tv5, -3);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({100}, options);
-
-  auto t1 = aten_input + 1;
-  auto t2 = t1 + 2;
-  auto t3 = aten_input + 3;
-  auto t4 = t3 + 4;
-  auto aten_output = t2 + t4;
-
-  at::Tensor cg_output = at::empty_like(aten_input, options);
-  fe.runFusion({aten_input}, {cg_output});
-
-  testValidate(
-      &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Test predication of grid reduction
-TEST(NVFuserTest, FusionThreadPredicate_CUDA) {
-  const int gdimx = 4;
-  const int bdimx = 128;
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-  TensorView* tv2 = unaryOp(UnaryOpType::Neg, tv1);
-  TensorView* tv3 = add(tv0, new Double(2));
-
-  fusion.addOutput(tv3);
-  fusion.addOutput(tv2);
-
-  tv1->split(1, bdimx);
-  tv1->split(1, gdimx);
-  tv3->split(1, bdimx);
-  tv3->split(1, gdimx);
-
-  TensorView* tv1_rf = tv1->rFactor({1});
-
-  tv1->computeAt(tv2, -1);
-
-  tv1->axis(0)->parallelize(ParallelType::BIDy);
-  tv1_rf->axis(0)->parallelize(ParallelType::BIDy);
-  tv2->axis(0)->parallelize(ParallelType::BIDy);
-  tv1->axis(-2)->parallelize(ParallelType::BIDx);
-  tv1_rf->axis(-2)->parallelize(ParallelType::BIDx);
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv1_rf->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv3->axis(3)->parallelize(ParallelType::TIDx);
-  tv3->axis(2)->parallelize(ParallelType::BIDx);
-  tv3->axis(0)->parallelize(ParallelType::BIDy);
-
-  int numel_x = 100;
-  int numel_y = 1000;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
-
-  auto t2 = -aten_input.to(at::kDouble).sum({1});
-  auto t3 = aten_input + 2.0;
-
-  std::vector<at::Tensor> aten_outputs = {t3, t2};
-
-  std::vector<at::Tensor> cg_outputs = {
-      at::empty_like(aten_input, options), at::empty({numel_x}, options)};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({aten_input}, cg_outputs);
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionLSTMCell_CUDA) {
-  const int hidden_features = 512;
-  const int batch_size = 64;
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  TensorView* tvs[16];
-  for (size_t i = 0; i < 16; i++) {
-    tvs[i] = makeSymbolicTensor(2);
-    fusion.addInput(tvs[i]);
-  }
-
-  auto ingate = unaryOp(
-      UnaryOpType::Sigmoid, add(add(add(tvs[0], tvs[1]), tvs[2]), tvs[3]));
-
-  auto forgetgate = unaryOp(
-      UnaryOpType::Sigmoid, add(add(add(tvs[4], tvs[5]), tvs[6]), tvs[7]));
-
-  auto cellgate = unaryOp(
-      UnaryOpType::Tanh, add(add(add(tvs[8], tvs[9]), tvs[10]), tvs[11]));
-
-  auto outgate = unaryOp(
-      UnaryOpType::Sigmoid, add(add(add(tvs[12], tvs[13]), tvs[14]), tvs[15]));
-
-  auto cx = makeContigTensor(2);
-  fusion.addInput(cx);
-
-  auto cy = add(mul(forgetgate, cx), mul(ingate, cellgate));
-
-  auto hy = mul(outgate, unaryOp(UnaryOpType::Tanh, cy));
-
-  fusion.addOutput(cy);
-  fusion.addOutput(hy);
-
-  std::vector<c10::IValue> aten_inputs;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor large_tensor0 =
-      at::randn({batch_size, hidden_features * 4}, options);
-  at::Tensor large_tensor1 =
-      at::randn({batch_size, hidden_features * 4}, options);
-  at::Tensor large_tensor2 =
-      at::randn({batch_size, hidden_features * 4}, options);
-  at::Tensor large_tensor3 =
-      at::randn({batch_size, hidden_features * 4}, options);
-
-  auto chunked0 = large_tensor0.chunk(4, 1);
-  auto chunked1 = large_tensor1.chunk(4, 1);
-  auto chunked2 = large_tensor2.chunk(4, 1);
-  auto chunked3 = large_tensor3.chunk(4, 1);
-
-  aten_inputs.insert(aten_inputs.end(), chunked0.begin(), chunked0.end());
-  aten_inputs.insert(aten_inputs.end(), chunked1.begin(), chunked1.end());
-  aten_inputs.insert(aten_inputs.end(), chunked2.begin(), chunked2.end());
-  aten_inputs.insert(aten_inputs.end(), chunked3.begin(), chunked3.end());
-
-  auto at_ingate =
-      chunked0[0].add(chunked0[1]).add(chunked0[2]).add(chunked0[3]).sigmoid();
-  auto at_forgetgate =
-      chunked1[0].add(chunked1[1]).add(chunked1[2]).add(chunked1[3]).sigmoid();
-  auto at_cellgate =
-      chunked2[0].add(chunked2[1]).add(chunked2[2]).add(chunked2[3]).tanh();
-  auto at_outgate =
-      chunked3[0].add(chunked3[1]).add(chunked3[2]).add(chunked3[3]).sigmoid();
-
-  auto at_cx = at::randn({batch_size, hidden_features}, options);
-  aten_inputs.push_back(at_cx);
-  auto at_cy = at_forgetgate.mul(at_cx).add(at_ingate.mul(at_cellgate));
-  auto at_hy = at_outgate.mul(at_cy.tanh());
-
-  auto lparams = schedulePointwise(&fusion, aten_inputs);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {at_cy, at_hy}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionComputeAtMultiBCast_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = mul(tv0, new Double(0.5));
-  TensorView* tv2 = broadcast(tv1, {true, false});
-  TensorView* tv3 = broadcast(tv1, {false, true});
-  TensorView* tv4 = add(tv2, tv3);
-  fusion.addOutput(tv4);
-
-  // Not possible to do computeAt at position -1 as recomputation
-  // would be required. An exception should be thrown.
-  ASSERT_ANY_THROW(tv1->computeAt(tv3, -1));
-}
-
-TEST(NVFuserTest, FusionReductionHalf_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(3, DataType::Half);
-  fusion.addInput(tv0);
-
-  auto tv1 = castOp(DataType::Float, tv0);
-  auto tv2 = add(tv1, new Double(1.0));
-  auto tv3 = sum(tv2, {2});
-  auto tv4 = castOp(DataType::Half, tv3);
-
-  fusion.addOutput(tv4);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({8, 8, 16}, options);
-
-  auto reduction_tv = tv3;
-
-  auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-
-  auto lparams = reduction_params.value().lparams;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  // no broadcasting needed, omitting the last optional argument;
-  auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
-  auto aten_output = aten_input.add(1.0).to(at::kDouble).sum({2});
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionReduceSingle_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeConcreteTensor({100, 1});
-  fusion.addInput(tv0);
-  auto tv1 = sum(tv0, {1});
-  fusion.addOutput(tv1);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({100, 1}, options);
-
-  // Grab only tensor views, though there shouldn't be any other type
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  // no broadcasting needed, omitting the last optional argument;
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  auto aten_output = aten_input.to(at::kDouble).sum({1});
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) {
-  constexpr int bid_x = 80;
-  constexpr int tid_x = 4096;
-  constexpr int red_dim = 1;
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1});
-  fusion.addInput(tv0);
-
-  TensorView* tv1 =
-      reductionOp(BinaryOpType::Add, {red_dim, 2}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({bid_x, tid_x, 1}, options);
-
-  // Apply reduction heuristic
-  auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-  auto lparams = reduction_params.value().lparams;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  // no broadcasting needed, omitting the last optional argument;
-  auto cg_outputs = fe.runFusion({aten_input}, lparams);
-  auto aten_output = aten_input.to(at::kDouble).sum({red_dim, 2});
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) {
-  constexpr int bid_x = 80;
-  constexpr int tid_x = 4096;
-  constexpr int red_dim = 1;
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1});
-  fusion.addInput(tv0);
-
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {2}, new Double(0), tv0);
-
-  TensorView* tv2 =
-      reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv1);
-  fusion.addOutput(tv2);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({bid_x, tid_x, 1}, options);
-
-  // Apply reduction heuristic
-  auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-
-  scheduleReduction(&fusion, reduction_params.value());
-  auto lparams = reduction_params.value().lparams;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  // no broadcasting needed, omitting the last optional argument;
-  auto cg_outputs = fe.runFusion({aten_input}, lparams);
-  auto aten_output = aten_input.to(at::kDouble).sum({1, 2});
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) {
-  constexpr int bid_x = 80;
-  constexpr int tid_x = 4096;
-  constexpr int red_dim = 1;
-
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1});
-  fusion.addInput(tv0);
-
-  TensorView* tv1 =
-      reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv0);
-
-  TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv1);
-  fusion.addOutput(tv2);
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({bid_x, tid_x, 1}, options);
-
-  // Apply reduction heuristic
-  auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-  auto lparams = reduction_params.value().lparams;
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  // no broadcasting needed, omitting the last optional argument;
-  auto cg_outputs = fe.runFusion({aten_input}, lparams);
-  auto aten_output = aten_input.to(at::kDouble).sum({2, 1});
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {aten_input},
-      {aten_output},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionTrivialReduction_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Set up your input tensor views
-  TensorView* tv0 = makeConcreteTensor({10, 20, 1});
-  fusion.addInput(tv0);
-  TensorView* tv1 = reductionOp(BinaryOpType::Add, {2}, new Double(0), tv0);
-  fusion.addOutput(tv1);
-
-  TORCH_CHECK(!fusion.hasReduction(), "Trivial reduction picked up by fusion");
-
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({10, 20, 1}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
-  auto aten_output = aten_input.to(at::kDouble).sum({2});
-
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTrivialReduction2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  int w = 1, x = 1, y = 7, z = 8;
-
-  auto tv0 = makeSymbolicTensor(2);
-  auto tv1 = makeConcreteTensor({w, x, y, z});
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = sum(tv1, {0});
-  auto tv3 = sum(tv2, {0});
-  auto tv4 = add(tv3, tv0);
-
-  fusion.addOutput(tv4);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({y, z}, options);
-  at::Tensor t1 = at::randn({w, x, y, z}, options);
-  auto aten_output = t1.to(at::kDouble).sum({0}).sum({0}).add(t0);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  auto lparams = schedulePointwise(&fusion, aten_inputs);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTrivialReduction3_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  int v = 1, w = 1, x = 1, y = 7, z = 8;
-
-  auto tv0 = makeSymbolicTensor(2);
-  auto tv1 = makeConcreteTensor({v, w, x, y, z});
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = sum(tv1, {0, 1, 2});
-  auto tv3 = add(tv2, tv0);
-
-  fusion.addOutput(tv3);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({y, z}, options);
-  at::Tensor t1 = at::randn({v, w, x, y, z}, options);
-  auto aten_output = t1.sum({0, 1, 2}).add(t0);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  auto lparams = schedulePointwise(&fusion, aten_inputs);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+/*
+ * Templatized Helper Function To generate single Op comparison between the
+ * JIT codegen for Cuda and the ATen Library.
+ */
 
-// Make sure trivial reductions are correctly detected even with
-// scheduling applied.
-TEST(NVFuserTest, FusionDetectTrivialReduction1_CUDA) {
+using OutputPair = std::pair<ValType, DataType>;
+template <
+    typename AtenFunc,
+    typename JitFunc,
+    typename InputTuple,
+    size_t... NumInputs>
+void test_op(
+    int blocks,
+    int threads,
+    std::string op_str,
+    AtenFunc af,
+    JitFunc jf,
+    OutputPair op,
+    InputTuple it,
+    std::index_sequence<NumInputs...>) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-
-  auto tv1 = broadcast(tv0, {false, true});
-  auto tv2 = sum(tv1, {1});
-  fusion.addOutput(tv2);
-
-  tv2->split(1, 4);
-  tv2->split(1, 8);
-  auto tv3 = tv2->rFactor({-1});
-  auto tv4 = tv2->rFactor({-1});
-
-  auto tv5 = broadcast(tv0, {true, false});
-  auto tv6 = add(tv5, new Double(1));
-  auto tv7 = sub(tv6, new Double(1));
-  auto tv8 = sum(tv7, {0});
-  fusion.addOutput(tv8);
-
-  auto tv9 = broadcast(tv0, {false, true, true});
-  auto tv10 = sum(tv9, {1});
-  auto tv11 = sum(tv10, {1});
-  fusion.addOutput(tv11);
-
-  tv8->split(0, 3);
-  tv10->split(1, 4);
-  tv11->split(1, 5);
-
-  tv0->computeAt(tv2, -1);
-  tv0->computeAt(tv8, -1);
-  tv0->computeAt(tv11, 1);
-
-  // Test indexing to gmem-backed tensors
-  tv3->setMemoryType(MemoryType::Global);
-  tv8->setMemoryType(MemoryType::Global);
+  // Generate Input JIT function Inputs and add them as Inputs to the Fusion
+  // Graph
+  std::array<Val*, sizeof...(NumInputs)> jit_inputs = {
+      gen_jit_operand(std::get<NumInputs>(it))...};
+  std::for_each(jit_inputs.begin(), jit_inputs.end(), [&fusion](Val* v) {
+    fusion.addInput(v);
+  });
+  TensorView* out =
+      static_cast<TensorView*>(jf(std::get<NumInputs>(jit_inputs)...));
+  fusion.addOutput(out);
 
-  GpuLower gpulw(&fusion);
+  std::for_each(jit_inputs.begin(), jit_inputs.end(), [out](Val* v) {
+    if (v->getValType() == ValType::TensorView)
+      static_cast<TensorView*>(v)->computeAt(out, -1);
+  });
+  out->axis(0)->parallelize(ParallelType::BIDx);
+  out->axis(-1)->parallelize(ParallelType::TIDx);
 
-  // No kir::ReductionOp should be generated as all the reduction
-  // exprs should be replaced with a unary set op.
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    TORCH_CHECK(!kir_node->isA<kir::ReductionOp>());
-  }
+  std::array<IValue, sizeof...(NumInputs)> aten_inputs = {gen_aten_operand(
+      std::get<NumInputs>(it), blocks, threads, /*rand*/ true)...};
+  const at::ArrayRef<IValue> aten_inputs_ivalues(aten_inputs);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({100}, options);
-  std::vector<IValue> aten_inputs = {t0};
+  at::Tensor output =
+      gen_aten_operand(op, blocks, threads, /*rand*/ false).toTensor();
+  std::vector<at::Tensor> output_vect = {output};
+  cudaDeviceSynchronize();
+  if (fusion.isStochastic())
+    at::manual_seed(0);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {t0, t0, t0}, __LINE__, __FILE__);
-}
-
-// Test detection of partially trivial reduction
-TEST(NVFuserTest, FusionDetectTrivialReduction2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = sum(tv0, {1});
-  auto tv2 = add(tv1, new Double(1));
-  fusion.addOutput(tv2);
-
-  tv1->split(1, 1);
-  // tv1->axis(1): non-trivial
-  // tv1->axis(2): trivial
+  fe.runFusion(aten_inputs_ivalues, output_vect);
+  cudaDeviceSynchronize();
 
-  auto tv3 = tv1->rFactor({-1});
+  if (fusion.isStochastic())
+    at::manual_seed(0);
+  at::Tensor ref_output = af(aten_inputs);
+  cudaDeviceSynchronize(); // This sync shouldn't be necessary;
 
-  GpuLower gpulw(&fusion);
+  std::function<std::string()> aten_inputs_to_str =
+      [&aten_inputs]() -> std::string {
+    int input_cnt = 1;
+    std::stringstream ss;
+    std::for_each(
+        aten_inputs.begin(), aten_inputs.end(), [&input_cnt, &ss](IValue& iv) {
+          ss << "\nINPUT" << input_cnt++ << ": " << iv.toTensor();
+        });
+    return ss.str();
+  };
 
-  // tv3's reduction axis is a trivial reduction. The only
-  // kir::ReductionOp should be for tv1.
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (kir_node->isA<kir::ReductionOp>()) {
-      auto reduction_out =
-          kir_node->as<kir::ReductionOp>()->outputs()[0]->as<kir::TensorView>();
-      TORCH_CHECK(reduction_out->fuserTv() == tv1);
-    }
+  at::Tensor diff;
+  if (output.scalar_type() == at::kBool) {
+    diff = at::eq(output, ref_output);
+  } else {
+    diff = at::sub(output, ref_output);
   }
-}
-
-TEST(NVFuserTest, FusionInputsIdLookup_CUDA) {
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({16, 8, 8}, options);
-  at::Tensor t1 = at::randn({8, 8}, options);
-  at::Tensor t2 = at::randn({6, 4}, options);
-
-  // create a cache with max size 2;
-  torch::jit::fuser::cuda::InputsIdLookup inputs_id_lookup(2);
-
-  // testing basic function, same encoding for identical inputs
-  auto id_0 = inputs_id_lookup.lookupId({t0, t1, 5.0});
-  auto id_0_lookup = inputs_id_lookup.lookupId({t0, t1, 2.5});
-  TORCH_CHECK(id_0.id == id_0_lookup.id);
-  TORCH_CHECK(inputs_id_lookup.size() == 1);
-  TORCH_CHECK(id_0.eviction == false);
-
-  // new input (even tho same shape, but we have different signature because of
-  // missing scalar input
-  auto id_1 = inputs_id_lookup.lookupId({t0, t1});
-  auto id_1_lookup = inputs_id_lookup.lookupId({t0, t1});
-  TORCH_CHECK(id_1.id == id_1_lookup.id);
-  TORCH_CHECK(inputs_id_lookup.size() == 2);
-  TORCH_CHECK(id_1.eviction == false);
-
-  // eviction should happen at this point
-  auto id_2 = inputs_id_lookup.lookupId({t2, t1});
-  TORCH_CHECK(id_2.id != id_0.id);
-  TORCH_CHECK(id_2.id != id_1.id);
-  TORCH_CHECK(inputs_id_lookup.size() == 2);
-  TORCH_CHECK(id_2.eviction == true);
-  TORCH_CHECK(id_2.evict_id == id_0.id);
-
-  // look at input 1 again
-  auto id_1_relook = inputs_id_lookup.lookupId({t0, t1});
-  TORCH_CHECK(id_1_relook.id == id_1.id);
-  TORCH_CHECK(id_1_relook.eviction == false);
-}
-
-TEST(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) {
-  std::vector<int64_t> sizes_vec({16, 8, 8});
-  std::vector<int64_t> strides_vec({64, 8, 1});
-  auto tensor_type = TensorType::create(
-      at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  // pass with identical shape
-  auto t0 = at::randn({16, 8, 8}, options);
-  TORCH_CHECK(complyWith(t0, tensor_type));
-
-  // pass with dynamic shape
-  auto t1 = at::randn({16, 16, 8}, options);
-  TORCH_CHECK(complyWith(t1, tensor_type));
-
-  // broadcasting semantic change failure
-  auto t2 = at::randn({16, 1, 8}, options);
-  TORCH_CHECK(!complyWith(t2, tensor_type));
-
-  // contiguity failure via slicing
-  auto t3 = t0.slice(1, 0, 8, 2);
-  TORCH_CHECK(!complyWith(t3, tensor_type));
-
-  // contiguity failure via slicing
-  auto t4 = t0.slice(2, 0, 8, 2);
-  TORCH_CHECK(!complyWith(t4, tensor_type));
-
-  // rank failure
-  auto t5 = at::randn({16, 8, 8, 8}, options);
-  TORCH_CHECK(!complyWith(t5, tensor_type));
-
-  // contiguity on stride 1 dimension with implicit broadcasting
-  auto t = at::randn({4}, options);
-  auto t6 = t.unsqueeze(1).expand({4, 8});
-  TORCH_CHECK(complyWith(t6, TensorType::create(t6)));
-}
-
-TEST(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) {
-  std::vector<int64_t> sizes_vec({16, 1, 8});
-  std::vector<int64_t> strides_vec({8, 8, 1});
-  auto tensor_type = TensorType::create(
-      at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  // broadcasting semantic change
-  auto t0 = at::randn({16, 8, 8}, options);
-  TORCH_CHECK(!complyWith(t0, tensor_type));
-
-  // dtype failure
-  auto t1 = at::randn({16, 1, 8}, options.dtype(at::kHalf));
-  TORCH_CHECK(!complyWith(t1, tensor_type));
-
-  // dtype failure
-  auto t2 = at::randn({16, 1, 8}, options);
-  TORCH_CHECK(complyWith(t2, tensor_type));
 
-  // device inconsistency shouldn't fail
-  auto t3 = at::randn({16, 1, 8}, options.device(at::kCPU, 0));
-  TORCH_CHECK(complyWith(t3, tensor_type));
+  TORCH_CHECK(
+      (output.scalar_type() == at::kBool
+           ? output.equal(ref_output)
+           :
+           // The absolute Tolerance was raised to 1e-07 from 1e-08 to allow
+           // allow for the remainder function to pass.
+           output.allclose(ref_output, /*rtol*/ 1e-05, /*atol*/ 1e-07)),
+      "\nOp Type: -- ",
+      op_str,
+      " -- had a mismatch.",
+      aten_inputs_to_str(),
+      "\nABS MAX DIFF: ",
+      output.sub(ref_output).abs().max(),
+      "\n");
 }
 
-TEST(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) {
-  std::vector<int64_t> sizes_vec({16, 8, 8});
-  std::vector<int64_t> strides_vec({64, 1, 8});
-  auto tensor_type = TensorType::create(
-      at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  // failing permutation
-  auto t0 = at::randn({16, 8, 8}, options);
-  TORCH_CHECK(!complyWith(t0, tensor_type));
-
-  // passing with dynamic shape
-  auto t1 = t0.permute({0, 2, 1});
-  TORCH_CHECK(complyWith(t1, tensor_type));
+/*
+ *  Templatized Helper Function that uses variadic templates to
+ *  process a variable length Input Tuple of different Operand Type.
+ */
+template <typename AtenFunc, typename JitFunc, typename InputTuple>
+void test_op(
+    int blocks,
+    int threads,
+    std::string op_str,
+    AtenFunc af,
+    JitFunc jf,
+    OutputPair op,
+    InputTuple it) {
+  static constexpr auto size = std::tuple_size<InputTuple>::value;
+  test_op(
+      blocks,
+      threads,
+      op_str,
+      af,
+      jf,
+      op,
+      it,
+      std::make_index_sequence<size>{});
 }
 
-TEST(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) {
-  std::vector<int64_t> sizes_vec({16, 8, 8});
-  std::vector<int64_t> strides_vec({128, 16, 1});
-  auto tensor_type = TensorType::create(
-      at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  // contiguity check passes although it differs
-  auto t0 = at::randn({16, 16, 8}, options);
-  TORCH_CHECK(complyWith(t0, tensor_type));
-
-  // passing with dynamic shape
-  auto t1 = t0.slice(1, 0, 16, 2);
-  TORCH_CHECK(complyWith(t1, tensor_type));
-}
+TEST(NVFuserTest, FusionUnaryOps_CUDA) {
+  using OpTuple =
+      std::tuple<at::Tensor (*)(const at::Tensor&), UnaryOpType, std::string>;
 
-TEST(NVFuserTest, FusionDisjointSet_CUDA) {
-  DisjointSet<int> set;
+  // [Note: explicit tuple type for uniform initialization list]
+  // Tuple type must be explicitly specified for each uniform initialization
+  // list within the vector to make this code compatible with some old env
+  // which we still need to support. eg. gcc 5.4 + cuda 9.2.
+  std::vector<OpTuple> ops{
+      OpTuple{at::abs, UnaryOpType::Abs, "abs"},
+      OpTuple{at::acos, UnaryOpType::Acos, "acos"},
+      OpTuple{at::asin, UnaryOpType::Asin, "asin"},
+      OpTuple{at::atan, UnaryOpType::Atan, "atan"},
+      // There does not appear to be an appropriate ATen function for atanh
+      // OpTuple{at::atanh,      UnaryOpType::Atanh,      "atanh"      },
+      OpTuple{at::ceil, UnaryOpType::Ceil, "ceil"},
+      OpTuple{at::cos, UnaryOpType::Cos, "cos"},
+      OpTuple{at::cosh, UnaryOpType::Cosh, "cosh"},
+      OpTuple{at::erf, UnaryOpType::Erf, "erf"},
+      OpTuple{at::erfc, UnaryOpType::Erfc, "erfc"},
+      OpTuple{at::exp, UnaryOpType::Exp, "exp"},
+      OpTuple{at::expm1, UnaryOpType::Expm1, "expm1"},
+      OpTuple{at::floor, UnaryOpType::Floor, "floor"},
+      OpTuple{at::frac, UnaryOpType::Frac, "frac"},
+      OpTuple{at::gelu, UnaryOpType::Gelu, "gelu"},
+      OpTuple{at::lgamma, UnaryOpType::Lgamma, "lgamma"},
+      OpTuple{at::log, UnaryOpType::Log, "log"},
+      OpTuple{at::log10, UnaryOpType::Log10, "log10"},
+      OpTuple{at::log1p, UnaryOpType::Log1p, "log1p"},
+      OpTuple{at::log2, UnaryOpType::Log2, "log2"},
+      OpTuple{at::neg, UnaryOpType::Neg, "neg"},
+      OpTuple{at::reciprocal, UnaryOpType::Reciprocal, "reciprocal"},
+      OpTuple{at::relu, UnaryOpType::Relu, "relu"},
+      OpTuple{at::round, UnaryOpType::Round, "round"},
+      OpTuple{at::rsqrt, UnaryOpType::Rsqrt, "rsqrt"},
+      OpTuple{at::sigmoid, UnaryOpType::Sigmoid, "sigmoid"},
+      OpTuple{at::sin, UnaryOpType::Sin, "sin"},
+      OpTuple{at::sinh, UnaryOpType::Sinh, "sinh"},
+      OpTuple{at::sqrt, UnaryOpType::Sqrt, "sqrt"},
+      OpTuple{at::tan, UnaryOpType::Tan, "tan"},
+      OpTuple{at::tanh, UnaryOpType::Tanh, "tanh"},
+      OpTuple{at::trunc, UnaryOpType::Trunc, "trunc"}};
 
-  const std::set<int> group_x({0, 1, 2});
-  const std::set<int> group_y({3, 4, 5});
-  const std::set<int> group_z({6, 7, 8});
-  const std::vector<std::set<int>> groups({group_x, group_y, group_z});
-  std::set<int> group_all;
-  std::for_each(groups.begin(), groups.end(), [&](const auto& g) {
-    group_all.insert(g.begin(), g.end());
+  std::for_each(ops.begin(), ops.end(), [](OpTuple& op) {
+    test_op(
+        /*blocks*/ 640,
+        /*threads*/ 64,
+        /*name*/ std::get<2>(op),
+        /*Aten Func   */
+        [&op](std::array<IValue, 1>& vals) {
+          return std::get<0>(op)(vals[0].toTensor());
+        },
+        /*JIT  Func   */
+        [&op](Val* in1) -> Val* { return unaryOp(std::get<1>(op), in1); },
+        /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
+        /*Inputs Tuple*/
+        std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
   });
 
-  // Initially, nothing should be considered equivalent
-  for (auto i : group_all) {
-    for (auto j : group_all) {
-      TORCH_CHECK(!set.areEquivalent(i, j));
-    }
-  }
-
-  // Sets values in group_x are equivalent
-  for (auto i : group_x) {
-    for (auto j : group_x) {
-      set.join(i, j);
-      TORCH_CHECK(set.contains(i));
-      TORCH_CHECK(set.contains(j));
-    }
-  }
+  test_op(
+      /*blocks*/ 128,
+      /*threads*/ 64,
+      /*name*/ "rand_like",
+      /*Aten Func   */
+      [](std::array<IValue, 1>& vals) {
+        return at::rand_like(vals[0].toTensor());
+      },
+      /*JIT  Func   */
+      [](Val* in1) -> Val* { return unaryOp(UnaryOpType::RandLike, in1); },
+      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
+      /*Inputs Tuple*/
+      std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
+}
 
-  // All values in group_x shoudl be equivalent with each other
-  for (auto i : group_x) {
-    for (auto j : group_x) {
-      TORCH_CHECK(set.areEquivalent(i, j));
-    }
-  }
-  // But nothing else should be equivalent
-  for (auto i : group_all) {
-    for (auto j : group_y) {
-      TORCH_CHECK(!set.areEquivalent(i, j));
-    }
-    for (auto j : group_z) {
-      TORCH_CHECK(!set.areEquivalent(i, j));
-    }
-  }
+TEST(NVFuserTest, FusionBinaryOps_CUDA) {
+  using AtenFuncSig = at::Tensor (*)(const at::Tensor&, const at::Tensor&);
+  using OpTuple = std::tuple<AtenFuncSig, BinaryOpType, std::string>;
 
-  // Sets values in group_y are equivalent
-  for (auto i : group_y) {
-    for (auto j : group_y) {
-      set.join(i, j);
-      TORCH_CHECK(set.contains(i));
-      TORCH_CHECK(set.contains(j));
-    }
-  }
+  // see [Note: explicit tuple type for uniform initialization list]
+  std::vector<OpTuple> logic_ops{
+      OpTuple{at::eq, BinaryOpType::Eq, "eq"},
+      OpTuple{at::ge, BinaryOpType::GE, "ge"},
+      OpTuple{at::gt, BinaryOpType::GT, "gt"},
+      OpTuple{at::le, BinaryOpType::LE, "le"},
+      OpTuple{at::lt, BinaryOpType::LT, "lt"},
+      OpTuple{at::ne, BinaryOpType::NE, "ne"}};
 
-  // group_x should be still equivalent
-  for (auto i : group_x) {
-    for (auto j : group_x) {
-      TORCH_CHECK(set.areEquivalent(i, j));
-    }
-  }
-  // group_y should be now equivalent
-  for (auto i : group_y) {
-    for (auto j : group_y) {
-      TORCH_CHECK(set.areEquivalent(i, j));
-    }
-  }
-  // But group_z should not be equivalent with anything yet
-  for (auto i : group_all) {
-    for (auto j : group_z) {
-      TORCH_CHECK(!set.areEquivalent(i, j));
-    }
-  }
+  std::for_each(logic_ops.begin(), logic_ops.end(), [](OpTuple& op) {
+    test_op(
+        /*blocks*/ 640,
+        /*threads*/ 64,
+        /*name*/ std::get<2>(op),
+        /*Aten Func   */
+        [&op](std::array<IValue, 2>& vals) {
+          return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor());
+        },
+        /*JIT  Func   */
+        [&op](Val* in1, Val* in2) -> Val* {
+          return binaryOp(std::get<1>(op), in1, in2);
+        },
+        /*Output      */ std::make_pair(ValType::TensorView, DataType::Bool),
+        /*Inputs Tuple*/
+        std::make_tuple(
+            std::make_pair(ValType::TensorView, DataType::Float),
+            std::make_pair(ValType::TensorView, DataType::Float)));
+  });
 
-  // Sets values in group_z are equivalent
-  for (auto i : group_z) {
-    for (auto j : group_z) {
-      set.join(i, j);
-      TORCH_CHECK(set.contains(i));
-      TORCH_CHECK(set.contains(j));
-    }
-  }
+  // see [Note: explicit tuple type for uniform initialization list]
+  std::vector<OpTuple> math_ops{
+      OpTuple{at::atan2, BinaryOpType::Atan2, "atan2"},
+      OpTuple{at::div, BinaryOpType::Div, "div"},
+      OpTuple{at::fmod, BinaryOpType::Fmod, "fmod"},
+      OpTuple{at::max, BinaryOpType::Max, "max"},
+      OpTuple{at::min, BinaryOpType::Min, "min"},
+      OpTuple{at::mul, BinaryOpType::Mul, "mul"},
+      OpTuple{at::pow, BinaryOpType::Pow, "pow"},
+      // NOTE: Remainder does not match the Aten impl exactly
+      // despite using an identical function.
+      OpTuple{at::remainder, BinaryOpType::Remainder, "remainder"},
+  };
 
-  // Now each of the three groups should be equivalent within each
-  // group
-  for (size_t gi = 0; gi < groups.size(); ++gi) {
-    for (size_t gj = 0; gj < groups.size(); ++gj) {
-      for (auto i : groups[gi]) {
-        for (auto j : groups[gj]) {
-          TORCH_CHECK(
-              (gi == gj && set.areEquivalent(i, j)) ||
-              (gi != gj && !set.areEquivalent(i, j)));
-        }
-      }
-    }
-  }
+  std::for_each(math_ops.begin(), math_ops.end(), [](OpTuple& op) {
+    test_op(
+        /*blocks*/ 640,
+        /*threads*/ 64,
+        /*name*/ std::get<2>(op),
+        /*Aten Func   */
+        [&op](std::array<IValue, 2>& vals) {
+          return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor());
+        },
+        /*JIT  Func   */
+        [&op](Val* in1, Val* in2) -> Val* {
+          return binaryOp(std::get<1>(op), in1, in2);
+        },
+        /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
+        /*Inputs Tuple*/
+        std::make_tuple(
+            std::make_pair(ValType::TensorView, DataType::Float),
+            std::make_pair(ValType::TensorView, DataType::Float)));
+  });
 
-  auto all_elements = set.getAllElements();
-  std::sort(all_elements.begin(), all_elements.end());
-  std::vector<int> group_all_vec(group_all.begin(), group_all.end());
-  std::sort(group_all_vec.begin(), group_all_vec.end());
-  TORCH_CHECK(all_elements == group_all_vec);
+  test_op(
+      /*blocks*/ 640,
+      /*threads*/ 64,
+      /*name*/ "add_alpha",
+      /*Aten Func   */
+      [](std::array<IValue, 3>& vals) {
+        return at::add(
+            vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
+      },
+      /*JIT  Func   */ static_cast<Val* (*)(Val*, Val*, Val*)>(&add_alpha),
+      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
+      /*Inputs Tuple*/
+      std::make_tuple(
+          std::make_pair(ValType::TensorView, DataType::Float),
+          std::make_pair(ValType::TensorView, DataType::Float),
+          std::make_pair(ValType::Scalar, DataType::Float)));
+  test_op(
+      /*blocks*/ 640,
+      /*threads*/ 64,
+      /*name*/ "sub_alpha",
+      /*Aten Func   */
+      [](std::array<IValue, 3>& vals) {
+        return at::sub(
+            vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
+      },
+      /*JIT  Func   */ static_cast<Val* (*)(Val*, Val*, Val*)>(&sub_alpha),
+      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
+      /*Inputs Tuple*/
+      std::make_tuple(
+          std::make_pair(ValType::TensorView, DataType::Float),
+          std::make_pair(ValType::TensorView, DataType::Float),
+          std::make_pair(ValType::Scalar, DataType::Float)));
+}
 
-  set.clear();
-  all_elements = set.getAllElements();
-  TORCH_CHECK(all_elements.size() == 0);
+TEST(NVFuserTest, FusionTernaryOps_CUDA) {
+  test_op(
+      /*blocks*/ 640,
+      /*threads*/ 64,
+      /*name*/ "clamp",
+      /*Aten Func   */
+      [](std::array<IValue, 1>& vals) {
+        return at::clamp(vals[0].toTensor(), 0.f, 1.f);
+      },
+      /*JIT  Func   */
+      [](Val* in1) -> Val* {
+        return clamp(in1, new Float(0.f), new Float(1.f));
+      },
+      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
+      /*Inputs Tuple*/
+      std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
+  test_op(
+      /*blocks*/ 640,
+      /*threads*/ 64,
+      /*name*/ "threshold",
+      /*Aten Func   */
+      [](std::array<IValue, 1>& vals) {
+        return at::threshold(vals[0].toTensor(), 0.f, 1.f);
+      },
+      /*JIT  Func   */
+      [](Val* in1) -> Val* {
+        return threshold(in1, new Float(0.f), new Float(1.f));
+      },
+      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
+      /*Inputs Tuple*/
+      std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
+  test_op(
+      /*blocks*/ 640,
+      /*threads*/ 64,
+      /*name*/ "where",
+      /*Aten Func   */
+      [](std::array<IValue, 3>& vals) {
+        return at::where(
+            vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
+      },
+      /*JIT  Func   */ static_cast<Val* (*)(Val*, Val*, Val*)>(&where),
+      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
+      /*Inputs Tuple*/
+      std::make_tuple(
+          std::make_pair(ValType::TensorView, DataType::Bool),
+          std::make_pair(ValType::TensorView, DataType::Float),
+          std::make_pair(ValType::TensorView, DataType::Float)));
+}
 
-  // All cleared. Nothing should be considered equivalent.
-  for (auto i : group_all) {
-    for (auto j : group_all) {
-      TORCH_CHECK(!set.areEquivalent(i, j));
-    }
-  }
+TEST(NVFuserTest, FusionCompoundOps_CUDA) {
+  test_op(
+      /*blocks*/ 640,
+      /*threads*/ 64,
+      /*name*/ "lerp",
+      /*Aten Func   */
+      [](std::array<IValue, 3>& vals) {
+        return at::lerp(
+            vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
+      },
+      /*JIT  Func   */ static_cast<Val* (*)(Val*, Val*, Val*)>(&lerp),
+      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
+      /*Inputs Tuple*/
+      std::make_tuple(
+          std::make_pair(ValType::TensorView, DataType::Float),
+          std::make_pair(ValType::TensorView, DataType::Float),
+          std::make_pair(ValType::TensorView, DataType::Float)));
+  test_op(
+      /*blocks*/ 640,
+      /*threads*/ 64,
+      /*name*/ "addcmul",
+      /*Aten Func   */
+      [](std::array<IValue, 4>& vals) {
+        return at::addcmul(
+            vals[0].toTensor(),
+            vals[1].toTensor(),
+            vals[2].toTensor(),
+            vals[3].toScalar());
+      },
+      /*JIT  Func   */ static_cast<Val* (*)(Val*, Val*, Val*, Val*)>(&addcmul),
+      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
+      /*Inputs Tuple*/
+      std::make_tuple(
+          std::make_pair(ValType::TensorView, DataType::Float),
+          std::make_pair(ValType::TensorView, DataType::Float),
+          std::make_pair(ValType::TensorView, DataType::Float),
+          std::make_pair(ValType::Scalar, DataType::Float)));
 }
 
-TEST(NVFuserTest, FusionNonUniqueBroadcastSize_CUDA) {
+TEST(NVFuserTest, FusionCastOps_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
-  auto tv1 = makeSymbolicTensor(2);
-  auto tv2 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addInput(tv2);
-
-  auto tv3 = broadcast(tv0, {false, true});
-  auto tv4 = add(tv3, tv1);
-  auto tv5 = add(tv3, tv2);
-
-  fusion.addOutput(tv4);
-  fusion.addOutput(tv5);
+  TensorView* tv0 = makeDummyTensor(2, DataType::Half);
 
-  // In order to do this, tv1->axis(1) and tv2->axis(1) must have the
-  // same size, but we can't prove it, so this should throw an error.
-  ASSERT_ANY_THROW(tv3->computeAt(tv4, -1));
-}
+  TensorView* intrm1 = castOp(DataType::Float, tv0);
+  TensorView* out = castOp(DataType::Half, intrm1);
 
-TEST(NVFuserTest, FusionBiasGeluFwd_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  fusion.addInput(tv0);
+  fusion.addOutput(out);
+  tv0->computeAt(out, -1);
 
-  const float k_079 = 0.79788456;
-  const float k_004 = 0.044715;
-
-  // bias vector
-  auto t0 = makeSymbolicTensor(1, DataType::Half);
-  fusion.addInput(t0);
-  auto t1 = castOp(DataType::Float, t0);
-  // input tensor
-  auto t2 = makeSymbolicTensor(3, DataType::Half);
-  fusion.addInput(t2);
-  auto t3 = castOp(DataType::Float, t2);
-  auto t4 = broadcast(t1, {true, true, false});
-  auto t5 = add(t4, t3);
-  auto t6 = mul(t5, new Double(0.5));
-  auto t7 = mul(t5, new Double(k_079));
-  auto t8 = mul(t5, new Double(k_004));
-  auto t9 = mul(t8, t5);
-  auto t10 = add(t9, new Int(1));
-  auto t11 = mul(t7, t10);
-  auto t12 = unaryOp(UnaryOpType::Tanh, t11);
-  auto t13 = add(t12, new Double(1));
-  auto t14 = mul(t6, t13);
-  auto t15 = castOp(DataType::Half, t14);
-  fusion.addOutput(t15);
+  out->axis(0)->parallelize(ParallelType::BIDx);
+  out->axis(-1)->parallelize(ParallelType::TIDx);
 
   auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  std::vector<int64_t> input_shape{6, 512, 4096};
-  std::vector<int64_t> bias_shape{4096};
 
-  auto at_input = at::randn(input_shape, options);
-  auto at_bias = at::randn(bias_shape, options);
-
-  auto at_x =
-      at_bias.to(c10::ScalarType::Float) + at_input.to(c10::ScalarType::Float);
-  auto aten_output_float =
-      at_x * 0.5 * (1.0 + (k_079 * at_x * (1 + k_004 * at_x * at_x)).tanh());
-  auto aten_output = aten_output_float.to(c10::ScalarType::Half);
+  at::Tensor input1 = at::rand({1, 4}, options);
+  at::Tensor ref_output = at::empty_like(input1);
 
-  std::vector<IValue> aten_inputs = {at_bias, at_input};
-  auto lparams = schedulePointwise(&fusion, aten_inputs);
+  std::array<IValue, 1> inputs = {input1};
+  const at::ArrayRef<IValue> input_ivalues(inputs);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion(input_ivalues);
 
-  auto cg_outputs = fe.runFusion(aten_inputs, lparams);
+  ref_output = at::_cast_Half(at::_cast_Float(input1));
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  TORCH_CHECK(
+      outputs[0].equal(ref_output),
+      "\nOp Type: -- ",
+      "cast FP16->FP32->FP16",
+      " -- had a mismatch.\n",
+      "\nABS MAX DIFF: ",
+      outputs[0].sub(ref_output).abs().max(),
+      "\n");
 }
 
-TEST(NVFuserTest, FusionBiasGeluBwd_CUDA) {
-  // skipping on pre-volta device
-  if (at::cuda::getDeviceProperties(c10::cuda::current_device())->major < 7) {
-    return;
-  }
+// We want split/merge/reorder all tested both on and off rfactor domains, also
+// want compute at into the rfactor domain, and into its consumer
+TEST(NVFuserTest, FusionRFactorReplay_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  const float k_079 = 0.79788456;
-  const float k_004 = 0.044715;
-  const float k_010 = 0.1070322243;
-
-  // gradient tensor
-  auto t0 = makeSymbolicTensor(3, DataType::Half);
-  fusion.addInput(t0);
-  auto t1 = castOp(DataType::Float, t0);
-  // bias tensor
-  auto t2 = makeSymbolicTensor(1, DataType::Half);
-  fusion.addInput(t2);
-  auto t3 = castOp(DataType::Float, t2);
-  // input tensor
-  auto t4 = makeSymbolicTensor(3, DataType::Half);
-  fusion.addInput(t4);
-  auto t5 = castOp(DataType::Float, t4);
-  auto t6 = broadcast(t3, {true, true, false});
-  auto t7 = add(t6, t5);
-  auto t8 = mul(t7, new Double(k_079));
-  auto t9 = mul(t7, new Double(k_004));
-  auto t10 = mul(t9, t7);
-  auto t11 = add(t10, new Int(1));
-  auto t12 = mul(t8, t11);
-  auto t13 = unaryOp(UnaryOpType::Tanh, t12);
-  auto t14 = mul(t7, new Double(0.5));
-  auto t15 = mul(t13, t13);
-  auto t16 = unaryOp(UnaryOpType::Neg, t15);
-  auto t17 = add(t16, new Int(1));
-  auto t18 = mul(t7, new Double(k_010));
-  auto t19 = mul(t18, t7);
-  auto t20 = add(t19, new Double(k_079));
-  auto t21 = mul(t17, t20);
-  auto t22 = mul(t14, t21);
-  auto t23 = add(t13, new Int(1));
-  auto t24 = mul(t23, new Double(0.5));
-  auto t25 = add(t22, t24);
-  auto t26 = mul(t25, t1);
-  // Save float output for validation
-  fusion.addOutput(t26);
-  auto t27 = castOp(DataType::Half, t26);
-  fusion.addOutput(t27);
-
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
-  at::manual_seed(1);
-  std::vector<int64_t> input_shape{6, 512, 4096};
-  std::vector<int64_t> bias_shape{4096};
-  auto at_input = at::randn(input_shape, options);
-  auto at_bias = at::randn(bias_shape, options);
-  auto at_grad = at::randn(input_shape, options);
-
-  auto at_x =
-      at_bias.to(c10::ScalarType::Float) + at_input.to(c10::ScalarType::Float);
-  auto at_tanh_out = (k_079 * at_x * (1 + k_004 * at_x * at_x)).tanh();
-  auto at_ff = 0.5 * at_x *
-          ((1 - at_tanh_out * at_tanh_out) * (k_079 + k_010 * at_x * at_x)) +
-      0.5 * (1 + at_tanh_out);
-  auto at_out = at_ff * at_grad;
-  auto at_out_half = at_out.to(c10::ScalarType::Half);
-
-  std::vector<IValue> aten_inputs = {at_grad, at_bias, at_input};
-  std::vector<at::Tensor> aten_outputs = {at_out, at_out_half};
-
-  auto lparams = schedulePointwise(&fusion, aten_inputs);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
-}
-
-// Reproducer of issue #459
-TEST(NVFuserTest, FusionIssue459_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
 
-  auto tv0 = makeSymbolicTensor(1);
+  // Register your inputs
   fusion.addInput(tv0);
-  auto tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
 
-  auto tv2 = add(tv0, new Double(1));
-  auto tv3 = broadcast(tv2, {true, false});
-  auto tv4 = add(tv1, tv3);
+  // Do math with it, it returns a `Val*` but can be static_casted back to
+  // TensorView
+  TensorView* tv1 = sum(tv0, {1});
+  // tv1[I0, R1]
+  tv1->split(0, 32);
+  // tv1[I0o, I0i{32}, R1]
+  tv1->split(0, 16);
+  // tv1[I0oo, I0oi{16}, I0i{32}, R1]
+  tv1->split(-1, 8);
+  // tv1[I0oo, I0oi{16}, I0i{32}, R1o, R1i{8}]
+  tv1->split(-2, 4);
+  // tv1[I0oo, I0oi{16}, I0i{32}, R1oo, R1oi{4}, R1i{8}]
+  tv1->reorder({{0, -2}, {2, -1}, {-3, 0}, {-1, 1}});
+  // tv1[R1oo, R1i{8}, I0oi{16}, R1oi{4}, I0oo, I0i{32}]
 
-  // Create two outputs from the final arithmetic result
-  auto tv5 = add(tv4, new Double(1));
-  fusion.addOutput(tv5);
-  auto tv6 = add(tv4, new Double(1));
-  fusion.addOutput(tv6);
+  tv1->merge(0);
+  tv1->merge(-2);
 
-  // Scheduling
-  for (auto output : ir_utils::filterByType<TensorView>(fusion.outputs())) {
-    output->merge(-2, -1);
-  }
-  for (auto output : ir_utils::filterByType<TensorView>(fusion.outputs())) {
-    output->split(0, 128);
-  }
+  // tv1[R1oo*R1i{8}, I0oi{16}, R1oi{4}, I0oo*I0i{32}]
+  TensorDomain* new_domain = TransformRFactor::runReplay(tv1->domain(), {0});
+  // new_domain[r(R1oo*R1i{8})rf, I0oi{16}, ir1oi{4}rf, I0oo*I0i{32}]
 
-  tv0->computeAt(tv5, -1);
+  TensorDomain* new_domain2 = TransformRFactor::runReplay2(tv1->domain(), {0});
+  // new_domain2[                 I0oi{16},           , I0oo*I0i{32}, R1oi{4}]
 
-  tv6->axis(0)->parallelize(ParallelType::BIDx);
-  tv6->axis(1)->parallelize(ParallelType::TIDx);
+  // Move rfactor axis to end, keep iter rfactor axis
+  new_domain->reorder({{0, -1}, {2, 2}});
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  const int numel_x = 10;
-  const int numel_y = 20;
-  auto t0 = at::randn({numel_x}, options);
-  auto t1 = at::randn({numel_y, numel_x}, options);
-  auto aten_output = (t0 + 1).unsqueeze(0) + t1 + 1;
+  // Replay casp, replay new_domain2 as new_domain
+  // reordered_new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf]
+  auto replay_casp = TransformReplay::replayCasP(new_domain2, new_domain, 2);
+  TensorDomain* casp = replay_casp.first;
+  // new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf]
+  //       casp[I0oi{16}, I0oo*I0i{32},  R1oi{4}]
 
-  std::vector<IValue> aten_inputs = {t0, t1};
+  casp->split(1, new Int(2));
+  // casp      [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4} ]
+  // new_domain[I0oi{16},  I0oo*I0i{32}  ,                 ir1oi{4}rf,
+  // R(R1oo*R1i{8})rf]
 
-  torch::jit::fuser::cuda::FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  auto replay_pasc = TransformReplay::replayPasC(new_domain, casp, 2);
+  TensorDomain* pasc = replay_pasc.first;
+  // pasc      [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4}rf,
+  // R(R1oo*R1i{8})rf]
+
+  TORCH_CHECK(
+      new_domain->nDims() - 1 == new_domain2->nDims(),
+      casp->nDims() == new_domain2->nDims() + 1,
+      pasc->nDims() == new_domain->nDims() + 1,
+      "Error in rfactor, number of dimensions is not correct.");
 
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  TORCH_CHECK(
+      !casp->sameAs(new_domain2) && !pasc->sameAs(new_domain) &&
+          !new_domain->sameAs(new_domain2) &&
+          !tv1->domain()->sameAs(new_domain) &&
+          !tv1->domain()->sameAs(new_domain2),
+      "Error in rfactor, number of dimensions is not correct.");
 
-  testValidate(
-      &fusion,
-      cg_outputs,
-      aten_inputs,
-      {aten_output, aten_output},
-      __LINE__,
-      __FILE__);
+  auto dom = new_domain->getRootDomain();
+  TORCH_CHECK(
+      !dom[0]->isReduction() &&
+          std::any_of(
+              dom.begin(),
+              dom.end(),
+              [](IterDomain* id) { return id->isReduction(); }) &&
+          std::any_of(
+              dom.begin(),
+              dom.end(),
+              [](IterDomain* id) { return id->isRFactorProduct(); }),
+      "Error in rFactor, there seems to be something wrong in root domain.");
+
+  auto dom2 = new_domain2->getRootDomain();
+  TORCH_CHECK(
+      !dom2[0]->isReduction() &&
+          std::any_of(
+              dom2.begin(),
+              dom2.end(),
+              [](IterDomain* id) { return id->isReduction(); }),
+      "Error in rFactor, there seems to be something wrong in root domain.");
 }
 
-TEST(NVFuserTest, FusionSmemIndexingSimple_CUDA) {
+// Start off simple, block on the outer dim
+// block stride + thread all reduce + unrolling on inner dim
+TEST(NVFuserTest, FusionReduction_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(1));
-  auto tv3 = add(tv2, new Double(1));
-  fusion.addOutput(tv3);
 
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
-  tv3->axis(1)->parallelize(ParallelType::TIDx);
-
-  tv0->computeAt(tv3, -1);
-
-  tv1->setMemoryType(MemoryType::Shared);
-  tv2->setMemoryType(MemoryType::Global);
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addOutput(tv1);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  tv1->split(1, 128);
+  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+  tv1->split(1, 4);
+  // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1]
 
-  auto aten_input = at::randn({12, 34}, options);
-  at::Tensor aten_output = aten_input + 1.0 + 1.0 + 1.0;
+  TensorView* tv2 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1]
+  // tv1[I0,        R1oi{4},  R1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}]
 
-  auto cg_outputs = fe.runFusion({aten_input});
+  TensorView* tv3 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1]
+  // tv3[I0,        R1oi{4}, Ir1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}]
+  // tv1[I0,                  R1i{128}] = tv3[I0,        R1oi{4}, Ir1i{128}]
 
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
+  // Incrementally, can print in between for debugging
+  tv0->computeAt(tv2, 1);
+  tv2->computeAt(tv3, 1);
+  tv3->computeAt(tv1, 1);
 
-TEST(NVFuserTest, FusionSmemIndexing_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  // Re do it all at once, because why not.
+  tv0->computeAt(tv1, 1);
 
-  // Symbolic integers we will use for runtime tiling
-  Int* symbolic_m_tile_dim = new Int();
-  Int* symbolic_split_k_tile_dim = new Int();
-  Int* symbolic_block_k_tile_dim = new Int();
-  // Compile-time integer for tiling
-  int n_smem_tile = 32;
+  tv2->axis(2)->parallelize(ParallelType::Unroll);
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
 
-  // Symbolic 2D tensors TV0[M, K], TV1[K, N]
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
 
-  // Broadcast tv0 to [M, K, *]
-  TensorView* tv2 = broadcast(tv0, {false, false, true});
-  // Broadcast tv1 to [*, K, N]
-  TensorView* tv3 = broadcast(tv1, {true, false, false});
+  int numel_x = 65000;
+  int numel_y = 1025;
 
-  // Pointwise multiplication resulting in tv3[M, K, N]
-  TensorView* tv4 = mul(tv2, tv3);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
 
-  // Sum the K-dim
-  TensorView* tv5 = sum(tv4, {1});
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  fe.runFusion({input}, {cg_output});
 
-  // Register inputs and outputs
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addOutput(tv5);
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(cg_output));
+}
 
-  // Register runtime tile dims as inputs
-  fusion.addInput(symbolic_m_tile_dim);
-  fusion.addInput(symbolic_split_k_tile_dim);
-  fusion.addInput(symbolic_block_k_tile_dim);
+TEST(NVFuserTest, FusionReduction2_CUDA) {
+  {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
 
-  // Make a 3D tile, mix of symbolic and constant, do in reverse order because
-  // dims are inserted
-  // [M, rK, N]
-  tv5->split(2, n_smem_tile);
-  // [M, rK, No, Ni{32}]
-  tv5->split(1, symbolic_block_k_tile_dim);
-  // [M, rKo, rKi{i2}, No, Ni{32}]
-  tv5->split(1, symbolic_split_k_tile_dim);
-  // [M, rKoo, rKoi{i1}, rKi{i2}, No, Ni{32}]
-  tv5->split(0, symbolic_m_tile_dim);
-  // [Mo, Mi{i0}, rKoo, rKoi{i1}, rKi{i2}, No, Ni{32}]
+    // Set up your input tensor views
+    TensorView* tv0 = makeDummyTensor(2);
+    fusion.addInput(tv0);
 
-  // Reorder so all outer tiles are in the leftmost 3 positions
-  // [Mo, Mi{i0}, rKoo, rKoi{i1}, rKi{i2},     No, Ni{32}]
-  // [Mo,     No, rKoo, rKoi{i1}, rKi{i2}, Mi{i0}, Ni{32}]
-  tv5->reorder({{1, 5}, {5, 1}});
+    // tv1[I0, R1] = tv0[I0, I1]
+    TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
 
-  // Factor out the outer reduction IterDomain, then run the inter-cta
-  // reduction, and intra-cta reduction
-  // [Mo, No, rKoo,  Koi{i1},  Ki{i2}, Mi{i0}, Ni{32}]
-  // [Mo, No,       rKoi{i1}, rKi{i2}, Mi{i0}, Ni{32}]
-  auto tv6 = tv5->rFactor({2});
+    fusion.addOutput(tv1);
 
-  // Scope computations
-  tv6->computeAt(tv5, 2);
+    // switches to try some different scenarios. maybe we should iterate on all
+    // permutations.
+    bool bind_bidx = true;
+    bool bind_tidx = true;
+    bool bind_tidy = true;
+    bool bind_unroll = true;
 
-  // [Mo, No, rKoo, Koi{i1},  Ki{i2}, Mi{i0}, Ni{32}]
-  // [Mo, No, Ki{i2}, Mi{i0}, Ni{32}, rKoo, Koi{i1}]
-  tv6->reorder({
-      {2, -2},
-      {3, -1},
-      {4, 2},
-      {5, 3},
-      {6, 4},
-  });
+    int numel_x = 1025; // Cannot exceed block dim max size / tidy
+    int numel_y = 129;
+    int tidx = 16;
+    int tidy = 8;
+    int unroll_factor = 4;
 
-  // Setup compute at schedule
-  tv0->computeAt(tv6, 3);
-  tv1->computeAt(tv6, 3);
-  tv4->computeAt(tv6, -1);
+    tv1->split(1, tidx);
+    // tv1[I0, R1o, R1i{tidx}] = tv0[I0, I1]
 
-  // Cache smem tiles
-  tv2->setMemoryType(MemoryType::Shared);
-  tv3->setMemoryType(MemoryType::Shared);
-  tv4->setMemoryType(MemoryType::Shared);
-  tv6->setMemoryType(MemoryType::Shared);
+    tv1->split(1, unroll_factor);
+    // tv1[I0, R1oo, R1oi{unroll}, R1i{tidx}] = tv0[I0, I1]
 
-  tv5->axis(0)->parallelize(ParallelType::BIDz);
-  tv5->axis(1)->parallelize(ParallelType::BIDy);
+    tv1->split(0, tidy);
 
-  std::vector<TensorView*> tv_list = {tv2, tv3, tv4, tv5, tv6};
-  for (auto tv : tv_list) {
-    tv->axis(-2)->parallelize(ParallelType::TIDz);
-    tv->axis(-1)->parallelize(ParallelType::TIDy);
-  }
+    TensorView* tv2 = tv1->rFactor({-3});
+    // tv2[I0,             >R1oo<, Ir1oi{unroll}, Ir1i{tidx}]
+    // tv1[I0o, I0i{tidy},          R1oi{unroll},  R1i{tidx}]
 
-  constexpr int M = 31, K = 65, N = 32;
+    TensorView* tv3 = tv1->rFactor({-2});
+    // tv2[I0,             >R1oo<, Ir1oi{unroll}, Ir1i{tidx}]
+    // tv3[I0,                      R1oi{unroll}, Ir1i{tidx}]
+    // tv1[I0o, I0i{tidy},                         R1i{tidx}]
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({M, K}, options);
-  at::Tensor t1 = at::randn({K, N}, options);
+    tv0->computeAt(tv1, -2);
 
-  at::Tensor aten_output =
-      mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1);
+    if (bind_unroll)
+      tv2->axis(-2)->parallelize(ParallelType::Unroll);
+    if (bind_bidx)
+      tv1->axis(0)->parallelize(ParallelType::BIDx);
+    if (bind_tidy)
+      tv1->axis(1)->parallelize(ParallelType::TIDy);
 
-  // A, B, m_tile_dim, split_k, intra_cta_tile
-  std::vector<IValue> aten_inputs = {t0, t1, 3, 4, 5};
+    if (bind_tidx) {
+      tv2->axis(-1)->parallelize(ParallelType::TIDx);
+      tv3->axis(-1)->parallelize(ParallelType::TIDx);
+      tv1->axis(-1)->parallelize(ParallelType::TIDx);
+    }
 
-  torch::jit::fuser::cuda::FusionExecutor fe;
-  fe.compileFusion(&fusion);
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    at::Tensor input = at::rand({numel_x, numel_y}, options);
 
-  auto cg_outputs = fe.runFusion(aten_inputs);
+    FusionExecutor fe;
+    fe.compileFusion(&fusion);
+    auto outputs = fe.runFusion({input});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+    auto aten_output = input.sum({1});
+    TORCH_CHECK(aten_output.allclose(outputs[0]));
+  }
 
-// Reproducer of issue 408
-TEST(NVFuserTest, FusionCacheBeforeReduction_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  {
+    // What if Z participates in the reduction with X?
+    Fusion fusion;
+    FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = sum(tv1, {1});
-  fusion.addOutput(tv2);
+    // Set up your input tensor views
+    TensorView* tv0 = makeDummyTensor(2);
+    fusion.addInput(tv0);
 
-  tv2->split(0, 4);
+    // tv1[I0, R1] = tv0[I0, I1]
+    TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
 
-  auto tv3 = tv2->cache_before();
+    fusion.addOutput(tv1);
 
-  tv0->computeAt(tv3, -1);
-  tv3->computeAt(tv2, -1);
+    int numel_x = 1025; // Cannot exceed block dim max size / tidy
+    int numel_y = 129;
+    int tidx = 16;
+    int tidz = 8;
 
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+    tv1->split(1, tidz);
+    // tv1[I0, R1o, R1i{tidz}] = tv0[I0, I1]
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
+    tv1->split(1, tidx);
+    // tv1[I0, R1oo, R1oi{tidx}, R1i{tidz}] = tv0[I0, I1]
 
-  const int numel_x = 100;
-  const int numel_y = 200;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    TensorView* tv2 = tv1->rFactor({-3});
+    // tv2[I0,  >R1oo<, Ir1oi{tidx}, Ir1i{tidz}]
+    // tv1[I0o,          R1oi{tidx},  R1i{tidz}]
 
-  at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
-  at::Tensor cg_output = at::empty({numel_x}, options);
+    tv0->computeAt(tv1, -3);
 
-  auto aten_output = (aten_input + 1).to(at::kDouble).sum({1});
+    tv1->axis(0)->parallelize(ParallelType::BIDx);
+    tv1->axis(-2)->parallelize(ParallelType::TIDx);
+    tv1->axis(-1)->parallelize(ParallelType::TIDz);
+
+    tv2->axis(-2)->parallelize(ParallelType::TIDx);
+    tv2->axis(-1)->parallelize(ParallelType::TIDz);
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    at::Tensor input = at::rand({numel_x, numel_y}, options);
+    at::Tensor cg_output = at::empty({numel_x}, options);
 
-  fe.runFusion({aten_input}, {cg_output});
+    FusionExecutor fe;
+    fe.compileFusion(&fusion);
+    fe.runFusion({input}, {cg_output});
 
-  testValidate(
-      &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
+    auto aten_output = input.sum({1});
+    TORCH_CHECK(aten_output.allclose(cg_output));
+  }
 }
 
-TEST(NVFuserTest, FusionCacheBeforeReduction2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+TEST(NVFuserTest, FusionReduction3_CUDA) {
+  {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(3);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = sum(tv1, {1});
-  auto tv3 = add(tv2, new Double(1));
-  fusion.addOutput(tv2);
-  fusion.addOutput(tv3);
+    // Set up your input tensor views
+    TensorView* tv0 = makeDummyTensor(2);
+    TensorView* tv1 = makeDummyTensor(2);
 
-  auto tv4 = tv2->cache_before();
+    TensorView* tv2 = add(tv0, tv1);
+    // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1]
 
-  tv4->computeAt(tv3, 1);
-  tv0->computeAt(tv4, -1);
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
 
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
+    TensorView* tv3 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv2);
+    // tv3[I0, R1] = tv2[I0, I1]
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
+    TensorView* tv4 = makeDummyTensor(1);
+    fusion.addInput(tv4);
 
-  const int numel_x = 10;
-  const int numel_y = 20;
-  const int numel_z = 30;
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    // tv5[I0] = tv3[I0, R1] * tv4[I0]
+    TensorView* tv5 = mul(tv3, tv4);
+    fusion.addOutput(tv5);
 
-  at::Tensor aten_input = at::randn({numel_x, numel_y, numel_z}, options);
-  auto t2 = (aten_input + 1).to(at::kDouble).sum({1});
-  auto t3 = t2 + 1;
-  std::vector<at::Tensor> aten_outputs = {t2, t3};
+    int tidx = 16;
 
-  auto cg_outputs = fe.runFusion({aten_input});
+    // RFactor the reduction
+    tv3->split(1, tidx);
+    // tv3[I0, R1o, R1i{tidx}] = tv2[I0, I1]
 
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
+    TensorView* tv6 = tv3->rFactor({-2});
+    // tv6[I0, R1o, iR1i{tidx}] = tv2[I0, I1]
+    // tv3[I0,       R1i{tidx}] = tv3[I0, I1]
+    tv2->computeAt(tv6, 2);
 
-TEST(NVFuserTest, FusionIssue367_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+    // Compute at inline with tv5 (only 1D)
+    tv6->computeAt(tv3, 1);
+    tv3->computeAt(tv5, 1);
 
-  // Symbolic integers we will use for runtime tiling
-  Int* symbolic_m_tile_dim = new Int();
-  Int* symbolic_split_k_tile_dim = new Int();
-  Int* symbolic_block_k_tile_dim = new Int();
-  // Compile-time integer for tiling
-  int n_smem_tile = 32;
+    tv5->axis(0)->parallelize(ParallelType::BIDx);
 
-  // Symbolic 2D tensors TV0[M, K], TV1[K, N]
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
+    // Intermediate tensors only need this, but doesn't hurt to do on inputs
+    // tv0, 1, 4
+    tv2->axis(-1)->parallelize(ParallelType::TIDx);
+    tv3->axis(-1)->parallelize(ParallelType::TIDx);
+    tv6->axis(-1)->parallelize(ParallelType::TIDx);
 
-  // Broadcast tv0 to [M, K, *]
-  TensorView* tv2 = broadcast(tv0, {false, false, true});
-  // Broadcast tv1 to [*, K, N]
-  TensorView* tv3 = broadcast(tv1, {true, false, false});
+    int numel_x = 1025;
+    int numel_y = 129;
 
-  // Pointwise multiplication resulting in tv3[M, K, N]
-  TensorView* tv4 = mul(tv2, tv3);
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    at::Tensor t0 = at::rand({numel_x, numel_y}, options);
+    at::Tensor t1 = at::rand({numel_x, numel_y}, options);
+    auto t2 = t0.add(t1);
+    auto t3 = t2.sum({1});
+    at::Tensor t4 = at::rand({numel_x}, options);
+    auto t5 = t3.mul(t4);
 
-  // Sum the K-dim
-  TensorView* tv5 = sum(tv4, {1});
+    FusionExecutor fe;
+    fe.compileFusion(&fusion);
+    auto outputs = fe.runFusion({t0, t1, t4});
 
-  // Register inputs and outputs
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addOutput(tv5);
+    TORCH_CHECK(
+        t5.allclose(outputs[0]), "Error of: ", t5.sub(outputs[0]).abs().max());
+  }
+}
 
-  // Register runtime tile dims as inputs
-  fusion.addInput(symbolic_m_tile_dim);
-  fusion.addInput(symbolic_split_k_tile_dim);
-  fusion.addInput(symbolic_block_k_tile_dim);
+TEST(NVFuserTest, FusionReduction4_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  // Make a 3D tile, mix of symbolic and constant, do in reverse order because
-  // dims are inserted
-  tv5->split(2, n_smem_tile);
-  tv5->split(1, symbolic_block_k_tile_dim);
-  tv5->split(1, symbolic_split_k_tile_dim);
-  tv5->split(0, symbolic_m_tile_dim);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(3);
 
-  // tv5[M/m_tile, m_tile, r{K/split_k/block_k}, r{split_k}, r{block_k}, N/32,
-  // 32]
-  tv5->reorder({{1, 5}, {5, 1}});
-  // tv5[M/m_tile, N/32, r{K/split_k/block_k}, r{split_k}, r{block_k},  m_tile,
-  // 32]
+  fusion.addInput(tv0);
 
-  auto tv6 = tv5->rFactor({2});
-  auto tv7 = tv5->rFactor({2});
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
 
-  // Scope computations
-  tv6->computeAt(tv5, 2);
+  fusion.addOutput(tv1);
 
-  tv6->reorder({
-      {2, -2},
-      {3, -1},
-      {4, 2},
-      {5, 3},
-      {6, 4},
-  });
+  int bidy = 2;
+  int tidy = 4;
+  int tidx = 5;
 
-  tv7->reorder({
-      {2, -2},
-      {3, -1},
-      {-2, 2},
-      {-1, 3},
-  });
+  int dim1 = 11;
 
-  tv0->computeAt(tv6, 3);
-  tv1->computeAt(tv6, 3);
-  tv4->computeAt(tv6, -1);
+  tv1->split(-2, tidy);
 
-  // Cache smem tiles
-  tv2->setMemoryType(MemoryType::Shared);
-  tv3->setMemoryType(MemoryType::Shared);
-  tv4->setMemoryType(MemoryType::Local);
-  tv6->setMemoryType(MemoryType::Local);
-  tv7->setMemoryType(MemoryType::Local);
+  TensorView* tv2 = tv1->rFactor({-3});
 
-  tv5->axis(0)->parallelize(ParallelType::BIDz);
-  tv5->axis(1)->parallelize(ParallelType::BIDy);
+  tv0->computeAt(tv1, 1);
+  tv1->axis(0)->parallelize(ParallelType::BIDy);
 
-  std::vector<TensorView*> tv_list = {tv2, tv3, tv4, tv5, tv6, tv7};
-  for (auto tv : tv_list) {
-    tv->axis(-2)->parallelize(ParallelType::TIDz);
-    tv->axis(-1)->parallelize(ParallelType::TIDy);
+  for (auto* val : fusion.vals()) {
+    if (!fusion.hasInput(val) &&
+        val->getValType().value() == ValType::TensorView) {
+      val->as<TensorView>()->axis(-1)->parallelize(ParallelType::TIDx);
+    }
   }
-  tv2->axis(3)->parallelize(ParallelType::TIDx);
-  tv3->axis(3)->parallelize(ParallelType::TIDx);
-  tv4->axis(3)->parallelize(ParallelType::TIDx);
-  tv6->axis(3)->parallelize(ParallelType::TIDx);
-  tv7->axis(2)->parallelize(ParallelType::TIDx);
-
-  tv2->axis(4)->parallelize(ParallelType::BIDx);
-  tv3->axis(4)->parallelize(ParallelType::BIDx);
-  tv4->axis(4)->parallelize(ParallelType::BIDx);
-  tv6->axis(4)->parallelize(ParallelType::BIDx);
-  tv7->axis(3)->parallelize(ParallelType::BIDx);
-  tv5->axis(2)->parallelize(ParallelType::BIDx);
 
-  constexpr int M = 3, K = 6, N = 16;
+  tv2->axis(-2)->parallelize(ParallelType::TIDy);
+  tv1->axis(-2)->parallelize(ParallelType::TIDy);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::randn({bidy, dim1, tidx}, options);
 
-  at::Tensor t0 = at::randn({M, K}, options);
-  at::Tensor t1 = at::randn({K, N}, options);
-
-  // A, B, m, split_k, block_k
-  std::vector<IValue> aten_inputs = {t0, t1, 2, 2, 3};
-  at::Tensor aten_output =
-      mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1);
+  at::Tensor cg_output = at::empty({bidy, tidx}, options);
 
-  torch::jit::fuser::cuda::FusionExecutor fe;
+  FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  fe.runFusion({input}, {cg_output});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(
+      aten_output.allclose(cg_output, 1e-5, 1e-7),
+      "Error of: ",
+      aten_output.sub(cg_output).abs().max());
 }
 
-TEST(NVFuserTest, FusionIssue468_CUDA) {
+TEST(NVFuserTest, FusionReduction5_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
+  const int bdimx = 64;
+  const int bdimy = 8;
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(3);
   fusion.addInput(tv0);
-  auto tv1 = sum(tv0, {1});
-  auto tv2 = sum(tv1, {0});
-  fusion.addOutput(tv2);
 
-  tv1->axis(0)->parallelize(ParallelType::TIDy);
-  tv1->axis(1)->parallelize(ParallelType::TIDx);
+  // tv1[I0, R1, R2] = tv0[I0, I1, I2]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Float(0), tv0);
+  fusion.addOutput(tv1);
+
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+  tv1->split(2, bdimx);
+  // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]
+  tv1->split(1, bdimy);
+  // tv1[I0, R1o, R1i{8}, R2o, R2i{128}] = tv0[I0, I1, I2]
 
-  tv2->axis(0)->parallelize(ParallelType::TIDy);
+  TensorView* tv2 = tv1->rFactor({3});
+  // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2]
+  // tv1[I0, R1o, R1i{8},      R2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}]
+
+  TensorView* tv3 = tv1->rFactor({1});
+  // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2]
+  // tv3[I0, R1o, I1i{8},      I2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}]
+  // tv1[I0,      R1i{8},      R2i{128}] = tv3[I0, R1o, I1i{8},      I2i{128}]
+
+  tv3->computeAt(tv1, 1);
+  tv2->computeAt(tv3, 2);
+
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv3->axis(0)->parallelize(ParallelType::BIDx);
+
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+  tv1->axis(-2)->parallelize(ParallelType::TIDy);
+  tv3->axis(-2)->parallelize(ParallelType::TIDy);
+  tv2->axis(-3)->parallelize(ParallelType::TIDy);
+
+  int numel_x = 650;
+  int numel_y = 1000;
+  int numel_z = 4;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_input = at::randn({10, 100}, options);
-  at::Tensor aten_output = aten_input.to(at::kDouble).sum({1}).sum({0});
+  at::Tensor input = at::rand({numel_x, numel_y, numel_z}, options);
 
-  torch::jit::fuser::cuda::FusionExecutor fe;
+  FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
+  auto outputs = fe.runFusion({input});
 
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
+  auto aten_output = input.sum({1, 2});
+  TORCH_CHECK(aten_output.allclose(outputs[0]));
 }
 
-TEST(NVFuserTest, FusionIssue363_CUDA) {
+TEST(NVFuserTest, FusionReductionTFT_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  // Symbolic 2D tensors TV0[M, K], TV1[K, N]
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(2);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
 
-  // Broadcast tv0 to [M, K, *]
-  TensorView* tv2 = broadcast(tv0, {false, false, true});
-  // Broadcast tv1 to [*, K, N]
-  TensorView* tv3 = broadcast(tv1, {true, false, false});
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
 
-  // Pointwise multiplication resulting in tv3[M, K, N]
-  TensorView* tv4 = mul(tv2, tv3);
+  fusion.addOutput(tv1);
 
-  // Sum the K-dim
-  TensorView* tv5 = sum(tv4, {1});
+  int numel_x = 1025;
+  int numel_y = 129;
+  int tidx = 16;
+  int tidy = 8;
+  int tidz = 8;
 
-  // Register inputs and outputs
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-  fusion.addOutput(tv5);
+  tv1->split(1, tidx);
+  // tv1[I0, R1o, R1i{tidx}]
 
-  tv2->setMemoryType(MemoryType::Global);
-  tv3->setMemoryType(MemoryType::Global);
-  tv4->setMemoryType(MemoryType::Global);
+  tv1->split(1, tidz);
+  // tv1[I0, R1oo, R1Oi{tidz}, R1R1i{tidx}]
+
+  tv1->split(0, tidy);
+  // tv1[I0o, I0i, R1oo, R1Oi{tidz}, R1R1i{tidx}]
+
+  TensorView* tv2 = tv1->rFactor({2});
+  // tv2[I0o, I0i, R1oo, I1Oi{tidz}, I11i{tidx}]
+  // tv1[I0o, I0i,       R1Oi{tidz}, R1R1i{tidx}]
 
-  tv0->computeAt(tv5, -1);
-  tv1->computeAt(tv5, -1);
+  tv2->computeAt(tv1, 2);
 
-  tv5->axis(0)->parallelize(ParallelType::BIDz);
-  tv5->axis(1)->parallelize(ParallelType::BIDy);
+  tv1->axis(1)->parallelize(ParallelType::TIDy);
 
-  tv5->axis(2)->parallelize(ParallelType::BIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
 
-  constexpr int M = 3, K = 6, N = 16;
+  tv1->axis(-2)->parallelize(ParallelType::TIDz);
+  tv2->axis(-2)->parallelize(ParallelType::TIDz);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
 
-  at::Tensor t0 = at::randn({M, K}, options);
-  at::Tensor t1 = at::randn({K, N}, options);
-  at::Tensor aten_output =
-      mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  torch::jit::fuser::cuda::FusionExecutor fe;
+  FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  fe.runFusion({input}, {cg_output});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(cg_output));
 }
 
-TEST(NVFuserTest, FusionIssue484_CUDA) {
+TEST(NVFuserTest, FusionBranches_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = makeDummyTensor(2);
+  TensorView* tv2 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  auto tv1 = sum(tv0, {1});
-  auto tv2 = add(tv1, new Double(0));
-  fusion.addOutput(tv2);
+  fusion.addInput(tv1);
+  fusion.addInput(tv2);
+
+  auto tv3 = add(tv0, new Float(1.0));
+  auto tv4 = add(tv3, tv1);
+  auto tv5 = add(tv3, tv2);
+  auto tv6 = add(tv4, tv5);
 
-  tv1->setMemoryType(MemoryType::Global);
-  tv1->axis(1)->parallelize(ParallelType::TIDx);
+  fusion.addOutput(tv6);
 
-  constexpr int M = 100;
+  constexpr int x = 63, y = 33;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  at::Tensor aten_input = at::randn({M, M}, options);
-  at::Tensor aten_output = aten_input.to(at::kDouble).sum({1});
+  at::Tensor t0 = at::randn({x, y}, options);
+  at::Tensor t1 = at::randn({x, y}, options);
+  at::Tensor t2 = at::randn({x, y}, options);
+
+  FusionExecutor fe;
+  tv6->merge(0);
+  tv6->split(0, 128);
+  tv6->split(0, 4);
+
+  tv6->axis(0)->parallelize(ParallelType::BIDx);
+
+  tv0->computeAt(tv6, 1);
+  tv1->computeAt(tv6, 1);
+  tv2->computeAt(tv6, 1);
+
+  tv3->axis(-2)->parallelize(ParallelType::Unroll);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  tv4->axis(-2)->parallelize(ParallelType::Unroll);
+  tv4->axis(-1)->parallelize(ParallelType::TIDx);
+  tv5->axis(-2)->parallelize(ParallelType::Unroll);
+  tv5->axis(-1)->parallelize(ParallelType::TIDx);
+  tv6->axis(-1)->parallelize(ParallelType::TIDx);
 
-  torch::jit::fuser::cuda::FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
+  auto outputs = fe.runFusion({t0, t1, t2});
+
+  auto t3 = t0.add(1.0);
+  auto t4 = t3.add(t1);
+  auto t5 = t3.add(t2);
+  auto t6 = t4.add(t5);
 
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
+  TORCH_CHECK(t6.allclose(outputs[0]));
 }
 
-TEST(NVFuserTest, Issue329_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+TEST(NVFuserTest, FusionSimpleBCast_CUDA) {
+  {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = sum(tv1, {1});
-  fusion.addOutput(tv2);
-  auto tv3 = sum(tv1, {1});
-  fusion.addOutput(tv3);
+    // Set up your input tensor views
+    TensorView* tv0 = makeDummyTensor(2);
+    fusion.addInput(tv0);
+    TensorView* tv1 = add(tv0, new Float(1.5));
 
-  tv1->computeAt(tv2, -1);
+    TensorView* tv2 = makeDummyTensor(2);
+    fusion.addInput(tv2);
+    TensorView* tv3 = makeDummyTensor(2);
+    fusion.addInput(tv3);
+    TensorView* tv4 = sub(tv2, tv3);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    TensorView* tv5 = broadcast(tv1, {false, false, true});
+    TensorView* tv6 = broadcast(tv4, {true, false, false});
 
-  std::vector<int64_t> t0_shape{17, 19};
-  auto aten_input = at::randn(t0_shape, options);
-  auto t2 = (aten_input + 1).to(at::kDouble).sum({1});
-  auto t3 = (aten_input + 1).to(at::kDouble).sum({1});
-  std::vector<at::Tensor> aten_outputs = {t2, t3};
+    TensorView* tv7 = add(tv5, tv6);
+    fusion.addOutput(tv7);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
+    tv7->split(-1, 4);
+    tv7->split(0, 8);
 
-  auto cg_outputs = fe.runFusion({aten_input});
+    tv0->computeAt(tv7, -1);
+    tv2->computeAt(tv7, -1);
 
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
+    tv7->axis(0)->parallelize(ParallelType::BIDx);
+    tv7->axis(-1)->parallelize(ParallelType::TIDx);
 
-TEST(NVFuserTest, FusionIssue382_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+    constexpr int x = 63, y = 33, z = 15;
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = broadcast(tv1, {false, false, true});
-  auto tv3 = makeSymbolicTensor(3);
-  fusion.addInput(tv3);
-  auto tv4 = add(tv2, tv3);
-  fusion.addOutput(tv4);
+    at::Tensor t0 = at::randn({x, y}, options);
+    at::Tensor t1 = t0.add(1.5);
 
-  tv2->merge(1);
-  tv4->merge(1);
+    at::Tensor t2 = at::randn({y, z}, options);
+    at::Tensor t3 = at::randn({y, z}, options);
 
-  tv1->computeAt(tv4, 1);
+    at::Tensor t4 = t2.sub(t3);
+    at::Tensor t5 = t1.unsqueeze(-1).expand({x, y, z});
 
-  tv4->axis(0)->parallelize(ParallelType::BIDx);
+    at::Tensor t6 = t4.expand({x, y, z});
+    at::Tensor t7 = t5.add(t6);
 
-  tv1->setMemoryType(MemoryType::Global);
-  tv2->setMemoryType(MemoryType::Global);
+    FusionExecutor fe;
+    fe.compileFusion(&fusion);
+    auto outputs = fe.runFusion({t0, t2, t3});
 
-  torch::jit::fuser::cuda::FusionExecutor fe;
-  fe.compileFusion(&fusion);
+    TORCH_CHECK(t7.allclose(outputs[0]));
+  }
 
-  const int numel_x = 12;
-  const int numel_y = 34;
-  const int numel_z = 56;
+  {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  auto t0 = at::randn({numel_x, numel_y}, options);
-  auto t3 = at::randn({numel_x, numel_y, numel_z}, options);
+    // Set up your input tensor views
+    TensorView* tv0 = makeDummyTensor(2);
+    fusion.addInput(tv0);
+    TensorView* tv1 = makeDummyTensor(2);
+    fusion.addInput(tv1);
 
-  std::vector<IValue> aten_inputs = {t0, t3};
-  auto aten_output = (t0 + 1).unsqueeze(-1) + t3;
+    TensorView* tv2 = add(tv0, tv1);
 
-  auto cg_outputs = fe.runFusion(aten_inputs);
+    TensorView* tv3 = broadcast(tv2, {false, false, true});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+    TensorView* tv4 = makeDummyTensor(2);
+    fusion.addInput(tv4);
 
-TEST(NVFuserTest, Issue507_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+    TensorView* tv5 = sub(tv4, new Float(0.1));
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(1));
-  fusion.addOutput(tv2);
+    TensorView* tv6 = broadcast(tv5, {true, false, false});
 
-  tv1->setMemoryType(MemoryType::Shared);
+    TensorView* tv7 = add(tv3, tv6);
 
-  tv1->axis(1)->parallelize(ParallelType::TIDx);
-  tv2->axis(1)->parallelize(ParallelType::TIDx);
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
+    fusion.addOutput(tv7);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    tv7->merge(0, 1);
 
-  std::vector<int64_t> t0_shape{17, 19};
-  auto aten_input = at::randn(t0_shape, options);
-  auto t1 = (aten_input + 1);
-  auto aten_output = (t1 + 1);
+    tv0->computeAt(tv7, -1);
+    tv4->computeAt(tv7, -1);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
+    tv7->axis(0)->parallelize(ParallelType::BIDx);
+    tv7->axis(-1)->parallelize(ParallelType::TIDx);
 
-  auto cg_outputs = fe.runFusion({aten_input});
+    constexpr int x = 63, y = 33, z = 15;
 
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-TEST(NVFuserTest, FusionIssue532_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+    at::Tensor t0 = at::randn({x, y}, options);
+    at::Tensor t1 = at::randn({x, y}, options);
+    at::Tensor t2 = t0.add(t1);
+    at::Tensor t3 = t2.unsqueeze(-1).expand({x, y, z});
 
-  // Algorithm
-  TensorView* tv0 = makeSymbolicTensor(1);
-  TensorView* tv1 = add(tv0, new Double(1));
-  TensorView* tv2 = add(tv1, new Double(1));
-  fusion.addInput(tv0);
-  fusion.addOutput(tv2);
+    at::Tensor t4 = at::randn({y, z}, options);
+    at::Tensor t5 = t4.sub(0.1);
+    at::Tensor t6 = t5.expand({x, y, z});
+    at::Tensor t7 = t3.add(t6);
 
-  const int M_BLOCK = 64;
-  const int M_THREAD = 4;
+    at::Tensor cg_output = at::empty({x, y, z}, options);
 
-  tv2->split(0, M_BLOCK);
-  // tv2: [M/M_BLOCK, M_BLOCK]
-  tv1->computeAt(tv2, 1);
-  // tv1: [M/M_BLOCK, M_BLOCK]
+    FusionExecutor fe;
+    fe.compileFusion(&fusion);
+    fe.runFusion({t0, t1, t4}, {cg_output});
 
-  tv1->split(-1, M_BLOCK / M_THREAD);
-  // tv1: [M/M_BLOCK, M_THREAD, M_BLOCK / M_THREAD]
+    TORCH_CHECK(t7.allclose(cg_output));
+  }
 
-  tv2->split(-1, M_THREAD);
-  // tv2: [M/M_BLOCK, M_BLOCK / M_THREAD, M_THREAD]
+  {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
 
-  constexpr int M = 1000;
+    // Set up your input tensor views
+    std::vector<IterDomain*> dom;
+    dom.push_back(new IterDomain(new Int(0), new Int()));
+    dom.push_back(new IterDomain(
+        new Int(0),
+        new Int(1),
+        ParallelType::Serial,
+        IterType::BroadcastWithStride));
+
+    // tv0[I1, B{1}]
+    TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float);
+    fusion.addInput(tv0);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({M}, options);
-  std::vector<IValue> aten_inputs = {t0};
+    // tv1[I0, I1, I2]
+    TensorView* tv2 = makeDummyTensor(3);
+    fusion.addInput(tv2);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(aten_inputs);
+    TensorView* tv3 = add(tv0, tv2);
 
-  at::Tensor aten_output = t0 + 1 + 1;
+    fusion.addOutput(tv3);
 
-  testValidate(
-      &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+    tv3->merge(0);
+    tv3->merge(0);
 
-TEST(NVFuserTest, FusionLoopUnswitch_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+    tv0->computeAt(tv3, -1);
+    tv2->computeAt(tv3, -1);
 
-  // Algorithm
-  TensorView* tv0 = makeSymbolicTensor(1);
-  TensorView* tv1 = add(tv0, new Double(1));
-  TensorView* tv2 = add(tv1, new Double(1));
-  fusion.addInput(tv0);
-  fusion.addOutput(tv2);
+    tv3->axis(0)->parallelize(ParallelType::BIDx);
 
-  tv2->split(0, 32);
-  tv1->computeAt(tv2, -1);
+    constexpr int x = 2, y = 3, z = 4;
 
-  tv2->axis(1)->parallelize(ParallelType::Unswitch);
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  constexpr int M = 1000;
+    at::Tensor t0 = at::randn({y, 1}, options);
+    at::Tensor t2 = at::randn({x, y, z}, options);
+    auto t3 = t0.add(t2);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({M}, options);
-  std::vector<IValue> aten_inputs = {t0};
+    at::Tensor cg_output = at::empty({x, y, z}, options);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(aten_inputs);
+    FusionExecutor fe;
+    fe.compileFusion(&fusion);
+    fe.runFusion({t0, t2}, {cg_output});
 
-  at::Tensor aten_output = t0 + 1 + 1;
+    TORCH_CHECK(t3.allclose(cg_output));
+  }
 
-  testValidate(
-      &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+  {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
 
-TEST(NVFuserTest, FusionIssue549_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+    // Set up your input tensor views
+    std::vector<IterDomain*> dom;
+    dom.push_back(new IterDomain(
+        new Int(0),
+        new Int(1),
+        ParallelType::Serial,
+        IterType::BroadcastWithStride));
+    dom.push_back(new IterDomain(new Int(0), new Int()));
+    TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float);
+
+    TensorView* tv1 = makeDummyTensor(3);
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
 
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2); // M, K
-  TensorView* tv1 = makeSymbolicTensor(2); // K, N
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
+    TensorView* tv3 = add(tv0, tv1);
 
-  auto tv2 = add(tv0, new Double(1));
+    tv3->merge(0);
+    tv3->merge(0);
+    tv3->split(0, 128);
+    tv3->split(0, 4);
 
-  TensorView* tv3 = broadcast(tv2, {false, false, true});
-  // tv3[I0, I1, B] = tv0[I0, I1]
+    fusion.addOutput(tv3);
 
-  TensorView* tv4 = broadcast(tv1, {true, false, false});
-  // tv4[B, I1, I2] = tv1[I1, I2]
+    tv0->computeAt(tv3, -1);
+    tv1->computeAt(tv3, -1);
 
-  // tv5[I0, I1, I2] = tv3[I0, I1, B] * tv4[B, I1, I2]
-  TensorView* tv5 = mul(tv3, tv4);
-  // tv6[I0, R1, I2] = tv5[I0, I1, I2]
-  TensorView* tv6 = sum(tv5, {1});
-  fusion.addOutput(tv6);
+    tv3->axis(0)->parallelize(ParallelType::BIDx);
+    tv3->axis(-1)->parallelize(ParallelType::TIDx);
+    tv3->axis(-2)->parallelize(ParallelType::Unroll);
 
-  tv6->split(1, 32);
-  // tv6[I0, R1o, R1i{32}, I2]
+    constexpr int x = 63, y = 33, z = 15;
 
-  auto tv7 = tv6->rFactor({1});
-  // tv7[I0, R1o, I1i{32}, I2] = tv5[I0, I1, I2]
-  // tv6[I0,    , R1i{32}, I2] = tv7[I0, R1o, I1i{32}, I2]
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  tv6->split(0, 4);
-  tv6->split(-1, 4);
-  // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
-  // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
+    at::Tensor t0 = at::randn({1, z}, options);
+    at::Tensor t1 = at::randn({x, y, z}, options);
 
-  tv0->computeAt(tv6, -1);
-  tv1->computeAt(tv6, -1);
+    at::Tensor cg_output = at::empty({x, y, z}, options);
 
-  // tv7[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
-  // tv6[I0o, I0i{4},    , R1i{32}, I2o, I2i{4}]
-  //--> (line symbolizes compute at location)
-  // tv5[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
-  // tv7[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
-  // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
+    FusionExecutor fe;
+    fe.compileFusion(&fusion);
+    fe.runFusion({t0, t1}, {cg_output});
 
-  tv0->computeAt(tv7, -1);
-  tv1->computeAt(tv7, -1);
-  // tv5[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
-  // tv7[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
-  // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
+    auto t3 = t0.add(t1);
 
-  tv6->axis(0)->parallelize(ParallelType::BIDz);
-  tv6->axis(1)->parallelize(ParallelType::TIDz);
+    TORCH_CHECK(t3.allclose(cg_output));
+  }
 
-  tv6->axis(-2)->parallelize(ParallelType::BIDy);
-  tv6->axis(-1)->parallelize(ParallelType::TIDy);
+  {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
 
-  tv6->axis(2)->parallelize(ParallelType::TIDx);
-  tv7->axis(2)->parallelize(ParallelType::TIDx);
+    constexpr int m = 2, k = 3, n = 4;
 
-  constexpr int M = 65, K = 33, N = 17;
+    auto zero = new Int(0);
+    auto M = new IterDomain(zero, new Int(m));
+    auto K = new IterDomain(zero, new Int(k));
+    auto N = new IterDomain(zero, new Int(n));
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    // Set up your input tensor views
+    TensorView* tv0 =
+        new TensorView(new TensorDomain({M, K}, {true, true}), DataType::Float);
+    TensorView* tv1 =
+        new TensorView(new TensorDomain({K, N}, {true, true}), DataType::Float);
 
-  at::Tensor t0 = at::randn({M, K}, options);
-  at::Tensor t1 = at::randn({K, N}, options);
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  // Lets specify a few bounds in launch params to make sure it works
-  fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4));
+    TensorView* tv2 = broadcast(tv0, {false, false, true});
+    TensorView* tv3 = broadcast(tv1, {true, false, false});
 
-  // Make sure bad launch params throws
-  // TODO: Re-enable once we have parallelization validation in.
-  // ASSERT_ANY_THROW(fe.runFusion({t0, t1}, LaunchParams(1, 2, 3, 4, 5, 6)));
+    TensorView* tv4 = add(tv2, tv3);
 
-  // Don't specify any launch params
-  auto cg_outputs = fe.runFusion({t0, t1});
+    fusion.addOutput(tv4);
 
-  auto aten_output = (t0 + 1).to(at::kDouble).matmul(t1.to(at::kDouble));
+    tv4->merge(0);
+    tv4->merge(0);
 
-  testValidate(
-      &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__);
-}
+    tv0->computeAt(tv4, -1);
+    tv1->computeAt(tv4, -1);
 
-TEST(NVFuserTest, simplecompileRtc_CUDA) {
-  FusionExecutor fe;
-  std::string kernel = R"(
-__global__ void kernel1(Tensor<float, 1> T0, Tensor<float, 1> T1) {
-  if(threadIdx.x==0){
-    for(size_t ki28 = 0; ki28 < T0.size[0]; ++ki28) {
-      T1[ki28*T1.stride[0]] = T0[ki28*T0.stride[0]]*2;
-    }
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+    at::Tensor t0 = at::randn({m, k}, options);
+    at::Tensor t1 = at::randn({k, n}, options);
+
+    at::Tensor cg_output = at::empty({m, k, n}, options);
+
+    FusionExecutor fe;
+    fe.compileFusion(&fusion);
+    fe.runFusion({t0, t1}, {cg_output});
+
+    auto t2 = t0.unsqueeze(-1).expand({m, k, n});
+    auto t3 = t1.expand({m, k, n});
+    auto t4 = t2.add(t3);
+
+    TORCH_CHECK(t4.allclose(cg_output));
   }
 }
-    )";
-  fe.compileRtc(kernel, "CudaCodeGen::kernel1");
-  LaunchParams lp(
-      256, // gdimx
-      1, // gdimy
-      1, // gdimz
-      1, // bdimx
-      1, // bdimy
-      1 // bdimz
-  );
-  lp.setSmem(0);
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const std::vector<int64_t> tensor_dims = {8};
-  auto in0 = at::randn(tensor_dims, options);
-  auto out0 = at::empty_like(in0);
-  fe.runRtc(lp, {in0, out0});
 
-  auto out_ref = in0 * 2;
-  TORCH_CHECK(out_ref.allclose(out0));
-}
+TEST(NVFuserTest, FusionComplexBCast_CUDA) {
+  {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
 
-TEST(NVFuserTest, serialWelford_CUDA) {
-  FusionExecutor fe;
-  int x = 128, y = 64, z = 64;
-
-  std::string kernel = R"(
-__global__ void kernel1(
-    Tensor<float,3> inp,
-    Tensor<float,1> out_var,
-    Tensor<float,1> out_avg
-){
-    for(int i0=0;i0<inp.size[0];i0++){
-        float tmp_M2=0;
-        float tmp_avg=0;
-        long tmp_N=0;
-        for(int i1=0;i1<inp.size[1];i1++){
-            for(int i2=0;i2<inp.size[2];i2++){
-                welfordCombine(
-                    tmp_avg,
-                    tmp_M2,
-                    tmp_N,
-                    inp[i0*inp.stride[0]+
-                        i1*inp.stride[1]+
-                        i2*inp.stride[2]],
-                    0.f,
-                    (long)1
-                );
-            }
-        }
-        out_var[i0*out_var.stride[0]]=
-            tmp_M2/(tmp_N);
-        out_avg[i0*out_avg.stride[0]]=
-            tmp_avg;
-    }
-}
-    )";
-  fe.compileRtc(kernel, "CudaCodeGen::kernel1");
-  LaunchParams lp(
-      1, // gdimx
-      1, // gdimy
-      1, // gdimz
-      1, // bdimx
-      1, // bdimy
-      1 // bdimz
-  );
-  lp.setSmem(0);
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const std::vector<int64_t> tensor_dims = {x, y, z};
-  auto in0 = at::randn(tensor_dims, options);
-  auto out_var = at::empty({x}, options);
-  auto out_avg = at::empty({x}, options);
-  fe.runRtc(lp, {in0, out_var, out_avg});
+    int x = 2, y = 3, z = 4;
+
+    auto tv0 = makeConcreteTensor({y});
+    auto tv1 = div(tv0, new Float(2.0));
+    auto tv2 = broadcast(tv1, {false, true});
+    auto tv3 = makeConcreteTensor({y, z});
+    auto tv4 = mul(tv2, tv3);
+    auto tv5 = broadcast(tv4, {true, false, false});
+    auto tv6 = makeConcreteTensor({x, y, z});
+    auto tv7 = add(tv5, tv6);
+
+    // tv0[    i1    ] = input
+    // tv1[    i1    ] = tv0/2.0
+    // tv2[    i1, b2] = bcast(tv1)
+    // tv3[    i1, i2] = input
+    // tv4[    i1, i2] = tv2 * tv3
+    // tv5[b0, i1, i2] = bcast(tv4)
+    // tv6[i0, i1, i2] = input
+    // tv7[i0, i1, i2] = tv5 + tv6
+
+    // tv4 = bcast(tv1) * tv3
+    // tv7 = bcast(tv4) + tv6
 
-  TORCH_CHECK(in0.var({1, 2}, false).allclose(out_var));
-  TORCH_CHECK(in0.mean({1, 2}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6));
-}
+    fusion.addInput(tv0);
+    fusion.addInput(tv3);
+    fusion.addInput(tv6);
 
-TEST(NVFuserTest, blockWelford_CUDA) {
-  FusionExecutor fe;
-  int x = 7, y = 8, z = 9;
-
-  std::string kernel = R"(
-__global__ void kernel1(
-    Tensor<float,2> inp,
-    Tensor<float,1> out_avg,
-    Tensor<float,1> out_var,
-    Tensor<float,1> init_avg,
-    Tensor<float,1> init_var,
-    Tensor<long,0> init_N
-){
-    //actual generated kernel will use dynamic shared mem,
-    // here is just for prototype
-    __shared__ float mem_avg[512];
-    __shared__ float mem_M2[512];
-    __shared__ long mem_N[512];
-    float in=inp[threadIdx.x*inp.stride[0]+
-                        threadIdx.y*inp.stride[1]];
-    float tmp_avg=0;
-    float tmp_M2=0;
-    long tmp_N=0;
-    blockWelford<false,true,false>(
-        tmp_avg,
-        tmp_M2,
-        tmp_N,
-        in,
-        0.f,
-        (long)1,
-        threadIdx,
-        blockDim,
-        (float*)mem_avg,
-        (float*)mem_M2,
-        (long*)mem_N,
-        (bool)(threadIdx.x<inp.size[0]),
-        0.f);
-    __syncthreads();
-    if(threadIdx.x<out_var.size[0] && threadIdx.y==0){
-        welfordCombine(
-                    tmp_avg,
-                    tmp_M2,
-                    tmp_N,
-                    init_avg[threadIdx.x*init_avg.stride[0]],
-                    init_var[threadIdx.x*init_var.stride[0]]*init_N[0],
-                    init_N[0]
-                );
-        out_avg[threadIdx.x*out_avg.stride[0]]=tmp_avg;
-        out_var[threadIdx.x*out_var.stride[0]]=tmp_M2/(tmp_N);
-    }
-}
-    )";
-  fe.compileRtc(kernel, "CudaCodeGen::kernel1");
-  LaunchParams lp(
-      1, // gdimx
-      1, // gdimy
-      1, // gdimz
-      x, // bdimx
-      y, // bdimy
-      1 // bdimz
-  );
-  lp.setSmem(0);
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const std::vector<int64_t> tensor_dims = {x, y};
-  const std::vector<int64_t> init_dims = {x, z};
-
-  // generate initial values
-  auto init_in = at::randn(init_dims, options);
-  auto init_var = init_in.var({1}, false);
-  auto init_avg = init_in.mean({1});
-  auto init_N =
-      at::tensor(z, at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0));
-
-  auto in0 = at::randn(tensor_dims, options);
-
-  // run kernel
-  auto out_var = at::zeros({x}, options);
-  auto out_avg = at::zeros({x}, options);
-  fe.runRtc(lp, {in0, out_avg, out_var, init_avg, init_var, init_N});
-
-  // compare with reference output
-  auto cat_tensor = at::cat({init_in, in0}, 1);
-  TORCH_CHECK(cat_tensor.var({1}, false).allclose(out_var));
-  TORCH_CHECK(
-      cat_tensor.mean({1}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6));
-}
+    fusion.addOutput(tv7);
 
-TEST(NVFuserTest, blockWelfordNoInit_CUDA) {
-  FusionExecutor fe;
-  int x = 7, y = 8, z = 9;
-
-  // need support IValue for integer input as initial count
-  std::string kernel = R"(
-__global__ void kernel1(
-    Tensor<float,3> inp,
-    Tensor<float,1> out_avg,
-    Tensor<float,1> out_var
-){
-    //actual generated kernel will use dynamic shared mem,
-    // here is just for prototype
-    __shared__ float mem_avg[512];
-    __shared__ float mem_M2[512];
-    __shared__ long mem_N[512];
-    float in=inp[threadIdx.x*inp.stride[0]+
-                        threadIdx.y*inp.stride[1]+
-                        threadIdx.z*inp.stride[2]];
-    float tmp_avg=0;
-    float tmp_M2=0;
-    long tmp_N=0;
-    block_sync::init();
-    blockWelford<false,true,true>(
-        tmp_avg,
-        tmp_M2,
-        tmp_N,
-        in,
-        0.f,
-        (long) 1,
-        threadIdx,
-        blockDim,
-        (float*)mem_avg,
-        (float*)mem_M2,
-        (long*)mem_N,
-        (bool)(threadIdx.x<inp.size[0]),
-        0.f);
-    __syncthreads();
-    if(threadIdx.x<out_var.size[0] && threadIdx.y==0 && threadIdx.z==0){
-        out_avg[threadIdx.x*out_var.stride[0]]=tmp_avg;
-        out_var[threadIdx.x*out_var.stride[0]]=tmp_M2/(tmp_N);
-    }
-}
-    )";
-  fe.compileRtc(kernel, "CudaCodeGen::kernel1");
-  LaunchParams lp(
-      1, // gdimx
-      1, // gdimy
-      1, // gdimz
-      x, // bdimx
-      y, // bdimy
-      z // bdimz
-  );
-  lp.setSmem(0);
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const std::vector<int64_t> tensor_dims = {x, y, z};
-  auto in0 = at::randn(tensor_dims, options);
-  auto out_var = at::empty({x}, options);
-  auto out_avg = at::empty({x}, options);
-  fe.runRtc(lp, {in0, out_avg, out_var});
+    tv7->merge(0);
+    tv7->merge(0);
+    tv0->computeAt(tv7, -1);
 
-  TORCH_CHECK(in0.var({1, 2}, false).allclose(out_var));
-  TORCH_CHECK(in0.mean({1, 2}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6));
-}
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-TEST(NVFuserTest, gridWelfordNoInit_CUDA) {
-  FusionExecutor fe;
-  int x = 128, y = 64, z = 128;
-
-  std::string kernel = R"(
-__global__ void kernel1(
-    Tensor<float,3> inp,
-    Tensor<float,1> out_avg,
-    Tensor<float,1> out_var,
-    Tensor<float,1> work_buf_avg,
-    Tensor<float,1> work_buf_M2,
-    Tensor<long,1> work_buf_N,
-    Tensor<int64_t,1> sync_flag
-){
-    __shared__ float shared_buf_avg[512];
-    __shared__ float shared_buf_M2[512];
-    __shared__ long shared_buf_N[512];
-    float tmp_avg=0;
-    float tmp_M2=0;
-    long tmp_N=0;
-    float in = inp[ blockIdx.x  * inp.stride[0]+
-                    blockIdx.y  * inp.stride[1]+
-                    threadIdx.x * inp.stride[2]];
-    bool T_pred;
-    block_sync::init();
-    T_pred=welford::gridWelford<
-        true,true,false,
-        true,false,false
-    >(
-        tmp_avg,
-        tmp_M2,
-        tmp_N,
-        in,
-        0.f,
-        (long) 1,
-        &work_buf_avg[0],
-        &work_buf_M2[0],
-        &work_buf_N[0],
-        sync_flag,
-        (float*)shared_buf_avg,
-        (float*)shared_buf_M2,
-        (long*)shared_buf_N,
-        threadIdx.x<out_var.size[0],
-        threadIdx.x<out_var.size[0],
-        0.f);
-    if(T_pred){
-        out_avg[threadIdx.x*out_avg.stride[0]]=tmp_avg;
-        out_var[threadIdx.x*out_var.stride[0]]=tmp_M2/tmp_N;
-    }
-}
-    )";
-  fe.compileRtc(kernel, "CudaCodeGen::kernel1");
-  LaunchParams lp(
-      x, // gdimx
-      y, // gdimy
-      1, // gdimz
-      z, // bdimx
-      1, // bdimy
-      1 // bdimz
-  );
-  lp.setSmem(0);
-  const auto options =
-      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const auto options_int =
-      at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
-
-  const std::vector<int64_t> tensor_dims = {x, y, z};
-  auto in0 = at::randn(tensor_dims, options);
-
-  auto out_avg = at::empty({z}, options);
-  auto out_var = at::empty({z}, options);
-  auto work_buf_avg = at::empty({x * y * z}, options);
-  auto work_buf_var = at::empty({x * y * z}, options);
-  auto work_buf_N = at::empty({x * y * z}, options_int);
-  auto sync_flag = at::zeros({1}, options_int);
-  fe.runRtc(
-      lp,
-      {in0,
-       out_avg,
-       out_var,
-       work_buf_avg,
-       work_buf_var,
-       work_buf_N,
-       sync_flag});
-  std::vector<int64_t> dims{0, 1};
-
-  TORCH_CHECK(in0.mean(dims).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6));
-  TORCH_CHECK(in0.var(dims, false).allclose(out_var));
-}
-
-TEST(NVFuserTest, FusionWelfordOp_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+    at::Tensor t0 = at::randn({y}, options);
+    at::Tensor t3 = at::randn({y, z}, options);
+    at::Tensor t6 = at::randn({x, y, z}, options);
 
-  int M = 64, N = 128;
+    auto t4 = t0.div(2.0).unsqueeze(-1).expand({y, z}) * t3;
+    auto t7 = t4.unsqueeze(0).expand({x, y, z}) + t6;
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = mul(tv0, new Double(1));
-  auto tvs = Welford(tv1, {1});
-  auto tv_avg = tvs.avg;
-  auto tv_M2 = tvs.var_sum;
-  auto tv_N = tvs.n;
-  fusion.addOutput(tv_avg);
-  fusion.addOutput(tv_M2);
-  fusion.addOutput(tv_N);
-
-  tv_avg->split(1, 32);
-  tv_avg->split(0, 32);
-  tv_avg->split(0, 4);
-  tv_avg->reorder({{-1, -3}, {-3, -1}});
-  tv1->computeAt(tv_avg, -1);
+    FusionExecutor fe;
+    fe.compileFusion(&fusion);
+    auto outputs = fe.runFusion({t0, t3, t6});
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({M, N}, options);
+    TORCH_CHECK(t7.allclose(outputs[0]));
+  }
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion({t0});
+  {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
 
-  // by default Welford outputs sum of square diff so need to divide to get var
-  outputs[1] /= N;
+    int x = 2, y = 3, z = 4;
 
-  testValidate(
-      &fusion,
-      outputs,
-      {t0},
-      {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N},
-      __LINE__,
-      __FILE__);
-}
+    auto tv0 = makeConcreteTensor({y, z});
+    auto tv1 = div(tv0, new Float(2.0));
+    auto tv2 = sum(tv1, {1});
+    auto tv3 = broadcast(tv2, {true, false});
+    auto tv4 = makeConcreteTensor({x, y});
+    auto tv5 = add(tv3, tv4);
 
-TEST(NVFuserTest, FusionBlockWelfordOp_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+    // tv0[    i1, i2] = input
+    // tv1[    i1, i2] = tv0/2.0
+    // tv2[    i1    ] = sum(tv1, 1)
+    // tv3[b0, i1    ] = bcast(tv2)
+    // tv4[i0, i1    ] = input
+    // tv5[i0, i1    ] = tv3 + tv4
 
-  int M = 64, N = 128;
+    // tv2 = sum(tv0/2.0, 1)
+    // tv5 = bcast(tv2) + tv4
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = mul(tv0, new Double(1));
-  auto tvs = Welford(tv1, {1});
-  auto tv_avg = tvs.avg;
-  auto tv_M2 = tvs.var_sum;
-  auto tv_N = tvs.n;
-  fusion.addOutput(tv_avg);
-  fusion.addOutput(tv_M2);
-  fusion.addOutput(tv_N);
+    fusion.addInput(tv0);
+    fusion.addInput(tv4);
 
-  tv_avg->axis(-1)->parallelize(ParallelType::TIDx);
+    fusion.addOutput(tv5);
 
-  tv1->computeAt(tv_avg, -1);
+    tv5->merge(0);
+    tv0->computeAt(tv5, -1);
+    tv1->computeAt(tv2, -1);
 
-  //
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({M, N}, options);
-  at::Tensor t_var = at::empty({M}, options);
-  at::Tensor t_avg = at::empty({M}, options);
-  at::Tensor t_N = at::empty({M}, options_int);
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion({t0});
+    at::Tensor t0 = at::randn({y, z}, options);
+    auto t1 = t0.div(2.0);
+    auto t2 = t1.sum(1);
+    auto t3 = t2.unsqueeze(0).expand({x, y});
+    at::Tensor t4 = at::randn({x, y}, options);
+    auto t5 = t3.add(t4);
 
-  // by default Welford outputs sum of square diff so need to divide to get var
-  outputs[1] /= N;
+    FusionExecutor fe;
+    fe.compileFusion(&fusion);
+    auto outputs = fe.runFusion({t0, t4});
 
-  testValidate(
-      &fusion,
-      outputs,
-      {t0},
-      {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N},
-      __LINE__,
-      __FILE__);
+    TORCH_CHECK(t5.allclose(outputs[0]));
+  }
 }
 
-TEST(NVFuserTest, FusionGridWelfordOp_CUDA) {
+TEST(NVFuserTest, FusionAdvancedIndexing1_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  int M = 64, N = 128;
+  int w = 3, x = 4, y = 7, z = 8;
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  auto tv0 = makeSymbolicTensor(2);
+  auto tv0 = makeDummyTensor(3);
+  auto tv1 = makeDummyTensor(4);
   fusion.addInput(tv0);
-  auto tv1 = mul(tv0, new Double(1));
-  auto tvs = Welford(tv1, {1});
-  auto tv_avg = tvs.avg;
-  auto tv_M2 = tvs.var_sum;
-  auto tv_N = tvs.n;
-  fusion.addOutput(tv_avg);
-  fusion.addOutput(tv_M2);
-  fusion.addOutput(tv_N);
+  fusion.addInput(tv1);
 
-  tv_avg->axis(0)->parallelize(ParallelType::TIDx);
-  tv_avg->axis(-1)->parallelize(ParallelType::BIDx);
+  auto tv2 = add(tv0, new Float(1.0));
+  auto tv3 = broadcast(tv2, {true, false, false, false});
+  auto tv4 = add(tv3, tv1);
 
-  tv1->computeAt(tv_avg, -1);
+  fusion.addOutput(tv4);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({M, N}, options);
-  at::Tensor t_avg = at::empty({M}, options);
-  at::Tensor t_var = at::empty({M}, options);
-  at::Tensor t_N = at::empty({M}, options_int);
+  tv4->merge(0);
+  tv4->merge(0);
+  tv4->merge(0);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion({t0});
+  tv4->split(0, 128);
+  tv4->split(0, 4);
 
-  // by default Welford outputs sum of square diff so need to divide to get var
-  outputs[1] /= N;
+  tv2->computeAt(tv4, 1);
 
-  testValidate(
-      &fusion,
-      outputs,
-      {t0},
-      {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N},
-      __LINE__,
-      __FILE__);
-}
+  tv4->axis(0)->parallelize(ParallelType::BIDx);
+  tv4->axis(1)->parallelize(ParallelType::Unroll);
+  tv4->axis(2)->parallelize(ParallelType::TIDx);
 
-TEST(NVFuserTest, FusionRfactorWelfordOp_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  tv3->axis(1)->parallelize(ParallelType::Unroll);
+  tv3->axis(2)->parallelize(ParallelType::TIDx);
 
-  int M = 64, N = 128;
+  tv2->axis(1)->parallelize(ParallelType::Unroll);
+  tv2->axis(2)->parallelize(ParallelType::TIDx);
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = mul(tv0, new Double(1));
-  auto tvs = Welford(tv1, {1});
-  auto tv_avg = tvs.avg;
-  auto tv_M2 = tvs.var_sum;
-  auto tv_N = tvs.n;
-  fusion.addOutput(tv_avg);
-  fusion.addOutput(tv_M2);
-  fusion.addOutput(tv_N);
-
-  tv_avg->split(1, 4);
-  auto rtvs = tvs.rFactor({2});
-  tv1->computeAt(tv_avg, -1);
+  FusionExecutor fe;
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({M, N}, options);
-  at::Tensor t_avg = at::empty({M}, options);
-  at::Tensor t_var = at::empty({M}, options);
-  at::Tensor t_N = at::empty({M}, options_int);
+  at::Tensor t0 = at::randn({x, y, z}, options);
+  at::Tensor t1 = at::randn({w, x, y, z}, options);
 
-  FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion({t0});
+  auto outputs = fe.runFusion({t0, t1});
 
-  // by default Welford outputs sum of square diff so need to divide to get var
-  outputs[1] /= N;
+  auto t3 = t0.add(1.0);
+  auto t4 = t3.add(t1);
 
-  testValidate(
-      &fusion,
-      outputs,
-      {t0},
-      {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N},
-      __LINE__,
-      __FILE__);
+  TORCH_CHECK(t4.allclose(outputs[0]));
 }
 
-TEST(NVFuserTest, FusionWelfordSchedule_CUDA) {
+TEST(NVFuserTest, FusionAdvancedIndexing2_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  int M = 64, N = 128;
+  int w = 3, x = 4, y = 7, z = 8;
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  auto tv0 = makeSymbolicTensor(2);
+  auto tv0 = makeDummyTensor(3);
+  auto tv1 = makeDummyTensor(4);
   fusion.addInput(tv0);
-  auto tv1 = mul(tv0, new Double(1));
-  auto tvs = Welford(tv1, {1});
-  auto tv_avg = tvs.avg;
-  auto tv_M2 = tvs.var_sum;
-  auto tv_N = tvs.n;
-  fusion.addOutput(tv_avg);
-  fusion.addOutput(tv_M2);
-  fusion.addOutput(tv_N);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({M, N}, options);
-  // TODO: Why do we use launch params from here, but not scheduling???
-  auto reduction_params = getReductionHeuristics(&fusion, {t0});
-  scheduleReduction(&fusion, reduction_params.value());
+  fusion.addInput(tv1);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion({t0}, reduction_params.value().lparams);
+  auto tv2 = add(tv0, new Float(1.0));
+  auto tv3 = broadcast(tv2, {true, false, false, false});
+  auto tv4 = add(tv3, tv1);
 
-  // by default Welford outputs sum of square diff so need to divide to get var
-  outputs[1] /= N;
+  fusion.addOutput(tv4);
 
-  auto at_avg = t0.mean({1});
-  auto at_var = t0.var({1}, false);
-  auto at_n = at::ones({M}, options_int) * N;
+  tv4->merge(-2);
+  tv4->merge(-2);
+  tv4->merge(-2);
 
-  testValidate(
-      &fusion,
-      outputs,
-      {t0},
-      {at_avg, at_var, at_n},
-      __LINE__,
-      __FILE__,
-      "validate welford",
-      reduction_params.value().lparams);
-}
+  tv4->split(0, 128);
+  tv4->split(0, 4);
 
-namespace {
-void testWelford(DataType dtype, int red_axis, int odim, int rdim) {
-  const int axis = red_axis;
-  at::ScalarType aten_dtype = data_type_to_aten(dtype);
+  tv2->computeAt(tv4, 1);
 
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-  TensorView* tv0 = makeSymbolicTensor(2, dtype);
-  bool is_fp16 = dtype == DataType::Half;
-  TensorView* tv0_cast = tv0;
-  if (is_fp16) {
-    tv0_cast = castOp(DataType::Float, tv0);
-  }
-  fusion.addInput(tv0);
-  auto tv1 = mul(tv0_cast, new Double(1));
-  auto tvs = Welford(tv1, {axis});
-  auto tv_avg = tvs.avg;
-  auto tv_M2 = tvs.var_sum;
-  auto tv_N = tvs.n;
-
-  TensorView* avg_cast = tv_avg;
-  TensorView* M2_cast = tv_M2;
-
-  if (is_fp16) {
-    avg_cast = castOp(DataType::Half, tv_avg);
-    M2_cast = castOp(DataType::Half, tv_M2);
-  }
+  tv4->axis(0)->parallelize(ParallelType::BIDx);
+  tv4->axis(1)->parallelize(ParallelType::Unroll);
+  tv4->axis(2)->parallelize(ParallelType::TIDx);
 
-  fusion.addOutput(avg_cast);
-  fusion.addOutput(M2_cast);
-  fusion.addOutput(tv_N);
+  tv3->axis(1)->parallelize(ParallelType::Unroll);
+  tv3->axis(2)->parallelize(ParallelType::TIDx);
 
-  auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);
-  auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  std::vector<TensorView*> outputs_of_red;
-  at::Tensor aten_input =
-      (axis ? at::randn({odim, rdim}, options)
-            : at::randn({rdim, odim}, options));
-
-  if (is_fp16) {
-    outputs_of_red.push_back(avg_cast);
-    outputs_of_red.push_back(M2_cast);
-  }
+  tv2->axis(1)->parallelize(ParallelType::Unroll);
+  tv2->axis(2)->parallelize(ParallelType::TIDx);
 
-  auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
-  scheduleReduction(&fusion, reduction_params.value());
+  FusionExecutor fe;
 
-  auto lparams = reduction_params.value().lparams;
+  at::Tensor t0 = at::randn({x, y, z}, options);
+  at::Tensor t1 = at::randn({w, x, y, z}, options);
 
-  FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion({aten_input}, reduction_params.value().lparams);
-
-  // by default Welford outputs sum of square diff so need to divide to
-  // get var
-
-  outputs[1] /= rdim;
-
-  auto at_avg = aten_input.mean({axis});
-  auto at_var = aten_input.var({axis}, false);
-  auto at_n =
-      (axis ? at::ones({odim, rdim}, options)
-            : at::ones({rdim, odim}, options));
-  at_n = at_n.sum({axis});
-
-  testValidate(
-      &fusion,
-      outputs,
-      {aten_input},
-      {at_avg, at_var, at_n},
-      __LINE__,
-      __FILE__,
-      "validate welford",
-      reduction_params.value().lparams);
-}
-} // namespace
+  auto outputs = fe.runFusion({t0, t1});
 
-TEST(NVFuserTest, FusionWelfordShmoo_CUDA) {
-  std::vector<DataType> dtypes = {
-      DataType::Double, DataType::Float, DataType::Half};
-  std::vector<int> red_axis = {1, 0};
-  std::vector<int> output_dims = {160, 320};
-  std::vector<int> red_dims;
-
-  // Tried to cut down the number iterations with just
-  // doing every other power of 2.
-  for (int i = 1; i <= 1024 * 1024; i <<= 2) {
-    red_dims.push_back(i);
-  }
+  auto t3 = t0.add(1.0);
+  auto t4 = t3.add(t1);
 
-  for (auto dtype : dtypes) {
-    for (auto& axis : red_axis) {
-      for (auto& odim : output_dims) {
-        for (auto& rdim : red_dims) {
-          // TODO: original welford algorithm actually keeps a running sum of
-          // squares, i.e. M_{2n} in the
-          //       cf:
-          //       https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-          //       algorithm notation, and it can reach inf for large numbers
-          //       with half precision. skipping too large volumes for half for
-          //       nwo might need further numerical experiments to re-design
-          //       this.
-          if (rdim > 32768 && dtype == DataType::Half) {
-            continue;
-          }
-          testWelford(dtype, axis, odim, rdim);
-        }
-      }
-    }
-  }
+  TORCH_CHECK(t4.allclose(outputs[0]));
 }
 
-TEST(NVFuserTest, FusionTranspose1_CUDA) {
+TEST(NVFuserTest, FusionAdvancedIndexing3_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  constexpr int M = 10;
-  constexpr int N = 20;
+  int w = 3, x = 4, y = 7, z = 8;
 
-  auto tv0 = makeSymbolicTensor(2);
-  auto tv1 = transpose(tv0, {{0, 1}});
+  auto tv0 = makeDummyTensor(3);
+  auto tv1 = makeDummyTensor(4);
   fusion.addInput(tv0);
-  fusion.addOutput(tv1);
+  fusion.addInput(tv1);
 
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(1)->parallelize(ParallelType::TIDx);
+  auto tv2 = add(tv0, new Float(1.0));
+  auto tv3 = add(tv2, tv1);
+  fusion.addOutput(tv3);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({M, N}, options);
-  std::vector<IValue> aten_inputs = {t0};
+  at::Tensor t0 = at::randn({x, y, z}, options);
+  at::Tensor t1 = at::randn({w, x, y, z}, options);
+
+  scheduleFusion(&fusion, {t0, t1});
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(aten_inputs);
+  auto outputs = fe.runFusion({t0, t1});
 
-  at::Tensor aten_output = t0.t();
+  auto t2 = t0.add(1.0);
+  auto t3 = t2.add(t1);
 
-  testValidate(
-      &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  TORCH_CHECK(t3.allclose(outputs[0]));
 }
 
-TEST(NVFuserTest, FusionTranspose2_CUDA) {
+TEST(NVFuserTest, FusionAdvancedIndexing4_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  constexpr int M = 10;
-  constexpr int N = 20;
-
-  auto tv0 = makeSymbolicTensor(2);
-  auto tv1 = transpose(tv0, {{0, 1}});
+  // Set up your input tensor views
+  TensorView* tv0 = makeConcreteTensor({10, 20});
   fusion.addInput(tv0);
-  fusion.addOutput(tv1);
-
-  tv1->merge(0);
-  tv1->split(0, 32);
+  TensorView* tv1 = makeConcreteTensor({10, 10, 20});
+  fusion.addInput(tv1);
 
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(1)->parallelize(ParallelType::TIDx);
+  TensorView* tv2 = add(tv0, new Float(1));
+  TensorView* tv3 = broadcast(tv2, {true, false, false});
+  TensorView* tv4 = add(tv3, tv1);
+  fusion.addOutput(tv4);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({M, N}, options);
-  std::vector<IValue> aten_inputs = {t0};
+  at::Tensor t0 = at::randn({10, 20}, options);
+  at::Tensor t1 = at::randn({10, 10, 20}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(aten_inputs);
+  auto outputs = fe.runFusion({t0, t1});
 
-  at::Tensor aten_output = t0.t();
+  auto t2 = t0.add(1.0);
+  auto t3 = t2.add(t1);
 
-  testValidate(
-      &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  TORCH_CHECK(t3.allclose(outputs[0]));
 }
 
-TEST(NVFuserTest, FusionSimpleGemmTransposed_CUDA) {
+// Test a simple Gemm but also play around with fusion executor features
+TEST(NVFuserTest, FusionSimpleGemm_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
   // Set up your input tensor views
-
-  TensorView* tv0 = makeSymbolicTensor(2); // K, M
-  TensorView* tv1 = makeSymbolicTensor(2); // N, K
+  TensorView* tv0 = makeDummyTensor(2); // M, K
+  TensorView* tv1 = makeDummyTensor(2); // K, N
   fusion.addInput(tv0);
   fusion.addInput(tv1);
 
-  TensorView* tv0_t = transpose(tv0, {{0, 1}});
-  TensorView* tv1_t = transpose(tv1, {{0, 1}});
-
-  TensorView* tv2 = broadcast(tv0_t, {false, false, true});
+  TensorView* tv2 = broadcast(tv0, {false, false, true});
   // tv2[I0, I1, B] = tv0[I0, I1]
 
-  TensorView* tv3 = broadcast(tv1_t, {true, false, false});
+  TensorView* tv3 = broadcast(tv1, {true, false, false});
   // tv3[B, I1, I2] = tv1[I1, I2]
 
   // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
@@ -12315,8 +4035,8 @@ TEST(NVFuserTest, FusionSimpleGemmTransposed_CUDA) {
   // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
   // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
 
-  tv0_t->computeAt(tv5, -1);
-  tv1_t->computeAt(tv5, -1);
+  tv0->computeAt(tv5, -1);
+  tv1->computeAt(tv5, -1);
 
   // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
   // tv5[I0o, I0i{4},    , R1i{32}, I2o, I2i{4}]
@@ -12325,8 +4045,8 @@ TEST(NVFuserTest, FusionSimpleGemmTransposed_CUDA) {
   // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
   // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
 
-  tv0_t->computeAt(tv6, -1);
-  tv1_t->computeAt(tv6, -1);
+  tv0->computeAt(tv6, -1);
+  tv1->computeAt(tv6, -1);
   // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
   // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
   // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
@@ -12344,3754 +4064,3565 @@ TEST(NVFuserTest, FusionSimpleGemmTransposed_CUDA) {
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  at::Tensor t0 = at::randn({K, M}, options);
-  at::Tensor t1 = at::randn({N, K}, options);
+  at::Tensor t0 = at::randn({M, K}, options);
+  at::Tensor t1 = at::randn({K, N}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
   // Lets specify a few bounds in launch params to make sure it works
   fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4));
 
-  // Don't specify any launch params
-  auto cg_outputs = fe.runFusion({t0, t1});
+  // Make sure bad launch params throws
+  ASSERT_ANY_THROW(fe.runFusion({t0, t1}, LaunchParams(1, 2, 3, 4, 5, 6)));
 
-  auto aten_output = t0.t().to(at::kDouble).matmul(t1.t().to(at::kDouble));
+  // Don't specify any launch params
+  auto outputs = fe.runFusion({t0, t1});
 
-  testValidate(
-      &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__);
+  auto t2 = t0.matmul(t1);
+  TORCH_CHECK(
+      t2.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      t2.sub(outputs[0]).abs().max());
 }
 
-TEST(NVFuserTest, FusionSoftmax3DTransposed_CUDA) {
+// Softmax with a 1D tensor. Parallelized only with a single thread block.
+TEST(NVFuserTest, FusionSoftmax1D_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  const int tidx = 32;
-  const int dimx = 32;
-  const int dimy = 16;
-  const int dimz = 130;
+  const int tidx = 128;
+  const int dimx = 1000;
 
   // Set up your input tensor views
-  TensorView* input_tv0 = makeSymbolicTensor(3);
+  TensorView* input_tv0 = makeDummyTensor(1);
   fusion.addInput(input_tv0);
 
-  TensorView* input_t = transpose(input_tv0, {{1, 2}});
-
-  TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_t);
+  TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0);
   TensorView* sum_exp_tv2 = sum(exp_tv1, {-1});
-  TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true});
+  TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {true});
 
   // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
   // computed at sum_exp_rf_tv8.
-  TensorView* input_t_copy = transpose(input_tv0, {{1, 2}});
-  TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_t_copy);
+  TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0);
 
   TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3);
 
   fusion.addOutput(output_tv4);
 
-  bcast_sum_tv3->split(-1, tidx);
+  bcast_sum_tv3->split(0, tidx);
 
   sum_exp_tv2->split(-1, tidx);
   TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2});
 
   output_tv4->split(-1, tidx);
 
-  input_t->computeAt(sum_exp_rf_tv5, -1);
-  input_t_copy->computeAt(output_tv4, -1);
+  exp_tv1->computeAt(sum_exp_rf_tv5, -1);
+  exp_tv1_copy->computeAt(output_tv4, -1);
 
   TensorView* tensors_to_parallelize[] = {
       sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5};
 
   for (auto tv : tensors_to_parallelize) {
-    tv->axis(0)->parallelize(ParallelType::BIDx);
-    tv->axis(1)->parallelize(ParallelType::BIDy);
     tv->axis(-1)->parallelize(ParallelType::TIDx);
   }
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({dimx, dimz, dimy}, options);
-
-  at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options);
+  at::Tensor t0 = at::randn({dimx}, options);
+  at::Tensor cg_output = at::empty({dimx}, options);
+  at::Tensor t3_output = at::empty_like(cg_output, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  fe.runFusion({input}, {cg_output});
-
-  auto aten_input_t = at::transpose(input, 1, 2);
-  auto aten_output = at::_softmax(aten_input_t.to(at::kDouble), -1, false);
+  fe.runFusion({t0}, {cg_output});
 
-  testValidate(
-      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
+  auto t2 = at::_softmax(t0, -1, false);
+  TORCH_CHECK(
+      t2.allclose(cg_output, 1e-5, 1e-5),
+      "Error of: ",
+      t2.sub(cg_output).abs().max());
 }
 
-TEST(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) {
-  // Case 1
-  // tv1 = tv0 * 0.5
-  // tv2 = tv1 * -1
-  // tv3 = tv1 + 3
-  // tv4 = tv1 * 2
-  // tv5 = tv3 + tv2
-  // tv6 = tv5 + tv4
-  // tv7 = tv1 + tv4
+// Softmax with a 1D tensor with input normalization.
+TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
+  const int tidx = 128;
+  const int dimx = 1000;
 
-  tv0 = transpose(tv0, {{0, 1}});
+  // Set up your input tensor views
+  TensorView* input_tv0 = makeDummyTensor(1);
+  fusion.addInput(input_tv0);
 
-  TensorView* tv1 = mul(tv0, new Double(0.5));
-  TensorView* tv2 = mul(tv1, new Double(-1.0));
-  TensorView* tv3 = add(tv1, new Double(3.0));
-  TensorView* tv4 = mul(tv1, new Double(2.0));
-  TensorView* tv5 = add(tv3, tv2);
+  // Normalize with the max value before computing exp.
+  TensorView* max_val_tv1 =
+      reductionOp(BinaryOpType::Max, {-1}, new Float(0), input_tv0);
+  TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {true});
+  TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2);
+  TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3);
+  TensorView* sum_exp_tv5 = sum(exp_tv4, {-1});
+  TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {true});
 
-  TensorView* tv6 = add(tv5, tv4);
-  TensorView* tv7 = add(tv1, tv4);
+  // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
+  // computed at sum_exp_rf_tv8.
+  TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2);
+  TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy);
 
-  fusion.addOutput(tv6);
-  fusion.addOutput(tv7);
+  TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6);
 
-  // Lets setup to actually run
-  tv7->merge(0);
-  tv7->split(0, 128);
-  tv7->split(0, 4);
+  fusion.addOutput(output_tv7);
+  bcast_max_tv2->split(0, tidx);
+  bcast_sum_tv6->split(0, tidx);
 
-  tv7->axis(0)->parallelize(ParallelType::BIDx);
+  max_val_tv1->split(-1, tidx);
+  TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2});
 
-  tv0->computeAt(tv7, 1);
+  sum_exp_tv5->split(-1, tidx);
+  TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2});
 
-  // The this-position of the last tensor should be zero.
-  TORCH_CHECK(
-      tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 &&
-      tv7->getMaxProducerPosition() == 1);
-  TORCH_CHECK(
-      tv6->nDims() == 3 && tv6->getComputeAtPosition() == 0 &&
-      tv6->getMaxProducerPosition() == 1);
-  // The position of every other tensor should be 1.
-  for (auto tv : {tv1, tv2, tv3, tv4, tv5}) {
-    TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1);
-  }
+  output_tv7->split(-1, tidx);
 
-  for (Val* val : fusion.vals()) {
-    if (!fusion.hasInput(val) &&
-        val->getValType().value() == ValType::TensorView) {
-      TensorView* tv = static_cast<TensorView*>(val);
-      tv->axis(1)->parallelize(ParallelType::Unroll);
-      tv->axis(-1)->parallelize(ParallelType::TIDx);
-    }
+  sub_tv3->computeAt(sum_exp_rf_tv9, -1);
+  sub_tv3_copy->computeAt(output_tv7, -1);
+
+  TensorView* tensors_to_parallelize[] = {
+      max_val_tv1,
+      bcast_max_tv2,
+      sum_exp_tv5,
+      bcast_sum_tv6,
+      output_tv7,
+      max_val_rf_tv8,
+      sum_exp_rf_tv9};
+
+  for (auto tv : tensors_to_parallelize) {
+    tv->axis(-1)->parallelize(ParallelType::TIDx);
   }
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor aten_input = at::randn({129, 127}, options);
+  at::Tensor t0 = at::randn({dimx}, options);
+  at::Tensor t3_output = at::empty({dimx}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  at::Tensor aten_input_t = aten_input.t();
-
-  auto t1 = aten_input_t.mul({0.5});
-  auto t2 = t1.mul({-1.0});
-  auto t3 = t1.add({3.0});
-  auto t4 = t1.mul({2.0});
-  auto t5 = t3.add(t2);
-  auto t6 = t5.add(t4);
-  auto t7 = t1.add(t4);
-
-  std::vector<at::Tensor> aten_outputs = {t6, t7};
+  auto outputs = fe.runFusion({t0});
 
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
+  auto t2 = at::_softmax(t0, -1, false);
+  TORCH_CHECK(
+      t2.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      t2.sub(outputs[0]).abs().max());
 }
 
-TEST(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) {
-  // Case 2
-  // tv1 = tv0 * -1
-  // tv2 = tv0 + 3
-  // tv3 = tv0 * 2
-  // tv4 = tv2 + tv1
-  // tv5 = tv4 + tv3
-  // tv6 = tv5 + tv3
+// Softmax with a 3D tensor, where the inner-most 3rd dimension is
+// normalized. Pallelized with multiple thread blocks.
+TEST(NVFuserTest, FusionSoftmax3D_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
+  const int tidx = 32;
+  const int dimx = 32;
+  const int dimy = 16;
+  const int dimz = 130;
 
-  tv0 = transpose(tv0, {{0, 1}});
+  // Set up your input tensor views
+  TensorView* input_tv0 = makeDummyTensor(3);
+  fusion.addInput(input_tv0);
 
-  TensorView* tv1 = mul(tv0, new Double(-1.0));
-  TensorView* tv2 = add(tv0, new Double(3.0));
-  TensorView* tv3 = mul(tv0, new Double(2.0));
-  TensorView* tv4 = add(tv2, tv1);
+  TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0);
+  TensorView* sum_exp_tv2 = sum(exp_tv1, {-1});
+  TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true});
 
-  TensorView* tv5 = add(tv4, tv3);
-  TensorView* tv6 = add(tv5, tv3);
+  // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
+  // computed at sum_exp_rf_tv8.
+  TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0);
 
-  fusion.addOutput(tv5);
-  fusion.addOutput(tv6);
+  TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3);
 
-  // Lets setup to actually run
-  tv6->merge(0);
-  tv6->split(0, 128);
-  tv6->split(0, 4);
+  fusion.addOutput(output_tv4);
 
-  tv6->axis(0)->parallelize(ParallelType::BIDx);
+  bcast_sum_tv3->split(-1, tidx);
 
-  tv0->computeAt(tv6, 1);
+  sum_exp_tv2->split(-1, tidx);
+  TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2});
 
-  for (Val* val : fusion.vals()) {
-    if (!fusion.hasInput(val) &&
-        val->getValType().value() == ValType::TensorView) {
-      TensorView* tv = static_cast<TensorView*>(val);
+  output_tv4->split(-1, tidx);
 
-      tv->axis(1)->parallelize(ParallelType::Unroll);
-      tv->axis(-1)->parallelize(ParallelType::TIDx);
-    }
+  exp_tv1->computeAt(sum_exp_rf_tv5, -1);
+  exp_tv1_copy->computeAt(output_tv4, -1);
+
+  TensorView* tensors_to_parallelize[] = {
+      sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5};
+
+  for (auto tv : tensors_to_parallelize) {
+    tv->axis(0)->parallelize(ParallelType::BIDx);
+    tv->axis(1)->parallelize(ParallelType::BIDy);
+    tv->axis(-1)->parallelize(ParallelType::TIDx);
   }
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({129, 127}, options);
-
+  at::Tensor t0 = at::randn({dimx, dimy, dimz}, options);
+  at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options);
+  at::Tensor t3_output = at::empty_like(cg_output, options);
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input});
-
-  auto input_t = input.t();
-  auto t1 = input_t.mul({-1.0});
-  auto t2 = input_t.add({3.0});
-  auto t3 = input_t.mul({2.0});
-  auto t4 = t2.add(t1);
-  auto t5 = t4.add(t3);
-  auto t6 = t5.add(t3);
-
-  std::vector<at::Tensor> aten_outputs = {t5, t6};
+  fe.runFusion({t0}, {cg_output});
 
-  testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__);
+  auto t2 = at::_softmax(t0, -1, false);
+  TORCH_CHECK(
+      t2.allclose(cg_output, 1e-5, 1e-5),
+      "Error of: ",
+      t2.sub(cg_output).abs().max());
 }
 
-TEST(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) {
-  // Case 3
-  // T2 = T1 * 0.979361
-  // T3 = T2 * T0
+// Softmax with a 3D tensor with input normalization.
+TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(4);
-  fusion.addInput(tv0);
+  const int tidx = 32;
+  const int dimx = 32;
+  const int dimy = 16;
+  const int dimz = 130;
 
-  tv0 = transpose(tv0, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
+  // Set up your input tensor views
+  TensorView* input_tv0 = makeDummyTensor(3);
+  fusion.addInput(input_tv0);
 
-  TensorView* tv1 = makeSymbolicTensor(4);
-  fusion.addInput(tv1);
+  // Normalize with the max value before computing exp.
+  TensorView* max_val_tv1 =
+      reductionOp(BinaryOpType::Max, {-1}, new Float(0), input_tv0);
+  TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true});
+  TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2);
+  TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3);
+  TensorView* sum_exp_tv5 = sum(exp_tv4, {-1});
+  TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {false, false, true});
 
-  tv1 = transpose(tv1, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
+  // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
+  // computed at sum_exp_rf_tv8.
+  TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2);
+  TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy);
 
-  TensorView* tv2 = mul(tv1, new Double(.979361));
-  TensorView* tv3 = mul(tv2, tv0);
+  TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6);
 
-  fusion.addOutput(tv3);
+  fusion.addOutput(output_tv7);
 
-  // Lets setup to actually run
-  while (tv3->nDims() > 1)
-    tv3->merge(0);
-  tv3->split(0, 128);
-  tv3->split(0, 4);
+  bcast_max_tv2->split(-1, tidx);
+  bcast_sum_tv6->split(-1, tidx);
 
-  tv0->computeAt(tv3, 1);
-  tv1->computeAt(tv3, 1);
+  max_val_tv1->split(-1, tidx);
+  TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2});
 
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
+  sum_exp_tv5->split(-1, tidx);
+  TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2});
 
-  for (Val* val : fusion.vals()) {
-    if (!fusion.hasInput(val) &&
-        val->getValType().value() == ValType::TensorView) {
-      TensorView* tv = static_cast<TensorView*>(val);
+  output_tv7->split(-1, tidx);
 
-      tv->axis(1)->parallelize(ParallelType::Unroll);
-      tv->axis(-1)->parallelize(ParallelType::TIDx);
-    }
+  sub_tv3->computeAt(sum_exp_rf_tv9, -1);
+  sub_tv3_copy->computeAt(output_tv7, -1);
+
+  TensorView* tensors_to_parallelize[] = {
+      max_val_tv1,
+      bcast_max_tv2,
+      sum_exp_tv5,
+      bcast_sum_tv6,
+      output_tv7,
+      max_val_rf_tv8,
+      sum_exp_rf_tv9};
+
+  for (auto tv : tensors_to_parallelize) {
+    tv->axis(0)->parallelize(ParallelType::BIDx);
+    tv->axis(1)->parallelize(ParallelType::BIDy);
+    tv->axis(-1)->parallelize(ParallelType::TIDx);
   }
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
-  at::Tensor t1 = at::rand_like(t0, options);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
+  at::Tensor t0 = at::randn({dimx, dimy, dimz}, options);
+  at::Tensor t3_output = at::empty({dimx, dimy, dimz}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  auto t0_t = t0.permute({3, 0, 1, 2});
-  auto t1_t = t1.permute({3, 0, 1, 2});
-  auto t2 = t1_t.mul({0.979361});
-  auto aten_output = t2.mul(t0_t);
+  auto outputs = fe.runFusion({t0});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  auto t2 = at::_softmax(t0, -1, false);
+  TORCH_CHECK(
+      t2.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      t2.sub(outputs[0]).abs().max());
 }
 
-TEST(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) {
-  // Case 4
-  // T4 = T2 - T3
-  // T5 = T1 + T4
-  // T6 = T5 - T0
+TEST(NVFuserTest, FusionSoftmaxComputeAt_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(4);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
 
-  tv0 = transpose(tv0, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
-
-  TensorView* tv1 = makeSymbolicTensor(4);
-  fusion.addInput(tv1);
-
-  tv1 = transpose(tv1, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
-
-  TensorView* tv2 = makeSymbolicTensor(4);
-  fusion.addInput(tv2);
-
-  tv2 = transpose(tv2, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
+  auto tv1 = sum(tv0, {1});
+  auto tv2 = broadcast(tv1, {false, true});
 
-  TensorView* tv3 = makeSymbolicTensor(4);
-  fusion.addInput(tv3);
+  auto tv3 = add(tv0, new Float(1.0));
 
-  tv3 = transpose(tv3, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
+  auto tv4 = mul(tv2, tv3);
 
-  TensorView* tv4 = sub(tv2, tv3);
-  TensorView* tv5 = add(tv1, tv4);
-  TensorView* tv6 = sub(tv5, tv0);
+  auto tv5 = sum(tv4, {1});
+  auto tv6 = broadcast(tv5, {false, true});
 
-  fusion.addOutput(tv6);
+  auto tv7 = sub(tv6, tv4);
+  fusion.addOutput(tv7);
 
-  // Lets setup to actually run
-  while (tv6->nDims() > 1)
-    tv6->merge(0);
-  tv6->split(0, 128);
-  tv6->split(0, 4);
+  tv1->computeAt(tv7, 1);
+  ASSERT_ANY_THROW(tv1->computeAt(tv7, -1));
+}
 
-  tv0->computeAt(tv6, 1);
-  tv1->computeAt(tv6, 1);
-  tv2->computeAt(tv6, 1);
-  tv3->computeAt(tv6, 1);
+// Similar to FusionReduction but uses grid reduction
+TEST(NVFuserTest, FusionGridReduction1_CUDA) {
+  const int gdimx = 32;
+  const int bdimx = 128;
 
-  tv6->axis(0)->parallelize(ParallelType::BIDx);
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  for (Val* val : fusion.vals()) {
-    if (!fusion.hasInput(val) &&
-        val->getValType().value() == ValType::TensorView) {
-      TensorView* tv = static_cast<TensorView*>(val);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
 
-      tv->axis(1)->parallelize(ParallelType::Unroll);
-      tv->axis(-1)->parallelize(ParallelType::TIDx);
-    }
-  }
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addOutput(tv1);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
-  at::Tensor t1 = at::rand_like(t0, options);
-  at::Tensor t2 = at::rand_like(t0, options);
-  at::Tensor t3 = at::rand_like(t0, options);
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
 
-  std::vector<IValue> aten_inputs = {t0, t1, t2, t3};
+  tv1->split(1, bdimx);
+  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+  tv1->split(1, gdimx);
+  // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  TensorView* tv2 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
+  // tv1[I0,        R1oi{32},  R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
 
-  auto t0_t = t0.permute({3, 0, 1, 2});
-  auto t1_t = t1.permute({3, 0, 1, 2});
-  auto t2_t = t2.permute({3, 0, 1, 2});
-  auto t3_t = t3.permute({3, 0, 1, 2});
-  auto t4 = t2_t.sub(t3_t);
-  auto t5 = t1_t.add(t4);
-  auto aten_output = t5.sub(t0_t);
+  // Incrementally, can print in between for debugging
+  tv0->computeAt(tv2, 1);
+  tv2->computeAt(tv1, 1);
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+  // Re do it all at once, because why not.
+  tv0->computeAt(tv1, 1);
 
-TEST(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) {
-  // Case 5
-  // tv2 = tv0 + 2.0
-  // tv3 = tv1 * tv2
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  tv1->axis(0)->parallelize(ParallelType::BIDy);
+  tv1->axis(1)->parallelize(ParallelType::BIDx);
+  tv2->axis(2)->parallelize(ParallelType::BIDx);
 
-  // Set up your input tensor views
-  TensorView* tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  tv0 = transpose(tv0, {{0, 1}});
-  TensorView* tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-  tv1 = transpose(tv1, {{0, 1}});
-  TensorView* tv2 = add(tv0, new Double(2.0));
-  TensorView* tv3 = mul(tv1, tv2);
-  fusion.addOutput(tv3);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
 
-  tv3->merge(0);
-  tv3->split(-1, 8);
-  tv3->split(-1, 4);
+  int numel_x = 10000;
+  int numel_y = 65000;
 
-  tv0->computeAt(tv3, 1);
-  tv1->computeAt(tv3, 1);
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
+  // fusion.printKernel();
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({63, 65}, options);
-  at::Tensor t1 = at::rand_like(t0, options);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  auto t2 = t0.t().add(2.0);
-  auto aten_output = t1.t().mul(t2);
+  fe.runFusion({input}, {cg_output});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(cg_output));
 }
 
-TEST(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) {
+// Same test as the above but uses BIDy and TIDx for reduction
+TEST(NVFuserTest, FusionGridReduction2_CUDA) {
+  const int gdimy = 32;
+  const int bdimx = 128;
+
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeSymbolicTensor(2);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  tv0 = transpose(tv0, {{0, 1}});
-  TensorView* tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-  tv1 = transpose(tv1, {{0, 1}});
-  TensorView* tv2 = add(tv0, new Double(2.0));
-  TensorView* tv3 = mul(tv1, tv2);
-  fusion.addOutput(tv3);
-
-  tv2->merge(0);
-  tv2->split(-1, 8);
-  tv2->split(-1, 4);
-  tv3->merge(0);
-  tv3->split(-1, 8);
-
-  tv0->computeAt(tv3, 1);
-  tv1->computeAt(tv3, 1);
-
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({63, 65}, options);
-  at::Tensor t1 = at::rand_like(t0, options);
 
-  std::vector<IValue> aten_inputs = {t0, t1};
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addOutput(tv1);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
 
-  auto t2 = t0.t().add(2.0);
-  auto aten_output = t1.t().mul(t2);
+  tv1->split(1, bdimx);
+  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+  tv1->split(1, gdimy);
+  // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+  TensorView* tv2 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
+  // tv1[I0,        R1oi{32},  R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
 
-TEST(NVFuserTest, FusionSegmentReducePointwise_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+  // Incrementally, can print in between for debugging
+  tv0->computeAt(tv2, 1);
+  tv2->computeAt(tv1, 1);
 
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(1);
-  TensorView* tv2 = makeSymbolicTensor(2);
+  // Re do it all at once, because why not.
+  tv0->computeAt(tv1, 1);
 
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
-  fusion->addInput(tv2);
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
+  tv1->axis(1)->parallelize(ParallelType::BIDy);
+  tv2->axis(2)->parallelize(ParallelType::BIDy);
 
-  TensorView* tv3 = add(tv0, new Double(1)); // Group 0
-  TensorView* tv4 =
-      max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues)
-  TensorView* tv5 = add(tv4, tv1); //  Group 0 (Non Broadcast after reduce,
-                                   //  keeps normalization scheduler away)
-  TensorView* tv6 = add(tv5, tv2); //  Group 1 (Broadcast after reduce)
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
 
-  fusion->addOutput(tv6);
+  int numel_x = 10000;
+  int numel_y = 65000;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({128, 65}, options);
-  at::Tensor t1 = at::randn({65}, options);
-  at::Tensor t2 = at::randn({128, 65}, options);
-
-  auto t3 = t0.add(1.0);
-  auto t4 = std::get<0>(at::max(t3, 0));
-  auto t5 = t4.add(t1);
-  auto t6 = t5.add(t2);
-
-  FusionExecutorCache executor_cache(std::move(fusion));
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
 
-  auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2});
-
-  TORCH_CHECK(
-      executor_cache.getMostRecentKernelRuntime()->isSegmented(),
-      "segmentation didn't happen");
-  TORCH_CHECK(
-      executor_cache.getMostRecentKernelRuntime()
-              ->fusionSegments()
-              ->groups()
-              .size() == 2,
-      "segmentation didn't happen as expected");
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({input});
 
-  testValidate(
-      executor_cache.fusion(), outputs, {t0, t1, t2}, {t6}, __LINE__, __FILE__);
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(outputs[0]));
 }
 
-TEST(NVFuserTest, FusionMultipleVectorize_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+// Same test but uses BIDy and BIDz for reduction. No TID used.
+TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) {
+  const int gdimz = 32;
+  const int gdimy = 128;
 
-  TensorView* tv0 = makeContigTensor(1);
-  TensorView* tv1 = makeContigTensor(1);
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
 
-  TensorView* tv3 = add(tv0, tv1);
-  fusion->addOutput(tv3);
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addOutput(tv1);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({40960}, options);
-  at::Tensor t1 = at::randn({40960}, options);
-  auto t2 = t0 + t1;
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
 
-  FusionExecutorCache executor_cache(std::move(fusion));
-  executor_cache.profile(true);
+  tv1->split(1, gdimy);
+  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+  tv1->split(1, gdimz);
+  // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
 
-  auto outputs = executor_cache.runFusionWithInputs({t0, t1});
-  auto runtime1 = executor_cache.getMostRecentKernelRuntime();
-  auto log1 = executor_cache.getMostRecentExecutorInfo().pointwise_params;
-  TORCH_CHECK(log1.has_value());
-  TORCH_CHECK(log1->vectorize);
+  TensorView* tv2 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
+  // tv1[I0,        R1oi{32},  R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
 
-  testValidate(
-      executor_cache.fusion(), outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
+  // Incrementally, can print in between for debugging
+  tv0->computeAt(tv2, 1);
+  tv2->computeAt(tv1, 1);
 
-  t0 = at::randn({40964}, options);
-  t1 = at::randn({40964}, options);
-  t2 = t0 + t1;
+  // Re do it all at once, because why not.
+  tv0->computeAt(tv1, 1);
 
-  outputs = executor_cache.runFusionWithInputs({t0, t1});
-  auto runtime2 = executor_cache.getMostRecentKernelRuntime();
-  auto log2 = executor_cache.getMostRecentExecutorInfo().pointwise_params;
-  TORCH_CHECK(log2.has_value());
-  TORCH_CHECK(log2->vectorize);
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
+  tv1->axis(1)->parallelize(ParallelType::BIDz);
+  tv2->axis(2)->parallelize(ParallelType::BIDz);
 
-  testValidate(
-      executor_cache.fusion(), outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
+  tv1->axis(-1)->parallelize(ParallelType::BIDy);
+  tv2->axis(-1)->parallelize(ParallelType::BIDy);
 
-  t0 = at::randn({40962}, options);
-  t1 = at::randn({40962}, options);
-  t2 = t0 + t1;
+  int numel_x = 100;
+  int numel_y = 6500;
 
-  outputs = executor_cache.runFusionWithInputs({t0, t1});
-  auto runtime3 = executor_cache.getMostRecentKernelRuntime();
-  auto log3 = executor_cache.getMostRecentExecutorInfo().pointwise_params;
-  TORCH_CHECK(log3.has_value());
-  TORCH_CHECK(log3->vectorize);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
 
-  testValidate(
-      executor_cache.fusion(), outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  fe.runFusion({input}, {cg_output});
 
-  TORCH_CHECK(runtime1 == runtime2);
-  TORCH_CHECK(runtime1 != runtime3);
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(cg_output));
 }
 
-TEST(NVFuserTest, FusionVectorizeSimple_CUDA) {
+// Same as testGPU_FusionGridReduction3dim1 but reduces dimension 0
+TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) {
+  const int rdim = 0;
+  const int gdimy = 128;
+  const int gdimz = 32;
+
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* tv0 = makeContigTensor(3);
-
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
 
-  auto tv1 = unaryOp(UnaryOpType::Sin, tv0);
-
+  // tv1[R0, I1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {rdim}, new Float(0), tv0);
   fusion.addOutput(tv1);
 
-  auto tv0_cache = tv0->cache_after();
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
 
-  auto tv1_cache = tv1->cache_before();
+  tv1->split(rdim, gdimy);
+  // tv1[R0o, R0i{128}, I1] = tv0[I0, I1]
+  tv1->split(rdim, gdimz);
+  // tv1[R0oo, R0oi{32}, R0i{128}, I1] = tv0[I0, I1]
 
-  tv1->merge(0);
-  tv1->merge(0);
-  tv1->split(0, 4);
-  tv1->split(0, 128);
+  TensorView* tv2 = tv1->rFactor({rdim});
+  // tv2[R0oo, I0oi{32}, I0i{128}, I1] = tv0[I0, I1]
+  // tv1[      R0oi{32}, R0i{128}, I1] = tv2[R0oo, I0oi{32}, I0i{128}, I1]
 
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  tv1->axis(1)->parallelize(ParallelType::TIDx);
+  // Note that computeAt isn't going to make anything better as there
+  // is no dynamically sized dimension.
 
-  tv0->computeAt(tv1, 2);
+  // Map parallelism as [Serial, BIDz, BIDy, BIDx]
+  tv1->axis(-1)->parallelize(ParallelType::BIDx);
+  tv2->axis(-1)->parallelize(ParallelType::BIDx);
+  tv1->axis(-2)->parallelize(ParallelType::BIDy);
+  tv2->axis(-2)->parallelize(ParallelType::BIDy);
+  tv1->axis(-3)->parallelize(ParallelType::BIDz);
+  tv2->axis(-3)->parallelize(ParallelType::BIDz);
 
-  tv0_cache->axis(2)->parallelize(ParallelType::Vectorize);
-  tv1->axis(2)->parallelize(ParallelType::Vectorize);
+  int numel_x = 6500;
+  int numel_y = 100;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
-  at::Tensor aten_input = at::empty({2, 6, 32}, options);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({aten_input});
-
-  at::Tensor aten_output = aten_input.sin();
+  auto outputs = fe.runFusion({input});
 
-  testValidate(
-      &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
+  auto aten_output = input.sum({0});
+  TORCH_CHECK(aten_output.allclose(outputs[0]));
 }
 
-TEST(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+// This is similar to the FusionReduction, but swaps BIDx and TIDx
+TEST(NVFuserTest, FusionGridReduction4_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  const int bdimx = 128;
+  const int gdimx = 1024;
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
+
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addOutput(tv1);
+
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
 
-  std::vector<int64_t> input_shape{32, 64, 8};
-  const int kReductionAxis = 1;
+  tv1->split(1, gdimx);
+  // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1]
+  tv1->split(1, 4);
+  // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1]
 
-  auto tv0 = TensorViewBuilder()
-                 .ndims(input_shape.size())
-                 .dtype(DataType::Double)
-                 .build();
+  TensorView* tv2 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
+  // tv1[I0,        R1oi{4},  R1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
 
-  fusion->addInput(tv0);
+  TensorView* tv3 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
+  // tv3[I0,        R1oi{4}, Ir1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
+  // tv1[I0,                  R1i{1024}] = tv3[I0,        R1oi{4}, Ir1i{1024}]
 
-  auto tv1 = add(tv0, new Double(1.0));
-  auto tv2 = sum(tv1, {2}); // Group 0
+  // Incrementally, can print in between for debugging
+  tv0->computeAt(tv2, 1);
+  tv2->computeAt(tv3, 1);
+  tv3->computeAt(tv1, 1);
 
-  auto output = softmax(tv2, kReductionAxis); // Group 1
-  fusion->addOutput(output);
+  // Re do it all at once, because why not.
+  tv0->computeAt(tv1, 1);
 
-  auto options = at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0);
-  at::Tensor at_x = at::randn(input_shape, options);
+  tv2->axis(2)->parallelize(ParallelType::Unroll);
+  tv1->axis(0)->parallelize(ParallelType::TIDx);
 
-  FusionExecutorCache executor_cache(std::move(fusion));
+  tv1->axis(-1)->parallelize(ParallelType::BIDx);
+  tv2->axis(-1)->parallelize(ParallelType::BIDx);
+  tv3->axis(-1)->parallelize(ParallelType::BIDx);
 
-  auto outputs = executor_cache.runFusionWithInputs({at_x});
+  int numel_x = bdimx;
+  int numel_y = 65000;
 
-  auto t1 = at_x.add(1.0);
-  auto t2 = t1.sum({2});
-  auto t3 = at::_softmax(t2.to(at::kDouble), -1, false);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
 
-  auto optimized_fusion = executor_cache.getMostRecentKernelRuntime();
-  TORCH_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen");
-  TORCH_CHECK(
-      optimized_fusion->fusionSegments()->groups().size() == 2,
-      "segmentation didn't happen as expected");
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  fe.runFusion({input}, {cg_output});
 
-  testValidate(
-      executor_cache.fusion(), outputs, {at_x}, {t3}, __LINE__, __FILE__);
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(cg_output));
 }
 
-TEST(NVFuserTest, FusionSwizzle1_CUDA) {
+// Grid reduction with 2D thread blocks but only TIDx and BIDx are
+// mapped to a reduction dim
+TEST(NVFuserTest, FusionGridReduction5_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
+  const int bdimx = 64;
+  const int bdimy = 16;
+  const int gdimx = 4;
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = mul(tv1, new Double(2));
-  fusion.addOutput(tv2);
 
-  tv2->split(0, 7);
-  tv2->split(0, 9);
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addOutput(tv1);
 
-  tv0->computeAt(tv2, 1);
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
 
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv1->split(1, bdimx);
+  // tv1[I0, R1o, R1i{64}] = tv0[I0, I1]
+  tv1->split(1, gdimx);
+  // tv1[I0, R1oo, R1oi{4}, R1i{64}] = tv0[I0, I1]
+
+  TensorView* tv2 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] = tv0[I0, I1]
+  // tv1[I0,        R1oi{4},  R1i{64}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}]
+
+  tv0->computeAt(tv1, 1);
 
-  tv1->setMemoryType(MemoryType::Shared);
-  tv1->swizzle(SwizzleType::Transpose, {1, 2});
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
 
-  tv1->axis(1)->parallelize(ParallelType::TIDx);
-  tv1->axis(2)->parallelize(ParallelType::TIDy);
+  tv1->axis(-2)->parallelize(ParallelType::BIDx);
+  tv2->axis(-2)->parallelize(ParallelType::BIDx);
 
-  tv2->axis(1)->parallelize(ParallelType::TIDx);
-  tv2->axis(2)->parallelize(ParallelType::TIDy);
+  tv1->axis(0)->parallelize(ParallelType::TIDy);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({100}, options);
+  int numel_x = bdimy;
+  int numel_y = 6500;
 
-  std::vector<IValue> aten_inputs = {t0};
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  auto outputs = fe.runFusion({input});
 
-  auto aten_output = (t0 + 1) * 2;
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(outputs[0]));
 }
 
-TEST(NVFuserTest, FusionSwizzle2_CUDA) {
+// Similar to FusionGridReduction1 but with 3D tensors
+TEST(NVFuserTest, FusionGridReduction6_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(3);
   fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = mul(tv1, new Double(2));
-  fusion.addOutput(tv2);
 
-  tv1->split(-1, 4);
-  tv1->split(-2, 4);
+  // tv1[I0, R1, R2] = tv0[I0, I1, I2]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Float(0), tv0);
+  fusion.addOutput(tv1);
 
-  tv2->split(-1, 4);
-  tv2->split(-2, 4);
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
 
-  tv0->computeAt(tv2, 1);
+  // Splitting for TID
+  tv1->split(2, 128);
+  // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]
 
-  tv2->reorder({{-1, -2}});
+  // Splitting for BID
+  tv1->split(1, 128);
 
-  tv1->setMemoryType(MemoryType::Shared);
-  tv1->swizzle(SwizzleType::Transpose, {-2, -1});
+  // tv1[I0, R1o, R1i{128}, R2o, R2i{128}] = tv0[I0, I1, I2]
+
+  TensorView* tv2 = tv1->rFactor({3});
+  // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
+  // tv1[I0, R1o, R1i{128},      R2i{128}]
+
+  TensorView* tv3 = tv1->rFactor({1});
+  // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
+  // tv3[I0, R1o, I1i{128},      I2i{128}]
+  // tv1[I0,      R1i{128},      R2i{128}]
+
+  tv3->computeAt(tv1, 1);
+  tv2->computeAt(tv3, 3);
+
+  tv1->axis(0)->parallelize(ParallelType::BIDy);
 
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-2)->parallelize(ParallelType::TIDy);
   tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv1->axis(-2)->parallelize(ParallelType::TIDy);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({123}, options);
+  tv1->axis(-2)->parallelize(ParallelType::BIDx);
+  tv2->axis(-3)->parallelize(ParallelType::BIDx);
+  tv3->axis(-2)->parallelize(ParallelType::BIDx);
+
+  int numel_x = 6500;
+  int numel_y = 200;
+  int numel_z = numel_y;
 
-  std::vector<IValue> aten_inputs = {t0};
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y, numel_z}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  auto aten_output = (t0 + 1) * 2;
+  fe.runFusion({input}, {cg_output});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  auto aten_output = input.sum({1, 2});
+  TORCH_CHECK(aten_output.allclose(cg_output));
 }
 
-TEST(NVFuserTest, FusionTransposeWithSwizzle_CUDA) {
+TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) {
+  int bid_x = 3;
+  int tid_x = 2;
+  int red_dim = 0;
+
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  auto tv1 = transpose(tv0, {{0, 1}});
-  fusion.addOutput(tv1);
-
-  // tv0: [I0, I1]
-  // tv1: [I1, I0]
-
-  const int BS = 32;
-
-  // CTA tiling by BS*BS
-  tv1->split(1, BS);
-  tv1->split(0, BS);
-  tv1->reorder({{1, 2}});
-  // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)]
-
-  // Create a smem buffer to cache each tile
-  auto tv0_cache = tv0->cache_after();
-  tv0_cache->setMemoryType(MemoryType::Shared);
-
-  tv0->computeAt(tv1, 2);
-  // tv0: [I0, I1]
-  // tv0_cache: [I1/BS, I0/BS, BS(I1), BS(I0)]
-  // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)]
-
-  // Assign each thread block to a tile
-  tv1->axis(0)->parallelize(ParallelType::BIDy);
-  tv1->axis(1)->parallelize(ParallelType::BIDx);
 
-  // Thread mapping for each tile. For both of the input and output
-  // tiles, map TIDx to the fastest-changing dimension to facilitate
-  // coalesced gmem accesses.
-  tv1->axis(2)->parallelize(ParallelType::TIDy);
-  tv1->axis(3)->parallelize(ParallelType::TIDx);
-  // Note that the fastest-changing axis is next to the inner-most
-  // axis since computeAt reorders the axes as the output tensor.
-  tv0_cache->axis(2)->parallelize(ParallelType::TIDx);
-  tv0_cache->axis(3)->parallelize(ParallelType::TIDy);
+  TensorView* tv1 =
+      reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0);
+  fusion.addOutput(tv1);
 
-  // Swizzles the smem cache to avoid bank conflicts
-  tv0_cache->swizzle(SwizzleType::Transpose, {3, 2});
+  tv1->split(-1, tid_x);
+  tv1->axis(-2)->parallelize(ParallelType::BIDx);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int bx = 100;
-  const int by = 200;
-  at::Tensor t0 = at::randn({bx, by}, options);
-  std::vector<IValue> aten_inputs = {t0};
+  at::Tensor input = at::rand({16, bid_x * tid_x}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({input});
 
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  auto aten_output = t0.t();
+  auto aten_output = input.sum({red_dim});
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0]),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
 }
 
-TEST(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) {
+TEST(NVFuserTest, FusionSplitBCast_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = transpose(tv0, {{0, 1}});
-  fusion.addOutput(tv1);
-
-  // tv0: [I0, I1]
-  // tv1: [I1, I0]
-
-  const int BS = 32;
-  const int BDIM = 256;
+  // Set up your input tensor views
+  TensorView* input_tv0 = makeDummyTensor(3);
+  TensorView* input_tv1 = makeDummyTensor(3);
+  fusion.addInput(input_tv0);
+  fusion.addInput(input_tv1);
 
-  // CTA tiling by BS*BS
-  tv1->split(1, BS);
-  tv1->split(0, BS);
-  tv1->reorder({{1, 2}});
-  // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)]
+  TensorView* sum_tv2 =
+      reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0);
+  TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true});
+  TensorView* output_tv4 = div(input_tv1, bcast_tv3);
 
-  // Create a smem buffer to cache each tile
-  auto tv0_cache = tv0->cache_after();
-  tv0_cache->setMemoryType(MemoryType::Shared);
+  sum_tv2->split(-1, 32);
+  TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2});
 
-  tv0->computeAt(tv1, 2);
-  // tv0: [I0, I1]
-  // tv0_cache: [I1/BS, I0/BS, BS*BS/BDIM, BDIM]
-  // tv1: [I1/BS, I0/BS, BS*BS/BDIM, BDIM]
+  bcast_tv3->split(-1, 32);
+  output_tv4->split(-1, 32);
 
-  // Tranform the tile axes for 1D thread mapping
-  tv1->merge(-2, -1);
-  tv1->split(-1, BDIM);
-  // tv1: [I1/BS, I0/BS, BS*BS/BDIM, BDIM]
+  sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx);
+  sum_tv2->axis(0)->parallelize(ParallelType::BIDx);
+  bcast_tv3->axis(0)->parallelize(ParallelType::BIDx);
+  output_tv4->axis(0)->parallelize(ParallelType::BIDx);
 
-  // Transform the cache similarly but apply swizzle to the 2D tile axes.
-  tv0_cache->reorder({{-2, -1}});
-  tv0_cache->swizzle(SwizzleType::Transpose, {2, 3});
-  tv0_cache->merge(-2, -1);
-  tv0_cache->split(-1, BDIM);
-  // tv0: [I1/BS, I0/BS, BS*BS/BDIM, BDIM]
+  sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy);
+  sum_tv2->axis(1)->parallelize(ParallelType::BIDy);
+  bcast_tv3->axis(1)->parallelize(ParallelType::BIDy);
+  output_tv4->axis(1)->parallelize(ParallelType::BIDy);
 
-  // Assign each thread block to a tile
-  tv1->axis(0)->parallelize(ParallelType::BIDy);
-  tv1->axis(1)->parallelize(ParallelType::BIDx);
+  sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx);
+  sum_tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  output_tv4->axis(-1)->parallelize(ParallelType::TIDx);
 
-  // Thread mapping for each tile.
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv0_cache->axis(-1)->parallelize(ParallelType::TIDx);
+  fusion.addOutput(output_tv4);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int bx = 100;
-  const int by = 200;
-  at::Tensor t0 = at::randn({bx, by}, options);
-  std::vector<IValue> aten_inputs = {t0};
+  at::Tensor t0 = at::randn({32, 32, 128}, options);
+  at::Tensor t1 = at::randn({32, 32, 128}, options);
+  at::Tensor cg_output = at::empty({32, 32, 128}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  auto aten_output = t0.t();
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  fe.runFusion({t0, t1}, {cg_output});
 }
 
-// Grid reduction can be executed only once in a kernel. Should result
-// in an error at the time of compilation.
-TEST(NVFuserTest, FusionGridReductionInLoop_CUDA) {
+TEST(NVFuserTest, FusionBCastInnerDim_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  auto tv1 = sum(tv0, {1});
-  fusion.addOutput(tv1);
 
-  tv1->axis(1)->parallelize(ParallelType::BIDx);
+  // reduce then broadcast
+  auto tv1 = sum(tv0, {0});
+  auto tv2 = broadcast(tv1, {false, true});
 
-  FusionExecutor fe;
-  ASSERT_ANY_THROW(fe.compileFusion(&fusion));
+  TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast());
 }
 
-TEST(NVFuserTest, FusionIssue633_CUDA) {
+TEST(NVFuserTest, FusionBCastReduce_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  const int dx = 10;
-  const int dy = 11;
-  const int dz = 12;
-
-  auto tv0 = makeConcreteTensor({dx, dy, dz});
-  fusion.addInput(tv0);
-  auto tv1 = makeConcreteTensor({dx, dy, 1});
-  fusion.addInput(tv1);
-  auto tv2 = add(tv0, tv1);
-  fusion.addOutput(tv2);
-
-  tv2->merge(1);
-  tv2->merge(0);
-  tv2->split(-1, 128);
-
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(1)->parallelize(ParallelType::TIDx);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({dx, dy, dz}, options);
-  at::Tensor t1 = at::randn({dx, dy, 1}, options);
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  auto aten_output = t0 + t1;
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  auto tv1 = broadcast(tv0, {true, false, false});
+  auto tv2 = sum(tv1, {1});
+  TORCH_CHECK(
+      tv2->axis(0)->isBroadcast() && tv2->axis(1)->isReduction() &&
+      !tv2->axis(2)->isBroadcast() && !tv2->axis(2)->isReduction());
 }
 
-TEST(NVFuserTest, FusionKirScoping_CUDA) {
+// Multiple consumer reduction with computeAt
+// https://github.com/csarofeen/pytorch/issues/110
+TEST(NVFuserTest, FusionReductionMultiConsumer_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(2));
-  fusion.addOutput(tv2);
+  auto tv1 = unaryOp(UnaryOpType::Exp, tv0);
+  auto tv2 = reductionOp(BinaryOpType::Max, {-1}, new Float(0), tv1);
+  auto tv3 = reductionOp(BinaryOpType::Min, {-1}, new Float(0), tv1);
+  auto tv4 = add(tv2, tv3);
+  fusion.addOutput(tv4);
+  tv1->computeAt(tv2, -1);
 
-  tv2->merge(0);
-  tv2->split(0, 4);
-  tv0->computeAt(tv2, -1);
+  TORCH_CHECK(
+      (tv1->getComputeAtView() == tv2 || tv1->getComputeAtView() == tv3) &&
+      tv1->getThisComputeAtAxis() == 2 && tv1->getRelativeComputeAtAxis() == 2);
+}
 
-  GpuLower gpulw(&fusion);
+TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) {
+  for (int i = 0; i < 2; ++i) {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
 
-  auto kir_tv1 = gpulw.lowerValue(tv1);
-  auto tv1_scope = kir_tv1->definition()->scope();
-  TORCH_CHECK(tv1_scope != nullptr);
-  TORCH_CHECK(tv1_scope->owner()->as<kir::IfThenElse>());
+    // Set up your input tensor views
+    TensorView* tv0 = makeDummyTensor(1);
+    fusion.addInput(tv0);
 
-  auto kir_tv2 = gpulw.lowerValue(tv2);
-  auto tv2_scope = kir_tv2->definition()->scope();
-  TORCH_CHECK(tv2_scope != nullptr);
-  TORCH_CHECK(tv2_scope->owner()->as<kir::IfThenElse>());
+    auto tv1 = add(tv0, new Float(1));
+    auto tv2 = add(tv0, new Float(1));
+    TensorView* tv3 = add(tv1, tv2);
+    if (i == 0) {
+      tv1->computeAt(tv3, -1);
+      fusion.addOutput(tv2);
+    } else {
+      tv2->computeAt(tv3, -1);
+      fusion.addOutput(tv1);
+    }
+    fusion.addOutput(tv3);
 
-  TORCH_CHECK(tv1_scope != tv2_scope);
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    at::Tensor input = at::rand({100}, options);
 
-  // tv1 and tv2 should have the same inner-most ForLoop
-  auto parent_scope = tv1_scope->owner()->scope();
-  TORCH_CHECK(parent_scope == tv2_scope->owner()->scope());
-  TORCH_CHECK(parent_scope->owner()->as<kir::ForLoop>());
-  // There should be one more loop
-  parent_scope = parent_scope->owner()->scope();
-  TORCH_CHECK(parent_scope->owner()->as<kir::ForLoop>());
+    FusionExecutor fe;
+    fe.compileFusion(&fusion);
+    auto outputs = fe.runFusion({input});
 
-  // scope() should return nullptr for top-level exprs
-  auto top_level_scope = parent_scope->owner()->scope();
-  TORCH_CHECK(top_level_scope == nullptr);
+    auto aten_output = (input + 1) * 2;
+    TORCH_CHECK(
+        aten_output.allclose(outputs[1]),
+        "Error of: ",
+        aten_output.sub(outputs[1]).abs().max());
+  }
 }
 
-TEST(NVFuserTest, FusionBroadcastAcrossComputeAt_CUDA) {
+TEST(NVFuserTest, FusionComputeAtExprOrder2_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  std::vector<int64_t> shape{17, 19};
-
-  auto tv0 = makeSymbolicTensor(1);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  auto tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-  auto tv2 = broadcast(tv0, {false, true});
-  auto tv3 = add(tv1, tv2);
+
+  auto tv1 = add(tv0, new Float(1));
+  auto tv2 = add(tv0, new Float(1));
+  TensorView* tv3 = add(tv1, tv2);
   fusion.addOutput(tv3);
 
-  tv3->split(1, 128);
-  tv0->computeAt(tv3, 2);
+  tv3->split(-1, 32);
 
-  for (auto tv : {tv2, tv3}) {
-    tv->axis(-1)->parallelize(ParallelType::TIDx);
-  }
+  tv1->computeAt(tv3, -1);
+  tv2->computeAt(tv3, -2);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({shape[0]}, options);
-  at::Tensor t1 = at::randn(shape, options);
-  std::vector<IValue> aten_inputs = {t0, t1};
+  at::Tensor input = at::rand({100, 100}, options);
+  at::Tensor output = at::empty_like(input, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  fe.runFusion({input}, {output});
 
-  auto t3 = t0.unsqueeze(-1).expand(shape) + t1;
-
-  testValidate(&fusion, cg_outputs, aten_inputs, {t3}, __LINE__, __FILE__);
+  auto aten_output = (input + 1) * 2;
+  TORCH_CHECK(
+      aten_output.allclose(output),
+      "Error of: ",
+      aten_output.sub(output).abs().max());
 }
 
-TEST(NVFuserTest, FusionVectorizeMisalignedPointwise_CUDA) {
+TEST(NVFuserTest, FusionZeroDimComputeAt_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeContigTensor(2);
-  auto tv1 = makeContigTensor(2);
+  TensorView* tv0 = makeDummyTensor(1);
   fusion.addInput(tv0);
-  fusion.addInput(tv1);
 
-  auto tv2 = add(tv0, tv1);
+  auto tv1 = sum(tv0, {0});
+  auto tv2 = add(tv1, new Float(1));
   fusion.addOutput(tv2);
-
-  const int kTDX = 64;
-  const int kVecSize = 4;
-  const int kNumElems = kTDX * kVecSize;
-
-  tv2->split(1, kNumElems);
-
-  auto c0 = tv0->cache_after();
-  auto c1 = tv1->cache_after();
-  auto c2 = tv2->cache_before();
-
-  tv2->split(-1, kVecSize);
-
-  c0->computeAt(tv2, -2);
-  c1->computeAt(tv2, -2);
-
-  c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-  c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(-2)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+  TORCH_CHECK(tv2->nDims() == 0);
+  tv1->computeAt(tv2, 0);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int bx = 128;
-  const int by = 457;
-  at::Tensor t0 = at::randn({bx, by}, options);
-  at::Tensor t1 = at::randn({bx, by}, options);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
+  at::Tensor input = at::rand({100}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  auto outputs = fe.runFusion({input});
 
-  auto aten_output = t0 + t1;
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  auto aten_output = input.sum() + 1;
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0]),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
 }
 
-TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeContig_CUDA) {
+TEST(NVFuserTest, FusionZeroDimBroadcast_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeContigTensor(4);
-  auto tv1 = makeContigTensor(4);
+  TensorView* tv0 = makeDummyTensor(0);
   fusion.addInput(tv0);
-  fusion.addInput(tv1);
 
-  auto tv2 = add(tv0, tv1);
-  fusion.addOutput(tv2);
-
-  tv2->reorder({{0, 1}, {1, 0}});
-  tv2->merge(-2);
-
-  const int kTDX = 64;
-  const int kVecSize = 2;
-  const int kNumElems = kTDX * kVecSize;
-
-  tv2->split(-1, kNumElems);
-
-  auto c0 = tv0->cache_after();
-  auto c1 = tv1->cache_after();
-  auto c2 = tv2->cache_before();
-
-  tv2->split(0, 128);
-  tv2->split(-1, kVecSize);
+  auto tv1 = broadcast(tv0, {true, true});
+  TORCH_CHECK(tv1->nDims() == 2);
 
-  c0->computeAt(tv2, -2);
-  c1->computeAt(tv2, -2);
+  TensorView* tv2 = makeDummyTensor(2);
+  fusion.addInput(tv2);
 
-  c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-  c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+  auto tv3 = add(tv1, tv2);
+  auto tv4 = sum(tv3, {0, 1});
+  fusion.addOutput(tv4);
 
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(1)->parallelize(ParallelType::BIDy);
-  tv2->axis(-2)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+  tv3->computeAt(tv4, -1);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int n = 32;
-  const int c = 127;
-  const int h = 51;
-  const int w = 23;
-  at::Tensor t0 = at::randn({n, c, h, w}, options);
-  at::Tensor t1 = at::randn({n, c, h, w}, options);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
+  at::Tensor input1 = at::rand({}, options);
+  at::Tensor input2 = at::rand({10, 10}, options);
+  at::Tensor output = at::empty({}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  fe.runFusion({input1, input2}, {output});
 
-  auto aten_output = t0 + t1;
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  auto aten_output =
+      (input1.unsqueeze(-1).unsqueeze(-1).expand({10, 10}) + input2).sum();
+  TORCH_CHECK(
+      aten_output.allclose(output),
+      "Error of: ",
+      aten_output.sub(output).abs().max());
 }
 
-TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) {
+TEST(NVFuserTest, FusionZeroDimReduction_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  constexpr int kNumDims = 4;
-  constexpr int kTDX = 64;
-  constexpr int kVecSize = 2;
-  constexpr int kNumElems = kTDX * kVecSize;
+  const int bdimx = 32;
+  const int gdimx = 32;
 
-  auto tv0 = makeSymbolicTensor(kNumDims);
-  auto tv1 = makeSymbolicTensor(kNumDims);
+  TensorView* tv0 = makeDummyTensor(1);
   fusion.addInput(tv0);
-  fusion.addInput(tv1);
 
-  auto tv2 = add(tv0, tv1);
-  fusion.addOutput(tv2);
-
-  // Create caches for vectorization
-  auto c0 = tv0->cache_after();
-  auto c1 = tv1->cache_after();
-  auto c2 = tv2->cache_before();
-
-  // Merge all dimensions together except inner-most dim
-  for (int idx = 0; idx < kNumDims - 2; ++idx) {
-    tv2->merge(0);
-  }
-  // Split inner-most dim
-  tv2->split(-1, kNumElems);
-  tv2->split(-1, kVecSize);
-  TransformPropagator::from(tv2);
-
-  c0->computeAt(tv2, -2);
-  c1->computeAt(tv2, -2);
+  auto tv1 = sum(tv0, {0});
+  fusion.addOutput(tv1);
 
-  // Parallelization Strategy
-  c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-  c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+  tv1->split(0, bdimx);
+  tv1->split(0, gdimx);
+  auto tv2 = tv1->rFactor({0});
 
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(2)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv1->axis(-2)->parallelize(ParallelType::BIDx);
+  tv2->axis(-2)->parallelize(ParallelType::BIDx);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int n = 5;
-  const int c = 3;
-  const int h = 51;
-  const int w = 257;
-  at::Tensor t0 = at::randn({n, c, h, w}, options);
-  at::Tensor t1 = at::randn({n, c, h, w}, options);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
+  at::Tensor input = at::rand({1000}, options);
+  at::Tensor output = at::empty({}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  auto aten_output = t0 + t1;
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  constexpr int kNumDims = 4;
-  constexpr int kTDX = 64;
-  constexpr int kVecSize = 2;
-  constexpr int kNumElems = kTDX * kVecSize;
-  std::vector<int64_t> bcast_shape{1, 1, 1, -1};
-
-  auto tv0 = makeContigTensor(kNumDims);
-  auto tv1 = TensorViewBuilder().shape(bcast_shape).build();
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = add(tv0, tv1);
-  fusion.addOutput(tv2);
-
-  // Create caches for vectorization
-  auto c0 = tv0->cache_after();
-  auto c1 = tv1->cache_after();
-  auto c2 = tv2->cache_before();
-
-  // Merge all dimensions together
-  // Backward merge order is necessary for vectorize validation
-  for (int idx = kNumDims - 1; idx > 0; --idx) {
-    tv2->merge(idx - 1);
-  }
-  tv2->split(-1, kNumElems);
-  tv2->split(-1, kVecSize);
-  TransformPropagator::from(tv2);
-
-  c0->computeAt(tv2, -2);
-  c1->computeAt(tv2, -2);
+  fe.runFusion({input}, {output});
 
-  // Parallelization Strategy
-  c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-  c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int n = 32;
-  const int c = 128;
-  const int h = 51;
-  const int w = 23;
-  at::Tensor t0 = at::randn({n, c, h, w}, options);
-  at::Tensor t1 = at::randn({1, 1, 1, w}, options);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
-
-  FusionExecutor fe;
-  // TODO: throw assertion - cannot merge non-contiguous vectorization axes
-  // Make sure compilation fails
-  ASSERT_ANY_THROW(fe.compileFusion(&fusion));
+  auto aten_output = input.sum();
+  TORCH_CHECK(
+      aten_output.allclose(output),
+      "Error of: ",
+      aten_output.sub(output).abs().max());
 }
 
-TEST(NVFuserTest, FusionVectorizeMisalignedRFactor_CUDA) {
+TEST(NVFuserTest, FusionBCastAfterReduce_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
+  const int tidx = 128;
 
-  auto tv0 = makeContigTensor(2);
-  auto tv1 = makeContigTensor(2);
-
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = add(tv0, tv1);
 
-  auto tv3 = sum(tv2, {-1});
-
-  fusion.addOutput(tv3);
+  auto tv1 = sum(tv0, {1});
+  auto tv2 = broadcast(tv1, {false, true});
 
-  auto c0 = tv0->cache_after();
-  auto c1 = tv1->cache_after();
+  tv1->split(1, tidx);
+  auto tv3 = tv1->rFactor({-2});
 
-  tv3->split(-1, 128 * 4);
-  tv3->split(-1, 4);
-  // Reduce outer dim first
-  auto tv4 = tv3->rFactor({-3, -1});
-  // Tv3 will reduce threads
+  TensorView* tv4 = makeDummyTensor(2);
+  fusion.addInput(tv4);
 
-  tv0->computeAt(tv3, 1);
-  tv1->computeAt(tv3, 1);
+  auto tv5 = add(tv2, tv4);
+  fusion.addOutput(tv5);
+  tv5->split(1, tidx);
 
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
+  tv3->computeAt(tv5, 1);
 
-  tv0->computeAt(tv4, -2);
-  tv1->computeAt(tv4, -2);
+  tv2->split(1, tidx);
 
-  c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-  c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  tv5->axis(-1)->parallelize(ParallelType::TIDx);
 
-  tv4->axis(-2)->parallelize(ParallelType::TIDx);
-  tv3->axis(1)->parallelize(ParallelType::TIDx);
+  tv5->axis(0)->parallelize(ParallelType::BIDx);
 
-  tv2->computeAt(tv4, -1);
+  int x = 63, y = 200;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int bx = 128;
-  const int by = 2050;
-  at::Tensor t0 = at::randn({bx, by}, options);
-  at::Tensor t1 = at::randn({bx, by}, options);
 
-  std::vector<IValue> aten_inputs = {t0, t1};
+  at::Tensor t0 = at::randn({x, y}, options);
+  at::Tensor t4 = at::randn({x, y}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  auto outputs = fe.runFusion({t0, t4});
 
-  auto aten_output = t0.add(t1).sum(1);
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  auto t3 = t0.sum({1}).unsqueeze(-1).expand({x, y});
+  auto t5 = t3.add(t4);
+
+  // Error is larger than the default threshold
+  TORCH_CHECK(t5.allclose(outputs[0], 1e-5, 1e-5));
 }
 
-TEST(NVFuserTest, FusionVectorizeMisalignedWrongDimFail_CUDA) {
+TEST(NVFuserTest, FusionReductionScheduler_CUDA) {
+  constexpr int bid_x = 80;
+  constexpr int tid_x = 4096;
+  constexpr int red_dim = 1;
+
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeContigTensor(2);
-  auto tv1 = makeContigTensor(2);
-
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = add(tv0, tv1);
-  fusion.addOutput(tv2);
-
-  tv2->split(1, 16);
-  tv2->split(1, 64);
-
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(2)->parallelize(ParallelType::TIDx);
 
-  auto c0 = tv0->cache_after();
-  auto c1 = tv1->cache_after();
-  auto c2 = tv2->cache_before();
+  TensorView* tv1 =
+      reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0);
+  fusion.addOutput(tv1);
 
-  c0->computeAt(tv2, -2);
-  c1->computeAt(tv2, -2);
+  const auto options =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::randn({bid_x, tid_x}, options);
 
-  std::vector<TensorView*> vectorized_tvs = {c0, c1, tv2};
-  for (auto tv : vectorized_tvs) {
-    tv->split(-1, 4);
-    // Vectorize the wrong dimension
-    tv->axis(-2)->parallelize(ParallelType::MisalignedVectorize);
-  }
+  // Apply reduction heuristic
+  auto reduction_params = getReductionHeuristics(&fusion, {input}, tv1);
+  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
+  scheduleReduction(&fusion, reduction_params.value(), tv1, {});
 
   FusionExecutor fe;
-  // Make sure compilation fails
-  ASSERT_ANY_THROW(fe.compileFusion(&fusion));
+  fe.compileFusion(&fusion);
+  // no broadcasting needed, omitting the last optional argument;
+  auto outputs = fe.runFusion({input}, reduction_params.value().lparams);
+  auto aten_output = input.sum({red_dim});
+
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-04, 1e-04),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
 }
 
-TEST(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) {
+// Simple reduction parallelized on a symbolic size.
+TEST(NVFuserTest, FusionSymbolicReduction_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-  auto tv1 = makeSymbolicTensor(2);
-
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = add(tv0, tv1);
-  fusion.addOutput(tv2);
 
-  const int kTDX = 64;
-  const int kVecSize = 4;
-  const int kNumElems = kTDX * kVecSize;
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addOutput(tv1);
 
-  tv2->split(1, kNumElems);
+  // Interface should just be a direct split with a Parallel type. We can
+  // include the parallelize call if we do this.
+  tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+  // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1]
 
-  auto c0 = tv0->cache_after();
-  auto c1 = tv1->cache_after();
+  TensorView* tv2 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}] = tv0[I0, I1]
+  // tv1[I0,        R1oi{4},  R1i{BIDx}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}]
 
-  tv2->split(-1, kVecSize);
+  // Incrementally, can print in between for debugging
+  tv0->computeAt(tv2, 1);
+  tv2->computeAt(tv1, 1);
 
-  c0->computeAt(tv2, -2);
-  c1->computeAt(tv2, -2);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
 
-  c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-  c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
 
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(-2)->parallelize(ParallelType::TIDx);
+  int numel_x = 65000;
+  int numel_y = 1025;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int bx = 128;
-  const int by = 2049;
-  at::Tensor t0 = at::randn({bx, by}, options).index({"...", Slice(3)});
-  at::Tensor t1 = at::randn({bx, by}, options).index({"...", Slice(3)});
-  std::vector<IValue> aten_inputs = {t0, t1};
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+
+  // How many threads to use for the block reduction
+  int runtime_threadIdx_dim = 128;
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+  auto outputs = fe.runFusion(
+      {input}, LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1));
 
-  auto aten_output = t0 + t1;
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(outputs[0]));
 }
 
-TEST(NVFuserTest, FusionVectorizeMisalignedStrideFail_CUDA) {
+TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) {
+  const std::vector<int> red_dims = {0, 2};
+  // Copy is because CodeGen requires int and Pytorch requires int64_t
+  // for a vector of reduction dimensions
+  const std::vector<int64_t> red_dims64 = {0, 2};
+  const std::vector<int64_t> tensor_dims_in = {5, 10, 15, 20};
+  const std::vector<int64_t> tensor_dims_out = {10, 20};
+
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-  auto tv1 = makeSymbolicTensor(2);
-
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(tensor_dims_in.size());
   fusion.addInput(tv0);
-  fusion.addInput(tv1);
 
-  auto tv2 = add(tv0, tv1);
-  fusion.addOutput(tv2);
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, red_dims, new Float(0), tv0);
+  fusion.addOutput(tv1);
+
+  const auto options =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::randn(tensor_dims_in, options);
+  at::Tensor cg_output = at::empty(tensor_dims_out, options);
+
+  // Apply reduction heuristic
+  auto reduction_params = getReductionHeuristics(&fusion, {input}, tv1);
+  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
+  scheduleReduction(&fusion, reduction_params.value(), tv1, {});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({input}, reduction_params.value().lparams);
 
-  const int kTDX = 64;
-  const int kVecSize = 4;
-  const int kNumElems = kTDX * kVecSize;
+  auto aten_output = input.sum(red_dims64);
 
-  tv2->split(1, kNumElems);
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-04, 1e-04),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
+}
 
-  auto c0 = tv0->cache_after();
-  auto c1 = tv1->cache_after();
-  auto c2 = tv2->cache_before();
+TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) {
+  const std::vector<int> red_dims = {1, 3};
+  // Copy is because CodeGen requires int and Pytorch requires int64_t
+  // for a vector of reduction dimensions
+  const std::vector<int64_t> red_dims64 = {1, 3};
+  const std::vector<int64_t> tensor_dims_in = {5, 10, 15, 20};
+  const std::vector<int64_t> tensor_dims_out = {5, 15};
 
-  tv2->split(-1, kVecSize);
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  c0->computeAt(tv2, -2);
-  c1->computeAt(tv2, -2);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(tensor_dims_in.size());
+  fusion.addInput(tv0);
 
-  c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-  c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, red_dims, new Float(0), tv0);
+  fusion.addOutput(tv1);
 
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(-2)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+  const auto options =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::randn(tensor_dims_in, options);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int bx = 128;
-  const int by = 2049;
-  at::Tensor t0 = at::randn({bx, by}, options).index({"...", Slice(3)});
-  at::Tensor t1 = at::randn({bx, by}, options).index({"...", Slice(3)});
-  std::vector<IValue> aten_inputs = {t0, t1};
+  auto reduction_params = getReductionHeuristics(&fusion, {input}, tv1);
+  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
+  scheduleReduction(&fusion, reduction_params.value(), tv1, {});
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({input}, reduction_params.value().lparams);
+
+  auto aten_output = input.sum(red_dims64);
 
-  // Failure because the input + output tensors do not have the same stride
-  ASSERT_ANY_THROW(fe.runFusion(aten_inputs));
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-05, 1e-05),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
 }
 
-TEST(NVFuserTest, FusionVectorization1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+TEST(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) {
+  std::vector<bool> fp16_usage = {true, false};
+  std::vector<int> red_axis = {1, 0};
+  std::vector<int> output_dims = {320, 640};
+  std::vector<int> red_dims;
 
-  auto tv0 = makeSymbolicTensor(2);
+  // Making sure we get deterministic results
+  // (see https://github.com/csarofeen/pytorch/issues/399)
+  at::manual_seed(0);
 
-  auto tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
+  // Tried to cut down the number iterations with just
+  // doing every other power of 2.
+  for (int i = 1; i <= 1024 * 1024; i <<= 2) {
+    red_dims.push_back(i);
+  }
 
-  auto tv2 = add(tv0, tv1);
-  fusion.addOutput(tv2);
+  for (auto fp16 : fp16_usage) {
+    for (auto& axis : red_axis) {
+      for (auto& odim : output_dims) {
+        for (auto& rdim : red_dims) {
+          Fusion fusion;
+          FusionGuard fg(&fusion);
 
-  tv2->split(1, 16);
-  tv2->split(1, 64);
+          TensorView* tv0 =
+              makeDummyTensor(2, (fp16 ? DataType::Half : DataType::Float));
+          fusion.addInput(tv0);
 
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(2)->parallelize(ParallelType::TIDx);
+          Val* tv0_cast = nullptr;
+          if (fp16) {
+            tv0_cast = castOp(DataType::Float, tv0);
+          }
 
-  auto c0 = tv0->cache_after();
-  auto c1 = tv1->cache_after();
-  auto c2 = tv2->cache_before();
+          TensorView* tv1 = reductionOp(
+              BinaryOpType::Add,
+              {axis},
+              new Float(0),
+              (fp16 ? tv0_cast->as<TensorView>() : tv0));
 
-  c0->computeAt(tv2, -2);
-  c1->computeAt(tv2, -2);
+          TensorView* tv1_cast = nullptr;
+          if (fp16) {
+            tv1_cast = castOp(DataType::Half, tv1);
+          }
 
-  std::vector<TensorView*> vectorized_tvs = {c0, c1, tv2};
-  for (auto tv : vectorized_tvs) {
-    tv->split(-1, 4);
-    tv->axis(-1)->parallelize(ParallelType::Vectorize);
-  }
+          fusion.addOutput((fp16 ? tv1_cast : tv1));
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int bx = 128;
-  const int by = 2048;
-  at::Tensor t0 = at::randn({bx, by}, options);
-  at::Tensor t1 = at::randn({bx, by}, options);
+          auto options = at::TensorOptions()
+                             .dtype((fp16 ? at::kHalf : at::kFloat))
+                             .device(at::kCUDA, 0);
+          at::Tensor input =
+              (axis ? at::randn({odim, rdim}, options)
+                    : at::randn({rdim, odim}, options));
 
-  std::vector<IValue> aten_inputs = {t0, t1};
+          std::vector<TensorView*> outputs_of_red;
+          if (fp16) {
+            outputs_of_red.push_back(tv1_cast);
+          }
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
+          auto reduction_params = getReductionHeuristics(&fusion, {input}, tv1);
+          TORCH_CHECK(reduction_params.has_value(), "Reduction is not found!");
+          scheduleReduction(
+              &fusion, reduction_params.value(), tv1, outputs_of_red);
+
+          FusionExecutor fe;
+          fe.compileFusion(&fusion);
 
-  auto aten_output = t0 + t1;
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+          auto outputs =
+              fe.runFusion({input}, reduction_params.value().lparams);
+          auto aten_output = input.sum({axis});
+
+          TORCH_CHECK(
+              aten_output.allclose(outputs[0], 1e-03, 1e-03),
+              "Error of: ",
+              aten_output.sub(outputs[0]).abs().max());
+        }
+      }
+    }
+  }
 }
 
-TEST(NVFuserTest, FusionVectorization2_CUDA) {
+TEST(NVFuserTest, FusionCacheBefore_CUDA) {
+  // TVM Cache Write
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-
-  auto tv1 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = add(tv0, new Float(1.0));
+  TensorView* tv2 = mul(tv1, new Float(3.0));
   fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = add(tv0, tv1);
   fusion.addOutput(tv2);
+  // Before: TV2 = TV1 * 3
+  // After:  TV3 = TV1 * 3;
+  //         TV2 = TV3;
+
+  constexpr int BSX = 32;
+  tv2->split(-1, BSX);
+  tv0->computeAt(tv2, -1);
 
-  tv2->split(1, 16);
-  tv2->split(1, 64);
+  // cache_before automatically applies ComputeAt to the cache TensorView
+  tv2->cache_before();
 
+  // Thread and Block binding
   tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(2)->parallelize(ParallelType::TIDx);
-
-  auto c0 = tv0->cache_after();
-  auto c1 = tv1->cache_after();
-  auto c2 = tv2->cache_before();
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
 
-  c0->computeAt(tv2, -2);
-  c1->computeAt(tv2, -2);
+  constexpr int M = 32, N = 750;
 
-  std::vector<TensorView*> vectorized_tvs = {c0, c1, tv2};
-  for (auto tv : vectorized_tvs) {
-    tv->split(-1, 4);
-    // Vectorize the wrong dimension
-    tv->axis(-2)->parallelize(ParallelType::Vectorize);
-  }
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({M, N}, options);
 
   FusionExecutor fe;
-  // Make sure compilation fails
-  ASSERT_ANY_THROW(fe.compileFusion(&fusion));
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({input});
+
+  at::Tensor aten_output = (input + 1.0) * 3.0;
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().sum());
 }
 
-TEST(NVFuserTest, FusionVectorization3_CUDA) {
+TEST(NVFuserTest, FusionCacheAfter_CUDA) {
+  // TVM Cache Read
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-
-  auto tv1 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = add(tv0, new Float(1.0));
+  TensorView* tv2 = mul(tv1, new Float(3.0));
   fusion.addInput(tv0);
-  fusion.addInput(tv1);
-
-  auto tv2 = add(tv0, tv1);
   fusion.addOutput(tv2);
+  // Before: TV1 = TV0 + 1
+  // After:  TV3 = TV0;
+  //         TV1 = TV3 + 1
 
-  tv2->split(1, 16);
-  tv2->split(1, 64);
-
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(2)->parallelize(ParallelType::TIDx);
+  constexpr int BSX = 32;
+  tv2->split(-1, BSX);
+  tv0->computeAt(tv2, -1);
 
-  auto c0 = tv0->cache_after();
-  auto c1 = tv1->cache_after();
-  auto c2 = tv2->cache_before();
+  // cache_after automatically applies ComputeAt to the cache TensorView
+  tv0->cache_after();
 
-  c0->computeAt(tv2, -2);
-  c1->computeAt(tv2, -2);
+  // Thread and Block binding
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
 
-  std::vector<TensorView*> vectorized_tvs = {c0, c1, tv2};
-  for (auto tv : vectorized_tvs) {
-    tv->split(-1, 4);
-    tv->axis(-1)->parallelize(ParallelType::Vectorize);
-  }
+  constexpr int M = 32, N = 457;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int bx = 128;
-  const int by = 2049;
-  at::Tensor t0 = at::randn({bx, by}, options);
-  at::Tensor t1 = at::randn({bx, by}, options);
+  at::Tensor input = at::rand({M, N}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({input});
 
-  std::vector<IValue> aten_inputs = {t0, t1};
-  ASSERT_ANY_THROW(fe.runFusion(aten_inputs));
-
-  aten_inputs[0] = t0.index({"...", Slice(1)});
-  aten_inputs[1] = t1.index({"...", Slice(1)});
-  ASSERT_ANY_THROW(fe.runFusion(aten_inputs));
-
-  t0 = at::randn({bx, 2048}, options).index({"...", Slice(4)});
-  t1 = at::randn({bx, 2048}, options).index({"...", Slice(4)});
-  aten_inputs = {t0, t1};
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  auto aten_output = t0 + t1;
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  at::Tensor aten_output = (input + 1.0) * 3.0;
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().sum());
 }
 
-TEST(NVFuserTest, FusionVectorizationRFactor_CUDA) {
+TEST(NVFuserTest, FusionCacheIndirect_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-
-  auto tv1 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = makeDummyTensor(2);
+  TensorView* tv2 = makeDummyTensor(2);
+  TensorView* tv3 = makeDummyTensor(2);
+  TensorView* tv4 = sub(tv2, tv3);
+  TensorView* tv5 = add(tv1, tv4);
+  TensorView* tv6 = sub(tv5, tv0);
   fusion.addInput(tv0);
   fusion.addInput(tv1);
+  fusion.addInput(tv2);
+  fusion.addInput(tv3);
+  fusion.addOutput(tv6);
+  // t6 = ((t1 + (t2 - t3)) - t0)
 
-  auto tv2 = add(tv0, tv1);
-
-  auto tv3 = sum(tv2, {-1});
-
-  fusion.addOutput(tv3);
-
-  tv3->split(-1, 128 * 4);
-  tv3->split(-1, 4);
-  // Reduce outer dim first
-  auto tv4 = tv3->rFactor({-3, -1});
-  // Tv3 will reduce threads
-
-  auto tv6 = tv0->cache_after();
-  auto tv7 = tv1->cache_after();
-
-  tv0->computeAt(tv3, 1);
-  tv1->computeAt(tv3, 1);
-
-  tv3->axis(0)->parallelize(ParallelType::BIDx);
+  // cache_after on inputs placed before schedule
+  constexpr int BSX = 32;
+  tv6->split(-1, BSX);
+  tv2->computeAt(tv6, -1);
 
-  tv0->computeAt(tv4, -2);
-  tv1->computeAt(tv4, -2);
+  tv5->cache_after();
+  tv5->cache_before();
 
-  tv6->axis(-1)->parallelize(ParallelType::Vectorize);
-  tv7->axis(-1)->parallelize(ParallelType::Vectorize);
+  // Thread and Block binding
+  tv6->axis(0)->parallelize(ParallelType::BIDx);
+  tv6->axis(-1)->parallelize(ParallelType::TIDx);
 
-  tv4->axis(-2)->parallelize(ParallelType::TIDx);
-  tv3->axis(1)->parallelize(ParallelType::TIDx);
+  constexpr int M = 32, N = 810;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  const int bx = 128;
-  const int by = 2048;
-  at::Tensor t0 = at::randn({bx, by}, options);
-  at::Tensor t1 = at::randn({bx, by}, options);
-
-  std::vector<IValue> aten_inputs = {t0, t1};
+  at::Tensor in0 = at::rand({M, N}, options);
+  at::Tensor in1 = at::rand({M, N}, options);
+  at::Tensor in2 = at::rand({M, N}, options);
+  at::Tensor in3 = at::rand({M, N}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-
-  auto aten_output = t0.add(t1).sum(1);
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+  auto outputs = fe.runFusion({in0, in1, in2, in3});
 
-  auto t3 = t0.add(t1).sum(1);
-
-  testValidate(&fusion, cg_outputs, aten_inputs, {t3}, __LINE__, __FILE__);
+  at::Tensor aten_output = (in1 + (in2 - in3)) - in0;
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().sum());
 }
 
-// Unswitched loops with extent one may omit else clause.
-TEST(NVFuserTest, FusionSizeOneLoop1_CUDA) {
+TEST(NVFuserTest, FusionCacheBcast_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  // Progressively broadcast tensors
-  TensorView* tv0 = makeSymbolicTensor(1);
+  // Algorithm
+  TensorView* tv0 = makeDummyTensor(1); // (M, 1)
+  TensorView* tv1 = broadcast(tv0, {false, true});
+  TensorView* tv2 = makeDummyTensor(1); // (1, N)
+  TensorView* tv3 = broadcast(tv2, {true, false});
+  TensorView* tv4 = mul(tv1, tv3);
   fusion.addInput(tv0);
-  TensorView* tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-  TensorView* tv2 = makeSymbolicTensor(3);
   fusion.addInput(tv2);
+  fusion.addOutput(tv4);
 
-  TensorView* tv3 = broadcast(tv0, {false, true});
-  TensorView* tv4 = add(tv3, tv1);
-  TensorView* tv5 = add(tv4, tv2);
+  constexpr int BSX = 128;
+  tv4->split(0, BSX);
+  tv4->split(-1, BSX);
+  tv4->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
+  // M/BSX, N/BSY, BSX, BSY
+  tv0->computeAt(tv4, 2);
+  tv2->computeAt(tv4, 2);
+  // 0, 1 | 2, 3, 4
 
-  fusion.addOutput(tv5);
+  // Case 1
+  tv0->cache_after();
 
-  // Split inner dimension
-  tv5->split(1, 8);
-  // Merge middle dims with outer dimensions
-  tv5->merge(2);
-  tv5->merge(0);
+  // Case 2
+  tv1->cache_before();
 
-  // tv5[I0*I1o, I1i*I2]
-  // Get a dim of size 1 to unswitch
-  tv5->split(0, 1, false);
+  // Case 3
+  tv1->cache_after();
 
-  // Compute everything inline
-  tv0->computeAt(tv5, -1);
+  // Case 4
+  TensorView* tv8 = tv4->cache_before();
 
-  tv5->axis(0)->parallelize(ParallelType::Unswitch);
-  tv5->axis(1)->parallelize(ParallelType::BIDx);
-  tv5->axis(2)->parallelize(ParallelType::TIDx);
+  tv4->axis(0)->parallelize(ParallelType::BIDx);
+  tv4->axis(1)->parallelize(ParallelType::BIDy);
+  tv4->axis(-1)->parallelize(ParallelType::TIDx);
+  // Manual Replay on TV3
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  tv8->axis(-1)->parallelize(ParallelType::TIDx);
 
-  // Make sure the unswitched loop does not have an else clause.
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto fl = dynamic_cast<kir::ForLoop*>(kir_node.get())) {
-      if (fl->iter_domain()->parallelType() != ParallelType::Unswitch) {
-        continue;
-      }
-      if (auto pred = dynamic_cast<kir::IfThenElse*>(fl->parentScope())) {
-        TORCH_CHECK(!pred->hasElse());
-      }
-    }
-  }
+  constexpr int M = 92, N = 500;
 
-  const int x = 11;
-  const int y = 12;
-  const int z = 13;
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({x}, options);
-  at::Tensor t1 = at::randn({x, y}, options);
-  at::Tensor t2 = at::randn({z, x, y}, options);
-  std::vector<IValue> aten_inputs = {t0, t1, t2};
+  at::Tensor t0 = at::randn({M}, options);
+  at::Tensor t1 = at::randn({N}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-  auto t6 = (t0.unsqueeze(-1) + t1).unsqueeze(0) + t2;
+  auto outputs = fe.runFusion({t0, t1});
 
-  testValidate(&fusion, cg_outputs, aten_inputs, {t6}, __LINE__, __FILE__);
+  at::Tensor aten_output = t0.unsqueeze(1).matmul(t1.unsqueeze(0));
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
 }
 
-// The unswitched loop has extent one but inner loops don't. The else
-// part should not be omitted.
-TEST(NVFuserTest, FusionSizeOneLoop2_CUDA) {
+TEST(NVFuserTest, FusionCacheComplex_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  const int x = 15;
-  auto tv0 = makeConcreteTensor({x});
+  TensorView* tv0 = makeDummyTensor(2); // (N, N)
+  TensorView* tv1 = makeDummyTensor(1); // (N)
+  TensorView* tv2 = sum(tv0, {1}); // (N)
+  TensorView* tv3 = broadcast(tv2, {false, true}); // (N, 1)
+  TensorView* tv4 = broadcast(tv1, {true, false}); // (1, N)
+  TensorView* tv5 = mul(tv3, tv4); // (N, N)
   fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  fusion.addOutput(tv5);
 
-  auto tv1 = add(tv0, new Double(1));
-  fusion.addOutput(tv1);
+  // Exception: Cache-Before on reduction Op
+  // TensorView* tv9 = tv2->cache_before();
+
+  constexpr int BSX = 128;
+  tv5->split(0, BSX);
+  tv5->split(-1, BSX);
+  // M/BSX, BSX, N/BSX, BSX
+  tv5->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
+  // M/BSX, N/BSY, BSX, BSY
+  tv0->computeAt(tv5, 2);
+  tv1->computeAt(tv5, 2);
+  // 0, 1 | 2, 3, 4
+
+  tv2->cache_after();
+  TensorView* tv7 = tv5->cache_before();
 
-  tv1->split(-1, 4);
-  tv1->split(-2, 1);
+  tv5->axis(0)->parallelize(ParallelType::BIDx);
+  tv5->axis(1)->parallelize(ParallelType::BIDy);
+  tv5->axis(-1)->parallelize(ParallelType::TIDx);
 
-  tv1->axis(-2)->parallelize(ParallelType::Unswitch);
+  tv4->axis(-1)->parallelize(ParallelType::TIDx);
+  tv7->axis(-1)->parallelize(ParallelType::TIDx);
 
-  // Make sure the size-one unswitched loop does not omit the else clause.
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto fl = dynamic_cast<kir::ForLoop*>(kir_node.get())) {
-      if (fl->iter_domain()->parallelType() != ParallelType::Unswitch) {
-        continue;
-      }
-      if (auto pred = dynamic_cast<kir::IfThenElse*>(fl->parentScope())) {
-        TORCH_CHECK(pred->hasElse());
-      }
-    }
-  }
+  constexpr int N = 800;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({x}, options);
-  std::vector<IValue> aten_inputs = {t0};
+  at::Tensor input1 = at::rand({N, N}, options);
+  at::Tensor input2 = at::rand({N}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion(aten_inputs);
-  auto t1 = t0 + 1;
+  auto outputs = fe.runFusion({input1, input2});
 
-  testValidate(&fusion, cg_outputs, aten_inputs, {t1}, __LINE__, __FILE__);
+  at::Tensor aten_output =
+      matmul(sum(input1, 1).unsqueeze(1), input2.unsqueeze(0));
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().sum());
 }
 
-TEST(NVFuserTest, FusionValidateParallelize1_CUDA) {
+TEST(NVFuserTest, FusionCacheMultiConsumer_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
+  TensorView* tv0 = makeDummyTensor(1);
+  TensorView* tv1 = add(tv0, new Float(1));
+  TensorView* tv2 = add(tv1, new Float(2));
+  TensorView* tv3 = add(tv0, new Float(1));
+  TensorView* tv4 = add(tv3, new Float(2));
 
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(1));
+  fusion.addInput(tv0);
   fusion.addOutput(tv2);
+  fusion.addOutput(tv4);
 
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDy);
-
-  // Invalid as tv1 and tv2 do have the same ParallelType
-  FusionExecutor fe;
-  ASSERT_ANY_THROW(fe.compileFusion(&fusion));
-}
+  tv1->computeAt(tv2, -1);
+  tv3->computeAt(tv4, -1);
 
-TEST(NVFuserTest, FusionValidateParallelize2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  auto tv5 = tv1->cache_before();
+  auto tv6 = tv3->cache_before();
+  tv5->setMemoryType(MemoryType::Shared);
+  tv6->setMemoryType(MemoryType::Shared);
 
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
+  // Fails because tensor must be recomputed twice
+  // auto tv7 = tv0->cache_after();
 
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(1));
-  fusion.addOutput(tv2);
+  constexpr int N = 800;
 
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDy);
-  tv1->setMemoryType(MemoryType::Shared);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({N}, options);
 
-  // tv1 and tv2 do have the same ParallelType, but tv1 is on shared
-  // memory, so it is valid
   FusionExecutor fe;
   fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({input});
+
+  auto aten_output = (input + 1) + 2;
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().sum());
+  TORCH_CHECK(
+      aten_output.allclose(outputs[1], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[1]).abs().sum());
 }
 
-TEST(NVFuserTest, FusionValidateParallelize3_CUDA) {
+TEST(NVFuserTest, FusionSmem_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
+  // Algorithm
+  TensorView* tv0 = makeDummyTensor(2); // (M, N)
+  TensorView* tv1 = makeDummyTensor(2); // (M, N)
+  TensorView* tv2 = mul(tv0, tv1);
   fusion.addInput(tv0);
-
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(1));
+  fusion.addInput(tv1);
   fusion.addOutput(tv2);
 
-  tv1->split(-1, 4);
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->split(-1, 4);
+  // Schedule
+  TensorView* tv3 = tv0->cache_after();
+  TensorView* tv4 = tv1->cache_after();
+  tv3->setMemoryType(MemoryType::Shared);
+  tv4->setMemoryType(MemoryType::Shared);
+
+  constexpr int BSY = 32;
+  constexpr int BSX = 128;
+  tv2->split(0, BSY);
+  tv2->split(2, BSX);
+  // M/BSX, BSX, N/BSX, BSX
+  tv2->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
+  // M/BSX, N/BSX, BSX, BSX
+
+  tv0->computeAt(tv2, 2);
+  tv1->computeAt(tv2, 2);
+
+  // Thread and Block binding
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(1)->parallelize(ParallelType::BIDy);
   tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  // Manual Binding
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  tv4->axis(-1)->parallelize(ParallelType::TIDx);
+
+  constexpr int M = 128, N = 10240;
 
-  tv1->setMemoryType(MemoryType::Global);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({M, N}, options);
+  at::Tensor t1 = at::randn({M, N}, options);
 
-  // tv1 and tv2 have the same shape and ParallelType
   FusionExecutor fe;
   fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({t0, t1});
+
+  at::Tensor aten_output = mul(t0, t1);
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
+  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
 }
 
-TEST(NVFuserTest, FusionValidateParallelize4_CUDA) {
+TEST(NVFuserTest, FusionSmemReduce_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
+  // Algorithm
+  TensorView* tv0 = makeDummyTensor(3); // M, K, N
+  TensorView* tv1 = sum(tv0, {1}); // M, R, N
   fusion.addInput(tv0);
+  fusion.addOutput(tv1);
 
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(1));
-  fusion.addOutput(tv2);
+  TensorView* tv2 = tv0->cache_after();
+  tv2->setMemoryType(MemoryType::Shared);
+
+  // Schedule
+  constexpr int BSX = 32;
+  tv1->split(2, BSX);
+  tv1->split(1, 128);
+  tv1->split(0, BSX);
+  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
+  tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}});
+  TensorView* tv3 = tv1->rFactor({-2});
+
+  tv0->computeAt(tv1, -2);
+  tv0->computeAt(tv3, -2);
 
-  tv1->split(-1, 4);
+  // Thread and Block binding
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
+  tv1->axis(1)->parallelize(ParallelType::BIDy);
   tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->split(-1, 8);
+  // Manual Binding
   tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+  constexpr int M = 154, K = 45, N = 1524;
 
-  tv1->setMemoryType(MemoryType::Global);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({M, K, N}, options);
 
-  // tv1 and tv2 do not have the same shape
   FusionExecutor fe;
-  ASSERT_ANY_THROW(fe.compileFusion(&fusion));
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({t0});
+
+  at::Tensor aten_output = sum(t0, {1});
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
+  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 1);
+  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(24) == 1);
 }
 
-TEST(NVFuserTest, FusionValidateParallelize5_CUDA) {
+TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
+  // Algorithm
+  TensorView* tv0 = makeDummyTensor(2); // (M, K)
+  TensorView* tv1 = makeDummyTensor(2); // (K, N)
+  TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
+  TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
+  TensorView* tv4 = mul(tv2, tv3); // M, K, N
+  TensorView* tv5 = sum(tv4, {1}); // M, R, N
   fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  fusion.addOutput(tv5);
 
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(1));
-  fusion.addOutput(tv2);
+  // Schedule
+  constexpr int BSX = 16;
+  tv5->split(2, BSX);
+  tv5->split(1, BSX);
+  tv5->split(0, BSX);
+  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
+  tv5->reorder({{0, 0}, {1, 3}, {2, 2}, {3, 5}, {4, 1}, {5, 4}});
+  // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
+  TensorView* tv6 = tv5->rFactor({-1});
 
-  tv1->split(-1, 4);
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv1->setMemoryType(MemoryType::Shared);
+  tv2->setMemoryType(MemoryType::Shared);
+  tv3->setMemoryType(MemoryType::Shared);
+  tv4->setMemoryType(MemoryType::Shared);
+  tv6->setMemoryType(MemoryType::Shared);
 
-  tv2->split(-1, 8);
+  tv0->computeAt(tv5, 3);
+  tv1->computeAt(tv5, 3);
+
+  // Thread and Block binding
+  tv5->axis(0)->parallelize(ParallelType::BIDx);
+  tv5->axis(1)->parallelize(ParallelType::BIDy);
+  tv5->axis(-2)->parallelize(ParallelType::TIDy);
+  tv5->axis(-1)->parallelize(ParallelType::TIDx);
+  // Manual Binding
   tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  tv4->axis(-1)->parallelize(ParallelType::TIDx);
+  tv6->axis(-3)->parallelize(ParallelType::TIDy);
+  tv6->axis(-2)->parallelize(ParallelType::TIDx);
+
+  constexpr int M = 154, K = 45, N = 1524;
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({M, K}, options);
+  at::Tensor t1 = at::randn({K, N}, options);
 
-  // tv1 and tv2 do not have the same shape, but tv1 is on shared
-  // memory, so it is valid
   FusionExecutor fe;
   fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({t0, t1});
+
+  at::Tensor aten_output = matmul(t0, t1);
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
+  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
 }
 
-// See issue #995
-TEST(NVFuserTest, FusionValidateParallelize6_CUDA) {
+TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(3);
-  auto tv1 = makeSymbolicTensor(4);
+  // Algorithm
+  TensorView* tv0 = makeDummyTensor(2); // (M, K)
+  TensorView* tv1 = makeDummyTensor(2); // (K, N)
+  TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
+  TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
+  TensorView* tv4 = mul(tv2, tv3); // M, K, N
+  TensorView* tv5 = sum(tv4, {1}); // M, R, N
   fusion.addInput(tv0);
   fusion.addInput(tv1);
+  fusion.addOutput(tv5);
 
-  auto tv2 = add(tv0, new Double(1));
-  auto tv3 = broadcast(tv2, {true, false, false, false});
-  auto tv4 = add(tv3, tv1);
-  fusion.addOutput(tv4);
-
-  tv4->merge(0);
-  tv4->merge(0);
-  tv4->merge(0);
-  tv4->split(0, 128);
-  tv4->split(0, 1);
-  tv4->split(0, 1);
-
-  TransformPropagator::from(tv4);
+  // Schedule
+  // Remove reduction axis from tv5
+  // tv6 = (M, R, N)
+  // tv5 = (M, N)
+  TensorView* tv6 = tv5->cache_before();
 
-  tv0->computeAt(tv2, 2);
-  tv3->computeAt(tv4, 2);
+  constexpr int BSX = 16;
+  tv5->split(1, BSX);
+  tv5->split(0, BSX);
+  // M/BSX, BSX, N/BSX, BSX
+  tv5->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
+  // tv5 = M/BSX, N/BSX, MSX, NSX
 
-  tv4->axis(0)->parallelize(ParallelType::BIDx);
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(0)->parallelize(ParallelType::BIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv6->computeAt(tv5, 2);
+  tv6->computeAt(tv5, 2);
 
-  // Validation should throw an exception saying the first axes of tv2
-  // and tv3 have incompatible parallelization. See also issue #995.
-  ASSERT_ANY_THROW(fusion.printKernel());
-}
+  tv6->split(-1, BSX);
+  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
+  tv6->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 4}, {4, 2}, {5, 5}});
+  // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
+  TensorView* tv7 = tv6->rFactor({-1});
+  // tv7 = M/BSX, N/BSX, K/BSXrf, MSX, NSX, KSXr
+  // tv6 = M/BSX, N/BSX, K/BSXr, MSX, NSX
 
-TEST(NVFuserTest, FusionDAGMerging_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  tv0->computeAt(tv6, 3);
+  tv1->computeAt(tv6, 3);
 
-  auto tv0 = makeSymbolicTensor(5);
-  auto tv1 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
+  tv0->computeAt(tv7, 3);
+  tv1->computeAt(tv7, 3);
 
-  // Branch 0
-  auto tv2 = sum(tv0, {0}); // 0
-  auto tv3 = sum(tv2, {0}); // 1
-  auto tv4 = sum(tv3, {0}); // 2
-  auto tv5 = sum(tv4, {0}); // 3
+  tv2->setMemoryType(MemoryType::Shared);
+  tv3->setMemoryType(MemoryType::Shared);
+  tv4->setMemoryType(MemoryType::Shared);
+  tv6->setMemoryType(MemoryType::Shared);
+  tv7->setMemoryType(MemoryType::Shared);
+  // Memory Type
 
-  // Branch 1
-  auto tv6 = add(tv1, new Double(1)); // 4
+  // Thread and Block binding
+  tv5->axis(0)->parallelize(ParallelType::BIDx);
+  tv5->axis(1)->parallelize(ParallelType::BIDy);
+  tv5->axis(-2)->parallelize(ParallelType::TIDy);
+  tv5->axis(-1)->parallelize(ParallelType::TIDx);
+  // Manual Binding
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  tv4->axis(-1)->parallelize(ParallelType::TIDx);
 
-  // Merge
-  auto tv7 = add(tv6, tv5); // 5
+  tv7->axis(-3)->parallelize(ParallelType::TIDy);
+  tv7->axis(-2)->parallelize(ParallelType::TIDx);
 
-  // Maximum expected output groups (can improve overtime):
-  //  {0}, {1}, {2}, {3,4,5}
-  //  without final merge would have been {0}, {1}, {2}, {3,4}, {5}
+  tv6->axis(-2)->parallelize(ParallelType::TIDy);
+  tv6->axis(-1)->parallelize(ParallelType::TIDx);
 
-  fusion.addOutput(tv7);
+  constexpr int M = 154, K = 45, N = 1524;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({2, 2, 2, 2, 2}, options);
-  at::Tensor t1 = at::randn({2}, options);
+  at::Tensor t0 = at::randn({M, K}, options);
+  at::Tensor t1 = at::randn({K, N}, options);
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({t0, t1});
 
-  auto fusion_segments = fusion.segment({t0, t1});
-  TORCH_CHECK(fusion_segments->groups().size() <= 4);
+  at::Tensor aten_output = matmul(t0, t1);
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
+  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
 }
 
-TEST(NVFuserTest, FusionDAGScalarMerging_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* x = makeDummyTensor(2);
+  fusion.addInput(x);
+  TensorView* max_val =
+      reductionOp(BinaryOpType::Max, {-1}, new Float(FLT_MIN), x); // (M)
+  TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B)
+  TensorView* x_max_sub = sub(x, bcast_max); // (M, N)
+  TensorView* exp = unaryOp(UnaryOpType::Exp, x_max_sub); // (M, N)
+  TensorView* sum_exp = sum(exp, {-1}); // (M, R)
+  TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B)
+  TensorView* softmax = div(exp, bcast_sum); // (M, N)
+  fusion.addOutput(softmax);
 
-  auto tv0 = makeSymbolicTensor(3);
-  auto i0 = new Double();
+  // Read Input into Shared Memory
+  // Load Input + Pwise into shared memory
+  auto cache_x = x->cache_after();
+  cache_x->setMemoryType(MemoryType::Shared);
+  exp->setMemoryType(MemoryType::Shared);
 
-  fusion->addInput(tv0);
-  fusion->addInput(i0);
+  std::vector<TensorView*> all_tensors(
+      {x,
+       cache_x,
+       max_val,
+       bcast_max,
+       x_max_sub,
+       exp,
+       sum_exp,
+       bcast_sum,
+       softmax});
 
-  auto i1 = add(i0, new Double(1.0));
-  auto i2 = mul(i1, i1);
-  auto i3 = add(i2, i1);
+  auto tidx = new Int();
+  fusion.addInput(tidx);
 
-  // Branch 0
-  auto tv1 = sum(tv0, {0}); // 0
-  auto tv2 = add(tv1, i2);
-  // Branch 1
-  auto tv3 = sum(tv2, {0}); // 1
-  auto tv4 = add(tv3, i3);
+  for (auto tensor : all_tensors) {
+    tensor->split(-1, tidx);
+  }
 
-  auto tv5 = add(tv4, i0);
+  auto sum_exp_rf = sum_exp->rFactor({1});
+  all_tensors.push_back(sum_exp_rf);
 
-  fusion->addOutput(tv5);
+  // computeAt
+  x->computeAt(x_max_sub, 1);
+  exp->computeAt(softmax, 1);
+  x_max_sub->computeAt(exp, 2);
 
-  FusionExecutorCache executor_cache(std::move(fusion));
+  softmax->axis(0)->parallelize(ParallelType::BIDx);
+  for (auto tensor : all_tensors) {
+    tensor->axis(-1)->parallelize(ParallelType::TIDx);
+  }
 
+  const size_t dimx = 1024;
+  const size_t dimy = 4096;
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({16, 16, 16}, options);
-  double s0 = 0.5;
-
-  auto s1 = s0 + 1.0;
-  auto s2 = s1 * s1;
-  auto s3 = s2 + s1;
-  auto t1 = t0.sum({0});
-  auto t2 = t1 + s2;
-  auto t3 = sum(t2, {0});
-  auto t4 = t3 + s3;
-  auto t5 = t4 + s0;
+  at::Tensor t0 = at::randn({dimx, dimy}, options);
 
-  auto outputs = executor_cache.runFusionWithInputs({t0, s0});
+  torch::jit::fuser::cuda::FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({t0, 128});
 
+  auto t1 = at::_softmax(t0, -1, false);
   TORCH_CHECK(
-      executor_cache.getMostRecentKernelRuntime()->isSegmented(),
-      "segmentation didn't happen");
-  TORCH_CHECK(
-      executor_cache.getMostRecentKernelRuntime()
-              ->fusionSegments()
-              ->groups()
-              .size() == 2,
-      "segmentation didn't happen as expected");
-
-  testValidate(
-      executor_cache.fusion(), outputs, {t0, s0}, {t5}, __LINE__, __FILE__);
+      t1.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      t1.sub(outputs[0]).abs().max());
 }
 
-TEST(NVFuserTest, FusionBlockReduceInSerialLoop_CUDA) {
+TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  constexpr int M = 10;
-  constexpr int N = 20;
-  constexpr int K = 20;
+  const int pixels_per_thread = 64;
+  const int TIDX = 128;
+  const int static_size = pixels_per_thread * TIDX;
 
-  auto tv0 = makeSymbolicTensor(3);
-  auto tv1 = sum(tv0, {{1, 2}});
-  fusion.addInput(tv0);
-  fusion.addOutput(tv1);
+  TensorView* sx = makeConcreteTensor({-1, static_size});
+  TensorView* dx = makeDummyTensor(2);
+  fusion.addInput(sx);
+  fusion.addInput(dx);
 
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
+  TensorView* max_sx =
+      reductionOp(BinaryOpType::Max, {-1}, new Float(FLT_MIN), sx); // (M)
+  TensorView* max_dx =
+      reductionOp(BinaryOpType::Max, {-1}, new Float(FLT_MIN), dx); // (M)
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({M, N, K}, options);
-  std::vector<IValue> aten_inputs = {t0};
+  // Reduction => merge local and shared memory TensorViews
+  TensorView* max_val = binaryOp(BinaryOpType::Max, max_sx, max_dx);
+  TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B)
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(aten_inputs);
-  at::Tensor aten_output = t0.sum({1, 2});
-  testValidate(
-      &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+  TensorView* sx_max_sub = sub(sx, bcast_max); // (M, N)
+  TensorView* dx_max_sub = sub(dx, bcast_max); // (M, N)
 
-TEST(NVFuserTest, FusionBlockWelfordInSerialLoop_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  TensorView* sx_exp = unaryOp(UnaryOpType::Exp, sx_max_sub); // (M, N)
+  TensorView* dx_exp = unaryOp(UnaryOpType::Exp, dx_max_sub); // (M, N)
 
-  constexpr int M = 10;
-  constexpr int N = 20;
-  constexpr int K = 20;
+  TensorView* sx_sum_exp = sum(sx_exp, {-1}); // (M, R)
+  TensorView* dx_sum_exp = sum(dx_exp, {-1}); // (M, R)
 
-  auto tv0 = makeSymbolicTensor(3);
-  auto tvs = Welford(tv0, {{1, 2}});
-  fusion.addInput(tv0);
-  auto tv_avg = tvs.avg;
-  auto tv_M2 = tvs.var_sum;
-  auto tv_N = tvs.n;
-  fusion.addOutput(tv_avg);
-  fusion.addOutput(tv_M2);
+  // Reduction => merge local and shared memory TensorViews
+  TensorView* sum_exp = binaryOp(BinaryOpType::Add, sx_sum_exp, dx_sum_exp);
+  TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B)
 
-  tv_avg->axis(-1)->parallelize(ParallelType::TIDx);
-  tv_avg->axis(0)->parallelize(ParallelType::BIDx);
+  TensorView* sx_softmax = div(sx_exp, bcast_sum); // (M, N)
+  TensorView* dx_softmax = div(dx_exp, bcast_sum); // (M, N)
+  fusion.addOutput(sx_softmax);
+  fusion.addOutput(dx_softmax);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({M, N, K}, options);
-  std::vector<IValue> aten_inputs = {t0};
+  auto sx_cache = sx->cache_after();
+  auto dx_cache = dx->cache_after();
+  dx_cache->setMemoryType(MemoryType::Shared);
+  dx_exp->setMemoryType(MemoryType::Shared);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(aten_inputs);
-  at::Tensor aten_avg = t0.mean({1, 2});
-  at::Tensor aten_M2 = t0.var({1, 2}, false) * N * K;
-  testValidate(
-      &fusion, outputs, aten_inputs, {aten_avg, aten_M2}, __LINE__, __FILE__);
-}
+  // Reduction and Broadcast Tensors common to both memory TVs
+  std::vector<TensorView*> common_tensors(
+      {max_val, sum_exp, bcast_max, bcast_sum});
 
-// See Issue #716
-TEST(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  // Static Local Memory TVs
+  std::vector<TensorView*> static_tensors(
+      {sx, sx_cache, max_sx, sx_max_sub, sx_exp, sx_sum_exp, sx_softmax});
 
-  constexpr int M = 10;
-  constexpr int N = 11;
+  // Dynamic Local Memory TVs
+  std::vector<TensorView*> dynamic_tensors(
+      {dx, dx_cache, max_dx, dx_max_sub, dx_exp, dx_sum_exp, dx_softmax});
 
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
+  std::vector<TensorView*> all_tensors;
+  all_tensors.insert(
+      all_tensors.end(), common_tensors.begin(), common_tensors.end());
+  all_tensors.insert(
+      all_tensors.end(), static_tensors.begin(), static_tensors.end());
+  all_tensors.insert(
+      all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end());
+
+  // M => M
+  // M, N => M, N/128, 128
+  for (auto tensor : all_tensors) {
+    if (tensor->nDims() > 1) {
+      tensor->split(-1, TIDX);
+    }
+  }
+
+  auto sx_sum_exp_rf = sx_sum_exp->rFactor({1});
+  auto dx_sum_exp_rf = dx_sum_exp->rFactor({1});
+  all_tensors.push_back(sx_sum_exp_rf);
+  all_tensors.push_back(dx_sum_exp_rf);
 
-  std::vector<int> reduction_axes = {1};
-  std::vector<bool> broadcast_mask = {false, true};
+  // computeAt
+  sx->computeAt(sx_max_sub, 1);
+  dx->computeAt(dx_max_sub, 1);
+
+  sx_exp->computeAt(sx_softmax, 1);
+  dx_exp->computeAt(dx_softmax, 1);
+
+  sx_max_sub->computeAt(sx_exp, 2);
+  dx_max_sub->computeAt(dx_exp, 2);
 
-  auto tv0_bcast = broadcast(tv0, broadcast_mask);
-  auto path1_bcast = add(tv0_bcast, new Double(1.0));
-  auto path1 = sum(path1_bcast, reduction_axes);
-  fusion.addOutput(path1);
+  sx_softmax->axis(0)->parallelize(ParallelType::BIDx);
+  dx_softmax->axis(0)->parallelize(ParallelType::BIDx);
+  for (auto tensor : all_tensors) {
+    if (tensor->nDims() > 1) {
+      tensor->axis(-1)->parallelize(ParallelType::TIDx);
+    }
+  }
 
-  auto p = path1->split(1, 1);
-  path1->rFactor({1});
-  path1->axis(0)->parallelize(ParallelType::BIDx);
-  tv0->computeAt(path1, 1);
+  const size_t dimx = 1024;
+  const size_t dimy = 16384;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({M}, options);
-  at::Tensor t0_ref = t0.clone();
-  std::vector<IValue> aten_inputs = {t0};
+  at::Tensor in = at::randn({dimx, dimy}, options);
+  at::Tensor static_in = in.narrow(1, 0, static_size);
+  at::Tensor dynamic_in = in.narrow(1, static_size, dimy - static_size);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  at::Tensor out = at::zeros({dimx, dimy}, options);
+  at::Tensor static_out = out.narrow(1, 0, static_size);
+  at::Tensor dynamic_out = out.narrow(1, static_size, dimy - static_size);
 
-  // inplace op, we are adding t0 to itself
-  auto outputs = fe.runFusion(aten_inputs, {t0});
+  torch::jit::fuser::cuda::FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs =
+      fe.runFusion({static_in, dynamic_in}, {static_out, dynamic_out});
 
-  TORCH_CHECK(outputs[0].allclose(t0_ref.add(1)));
+  auto t1 = at::_softmax(in, -1, false);
+  TORCH_CHECK(
+      t1.allclose(out, 1e-5, 1e-5), "Error of: ", t1.sub(out).abs().max());
 }
 
-TEST(NVFuserTest, FusionReductionPredicate_CUDA) {
+TEST(NVFuserTest, FusionPersistentBatchNormLocalShared_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = sum(tv0, {0});
-  fusion.addOutput(tv1);
+  const int pixels_per_thread = 64;
+  const int TIDX = 128;
+  const int static_size = pixels_per_thread * TIDX;
 
-  auto tv2 = tv0->cache_after();
+  TensorView* sx = makeConcreteTensor({-1, static_size});
+  TensorView* dx = makeDummyTensor(2);
+  fusion.addInput(sx);
+  fusion.addInput(dx);
 
-  const int bdimx = 128;
-  tv1->split(1, bdimx);
-  tv1->split(1, 4);
-  tv1->split(1, 1);
+  Float* gamma = new Float();
+  Float* beta = new Float();
+  Float* eps = new Float();
+  Int* N = new Int();
+  fusion.addInput(gamma);
+  fusion.addInput(beta);
+  fusion.addInput(eps);
+  fusion.addInput(N);
 
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv1->axis(2)->parallelize(ParallelType::Unroll);
-  tv1->split(0, 10);
-  tv0->computeAt(tv1, 4);
+  // Reduction
+  auto sx_sum = sum(sx, {-1}); // (M, R)
+  auto dx_sum = sum(dx, {-1}); // (M, R)
+  // Reduction => merge local and shared memory TensorViews
+  auto x_sum = binaryOp(BinaryOpType::Add, sx_sum, dx_sum);
 
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  // Broadcast
+  auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B)
+  // Pwise
+  auto x_mean = div(x_sum_bcast, N); // (M, B)
 
-  int numel_x = 650;
-  int numel_y = 102;
+  auto sx_mean_sub = sub(sx, x_mean); // (M, N)
+  auto dx_mean_sub = sub(dx, x_mean); // (M, N)
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({numel_x, numel_y}, options);
-  at::Tensor cg_output = at::empty({numel_y}, options);
+  auto sx_mean_sub_pow = mul(sx_mean_sub, sx_mean_sub); // (M, N)
+  auto dx_mean_sub_pow = mul(dx_mean_sub, dx_mean_sub); // (M, N)
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  fe.runFusion({input}, {cg_output});
+  // Reduction
+  auto sx_var_sum = sum(sx_mean_sub_pow, {-1}); // (M, R)
+  auto dx_var_sum = sum(dx_mean_sub_pow, {-1}); // (M, R)
+  // Reduction => merge local and shared memory TensorViews
+  auto var_sum = binaryOp(BinaryOpType::Add, sx_var_sum, dx_var_sum);
 
-  auto aten_output = input.to(at::kDouble).sum({0});
+  // Broadcast
+  auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B)
+  // Pwise
+  auto var = div(var_sum_bcast, N); // (M, B)
+  auto var_eps = add(var, eps); // (M, B)
+  auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B)
 
-  testValidate(
-      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
+  auto sx_norm = mul(sx_mean_sub, rvar);
+  auto dx_norm = mul(dx_mean_sub, rvar);
 
-TEST(NVFuserTest, FusionIssue728_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  auto sx_norm_gamma = mul(sx_norm, gamma);
+  auto dx_norm_gamma = mul(dx_norm, gamma);
 
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addOutput(tv0);
-  auto tv1 = makeSymbolicTensor(1);
-  fusion.addOutput(tv1);
-  auto tv2 = makeSymbolicTensor(1);
-  fusion.addOutput(tv2);
+  auto sx_norm_gamma_beta = add(sx_norm_gamma, beta);
+  auto dx_norm_gamma_beta = add(dx_norm_gamma, beta);
+  fusion.addOutput(sx_norm_gamma_beta);
+  fusion.addOutput(dx_norm_gamma_beta);
 
-  auto tv3 = add(tv0, new Double(1));
-  auto tv4 = add(tv3, tv1);
-  auto tv5 = add(tv4, new Double(1));
-  auto tv6 = add(tv2, new Double(1));
-  fusion.addOutput(tv5);
-  fusion.addOutput(tv6);
+  // Read Input into Shared Memory
+  // Read Input minus Input_Mean into Shared Memory
+  auto sx_cache = sx->cache_after();
+  auto dx_cache = dx->cache_after();
+  dx_cache->setMemoryType(MemoryType::Shared);
+  dx_mean_sub->setMemoryType(MemoryType::Shared);
 
-  // tv0 -> tv3 -+
-  // tv1 --------+-> tv4 -> tv5
-  //
-  // tv2 -> tv6
+  std::vector<TensorView*> common_tensors(
+      {x_sum, x_sum_bcast, x_mean, var_sum, var_sum_bcast, var, var_eps, rvar});
+
+  std::vector<TensorView*> static_tensors(
+      {sx,
+       sx_cache,
+       sx_sum,
+       sx_mean_sub,
+       sx_mean_sub_pow,
+       sx_var_sum,
+       sx_norm,
+       sx_norm_gamma,
+       sx_norm_gamma_beta});
+
+  std::vector<TensorView*> dynamic_tensors(
+      {dx,
+       dx_cache,
+       dx_sum,
+       dx_mean_sub,
+       dx_mean_sub_pow,
+       dx_var_sum,
+       dx_norm,
+       dx_norm_gamma,
+       dx_norm_gamma_beta});
+
+  std::vector<TensorView*> all_tensors;
+  all_tensors.insert(
+      all_tensors.end(), common_tensors.begin(), common_tensors.end());
+  all_tensors.insert(
+      all_tensors.end(), static_tensors.begin(), static_tensors.end());
+  all_tensors.insert(
+      all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end());
 
-  auto all_vals_under_tv3 =
-      DependencyCheck::getAllValsBetween({tv3}, fusion.outputs());
-  std::unordered_set<Val*> included_tensors({tv3, tv4, tv5});
-  for (auto tv : included_tensors) {
-    TORCH_CHECK(
-        std::find(all_vals_under_tv3.begin(), all_vals_under_tv3.end(), tv) !=
-            all_vals_under_tv3.end(),
-        "TV",
-        tv->name(),
-        " not found");
-  }
-  for (auto tv : ir_utils::filterByType<TensorView>(fusion.vals())) {
-    if (included_tensors.find(tv) == included_tensors.end()) {
-      TORCH_CHECK(
-          std::find(all_vals_under_tv3.begin(), all_vals_under_tv3.end(), tv) ==
-              all_vals_under_tv3.end(),
-          "TV",
-          tv->name(),
-          " should not be found");
+  // M => M
+  // M, N => M, N/128, 128
+  for (auto tensor : all_tensors) {
+    if (tensor->nDims() > 1) {
+      tensor->split(-1, TIDX);
     }
   }
 
-  auto no_dependency = DependencyCheck::getAllValsBetween({}, fusion.outputs());
-  TORCH_CHECK(no_dependency.empty(), "No val should be returned");
-
-  auto no_dep_path = DependencyCheck::getAllValsBetween({tv0, tv1}, {tv6});
-  TORCH_CHECK(no_dep_path.empty(), "No val should be returned");
+  // Local Sum => Block Broadcast
+  TensorView* sx_sum_rf = sx_sum->rFactor({1});
+  TensorView* sx_var_sum_rf = sx_var_sum->rFactor({1});
+  TensorView* dx_sum_rf = dx_sum->rFactor({1});
+  TensorView* dx_var_sum_rf = dx_var_sum->rFactor({1});
+  all_tensors.push_back(sx_sum_rf);
+  all_tensors.push_back(sx_var_sum_rf);
+  all_tensors.push_back(dx_sum_rf);
+  all_tensors.push_back(dx_var_sum_rf);
 
-  auto no_dep_path2 = DependencyCheck::getAllValsBetween({tv2}, {tv5});
-  TORCH_CHECK(no_dep_path2.empty(), "No val should be returned");
+  // ComputeAt
+  sx->computeAt(sx_mean_sub_pow, 1);
+  dx->computeAt(dx_mean_sub_pow, 1);
 
-  auto just_tv3 = DependencyCheck::getAllValsBetween({tv3}, {tv3});
-  TORCH_CHECK(
-      just_tv3.size() == 1 && *(just_tv3.begin()) == tv3,
-      "Only tv3 should be included");
-}
+  var_sum->computeAt(rvar, 1);
 
-TEST(NVFuserTest, FusionIssue757_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  sx_mean_sub_pow->computeAt(sx_var_sum_rf, 2);
+  dx_mean_sub_pow->computeAt(dx_var_sum_rf, 2);
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = sum(tv0, {1});
-  auto tv2 = broadcast(tv1, {false, true});
-  auto tv3 = makeSymbolicTensor(2);
-  fusion.addInput(tv3);
-  auto tv4 = add(tv2, tv3);
-  fusion.addOutput(tv4);
+  sx_norm->computeAt(sx_norm_gamma_beta, 2);
+  dx_norm->computeAt(dx_norm_gamma_beta, 2);
 
-  tv1->computeAt(tv4, -1);
+  sx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx);
+  dx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx);
+  for (auto tensor : all_tensors) {
+    if (tensor->nDims() > 1) {
+      tensor->axis(-1)->parallelize(ParallelType::TIDx);
+    }
+  }
 
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  const int dimx = 1024;
+  const int dimy = 16384;
+  const float kGamma = 1.0f;
+  const float kBeta = 0.0f;
+  const float kEps = 1e-5;
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  int numel_x = 650;
-  int numel_y = 102;
+  at::Tensor in = at::randn({dimx, dimy}, options);
+  at::Tensor static_in = in.narrow(1, 0, static_size);
+  at::Tensor dynamic_in = in.narrow(1, static_size, dimy - static_size);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  at::Tensor t3 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0, t3};
+  at::Tensor out = at::zeros({dimx, dimy}, options);
+  at::Tensor static_out = out.narrow(1, 0, static_size);
+  at::Tensor dynamic_out = out.narrow(1, static_size, dimy - static_size);
 
-  FusionExecutor fe;
+  torch::jit::fuser::cuda::FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0.sum({1});
-  auto t2 = t1.unsqueeze(-1).expand({numel_x, numel_y});
-  auto t4 = t2 + t3;
+  auto outputs = fe.runFusion(
+      {static_in, dynamic_in, kGamma, kBeta, kEps, dimy},
+      {static_out, dynamic_out});
 
-  testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__);
+  auto at_mu = at::mean(in, -1).unsqueeze(1);
+  auto at_var = at::var(in, -1).unsqueeze(1);
+  auto at_rvar = at::rsqrt(at::add(at_var, kEps));
+  auto at_norm = at::mul(at::sub(in, at_mu), at_rvar);
+  auto at_norm_gamma_beta = at::add(at::mul(at_norm, kGamma), kBeta);
+  TORCH_CHECK(
+      at_norm_gamma_beta.allclose(out, 1e-3, 1e-3),
+      "Error of: ",
+      at_norm_gamma_beta.sub(out).abs().max());
 }
 
-// See issue #759
-TEST(NVFuserTest, FusionPredicatedBlockBroadcast_CUDA) {
+TEST(NVFuserTest, FusionSmemDynamicPersistentBatchNorm_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = sum(tv0, {1});
-  auto tv2 = broadcast(tv1, {false, true});
-  auto tv3 = makeSymbolicTensor(2);
-  fusion.addInput(tv3);
-  auto tv4 = add(tv2, tv3);
-  fusion.addOutput(tv4);
-
-  tv4->split(0, 4);
-  tv1->computeAt(tv4, -1);
+  // Set up your input tensor views
+  auto x = makeDummyTensor(2);
+  Float* gamma = new Float();
+  Float* beta = new Float();
+  Float* eps = new Float();
+  Int* N = new Int();
+  fusion.addInput(x);
+  fusion.addInput(gamma);
+  fusion.addInput(beta);
+  fusion.addInput(eps);
+  fusion.addInput(N);
 
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(1)->parallelize(ParallelType::TIDy);
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
-  tv4->axis(1)->parallelize(ParallelType::TIDy);
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  // Reduction
+  auto x_sum = sum(x, {-1}); // (M, R)
+  // Broadcast
+  auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B)
+  // Pwise
+  auto x_mean = div(x_sum_bcast, N); // (M, B)
+  auto x_mean_sub = sub(x, x_mean); // (M, N)
+  auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub); // (M, N)
+  // Reduction
+  auto var_sum = sum(x_mean_sub_pow, {-1}); // (M, R)
+  // Broadcast
+  auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B)
+  // Pwise
+  auto var = div(var_sum_bcast, N); // (M, B)
+  auto var_eps = add(var, eps); // (M, B)
+  auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B)
+  auto norm = mul(x_mean_sub, rvar);
+  auto norm_gamma = mul(norm, gamma);
+  auto norm_gamma_beta = add(norm_gamma, beta);
+  fusion.addOutput(norm_gamma_beta);
 
-  int numel_x = 100;
-  int numel_y = 101;
+  // Read Input into Shared Memory
+  // Read Input minus Input_Mean into Shared Memory
+  auto cache_x = x->cache_after();
+  cache_x->setMemoryType(MemoryType::Shared);
+  x_mean_sub->setMemoryType(MemoryType::Shared);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  at::Tensor t3 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0, t3};
+  std::vector<TensorView*> all_tensors(
+      {x_sum,
+       x_mean,
+       cache_x,
+       x_sum_bcast,
+       x_mean_sub,
+       x_mean_sub_pow,
+       var_sum,
+       var_sum_bcast,
+       var,
+       var_eps,
+       rvar,
+       norm,
+       norm_gamma,
+       norm_gamma_beta});
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(inputs);
+  auto tidx = new Int();
+  fusion.addInput(tidx);
 
-  auto t1 = t0.sum({1});
-  auto t2 = t1.unsqueeze(-1).expand({numel_x, numel_y});
-  auto t4 = t2 + t3;
+  for (auto tensor : all_tensors) {
+    tensor->split(-1, tidx);
+  }
+  norm_gamma->split(1, 1);
+  norm_gamma_beta->split(1, 1);
 
-  testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__);
-}
+  // Local Sum => Block Broadcast
+  TensorView* x_sum_rf = x_sum->rFactor({1});
+  TensorView* var_sum_rf = var_sum->rFactor({1});
+  all_tensors.push_back(x_sum_rf);
+  all_tensors.push_back(var_sum_rf);
 
-TEST(NVFuserTest, FusionSegmentVerticalMerge_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+  // ComputeAt
+  x->computeAt(x_mean_sub_pow, 1);
+  var_sum->computeAt(rvar, 1);
+  x_mean_sub_pow->computeAt(var_sum_rf, 2);
+  norm->computeAt(norm_gamma_beta, 2);
 
-  auto tv0 = makeSymbolicTensor(3);
+  for (auto tv : all_tensors) {
+    tv->axis(0)->parallelize(ParallelType::BIDx);
+    tv->axis(-1)->parallelize(ParallelType::TIDx);
+  }
 
-  fusion->addInput(tv0);
-  // {first kernel}
-  auto tv1 = sum(tv0, {0});
-  auto tv2 = add(tv1, tv0);
-  auto tv3 = sum(tv2, {0});
-  auto tv4 = add(tv3, tv0);
-  auto tv5 = sum(tv4, {0});
-  auto tv6 = sum(tv5, {0});
-  // {second kernel}
-  auto tv7 = add(tv6, tv5);
-  auto tv8 = add(tv7, tv5);
-  auto tv9 = sum(tv8, {0});
-
-  fusion->addOutput(tv9);
-
-  SegmentCandidateFinderOptions segment_options;
-  segment_options.run_herrmann_merge = false;
-  segment_options.run_final_merge = false;
+  const int dimx = 128;
+  const int dimy = 2048;
+  const float kGamma = 1.0f;
+  const float kBeta = 0.0f;
+  const float kEps = 1e-5;
+  const int TIDX = 128;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({2, 2, 2}, options);
+  at::Tensor t0 = at::randn({dimx, dimy}, options);
 
-  auto segmented_fusion =
-      SegmentCandidateFinder::segment(fusion.get(), {t0}, segment_options);
+  torch::jit::fuser::cuda::FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({t0, kGamma, kBeta, kEps, dimy, TIDX});
 
-  TORCH_CHECK(segmented_fusion->groups().size() == 2);
+  auto at_mu = at::mean(t0, -1).unsqueeze(1);
+  auto at_var = at::var(t0, -1).unsqueeze(1);
+  auto at_rvar = at::rsqrt(at::add(at_var, kEps));
+  auto at_norm = at::mul(at::sub(t0, at_mu), at_rvar);
+  auto at_norm_gamma_beta = at::add(at::mul(at_norm, kGamma), kBeta);
+  TORCH_CHECK(
+      at_norm_gamma_beta.allclose(outputs[0], 1e-3, 1e-3),
+      "Error of: ",
+      at_norm_gamma_beta.sub(outputs[0]).abs().max());
 }
 
-TEST(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
-
-  auto tv0 = makeSymbolicTensor(3);
-  auto i0 = new Double();
+TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  fusion->addInput(tv0);
-  fusion->addInput(i0);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addInput(tv0);
+  fusion.addOutput(tv1);
+  // tv1[I0, R1] = tv0[I0, I1]
 
-  // Branch 0 {first kernel}
-  auto tv1 = sum(tv0, {0});
-  auto tv2 = add(tv0, i0);
-  auto tv3 = unaryOp(UnaryOpType::Rsqrt, tv2);
-  auto tv4 = sum(tv3, {0});
+  // Interface should just be a direct split with a Parallel type. We can
+  // include the parallelize call if we do this.
+  tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+  // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1]
 
-  // Branch 1 {first kernel}
-  auto tv5 = unaryOp(UnaryOpType::Rsqrt, tv3);
-  auto tv6 = sum(tv5, {0});
+  TensorView* tv2 = tv1->rFactor({2});
+  tv2->setMemoryType(MemoryType::Shared);
+  // tv2[I0, R1oo, Ir1i{BIDx}] = tv0[I0, I1]
+  // tv1[I0,        R1i{BIDx}] = tv2[I0, R1oo, Ir1i{BIDx}]
 
-  // Incompatible {second kernel}
-  auto tv7 = sum(tv6, {0});
+  tv0->computeAt(tv1, 1);
 
-  fusion->addOutput(tv1);
-  fusion->addOutput(tv4);
-  fusion->addOutput(tv7);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
 
-  SegmentCandidateFinderOptions segment_options;
-  segment_options.run_herrmann_merge = false;
-  segment_options.run_final_merge = false;
+  constexpr int numel_x = 65000, numel_y = 1024;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({2, 2, 2}, options);
-
-  auto segmented_fusion =
-      SegmentCandidateFinder::segment(fusion.get(), {t0, 1.0}, segment_options);
-
-  TORCH_CHECK(segmented_fusion->groups().size() == 2);
-}
-
-TEST(NVFuserTest, FusionSegmentMixReduction_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
 
-  auto tv0 = makeSymbolicTensor(3);
-
-  fusion->addInput(tv0);
-
-  // def of tv1 in kernel 1 through horizontal
-  auto tv1 = sum(tv0, {0, 1});
-  // kernel 2
-  auto tv2 = sum(tv0, {2});
-  auto tv3 = broadcast(tv2, {false, false, true});
-  auto tv4 = add(tv0, tv3);
-  auto tv5 = sum(tv4, {2});
-  // end of kernel 2
-  // kernel 1
-  auto tv6 = unaryOp(UnaryOpType::Rsqrt, tv0);
-  auto tv7 = sum(tv6, {0, 1});
-  auto tv8 = sum(tv6, {0, 1});
-
-  fusion->addOutput(tv1);
-  fusion->addOutput(tv5);
-  fusion->addOutput(tv7);
-  fusion->addOutput(tv8);
-
-  SegmentCandidateFinderOptions segment_options;
-  segment_options.run_herrmann_merge = false;
-  segment_options.run_final_merge = false;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({2, 2, 2}, options);
+  // How many threads to use for the block reduction
+  constexpr int runtime_threadIdx_dim = 128;
 
-  auto segmented_fusion =
-      SegmentCandidateFinder::segment(fusion.get(), {t0}, segment_options);
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion(
+      {input}, LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1));
 
-  TORCH_CHECK(segmented_fusion->groups().size() <= 2);
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
+  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
 }
 
-TEST(NVFuserTest, FusionSBAR_CUDA) {
+TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  // N, H, W, C format
-  std::vector<int64_t> input_shape{656, 7, 7, 64};
-
-  auto x = makeContigTensor(4);
-  auto y = makeContigTensor(4);
-  auto weight = makeContigTensor(1);
-  auto bias = makeContigTensor(1);
-
-  fusion.addInput(x);
-  fusion.addInput(y);
-  fusion.addInput(weight);
-  fusion.addInput(bias);
-
-  const size_t kNumberOfDims = x->nDims();
-  std::vector<bool> broadcast_mask(kNumberOfDims, false);
-  for (size_t axis = 0; axis < kNumberOfDims - 1; ++axis) {
-    broadcast_mask[axis] = true;
-  }
+  // Algorithm
+  Int* sym_bsx = new Int();
+  TensorView* tv0 = makeDummyTensor(3); // M, K, N
+  fusion.addInput(tv0);
+  fusion.addInput(sym_bsx);
 
-  auto weight_bcast = broadcast(weight, broadcast_mask);
-  auto scale = mul(x, weight_bcast);
-  auto bias_bcast = broadcast(bias, broadcast_mask);
-  auto scale_bias = add(scale, bias_bcast);
-  auto scale_bias_add = add(scale_bias, y);
-  auto scale_bias_add_relu = unaryOp(UnaryOpType::Relu, scale_bias_add);
+  TensorView* tv1 = sum(tv0, {1}); // M, R, N
+  fusion.addOutput(tv1);
 
-  fusion.addOutput(scale_bias_add_relu);
+  TensorView* tv2 = tv0->cache_after();
+  tv2->setMemoryType(MemoryType::Shared);
 
-  // inputs
-  at::manual_seed(0);
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor at_x = at::randn(input_shape, options);
-  at::Tensor at_y = at::randn(input_shape, options);
-  at::Tensor at_weight = at::ones({input_shape[3]}, options);
-  at::Tensor at_bias = at::zeros({input_shape[3]}, options);
+  // Schedule
+  constexpr int BSX = 32;
+  tv1->split(2, BSX);
+  tv1->split(1, sym_bsx);
+  tv1->split(0, BSX);
+  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
+  tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}});
+  TensorView* tv3 = tv1->rFactor({-2});
 
-  // inputs
-  std::vector<c10::IValue> inputs = {at_x, at_y, at_weight, at_bias};
+  tv0->computeAt(tv1, -2);
+  tv0->computeAt(tv3, -2);
 
-  // outputs
-  std::vector<at::Tensor> outputs;
+  // Thread and Block binding
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
+  tv1->axis(1)->parallelize(ParallelType::BIDy);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  // Manual Binding
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
 
-  auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
+  constexpr int M = 154, K = 45, N = 1524;
 
-  FusionExecutor executor;
-  executor.compileFusion(&fusion);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({M, K, N}, options);
 
-  outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
+  // How many threads to use for the block reduction
+  constexpr int runtime_threadIdx_dim = 128;
 
-  auto at_scale = at::mul(at_x, at_weight);
-  auto at_scale_bias = at::add(at_scale, at_bias);
-  auto pwise_add = at::add(at_scale_bias, at_y);
-  auto output = at::relu(pwise_add);
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion(
+      {t0, runtime_threadIdx_dim},
+      LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1));
 
-  testValidate(&fusion, outputs, inputs, {output}, __LINE__, __FILE__);
+  at::Tensor aten_output = sum(t0, {1});
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
+  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 1);
+  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(24) == 1);
 }
 
-TEST(NVFuserTest, FusionSingleElement_CUDA) {
+TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(0);
-  fusion.addInput(tv0);
+  Int* sym_bsx = new Int();
+  TensorView* tv0 = makeDummyTensor(2); // (M, K)
+  TensorView* tv1 = makeDummyTensor(2); // (K, N)
+  TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
+  TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
+  TensorView* tv4 = mul(tv2, tv3); // M, K, N
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  fusion.addInput(sym_bsx);
+  fusion.addOutput(tv4);
+  // Algorithm
+
+  tv2->setMemoryType(MemoryType::Shared);
+  tv3->setMemoryType(MemoryType::Shared);
+
+  constexpr int BSX = 32;
+  tv4->split(2, BSX);
+  tv4->split(1, sym_bsx);
+  tv4->split(0, BSX);
+  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
+  tv4->reorder({{0, 0}, {1, 3}, {2, 1}, {3, 4}, {4, 2}, {5, 5}});
+  // M/BSX, K/BSX, N/BSX, MSX, KSX, NSX
+
+  tv0->computeAt(tv4, 3);
+  tv1->computeAt(tv4, 3);
+  // Schedule
 
-  auto tv1 = add(tv0, new Double(2.5));
+  tv4->axis(0)->parallelize(ParallelType::BIDx);
+  tv4->axis(2)->parallelize(ParallelType::BIDy);
+  // Manual Binding
+  tv2->axis(-2)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  // Thread and Block binding
 
-  auto tv2 = add(tv1, new Double(3.5));
-  fusion.addOutput(tv2);
+  constexpr int M = 128, K = 457, N = 1024;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input = at::randn({}, options);
-
-  at::Tensor cg_output = at::empty({}, options);
-
-  auto lparams = schedulePointwise(&fusion, {input});
+  at::Tensor t0 = at::randn({M, K}, options);
+  at::Tensor t1 = at::randn({K, N}, options);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  fe.runFusion({input}, {cg_output}, lparams);
+  auto outputs =
+      fe.runFusion({t0, t1, BSX}, LaunchParams(-1, -1, -1, BSX, -1, -1));
 
-  auto aten_output = input.add(2.5).add(3.5);
-
-  testValidate(
-      &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
+  at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0));
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
+  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 1);
+  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(22) == 1);
 }
 
-TEST(NVFuserTest, FusionBNBackwardRepro_CUDA) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
+TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) {
+  Fusion fusion;
   FusionGuard fg(&fusion);
 
-  int batch = 4;
-  int c = 4;
-  int h = 4;
-  int w = 4;
-  int numDims = 4;
-
-  auto input = makeSymbolicTensor(numDims);
-  fusion.addInput(input);
-  auto weight = makeSymbolicTensor(1);
-  fusion.addInput(weight);
-  auto running_mean = makeSymbolicTensor(1);
-  fusion.addInput(running_mean);
-  auto running_var = makeSymbolicTensor(1);
-  fusion.addInput(running_var);
-  auto save_mean = makeSymbolicTensor(1);
-  fusion.addInput(save_mean);
-  auto save_invstd = makeSymbolicTensor(1);
-  fusion.addInput(save_invstd);
-
-  auto grad_out_prev = makeSymbolicTensor(numDims);
-  fusion.addInput(grad_out_prev);
-  auto gt_0 =
-      makeSymbolicTensor(numDims); // single tensor broadcasted is dangerous.
-  fusion.addInput(gt_0);
-
-  auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, new Int(1));
-  auto gt_float = castOp(DataType::Float, gt_bool);
-
-  auto grad_out = mul(grad_out_prev, gt_float);
-
-  Val* eps_ptr = new Double(1e-5);
-
-  auto grads = batch_norm_backward(
-      input,
-      grad_out,
-      weight,
-      running_mean,
-      running_var,
-      save_mean,
-      save_invstd,
-      true,
-      eps_ptr,
-      {true, true, true});
-
-  fusion.addOutput(grads.grad_input);
-  fusion.addOutput(grads.grad_weight);
-  fusion.addOutput(grads.grad_bias);
+  // Symbolic integers we will use for runtime tiling
+  Int* symbolic_m_tile_dim = new Int(); // bound to threadIdx.z
+  Int* symbolic_split_k_tile_dim = new Int(); // bound to blockIdx.x
+  Int* symbolic_block_k_tile_dim = new Int(); // bound to threadIdx.x
+  // Compile-time integer for tiling
+  int n_smem_tile = 8; // bound to threadIdx.y
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input0 = at::randn({batch, c, h, w}, options);
-  at::Tensor input1 = at::randn({c}, options);
-  at::Tensor input2 = at::randn_like(input1);
-  at::Tensor input3 = at::randn_like(input1);
-  at::Tensor input4 = at::randn_like(input1);
-  at::Tensor input5 = at::randn_like(input1);
-  at::Tensor input6 = at::randn_like(input0);
-  at::Tensor input7 = at::randn_like(input0);
-
-  FusionExecutorCache fec(std::move(fusion_ptr));
-  std::vector<IValue> inputs = {
-      input0, input1, input2, input3, input4, input5, input6, input7};
-  auto outputs = fec.runFusionWithInputs(inputs);
-}
-
-// TODO: We only changed inputs, merge this with the test above.
-TEST(NVFuserTest, FusionBNBackwardRepro2_CUDA) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
-  FusionGuard fg(&fusion);
+  // Symbolic 2D tensors TV0[M, K], TV1[K, N]
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = makeDummyTensor(2);
 
-  int batch = 2;
-  int c = 81;
-  int h = 1;
-  int w = 1;
-  int numDims = 4;
-
-  // auto input = makeSymbolicTensor(numDims);
-  auto input = makeConcreteTensor({-1, -1, 1, 1});
-  fusion.addInput(input);
-  auto weight = makeSymbolicTensor(1);
-  fusion.addInput(weight);
-  auto running_mean = makeSymbolicTensor(1);
-  fusion.addInput(running_mean);
-  auto running_var = makeSymbolicTensor(1);
-  fusion.addInput(running_var);
-  auto save_mean = makeSymbolicTensor(1);
-  fusion.addInput(save_mean);
-  auto save_invstd = makeSymbolicTensor(1);
-  fusion.addInput(save_invstd);
-
-  // auto grad_out_prev = makeSymbolicTensor(numDims);
-  auto grad_out_prev = makeConcreteTensor({-1, -1, 1, 1});
-  fusion.addInput(grad_out_prev);
-  // auto gt_0 =
-  //     makeSymbolicTensor(numDims); // single tensor broadcasted is dangerous.
-  auto gt_0 = makeConcreteTensor({-1, -1, 1, 1});
-  fusion.addInput(gt_0);
-
-  auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, new Int(1));
-  auto gt_float = castOp(DataType::Float, gt_bool);
-
-  auto grad_out = mul(grad_out_prev, gt_float);
-
-  Val* eps_ptr = new Double(1e-5);
-
-  auto grads = batch_norm_backward(
-      input,
-      grad_out,
-      weight,
-      running_mean,
-      running_var,
-      save_mean,
-      save_invstd,
-      true,
-      eps_ptr,
-      {true, true, true});
-
-  fusion.addOutput(grads.grad_input);
-  fusion.addOutput(grads.grad_weight);
-  fusion.addOutput(grads.grad_bias);
+  // Broadcast tv0 to [M, K, *]
+  TensorView* tv2 = broadcast(tv0, {false, false, true});
+  // Broadcast tv1 to [*, K, N]
+  TensorView* tv3 = broadcast(tv1, {true, false, false});
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input0 = at::randn({batch, c, h, w}, options);
-  at::Tensor input1 = at::randn({c}, options);
-  at::Tensor input2 = at::randn_like(input1);
-  at::Tensor input3 = at::randn_like(input1);
-  at::Tensor input4 = at::randn_like(input1);
-  at::Tensor input5 = at::randn_like(input1);
-  at::Tensor input6 = at::randn_like(input0);
-  at::Tensor input7 = at::randn_like(input0);
+  // Pointwise multiplication resulting in tv3[M, K, N]
+  TensorView* tv4 = mul(tv2, tv3);
 
-  FusionExecutorCache fec(std::move(fusion_ptr));
-  std::vector<IValue> inputs = {
-      input0, input1, input2, input3, input4, input5, input6, input7};
-  auto outputs = fec.runFusionWithInputs(inputs);
-}
+  // Turn the K-dimension of tv4 into a reduction dimension
+  TensorView* tv5 = sum(tv4, {1});
 
-TEST(NVFuserTest, FusionBNRepro_CUDA) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
-  FusionGuard fg(&fusion);
+  // Register inputs and outputs
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  fusion.addOutput(tv5);
 
-  const bool kTraining = true;
-  const float kMomentum = 0.1;
-  const float kEps = 1e-5;
+  // Register runtime tile dims as inputs
+  fusion.addInput(symbolic_m_tile_dim);
+  fusion.addInput(symbolic_split_k_tile_dim);
+  fusion.addInput(symbolic_block_k_tile_dim);
 
-  int batch = 14;
-  int c = 65;
-  int h = 7;
-  int w = 7;
-  int numDims = 4;
-
-  auto input = makeSymbolicTensor(numDims);
-  fusion.addInput(input);
-  auto weight = makeSymbolicTensor(1);
-  fusion.addInput(weight);
-  auto bias = makeSymbolicTensor(1);
-  fusion.addInput(bias);
-  auto running_mean = makeSymbolicTensor(1);
-  fusion.addInput(running_mean);
-  auto running_var = makeSymbolicTensor(1);
-  fusion.addInput(running_var);
-
-  auto momentum_ptr = new Double(kMomentum);
-  auto eps_ptr = new Double(kEps);
-
-  auto result = batch_norm(
-      input,
-      weight,
-      bias,
-      running_mean,
-      running_var,
-      kTraining,
-      momentum_ptr,
-      eps_ptr);
-
-  fusion.addOutput(result.output);
-  fusion.addOutput(result.mean);
-  fusion.addOutput(result.invstd);
+  // Make a 3D tile, mix of symbolic and constant, do in reverse order because
+  // dims are inserted
+  tv5->split(2, n_smem_tile);
+  tv5->split(1, symbolic_block_k_tile_dim);
+  tv5->split(1, symbolic_split_k_tile_dim);
+  tv5->split(0, symbolic_m_tile_dim);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input1 = at::randn({batch, c, h, w}, options);
-  at::Tensor input2 = at::randn({c}, options);
-  at::Tensor input3 = at::randn_like(input2);
-  at::Tensor input4 = at::randn_like(input2);
-  at::Tensor input5 = at::randn_like(input2);
-
-  auto input1_ref = input1.clone();
-  auto input2_ref = input2.clone();
-  auto input3_ref = input3.clone();
-  auto input4_ref = input4.clone();
-  auto input5_ref = input5.clone();
-
-  FusionExecutorCache fec(std::move(fusion_ptr));
-  std::vector<IValue> aten_inputs = {input1, input2, input3, input4, input5};
-  auto cg_outputs = fec.runFusionWithInputs(aten_inputs);
-
-  auto at_results = at::native_batch_norm(
-      input1_ref,
-      input2_ref,
-      input3_ref,
-      input4_ref,
-      input5_ref,
-      kTraining,
-      kMomentum,
-      kEps);
-
-  auto at_output = std::get<0>(at_results);
-  auto at_mean = std::get<1>(at_results);
-  auto at_invstd = std::get<2>(at_results);
-
-  std::vector<at::Tensor> aten_outputs = {
-      input4_ref, input5_ref, at_output, at_mean, at_invstd};
-
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionBNRepro2_CUDA) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
-  FusionGuard fg(&fusion);
+  // Reorder so all outer tiles are in the leftmost 3 positions
+  tv5->reorder({{1, 5}, {5, 1}});
 
-  const bool kTraining = true;
-  const float kMomentum = 0.1;
-  const float kEps = 1e-5;
+  // Factor out the outer reduction IterDomain, then run the inter-cta
+  // reduction, and intra-cta reduction
+  auto tv6 = tv5->rFactor({2});
 
-  int batch = 2;
-  int c = 4;
-  int h = 17;
-  int w = 17;
-  int numDims = 4;
+  // Scope computations
+  tv6->computeAt(tv5, 2);
 
-  auto input = makeSymbolicTensor(numDims);
-  fusion.addInput(input);
+  // RFactor moves reduction axes around, reorder to match ordering of tv5
+  tv6->reorder({
+      {2, -2},
+      {3, -1},
+      {4, 2},
+      {5, 3},
+      {6, 4},
+  });
 
-  Val* momentum_ptr = new Double(kMomentum);
-  Val* eps_ptr = new Double(kEps);
+  // Setup compute at schedule
+  tv0->computeAt(tv6, 3);
+  tv1->computeAt(tv6, 3);
+  tv4->computeAt(tv6, -1);
+  //
+  // T2[Mo,  bNo, Koo, Koi,  Kii,  Mi, bNi] CA(4, 3)
+  // T3[bMo,  No, Koo, Koi,  Kii, bMi,  Ni] CA(4, 3)
+  // T4[ Mo,  No, Koo, Koi,  Kii,  Mi,  Ni]
+  // T6[ Mo,  No, rKoo, Koi, Kii,  Mi,  Ni]
+  // T5[ Mo,  No,      rKoi, rKii, Mi,  Ni]
 
-  auto result = batch_norm(
-      input,
-      nullptr,
-      nullptr,
-      nullptr,
-      nullptr,
-      kTraining,
-      momentum_ptr,
-      eps_ptr);
+  // Cache smem tiles
+  tv2->setMemoryType(MemoryType::Shared);
+  tv3->setMemoryType(MemoryType::Shared);
+  tv4->setMemoryType(MemoryType::Local);
+  tv6->setMemoryType(MemoryType::Local);
 
-  fusion.addOutput(result.output);
-  fusion.addOutput(result.mean);
-  fusion.addOutput(result.invstd);
+  tv5->axis(0)->parallelize(ParallelType::BIDz);
+  tv5->axis(1)->parallelize(ParallelType::BIDy);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input1 = at::randn({batch, c, h, w}, options);
+  std::vector<TensorView*> tv_list = {tv2, tv3, tv4, tv5, tv6};
+  for (auto tv : tv_list) {
+    tv->axis(-2)->parallelize(ParallelType::TIDz);
+    tv->axis(-1)->parallelize(ParallelType::TIDy);
+  }
+  tv2->axis(3)->parallelize(ParallelType::TIDx);
+  tv3->axis(3)->parallelize(ParallelType::TIDx);
+  tv4->axis(3)->parallelize(ParallelType::TIDx);
+  tv6->axis(3)->parallelize(ParallelType::TIDx);
+  tv5->axis(2)->parallelize(ParallelType::TIDx);
 
-  auto input1_ref = input1.clone();
-  at::Tensor r_m;
-  at::Tensor r_v;
-  at::Tensor weight;
-  at::Tensor bias;
+  tv2->axis(4)->parallelize(ParallelType::BIDx);
+  tv3->axis(4)->parallelize(ParallelType::BIDx);
+  tv4->axis(4)->parallelize(ParallelType::BIDx);
+  tv6->axis(4)->parallelize(ParallelType::BIDx);
+  tv5->axis(3)->parallelize(ParallelType::BIDx);
+
+  constexpr int M = 31, K = 65, N = 33;
 
-  FusionExecutorCache fec(std::move(fusion_ptr));
-  std::vector<IValue> aten_inputs = {input1};
-  auto cg_outputs = fec.runFusionWithInputs(aten_inputs);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor A = at::randn({M, K}, options);
+  at::Tensor B = at::randn({K, N}, options);
 
-  auto at_results = at::native_batch_norm(
-      input1_ref, r_m, r_v, weight, bias, kTraining, kMomentum, kEps);
+  FusionExecutor fe;
+  // Generate CUDA and compile with nvRTC
+  fe.compileFusion(&fusion);
 
-  auto at_output = std::get<0>(at_results);
-  auto at_mean = std::get<1>(at_results);
-  auto at_invstd = std::get<2>(at_results);
+  // Runtime tiling
+  int m_tile = 4; // bound to threadIdx.z
+  int split_k = 7; // bound to blockIdx.x
+  int intra_cta = 8; // bound to threadIdx.x
 
-  std::vector<at::Tensor> aten_outputs = {at_output, at_mean, at_invstd};
+  auto fuser_outputs = fe.runFusion({A, B, m_tile, split_k, intra_cta});
+  auto C_fuser = fuser_outputs[0];
 
-  testValidate(
-      &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
+  at::Tensor aten_C = mul(A.unsqueeze(2), B.unsqueeze(0)).sum(1);
+  TORCH_CHECK(
+      aten_C.allclose(C_fuser, 1e-5, 1e-5),
+      "Error of: ",
+      aten_C.sub(C_fuser).abs().max());
+  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 1);
+  TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(41) == 1);
 }
 
-TEST(NVFuserTest, FusionZeroSizeTensorPW_CUDA) {
+TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
   fusion.addInput(tv0);
+  fusion.addOutput(tv1);
+  // tv1[I0, R1] = tv0[I0, I1]
 
-  auto tv1 = makeConcreteTensor({0});
-  fusion.addInput(tv1);
+  // Interface should just be a direct split with a Parallel type. We can
+  // include the parallelize call if we do this.
+  tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+  // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1]
 
-  auto tv2 = add(tv0, new Double(2.5));
-  fusion.addOutput(tv2);
+  TensorView* tv2 = tv1->rFactor({2});
+  tv2->setMemoryType(MemoryType::Global);
+  // tv2[I0, R1oo, Ir1i{BIDx}] = tv0[I0, I1]
+  // tv1[I0,        R1i{BIDx}] = tv2[I0, R1oo, Ir1i{BIDx}]
 
-  auto tv3 = makeConcreteTensor({0});
-  fusion.addOutput(tv3);
+  tv0->computeAt(tv1, 1);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
+
+  constexpr int numel_x = 65000, numel_y = 1024;
 
-  at::Tensor input0 = at::randn({2}, options);
-  at::Tensor input1 = at::randn({0}, options);
-  at::Tensor cg_output2 = at::empty({2}, options);
-  at::Tensor cg_output3 = at::empty({0}, options);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
 
-  auto lparams = schedulePointwise(&fusion, {input0, input1});
+  // How many threads to use for the block reduction
+  constexpr int runtime_threadIdx_dim = 128;
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  fe.runFusion({input0, input1}, {cg_output2, cg_output3}, lparams);
+  auto outputs = fe.runFusion(
+      {input}, LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1));
 
-  auto aten_output2 = input0.add(2.5);
-  at::Tensor aten_output3 = at::empty({0}, options);
-
-  testValidate(
-      &fusion,
-      {cg_output2, cg_output3},
-      {input0, input1},
-      {aten_output2, aten_output3},
-      __LINE__,
-      __FILE__);
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
 }
 
-TEST(NVFuserTest, FusionZeroSizeTensorReduction_CUDA) {
+TEST(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = makeDummyTensor(2);
+  TensorView* tv2 = makeDummyTensor(2);
+  TensorView* tv3 = makeDummyTensor(2);
+  TensorView* tv4 = sub(tv2, tv3);
+  TensorView* tv5 = add(tv1, tv4);
+  TensorView* tv6 = sub(tv5, tv0);
   fusion.addInput(tv0);
-
-  auto tv1 = makeConcreteTensor({0});
   fusion.addInput(tv1);
+  fusion.addInput(tv2);
+  fusion.addInput(tv3);
+  fusion.addOutput(tv6);
+  // t6 = ((t1 + (t2 - t3)) - t0)
 
-  auto tv2 = sum(tv0, {1});
-  fusion.addOutput(tv2);
-
-  auto tv3 = makeConcreteTensor({0});
-  fusion.addOutput(tv3);
+  tv4->setMemoryType(MemoryType::Global);
+  tv5->setMemoryType(MemoryType::Global);
+  tv6->setMemoryType(MemoryType::Global);
 
+  constexpr int M = 32, N = 810;
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor in0 = at::rand({M, N}, options);
+  at::Tensor in1 = at::rand({M, N}, options);
+  at::Tensor in2 = at::rand({M, N}, options);
+  at::Tensor in3 = at::rand({M, N}, options);
 
-  at::Tensor input0 = at::randn({2, 4}, options);
-  at::Tensor input1 = at::randn({0}, options);
-  at::Tensor cg_output2 = at::empty({2}, options);
-  at::Tensor cg_output3 = at::empty({0}, options);
-
-  auto reduction_params = getReductionHeuristics(&fusion, {input0, input1});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fusion, reduction_params.value());
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-
-  auto lparams = reduction_params.value().lparams;
   FusionExecutor fe;
   fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input0, input1}, lparams);
-  auto aten_output2 = input0.sum({1});
-  at::Tensor aten_output3 = at::empty({0}, options);
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {input0, input1},
-      {aten_output2, aten_output3},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) {
+  auto outputs = fe.runFusion({in0, in1, in2, in3});
+
+  at::Tensor aten_output = (in1 + (in2 - in3)) - in0;
+  TORCH_CHECK(
+      aten_output.allclose(outputs[0], 1e-5, 1e-5),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().sum());
+}
+
+TEST(NVFuserTest, FusionConstCheck_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  auto tv1 = makeConcreteTensor({0});
-  fusion.addInput(tv1);
+  auto one = new Int(1);
+  TORCH_CHECK(one->isConstScalar());
 
-  auto tv2 = sum(tv0, {0});
-  auto tv3 = broadcast(tv2, {true, false});
-  auto tv4 = add(tv0, tv3);
-  fusion.addOutput(tv4);
+  auto one_x2 = mul(one, one);
+  TORCH_CHECK(one_x2->isConstScalar());
 
-  auto tv5 = makeConcreteTensor({0});
-  fusion.addOutput(tv5);
+  auto one_x3 = mul(one_x2, one);
+  TORCH_CHECK(one_x3->isConstScalar());
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto one_x4 = mul(one_x3, one);
+  TORCH_CHECK(one_x4->isConstScalar());
+}
 
-  at::Tensor input0 = at::randn({2, 4}, options);
-  at::Tensor input1 = at::randn({0}, options);
-  at::Tensor cg_output2 = at::empty({2, 4}, options);
-  at::Tensor cg_output3 = at::empty({0}, options);
+TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) {
+  const std::vector<int64_t> tensor_dims_in = {128, 128};
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  auto reduction_params = getNormalizationHeuristics(&fusion, {input0, input1});
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleNormalization(&fusion, reduction_params.value());
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(tensor_dims_in.size());
+  fusion.addInput(tv0);
 
-  auto lparams = reduction_params.value().lparams;
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto cg_outputs = fe.runFusion({input0, input1}, lparams);
-  auto aten_output2 = input0.sum({0}).add(input0);
-  at::Tensor aten_output3 = at::empty({0}, options);
-
-  testValidate(
-      &fusion,
-      cg_outputs,
-      {input0, input1},
-      {aten_output2, aten_output3},
-      __LINE__,
-      __FILE__,
-      "",
-      lparams);
-}
-
-TEST(NVFuserTest, FusionSegmentIoAlias_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+  TensorView* tv1 = add(tv0, new Float(0));
+  TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv1);
+  fusion.addOutput(tv2);
 
-  TensorView* tv0 = makeSymbolicTensor(2);
-  TensorView* tv1 = makeSymbolicTensor(1);
-  TensorView* tv2 = makeSymbolicTensor(2);
+  const auto options =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand(tensor_dims_in, options);
+  at::Tensor cg_output = at::empty({tensor_dims_in[0]}, options);
 
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
-  fusion->addInput(tv2);
+  // const at::ArrayRef<c10::IValue> inputs({input});
 
-  TensorView* tv3 = add(tv0, new Double(1)); // Group 0
-  TensorView* tv4 =
-      max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues)
-  TensorView* tv5 = add(tv4, tv1); //  Group 0 (Non Broadcast after reduce,
-                                   //  keeps normalization scheduler away)
-  TensorView* tv6 = add(tv5, tv2); //  Group 1 (Broadcast after reduce)
+  // Schedule
+  tv2->split(1, 32);
+  tv2->split(1, 4); // unroll
 
-  fusion->addOutput(tv6);
-  // Note: test alias;
-  fusion->aliasOutputToInput(tv6, tv0);
+  auto tv2_rf = tv2->rFactor({-3, -2});
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({128, 65}, options);
-  at::Tensor t1 = at::randn({65}, options);
-  at::Tensor t2 = at::randn({128, 65}, options);
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
 
-  auto t3 = t0.add(1.0);
-  auto t4 = std::get<0>(at::max(t3, 0));
-  auto t5 = t4.add(t1);
-  auto t6 = t5.add(t2);
+  tv2_rf->axis(0)->parallelize(ParallelType::BIDx);
+  tv2_rf->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2_rf->axis(-2)->parallelize(ParallelType::Unroll);
 
-  FusionExecutorCache executor_cache(std::move(fusion));
+  tv1->computeAt(tv2_rf, -1);
 
-  auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2});
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({input});
 
-  // validating aliasing
-  TORCH_INTERNAL_ASSERT(outputs[0].data_ptr() == t0.data_ptr());
+  auto aten_output = (input + 0).sum(1);
 
   TORCH_CHECK(
-      executor_cache.getMostRecentKernelRuntime()->isSegmented(),
-      "segmentation didn't happen");
-  TORCH_CHECK(
-      executor_cache.getMostRecentKernelRuntime()
-              ->fusionSegments()
-              ->groups()
-              .size() == 2,
-      "segmentation didn't happen as expected");
-
-  testValidate(
-      executor_cache.fusion(), outputs, {t0, t1, t2}, {t6}, __LINE__, __FILE__);
+      aten_output.allclose(outputs[0]),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
 }
 
-TEST(NVFuserTest, FusionWelford1Output_CUDA) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  auto fusion = fusion_ptr.get();
-  FusionGuard fg(fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion->addInput(tv0);
-
-  auto tvs = Welford(tv0, {1});
-  fusion->addOutput(tvs.var_sum);
-  FusionExecutorCache executor_cache(std::move(fusion_ptr));
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({128, 65}, options);
-  auto outputs = executor_cache.runFusionWithInputs({t0});
+// Test isZeroInt
+TEST(NVFuserTest, FusionIsZeroInt_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  auto t1 = t0.var({1}, false) * 65;
-  testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__);
+  Int* x = new Int(0);
+  Int* y = new Int(1);
+  Val* z = mul(x, y);
+  TORCH_CHECK(x->isZeroInt());
+  TORCH_CHECK(!y->isZeroInt());
+  TORCH_CHECK(!z->isZeroInt());
 }
 
-TEST(NVFuserTest, FusionTranslate1Welford_CUDA) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  auto fusion = fusion_ptr.get();
-  FusionGuard fg(fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion->addInput(tv0);
-
-  auto tvs = Welford(tv0, {1});
-  fusion->addOutput(tvs.var_sum);
-  FusionExecutorCache executor_cache(std::move(fusion_ptr));
-
-  auto run_test = [&executor_cache,
-                   fusion](auto inner_size) -> FusionKernelRuntime* {
-    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-    at::Tensor t0 = at::randn({128, inner_size}, options);
-    auto outputs = executor_cache.runFusionWithInputs({t0});
-    // Square sums does not fit well in the testValidate assumptions,
-    //  so we just compare the divided output here.
-    outputs[0] /= inner_size;
-    auto t1 = t0.var({1}, false);
-    testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__);
-
-    return executor_cache.getMostRecentKernelRuntime();
-  };
-
-  // Run a translated welford
-  auto runtime1 = run_test(64);
-  // Check it was translated
-  TORCH_CHECK(runtime1->singleKernelFusion()->unordered_exprs().size() > 2);
-  TORCH_CHECK(
-      runtime1->schedulerHeuristics()->singleKernelHeuristics()->heuristc() ==
-      ScheduleHeuristic::Normalization);
+// Test isOneInt
+TEST(NVFuserTest, FusionIsOneInt_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  // Run an un-translated welford
-  auto runtime2 = run_test(65536);
-  // Check it was not translated
-  TORCH_CHECK(runtime2->singleKernelFusion()->unordered_exprs().size() == 1);
-  TORCH_CHECK(
-      runtime2->schedulerHeuristics()->singleKernelHeuristics()->heuristc() ==
-      ScheduleHeuristic::Reduction);
+  Int* x = new Int(1);
+  Int* y = new Int(1);
+  Val* z = mul(x, y);
+  TORCH_CHECK(x->isOneInt());
+  TORCH_CHECK(y->isOneInt());
+  TORCH_CHECK(!z->isOneInt());
 }
 
-TEST(NVFuserTest, FusionTranslate2Welford_CUDA) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  auto fusion = fusion_ptr.get();
-  FusionGuard fg(fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion->addInput(tv0);
-
-  auto tvs1 = Welford(tv0, {1});
-  auto tvs2 = Welford(tv0, {1});
-
-  fusion->addOutput(tvs1.var_sum);
-  fusion->addOutput(tvs2.var_sum);
-
-  FusionExecutorCache executor_cache(std::move(fusion_ptr));
-
-  auto run_test = [&executor_cache,
-                   fusion](auto inner_size) -> FusionKernelRuntime* {
-    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-    at::Tensor t0 = at::randn({128, inner_size}, options);
-    auto outputs = executor_cache.runFusionWithInputs({t0});
-
-    // Square sums does not fit well in the testValidate assumptions,
-    //  so we just compare the divided output here.
-    outputs[0] /= inner_size;
-    outputs[1] /= inner_size;
-    auto t1 = t0.var({1}, false);
-    testValidate(fusion, outputs, {t0}, {t1, t1}, __LINE__, __FILE__);
-
-    return executor_cache.getMostRecentKernelRuntime();
-  };
+// This is to verify no cycle of computeAt is created. A more complex
+// variation of this pattern appears in one of the Python tests
+// (test_random_topo).
+TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  // Run a translated welford
-  auto runtime1 = run_test(64);
-  // Check it was translated
-  TORCH_CHECK(runtime1->singleKernelFusion()->unordered_exprs().size() > 4);
-  TORCH_CHECK(
-      runtime1->schedulerHeuristics()->singleKernelHeuristics()->heuristc() ==
-      ScheduleHeuristic::Normalization);
+  TensorView* tv0 = makeDummyTensor(1);
+  fusion.addInput(tv0);
 
-  // Run an un-translated welford
-  auto runtime2 = run_test(65536);
-  // // Check it was not translated
-  TORCH_CHECK(runtime2->singleKernelFusion()->unordered_exprs().size() == 2);
-}
+  // Common intermediate tensor
+  auto tv1 = add(tv0, new Float(1));
+  // tv1 -> tv2
+  auto tv2 = add(tv1, new Float(2));
+  // tv1 -> tv3 -> tv4
+  auto tv3 = add(tv1, new Float(3));
+  auto tv4 = add(tv3, new Float(4));
 
-TEST(NVFuserTest, FusionLargeWelfordNormalization_CUDA) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  auto fusion = fusion_ptr.get();
-  FusionGuard fg(fusion);
+  // NOTE: This should no longer occur as of PR #201.
+  // The order of adding outputs matters. If tv3 is added before tv4,
+  // it should be fine. However, if tv4 is added before tv3, there
+  // will be a cycle of tv3->tv4 and tv4->tv3. tv3->tv4 is created
+  // first, and then tv4->tv3 is created at the final phase of
+  // computeAt (ComputeAt::setupOutputs).
+  fusion.addOutput(tv2);
+  fusion.addOutput(tv4);
+  fusion.addOutput(tv3);
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion->addInput(tv0);
+  tv0->computeAt(tv2, -1);
 
-  auto tvs1 = Welford(tv0, {1});
-  auto sum_of_tv0 = sum(tv0, {1});
-  auto sum_plus_avg = add(tvs1.avg, sum_of_tv0);
+  TORCH_CHECK(
+      !(tv3->getComputeAtView() == tv4 && tv4->getComputeAtView() == tv3),
+      "ComputeAt cycle detected between tv3 and tv4");
 
-  fusion->addOutput(sum_plus_avg);
+  const auto options =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand(100, options);
 
-  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({input});
 
-  auto run_test = [&executor_cache,
-                   fusion](auto inner_size) -> FusionKernelRuntime* {
-    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-    at::Tensor t0 = at::randn({128, inner_size}, options);
-    auto outputs = executor_cache.runFusionWithInputs({t0});
+  auto& output_tv2 = outputs[0];
+  auto& output_tv4 = outputs[1];
+  auto& output_tv3 = outputs[2];
 
-    auto t1 = t0.mean({1}) + t0.sum({1});
-    testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__);
+  auto aten_t1 = input + 1;
+  auto aten_t2 = aten_t1 + 2;
+  auto aten_t3 = aten_t1 + 3;
+  auto aten_t4 = aten_t3 + 4;
 
-    return executor_cache.getMostRecentKernelRuntime();
-  };
+  TORCH_CHECK(
+      aten_t2.allclose(output_tv2),
+      "Error of: ",
+      aten_t2.sub(output_tv2).abs().max());
+  TORCH_CHECK(
+      aten_t3.allclose(output_tv3),
+      "Error of: ",
+      aten_t3.sub(output_tv3).abs().max());
+  TORCH_CHECK(
+      aten_t4.allclose(output_tv4),
+      "Error of: ",
+      aten_t4.sub(output_tv4).abs().max());
 
-  auto runtime = run_test(65536);
-  TORCH_CHECK(!runtime->isSegmented());
+  return;
 }
 
-TEST(NVFuserTest, FusionWelfordOtherPersistence_CUDA) {
-  auto fusion_ptr = std::make_unique<Fusion>();
-  auto fusion = fusion_ptr.get();
-  FusionGuard fg(fusion);
+TEST(NVFuserTest, FusionTraversalOrder1_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-  fusion->addInput(tv0);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
 
-  auto tvs1 = Welford(tv0, {1});
-  auto sum_of_tv0 = sum(tv0, {1});
-  auto sum_bcasted = broadcast(sum_of_tv0, {false, true});
-  auto avg_bcasted = broadcast(tvs1.avg, {false, true});
-  auto tv0_plus_sum = add(tv0, sum_bcasted);
-  auto tv0_plus_avg = add(tv0, avg_bcasted);
+  TensorView* tv1 = add(tv0, new Float(1));
+  TensorView* tv2 = add(tv0, new Float(2));
+  TensorView* tv3 = add(tv1, new Float(3));
+  TensorView* tv4 = add(tv1, new Float(4));
 
-  fusion->addOutput(tv0_plus_sum);
-  fusion->addOutput(tv0_plus_avg);
+  fusion.addOutput(tv2);
+  fusion.addOutput(tv3);
+  fusion.addOutput(tv4);
 
-  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  tv1->computeAt(tv3, -1);
 
-  auto run_test = [&executor_cache,
-                   fusion](auto inner_size) -> FusionKernelRuntime* {
-    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-    at::Tensor t0 = at::randn({128, inner_size}, options);
-    auto outputs = executor_cache.runFusionWithInputs({t0});
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
 
-    auto t1 = t0.mean({1}).unsqueeze(1) + t0;
-    auto t2 = t0.sum({1}).unsqueeze(1) + t0;
-    testValidate(fusion, outputs, {t0}, {t2, t1}, __LINE__, __FILE__);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({10, 10}, options);
+  at::Tensor cg_output_tv2 = at::empty_like(input, options);
+  at::Tensor cg_output_tv3 = at::empty_like(input, options);
+  at::Tensor cg_output_tv4 = at::empty_like(input, options);
+  fe.runFusion({input}, {cg_output_tv2, cg_output_tv3, cg_output_tv4});
 
-    return executor_cache.getMostRecentKernelRuntime();
-  };
+  auto t1 = input + 1;
+  auto t2 = input + 2;
+  auto t3 = t1 + 3;
+  auto t4 = t1 + 4;
 
-  for (auto inner_size : {4096, 8192, 32768}) {
-    auto runtime = run_test(4096);
-    TORCH_CHECK(!runtime->isSegmented());
-  }
+  TORCH_CHECK(
+      t2.allclose(cg_output_tv2),
+      "tv2 error of: ",
+      t2.sub(cg_output_tv2).abs().max());
+  TORCH_CHECK(
+      t3.allclose(cg_output_tv3),
+      "tv5 error of: ",
+      t3.sub(cg_output_tv3).abs().max());
+  TORCH_CHECK(
+      t4.allclose(cg_output_tv4),
+      "tv4 error of: ",
+      t4.sub(cg_output_tv4).abs().max());
 }
 
-TEST(NVFuserTest, TestSegmentIslands_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
-
-  auto tv0 = makeSymbolicTensor(2);
-  auto tv1 = makeSymbolicTensor(2);
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
-
-  auto tv2 = sum(tv0, {0});
-  auto tv3 = sum(tv1, {1});
-  fusion->addOutput(tv2);
-  fusion->addOutput(tv3);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({16, 16}, options);
-  at::Tensor t1 = at::randn({16, 16}, options);
+TEST(NVFuserTest, FusionTraversalOrder2_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  FusionExecutorCache fusion_executor_cache(std::move(fusion));
-  fusion_executor_cache.runFusionWithInputs({t0, t1});
-}
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
 
-TEST(NVFuserTest, TestBackOffInnerBroadcast_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+  TensorView* tv1 = add(tv0, new Float(1));
+  TensorView* tv2 = add(tv1, new Float(2));
 
-  auto tv0 = makeSymbolicTensor(1);
-  auto tv1 = makeSymbolicTensor(2);
-  auto tv2 = makeSymbolicTensor(4);
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
+  TensorView* tv3 = add(tv0, new Float(3));
+  TensorView* tv4 = add(tv3, new Float(4));
 
-  auto tv3 = broadcast(tv0, {false, true, true, true});
-  auto tv4 = broadcast(tv1, {false, false, true, true});
-  auto tv5 = unaryOp(UnaryOpType::Rsqrt, tv2);
+  TensorView* tv5 = add(tv1, tv3);
 
-  auto tv6 = add(tv3, tv5);
-  auto tv7 = add(tv4, tv5);
-  auto tv8 = add(tv3, tv4);
+  fusion.addOutput(tv2);
+  fusion.addOutput(tv4);
+  fusion.addOutput(tv5);
 
-  auto tv9 = add(tv6, tv7);
-  auto tv10 = add(tv9, tv8);
+  tv1->computeAt(tv5, -1);
+  tv3->computeAt(tv5, -1);
 
-  fusion->addOutput(tv10);
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
 
-  tv0->computeAt(tv10, -2);
-  tv1->computeAt(tv10, -2);
-  tv2->computeAt(tv10, -2);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({10, 10}, options);
+  at::Tensor cg_output_tv2 = at::empty_like(input, options);
+  at::Tensor cg_output_tv4 = at::empty_like(input, options);
+  at::Tensor cg_output_tv5 = at::empty_like(input, options);
+  fe.runFusion({input}, {cg_output_tv2, cg_output_tv4, cg_output_tv5});
 
-  TORCH_CHECK(tv3->getComputeAtPosition() == 1);
-  TORCH_CHECK(tv4->getComputeAtPosition() == 2);
-  TORCH_CHECK(tv5->getComputeAtPosition() == 3);
+  auto t1 = input + 1;
+  auto t2 = t1 + 2;
+  auto t3 = input + 3;
+  auto t4 = t3 + 4;
+  auto t5 = t1 + t3;
 
-  TORCH_CHECK(tv6->getMaxProducerPosition() == 3);
-  TORCH_CHECK(tv7->getMaxProducerPosition() == 3);
-  TORCH_CHECK(tv8->getMaxProducerPosition() == 2);
+  TORCH_CHECK(
+      t2.allclose(cg_output_tv2),
+      "tv2 error of: ",
+      t2.sub(cg_output_tv2).abs().max());
+  TORCH_CHECK(
+      t4.allclose(cg_output_tv4),
+      "tv4 error of: ",
+      t4.sub(cg_output_tv4).abs().max());
+  TORCH_CHECK(
+      t5.allclose(cg_output_tv5),
+      "tv5 error of: ",
+      t5.sub(cg_output_tv5).abs().max());
 }
 
-TEST(NVFuserTest, TestBackOffInnerBroadcast2_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+TEST(NVFuserTest, FusionTraversalOrder3_CUDA) {
+  for (int i = 0; i < 2; ++i) {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-  auto tv1 = makeSymbolicTensor(3);
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
-  auto tv2 = broadcast(tv0, {false, false, true});
-  auto tv3 = add(tv2, tv1);
+    TensorView* tv0 = makeDummyTensor(1);
+    fusion.addInput(tv0);
 
-  fusion->addOutput(tv3);
-  tv3->split(-2, 4);
-  tv3->reorder({{-1, -2}});
-  tv0->computeAt(tv3, -2);
-  tv1->computeAt(tv3, -2);
-  TORCH_CHECK(tv2->getComputeAtPosition() == 2);
-  TORCH_CHECK(tv3->getMaxProducerPosition() == 2);
-}
+    TensorView* tv1 = add(tv0, new Float(1));
+    TensorView* tv2 = add(tv1, new Float(2));
 
-TEST(NVFuserTest, TestBackOffInnerBroadcast3_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+    TensorView* tv3 = add(tv0, new Float(3));
+    TensorView* tv4 = add(tv3, new Float(4));
 
-  auto tv0 = makeSymbolicTensor(2);
-  auto tv1 = makeSymbolicTensor(4);
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
-  auto tv2 = broadcast(tv0, {false, false, true});
-  auto tv3 = broadcast(tv2, {false, true, false, false});
-  auto tv4 = add(tv3, tv1);
+    TensorView* tv5 = add(tv1, tv3);
 
-  fusion->addOutput(tv4);
-  tv0->computeAt(tv4, -1);
-  tv1->computeAt(tv4, -1);
-  TORCH_CHECK(tv2->getComputeAtPosition() == 2);
-  TORCH_CHECK(tv3->getMaxProducerPosition() == 3);
-}
+    fusion.addOutput(tv2);
+    fusion.addOutput(tv4);
+    fusion.addOutput(tv5);
 
-TEST(NVFuserTest, FusionSegfaultReduction_CUDA) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  Fusion& fusion = *fusion_ptr.get();
-  FusionGuard fg(&fusion);
+    const int tile = 32;
+
+    tv1->split(-1, tile);
+    tv2->split(-1, tile);
+    tv3->split(-1, tile);
+    tv4->split(-1, tile);
+    tv5->split(-1, tile);
 
-  int batch = 2;
-  int c = 1;
-  int h = 1;
-  int w = 1;
-  int numDims = 4;
-
-  auto input = makeConcreteTensor({-1, 1, 1, 1});
-  fusion.addInput(input);
-  auto bcast_bias = makeConcreteTensor({-1, 1, 1, 1});
-  fusion.addInput(bcast_bias);
-
-  std::vector<int64_t> at_sum_axes;
-  std::vector<int> outer_reduction_axes;
-  std::vector<bool> outer_broadcast_mask(numDims, false);
-  Val* N = new Double(1);
-  for (size_t axis = 0; axis < numDims; ++axis) {
-    if (axis != 1) {
-      outer_reduction_axes.push_back(axis);
-      at_sum_axes.push_back(axis);
-      outer_broadcast_mask[axis] = true;
-      N = mul(N, input->domain()->domain()[axis]->extent());
+    auto compute_at_outer = tv1;
+    auto compute_at_inner = tv3;
+    if (i == 1) {
+      std::swap(compute_at_inner, compute_at_outer);
     }
-  }
 
-  auto output0 = mul(input, bcast_bias);
-  fusion.addOutput(output0);
-  auto output1 = sum(output0, outer_reduction_axes);
-  fusion.addOutput(output1);
+    compute_at_outer->computeAt(tv5, -2);
+    compute_at_inner->computeAt(tv5, -1);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input0 = at::randn({batch, c, h, w}, options);
-  at::Tensor input1 = at::randn({batch, c, h, w}, options);
+    FusionExecutor fe;
+    fe.compileFusion(&fusion);
 
-  auto at_output0 = input0.mul(input1);
-  auto at_output1 = at_output0.sum(at_sum_axes);
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    at::Tensor input = at::rand({100}, options);
+    at::Tensor cg_output_tv2 = at::empty_like(input, options);
+    at::Tensor cg_output_tv4 = at::empty_like(input, options);
+    at::Tensor cg_output_tv5 = at::empty_like(input, options);
+    fe.runFusion({input}, {cg_output_tv2, cg_output_tv4, cg_output_tv5});
 
-  FusionExecutorCache fec(std::move(fusion_ptr));
-  std::vector<IValue> inputs = {input0, input1};
-  auto outputs = fec.runFusionWithInputs(inputs);
+    auto t1 = input + 1;
+    auto t2 = t1 + 2;
+    auto t3 = input + 3;
+    auto t4 = t3 + 4;
+    auto t5 = t1 + t3;
 
-  testValidate(
-      &fusion, outputs, inputs, {at_output0, at_output1}, __LINE__, __FILE__);
+    TORCH_CHECK(
+        t2.allclose(cg_output_tv2),
+        "tv2 error of: ",
+        t2.sub(cg_output_tv2).abs().max());
+    TORCH_CHECK(
+        t4.allclose(cg_output_tv4),
+        "tv4 error of: ",
+        t4.sub(cg_output_tv4).abs().max());
+    TORCH_CHECK(
+        t5.allclose(cg_output_tv5),
+        "tv5 error of: ",
+        t5.sub(cg_output_tv5).abs().max());
+  }
 }
 
-TEST(NVFuserTest, FusionPredicateElimination_CUDA) {
+TEST(NVFuserTest, FusionTraversalOrder4_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
+  // First tree
+  TensorView* tv0 = makeDummyTensor(1);
   fusion.addInput(tv0);
+  TensorView* tv1 = add(tv0, new Float(1));
+  TensorView* tv2 = add(tv1, new Float(2));
+  TensorView* tv3 = add(tv1, new Float(3));
+  fusion.addOutput(tv2);
+  fusion.addOutput(tv3);
 
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(2));
-  auto tv3 = add(tv2, new Double(3));
+  // Second tree
+  TensorView* tv4 = makeDummyTensor(1);
+  fusion.addInput(tv4);
+  TensorView* tv5 = add(tv4, new Float(5));
+  TensorView* tv6 = add(tv5, new Float(6));
+  TensorView* tv7 = add(tv5, new Float(7));
+  fusion.addOutput(tv6);
+  fusion.addOutput(tv7);
 
-  fusion.addOutput(tv3);
+  tv1->computeAt(tv2, -1);
+  tv5->computeAt(tv6, -1);
 
-  tv3->split(0, 32);
-  tv0->computeAt(tv3, 1);
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
 
-  tv2->axis(1)->parallelize(ParallelType::Unswitch);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::rand({100}, options);
+  at::Tensor t4 = at::rand_like(t0, options);
+  at::Tensor cg_output_tv2 = at::empty_like(t0, options);
+  at::Tensor cg_output_tv3 = at::empty_like(t0, options);
+  at::Tensor cg_output_tv6 = at::empty_like(t0, options);
+  at::Tensor cg_output_tv7 = at::empty_like(t0, options);
 
-  {
-    GpuLower gpulw(&fusion);
-    TORCH_CHECK(!isPredicated(tv2, gpulw));
-  }
+  fe.runFusion(
+      {t0, t4}, {cg_output_tv2, cg_output_tv3, cg_output_tv6, cg_output_tv7});
 
-  tv2->axis(1)->parallelize(ParallelType::Serial);
-  tv2->split(1, 5);
+  auto t1 = t0 + 1;
+  auto t2 = t1 + 2;
+  auto t3 = t1 + 3;
+  auto t5 = t4 + 5;
+  auto t6 = t5 + 6;
+  auto t7 = t5 + 7;
 
-  {
-    GpuLower gpulw(&fusion);
-    TORCH_CHECK(isPredicated(tv2, gpulw));
-  }
+  TORCH_CHECK(
+      t2.allclose(cg_output_tv2),
+      "tv2 error of: ",
+      t2.sub(cg_output_tv2).abs().max());
+  TORCH_CHECK(
+      t3.allclose(cg_output_tv3),
+      "tv3 error of: ",
+      t3.sub(cg_output_tv3).abs().max());
+  TORCH_CHECK(
+      t6.allclose(cg_output_tv6),
+      "tv6 error of: ",
+      t6.sub(cg_output_tv6).abs().max());
+  TORCH_CHECK(
+      t7.allclose(cg_output_tv7),
+      "tv7 error of: ",
+      t7.sub(cg_output_tv7).abs().max());
 }
 
-TEST(NVFuserTest, ForceFp16Simple_CUDA) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  auto fusion = fusion_ptr.get();
-  FusionGuard fg(fusion);
+TEST(NVFuserTest, FusionTraversalOrder5_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
-  auto tv1 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(1);
+  fusion.addInput(tv0);
+  TensorView* tv1 = add(tv0, new Float(1));
+  TensorView* tv2 = add(tv1, new Float(2));
+  TensorView* tv3 = add(tv0, new Float(3));
+  TensorView* tv4 = add(tv3, new Float(4));
+  TensorView* tv5 = add(tv2, tv4);
 
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
+  fusion.addOutput(tv1);
+  fusion.addOutput(tv3);
+  fusion.addOutput(tv5);
 
-  // Group 1
-  auto tv2 = sum(tv0, {1});
-  auto tv3 = broadcast(tv2, {false, true});
+  tv2->computeAt(tv5, -1);
+  tv4->computeAt(tv5, -1);
 
-  // Group 2
-  auto tv4 = add(tv3, tv1); // Edge: tv3: expect cast
-  auto tv5 = castOp(DataType::Half, tv4);
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
 
-  fusion->addOutput(tv5);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::rand({100}, options);
+  at::Tensor cg_output_tv1 = at::empty_like(t0, options);
+  at::Tensor cg_output_tv3 = at::empty_like(t0, options);
+  at::Tensor cg_output_tv5 = at::empty_like(t0, options);
 
-  FusionExecutorCache fec(std::move(fusion_ptr));
+  fe.runFusion({t0}, {cg_output_tv1, cg_output_tv3, cg_output_tv5});
 
-  std::vector<int64_t> shape{15, 16};
+  auto t1 = t0 + 1;
+  auto t2 = t1 + 2;
+  auto t3 = t0 + 3;
+  auto t4 = t3 + 4;
+  auto t5 = t2 + t4;
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto in0 = at::randn(shape, options);
-  auto in1 = at::randn(shape, options);
-  fec.runFusionWithInputs({in0, in1});
-
-  // Check the segmented edge is fp16
-  auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments();
-  for (auto edge : segmented_fusion->edges()) {
-    auto edge_tv = edge->val->as<TensorView>();
-    TORCH_CHECK(edge_tv->getDataType() == DataType::Half);
-  }
+  TORCH_CHECK(
+      t1.allclose(cg_output_tv1),
+      "tv1 error of: ",
+      t1.sub(cg_output_tv1).abs().max());
+  TORCH_CHECK(
+      t3.allclose(cg_output_tv3),
+      "tv3 error of: ",
+      t3.sub(cg_output_tv3).abs().max());
+  TORCH_CHECK(
+      t5.allclose(cg_output_tv5),
+      "tv5 error of: ",
+      t5.sub(cg_output_tv5).abs().max());
 }
 
-TEST(NVFuserTest, ForceFp16NotAllCast_CUDA) {
-  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
-  auto fusion = fusion_ptr.get();
-  FusionGuard fg(fusion);
+TEST(NVFuserTest, FusionTraversalOrder6_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(3);
-  auto tv1 = makeSymbolicTensor(3);
+  TensorView* tv0 = makeDummyTensor(1);
+  fusion.addInput(tv0);
+  TensorView* tv1 = add(tv0, new Float(1));
+  TensorView* tv2 = add(tv0, new Float(2));
+  TensorView* tv3 = add(tv1, tv2);
+  TensorView* tv4 = add(tv3, new Float(4));
 
-  fusion->addInput(tv0);
-  fusion->addInput(tv1);
+  fusion.addOutput(tv4);
 
-  // Group 1
-  auto tv3 = sum(tv0, {1});
-  auto tv4 = broadcast(tv3, {false, true, false});
-  auto tv5 = sum(tv0, {1});
+  tv1->split(0, 32);
+  tv2->split(0, 32);
+  tv3->split(0, 32);
+  tv4->split(0, 32);
 
-  // Group 2
-  auto tv6 = add(tv4, tv1); // edge tv4, expect cast
-  auto tv7 = castOp(DataType::Half, tv6);
+  tv3->computeAt(tv4, -2);
+  tv1->computeAt(tv3, -1);
+  tv2->computeAt(tv3, -2);
 
-  // Group 3
-  auto tv8 = sum(tv5, {1}); // edge tv5, don't expect cast
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
 
-  fusion->addOutput(tv7);
-  fusion->addOutput(tv8);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::rand({100}, options);
+  at::Tensor cg_output_tv4 = at::empty_like(t0, options);
 
-  FusionExecutorCache fec(std::move(fusion_ptr));
+  fe.runFusion({t0}, {cg_output_tv4});
 
-  std::vector<int64_t> shape{16, 16, 16};
+  auto t1 = t0 + 1;
+  auto t2 = t0 + 2;
+  auto t3 = t1 + t2;
+  auto t4 = t3 + 4;
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto in0 = at::randn(shape, options);
-  auto in1 = at::randn(shape, options);
-  fec.runFusionWithInputs({in0, in1});
-
-  auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments();
-  auto complete_fusion = segmented_fusion->completeFusion();
-
-  // Check that the edge that wasn't fp16 is the producer of the
-  //  reduction op, i.e. tv8 = sum(tv5,{1});.
-  for (auto edge : segmented_fusion->edges()) {
-    auto edge_tv = edge->val->as<TensorView>();
-    if (edge_tv->getDataType() == DataType::Float) {
-      auto consumer = *(complete_fusion->unordered_uses(edge_tv).begin());
-      TORCH_CHECK(consumer->isA<ReductionOp>());
-    }
-  }
+  TORCH_CHECK(
+      t4.allclose(cg_output_tv4),
+      "tv4 error of: ",
+      t4.sub(cg_output_tv4).abs().max());
 }
 
-TEST(NVFuserTest, FusionIssue970_CUDA) {
+TEST(NVFuserTest, FusionTraversalOrder7_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  const int nelm = 10;
-
-  // tv3 = tv0 + sum(tv0)
-  auto tv0 = makeConcreteTensor({nelm, nelm});
+  TensorView* tv0 = makeDummyTensor(1);
   fusion.addInput(tv0);
-  auto tv1 = sum(tv0, {1});
-  auto tv2 = broadcast(tv1, {false, true});
-  auto tv3 = add(tv2, tv0);
-  fusion.addOutput(tv3);
+  TensorView* tv1 = add(tv0, new Float(1));
+  TensorView* tv2 = add(tv1, new Float(2));
+  TensorView* tv3 = add(tv0, new Float(3));
+  TensorView* tv4 = add(tv3, new Float(4));
+  TensorView* tv5 = add(tv2, tv4);
 
-  tv1->split(1, 4);
+  fusion.addOutput(tv5);
+
+  TensorView* tvs[] = {tv1, tv2, tv3, tv4, tv5};
+  for (auto tv : tvs) {
+    tv->split(0, 2);
+    tv->split(0, 4);
+    tv->split(0, 8);
+  }
+
+  // computeAt into inner loop nests
+  tv1->computeAt(tv2, -1);
+  tv3->computeAt(tv4, -2);
+
+  tv2->computeAt(tv5, -4);
+  tv4->computeAt(tv5, -3);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor t0 = at::randn({nelm, nelm}, options);
-
-  auto outputs = fe.runFusion({t0});
+  at::Tensor t0 = at::rand({100}, options);
+  at::Tensor cg_output_tv5 = at::empty_like(t0, options);
+  fe.runFusion({t0}, {cg_output_tv5});
 
-  auto ref = sum(t0, {1}).unsqueeze(-1).expand({nelm, nelm}) + t0;
+  auto t1 = t0 + 1;
+  auto t2 = t1 + 2;
+  auto t3 = t0 + 3;
+  auto t4 = t3 + 4;
+  auto t5 = t2 + t4;
 
-  testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__);
+  TORCH_CHECK(
+      t5.allclose(cg_output_tv5),
+      "tv5 error of: ",
+      t5.sub(cg_output_tv5).abs().max());
 }
 
-// Reproducer of #1016
-TEST(NVFuserTest, FusionIssue1016_CUDA) {
+// Test predication of grid reduction
+TEST(NVFuserTest, FusionThreadPredicate_CUDA) {
+  const int gdimx = 4;
+  const int bdimx = 128;
+
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(2);
+  TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
 
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(2));
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  TensorView* tv2 = unaryOp(UnaryOpType::Neg, tv1);
+  TensorView* tv3 = add(tv0, new Float(2));
 
+  fusion.addOutput(tv3);
   fusion.addOutput(tv2);
 
-  tv1->setMemoryType(MemoryType::Shared);
+  tv1->split(1, bdimx);
+  tv1->split(1, gdimx);
+  tv3->split(1, bdimx);
+  tv3->split(1, gdimx);
+
+  TensorView* tv1_rf = tv1->rFactor({1});
 
-  tv2->split(-1, 8);
+  tv1->computeAt(tv2, -1);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
+  tv1->axis(0)->parallelize(ParallelType::BIDy);
+  tv1_rf->axis(0)->parallelize(ParallelType::BIDy);
+  tv2->axis(0)->parallelize(ParallelType::BIDy);
+  tv1->axis(-2)->parallelize(ParallelType::BIDx);
+  tv1_rf->axis(-2)->parallelize(ParallelType::BIDx);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv1_rf->axis(-1)->parallelize(ParallelType::TIDx);
+
+  tv3->axis(3)->parallelize(ParallelType::TIDx);
+  tv3->axis(2)->parallelize(ParallelType::BIDx);
+  tv3->axis(0)->parallelize(ParallelType::BIDy);
 
-  int numel_x = 10;
-  int numel_y = 11;
+  int numel_x = 100;
+  int numel_y = 1000;
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output_tv2 = at::empty({numel_x}, options);
+  at::Tensor cg_output_tv3 = at::empty_like(input, options);
 
-  auto ref = t0 + 1 + 2;
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  fe.runFusion({input}, {cg_output_tv3, cg_output_tv2});
 
-  testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__);
+  auto aten_output_tv2 = -input.sum({1});
+  TORCH_CHECK(aten_output_tv2.allclose(cg_output_tv2));
+  auto aten_output_tv3 = input + 2.0;
+  TORCH_CHECK(aten_output_tv3.allclose(cg_output_tv3));
 }
 
-// Reproducer of #1021
-TEST(NVFuserTest, FusionIssue1021_CUDA) {
+TEST(NVFuserTest, FusionLSTMCell_CUDA) {
+  const int hidden_features = 512;
+  const int batch_size = 64;
+
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = broadcast(tv1, {false, true});
-  fusion.addOutput(tv2);
-
-  auto tv3 = tv2->cache_before();
-
-  tv2->split(0, 2);
-
-  tv1->computeAt(tv2, 1);
-
-  tv2->axis(0)->parallelize(ParallelType::TIDx);
-  tv2->axis(1)->parallelize(ParallelType::Vectorize);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({10}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto ref = (t0 + 1).unsqueeze(-1);
+  TensorView* tvs[16];
+  for (auto& tv : tvs) {
+    tv = makeDummyTensor(2);
+    fusion.addInput(tv);
+  }
 
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
+  auto ingate = unaryOp(
+      UnaryOpType::Sigmoid, add(add(add(tvs[0], tvs[1]), tvs[2]), tvs[3]));
 
-// Reproducer of issue #1053
-TEST(NVFuserTest, FusionNonUniqueThreadDim_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+  auto forgetgate = unaryOp(
+      UnaryOpType::Sigmoid, add(add(add(tvs[4], tvs[5]), tvs[6]), tvs[7]));
 
-  auto tv0 = makeSymbolicTensor(1);
-  fusion->addInput(tv0);
-  auto tv1 = sum(tv0, {0});
-  fusion->addOutput(tv1);
+  auto cellgate = unaryOp(
+      UnaryOpType::Tanh, add(add(add(tvs[8], tvs[9]), tvs[10]), tvs[11]));
 
-  auto tv2 = add(tv0, new Double(1));
-  fusion->addOutput(tv2);
+  auto outgate = unaryOp(
+      UnaryOpType::Sigmoid, add(add(add(tvs[12], tvs[13]), tvs[14]), tvs[15]));
 
-  tv1->split(0, 8);
-  auto tv1_rf = tv1->rFactor({-1});
+  auto cx = makeContigTensor(2);
+  fusion.addInput(cx);
 
-  tv1_rf->computeAt(tv1, 1);
+  auto cy = add(mul(forgetgate, cx), mul(ingate, cellgate));
 
-  tv1_rf->axis(-1)->parallelize(ParallelType::TIDx);
+  auto hy = mul(outgate, unaryOp(UnaryOpType::Tanh, cy));
 
-  tv2->axis(0)->parallelize(ParallelType::TIDx);
+  fusion.addOutput(cy);
+  fusion.addOutput(hy);
 
+  std::vector<c10::IValue> inputs;
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input1 = at::randn({32}, options);
-
-  auto at_tv1 = (input1).sum({0});
-  auto at_tv2 = input1 + 1;
+  at::Tensor large_tensor0 =
+      at::randn({batch_size, hidden_features * 4}, options);
+  at::Tensor large_tensor1 =
+      at::randn({batch_size, hidden_features * 4}, options);
+  at::Tensor large_tensor2 =
+      at::randn({batch_size, hidden_features * 4}, options);
+  at::Tensor large_tensor3 =
+      at::randn({batch_size, hidden_features * 4}, options);
 
-  FusionExecutor fe;
-  fe.compileFusion(fusion.get());
-  auto outputs = fe.runFusion({input1});
-  testValidate(
-      fusion.get(), outputs, {input1}, {at_tv1, at_tv2}, __LINE__, __FILE__);
-}
+  auto chunked0 = large_tensor0.chunk(4, 1);
+  auto chunked1 = large_tensor1.chunk(4, 1);
+  auto chunked2 = large_tensor2.chunk(4, 1);
+  auto chunked3 = large_tensor3.chunk(4, 1);
 
-TEST(NVFuserTest, FusionParallelDimensionMap1_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+  inputs.insert(inputs.end(), chunked0.begin(), chunked0.end());
+  inputs.insert(inputs.end(), chunked1.begin(), chunked1.end());
+  inputs.insert(inputs.end(), chunked2.begin(), chunked2.end());
+  inputs.insert(inputs.end(), chunked3.begin(), chunked3.end());
 
-  auto tv0 = makeSymbolicTensor(1);
-  fusion->addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv0, new Double(1));
-  fusion->addOutput(tv1);
-  fusion->addOutput(tv2);
-
-  tv1->split(0, 8, false);
-  tv1->axis(1)->parallelize(ParallelType::TIDx);
-  tv2->split(0, 8, false);
-  tv2->axis(1)->parallelize(ParallelType::TIDx);
-
-  // The extents of tv1 and tv2 axes are equal even though their
-  // actual values are not statically known
-  GpuLower gpulw(fusion.get());
-  const auto& pdmap = gpulw.parallelDimensionMap();
-  auto kir_tv1 = gpulw.lowerValue(tv1)->as<kir::TensorView>();
-  auto kir_tv2 = gpulw.lowerValue(tv2)->as<kir::TensorView>();
-  for (size_t i = 0; i < kir_tv1->domain()->domain().size(); ++i) {
-    auto dom1 = kir_tv1->domain()->domain()[i];
-    auto dom2 = kir_tv2->domain()->domain()[i];
-    TORCH_INTERNAL_ASSERT(pdmap.equalDim(dom1->extent(), dom2->extent()));
-  }
+  auto at_ingate =
+      chunked0[0].add(chunked0[1]).add(chunked0[2]).add(chunked0[3]).sigmoid();
+  auto at_forgetgate =
+      chunked1[0].add(chunked1[1]).add(chunked1[2]).add(chunked1[3]).sigmoid();
+  auto at_cellgate =
+      chunked2[0].add(chunked2[1]).add(chunked2[2]).add(chunked2[3]).tanh();
+  auto at_outgate =
+      chunked3[0].add(chunked3[1]).add(chunked3[2]).add(chunked3[3]).sigmoid();
 
-  TORCH_CHECK(pdmap.isExact(ParallelType::TIDx));
-  TORCH_CHECK(
-      pdmap.get(ParallelType::TIDx)->isA<kir::NamedScalar>() &&
-      pdmap.get(ParallelType::TIDx)->as<kir::NamedScalar>()->name() ==
-          "blockDim.x");
+  auto at_cx = at::randn({batch_size, hidden_features}, options);
+  inputs.push_back(at_cx);
+  auto at_cy = at_forgetgate.mul(at_cx).add(at_ingate.mul(at_cellgate));
+  auto at_hy = at_outgate.mul(at_cy.tanh());
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input1 = at::randn({32}, options);
+  scheduleFusion(&fusion, c10::ArrayRef<c10::IValue>(inputs));
 
   FusionExecutor fe;
-  fe.compileFusion(fusion.get());
-  auto outputs = fe.runFusion({input1});
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion(c10::ArrayRef<c10::IValue>(inputs));
 
-  testValidate(
-      fusion.get(),
-      outputs,
-      {input1},
-      {input1 + 1, input1 + 1},
-      __LINE__,
-      __FILE__);
+  TORCH_CHECK(at_cy.allclose(outputs[0], 1e-4, 1e-7));
+  TORCH_CHECK(at_hy.allclose(outputs[1], 1e-4, 1e-7));
 }
 
-TEST(NVFuserTest, FusionParallelDimensionMap2_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+TEST(NVFuserTest, FusionComputeAtMultiBCast_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  auto tv0 = makeSymbolicTensor(1);
-  fusion->addInput(tv0);
-  auto tv1 = makeSymbolicTensor(2);
-  fusion->addInput(tv1);
-  auto tv2 = broadcast(tv0, {false, true});
-  auto tv3 = add(tv1, tv2);
-  fusion->addOutput(tv3);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(1);
+  fusion.addInput(tv0);
 
-  tv3->split(-1, 8, false);
-  tv2->computeAt(tv3, -1);
+  TensorView* tv1 = mul(tv0, new Float(0.5));
+  TensorView* tv2 = broadcast(tv1, {true, false});
+  TensorView* tv3 = broadcast(tv1, {false, true});
+  TensorView* tv4 = add(tv2, tv3);
+  fusion.addOutput(tv4);
 
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  // This is not supported and should throw an exception.
+  ASSERT_ANY_THROW(tv1->computeAt(tv3, -1));
+}
 
-  GpuLower gpulw(fusion.get());
-  const auto& pdmap = gpulw.parallelDimensionMap();
-  TORCH_CHECK(pdmap.isExact(ParallelType::TIDx));
-  TORCH_CHECK(
-      pdmap.get(ParallelType::TIDx)->isA<kir::NamedScalar>() &&
-      pdmap.get(ParallelType::TIDx)->as<kir::NamedScalar>()->name() ==
-          "blockDim.x");
+TEST(NVFuserTest, FusionReductionHalf_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input1 = at::randn({11}, options);
-  at::Tensor input2 = at::randn({11, 13}, options);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(3, DataType::Half);
+  fusion.addInput(tv0);
 
-  FusionExecutor fe;
-  fe.compileFusion(fusion.get());
-  auto outputs = fe.runFusion({input1, input2});
+  auto tv1 = castOp(DataType::Float, tv0);
+  auto tv2 = add(tv1, new Float(1.0));
+  auto tv3 = sum(tv2, {2});
+  auto tv4 = castOp(DataType::Half, tv3);
 
-  auto ref = input1.unsqueeze(-1) + input2;
+  fusion.addOutput(tv4);
 
-  testValidate(
-      fusion.get(), outputs, {input1, input2}, {ref}, __LINE__, __FILE__);
-}
+  const auto options =
+      at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  at::Tensor input = at::randn({8, 8, 16}, options);
 
-// Mix symbolic and concrete tensors
-TEST(NVFuserTest, FusionParallelDimensionMap3_CUDA) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
+  auto reduction_tv = tv3;
 
-  auto tv0 = makeSymbolicTensor(1);
-  fusion->addInput(tv0);
+  auto outputsOfReduction = DependencyCheck::getAllOutputsOf({reduction_tv});
 
-  auto tv2 = add(tv0, new Double(1));
-  fusion->addOutput(tv2);
-  auto tv3 = add(tv0, new Double(1));
-  fusion->addOutput(tv3);
+  // Grab only tensor views, though there shouldn't be any other type
+  auto tv_entries = ir_utils::filterByType<TensorView>(outputsOfReduction);
 
-  tv2->split(0, 10);
-  tv3->split(0, 20);
+  std::vector<TensorView*> tvOutputsOfReduction(
+      tv_entries.begin(), tv_entries.end());
 
-  auto tv4 = add(tv0, new Double(1));
-  fusion->addOutput(tv4);
-  auto tv5 = add(tv0, new Double(1));
-  fusion->addOutput(tv5);
+  auto reduction_params =
+      getReductionHeuristics(&fusion, {input}, reduction_tv);
+  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
+  scheduleReduction(
+      &fusion, reduction_params.value(), reduction_tv, tvOutputsOfReduction);
 
-  // Not mapped but equal extent
-  tv4->split(0, 10);
-  tv5->split(0, 10);
+  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
 
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  // no broadcasting needed, omitting the last optional argument;
+  auto outputs = fe.runFusion({input}, reduction_params.value().lparams);
 
-  tv4->axis(-1)->parallelize(ParallelType::TIDy);
-  tv5->axis(-1)->parallelize(ParallelType::TIDy);
+  auto aten_output = input.to(c10::ScalarType::Float)
+                         .add(1.0)
+                         .sum({2})
+                         .to(c10::ScalarType::Half);
 
-  GpuLower gpulw(fusion.get());
-  const auto& pdmap = gpulw.parallelDimensionMap();
-  TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx));
-  TORCH_CHECK(
-      pdmap.get(ParallelType::TIDx)->isA<kir::NamedScalar>() &&
-      pdmap.get(ParallelType::TIDx)->as<kir::NamedScalar>()->name() ==
-          "blockDim.x");
-  TORCH_CHECK(pdmap.isExact(ParallelType::TIDy));
   TORCH_CHECK(
-      pdmap.get(ParallelType::TIDy)->isConst() &&
-      pdmap.get(ParallelType::TIDy)->as<kir::Int>()->value().value() == 10);
+      aten_output.allclose(outputs[0], 1e-04, 1e-04),
+      "Error of: ",
+      aten_output.sub(outputs[0]).abs().max());
+}
 
+TEST(NVFuserTest, FusionInputsIdLookup_CUDA) {
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input1 = at::randn({13}, options);
+  at::Tensor t0 = at::randn({16, 8, 8}, options);
+  at::Tensor t1 = at::randn({8, 8}, options);
+  at::Tensor t2 = at::randn({6, 4}, options);
 
-  FusionExecutor fe;
-  fe.compileFusion(fusion.get());
-  auto outputs = fe.runFusion({input1});
+  // create a cache with max size 2;
+  auto inputs_id_lookup = InputsIdLookup(2);
 
-  testValidate(
-      fusion.get(),
-      outputs,
-      {input1},
-      {input1 + 1, input1 + 1, input1 + 1, input1 + 1},
-      __LINE__,
-      __FILE__);
-}
+  // testing basic function, same encoding for identical inputs
+  auto id_0 = inputs_id_lookup.lookupId({t0, t1, 5.0});
+  auto id_0_lookup = inputs_id_lookup.lookupId({t0, t1, 2.5});
+  TORCH_CHECK(id_0.id == id_0_lookup.id);
+  TORCH_CHECK(inputs_id_lookup.size() == 1);
+  TORCH_CHECK(id_0.eviction == false);
 
-// Parallelizing merged broadcast domains
-TEST(NVFuserTest, FusionParallelDimensionMap4_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  // new input (even tho same shape, but we have different signature because of
+  // missing scalar input
+  auto id_1 = inputs_id_lookup.lookupId({t0, t1});
+  auto id_1_lookup = inputs_id_lookup.lookupId({t0, t1});
+  TORCH_CHECK(id_1.id == id_1_lookup.id);
+  TORCH_CHECK(inputs_id_lookup.size() == 2);
+  TORCH_CHECK(id_1.eviction == false);
 
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  auto tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-  auto tv2 = add(tv0, new Double(1));
-  auto tv3 = broadcast(tv2, {true, false});
-  auto tv4 = add(tv3, tv1);
-  fusion.addOutput(tv4);
+  // eviction should happen at this point
+  auto id_2 = inputs_id_lookup.lookupId({t2, t1});
+  TORCH_CHECK(id_2.id != id_0.id);
+  TORCH_CHECK(id_2.id != id_1.id);
+  TORCH_CHECK(inputs_id_lookup.size() == 2);
+  TORCH_CHECK(id_2.eviction == true);
+  TORCH_CHECK(id_2.evict_id == id_0.id);
 
-  tv4->split(1, 4);
-  tv4->reorder({{1, 2}, {2, 1}});
-  tv4->merge(0);
-  tv0->computeAt(tv4, 1);
-  tv1->computeAt(tv4, 1);
+  // look at input 1 again
+  auto id_1_relook = inputs_id_lookup.lookupId({t0, t1});
+  TORCH_CHECK(id_1_relook.id == id_1.id);
+  TORCH_CHECK(id_1_relook.eviction == false);
+}
 
-  // TIDx is mapped to tv4.axis(0) as well as tv2.axis(0), so it's not
-  // exact.
-  tv4->axis(0)->parallelize(ParallelType::TIDx);
+TEST(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) {
+  std::vector<int64_t> sizes_vec({16, 8, 8});
+  std::vector<int64_t> strides_vec({64, 8, 1});
+  auto tensor_type = TensorType::create(
+      at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  tv2->setMemoryType(MemoryType::Shared);
-  tv3->setMemoryType(MemoryType::Shared);
+  // pass with identical shape
+  auto t0 = at::randn({16, 8, 8}, options);
+  TORCH_CHECK(complyWith(t0, tensor_type));
 
-  GpuLower gpulw(&fusion);
-  const auto& pdmap = gpulw.parallelDimensionMap();
-  TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx));
-  TORCH_CHECK(
-      pdmap.get(ParallelType::TIDx)->isA<kir::NamedScalar>() &&
-      pdmap.get(ParallelType::TIDx)->as<kir::NamedScalar>()->name() ==
-          "blockDim.x");
+  // pass with dynamic shape
+  auto t1 = at::randn({16, 16, 8}, options);
+  TORCH_CHECK(complyWith(t1, tensor_type));
 
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input1 = at::randn({13}, options);
-  at::Tensor input2 = at::randn({15, 13}, options);
+  // rank failure
+  auto t5 = at::randn({16, 8, 8, 8}, options);
+  TORCH_CHECK(!complyWith(t5, tensor_type));
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion({input1, input2});
+  // broadcasting semantic change failure
+  auto t2 = at::randn({16, 1, 8}, options);
+  TORCH_CHECK(!complyWith(t2, tensor_type));
 
-  auto ref = (input1 + 1).unsqueeze(0) + input2;
+  // contiguity failure via slicing
+  auto t3 = t0.slice(1, 0, 8, 2);
+  TORCH_CHECK(!complyWith(t3, tensor_type));
 
-  testValidate(&fusion, outputs, {input1, input2}, {ref}, __LINE__, __FILE__);
+  // contiguity failure via slicing
+  auto t4 = t0.slice(2, 0, 8, 2);
+  TORCH_CHECK(!complyWith(t4, tensor_type));
 }
 
-TEST(NVFuserTest, FusionParallelDimensionMap5_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+TEST(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) {
+  std::vector<int64_t> sizes_vec({16, 1, 8});
+  std::vector<int64_t> strides_vec({8, 8, 1});
+  auto tensor_type = TensorType::create(
+      at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  auto tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-  auto tv3 = broadcast(tv0, {false, true});
-  auto tv4 = add(tv3, tv1);
-  fusion.addOutput(tv4);
+  // broadcasting semantic change
+  auto t0 = at::randn({16, 8, 8}, options);
+  TORCH_CHECK(!complyWith(t0, tensor_type));
 
-  tv4->split(1, 4);
-  tv0->computeAt(tv4, -1);
-  tv1->computeAt(tv4, -1);
+  // dtype failure
+  auto t1 = at::randn({16, 1, 8}, options.dtype(at::kHalf));
+  TORCH_CHECK(!complyWith(t1, tensor_type));
 
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv4->axis(-2)->parallelize(ParallelType::TIDy);
-  tv3->axis(-2)->parallelize(ParallelType::TIDy);
+  // dtype failure
+  auto t2 = at::randn({16, 1, 8}, options);
+  TORCH_CHECK(complyWith(t2, tensor_type));
 
-  GpuLower gpulw(&fusion);
-  const auto& pdmap = gpulw.parallelDimensionMap();
-  TORCH_CHECK(pdmap.isExact(ParallelType::TIDx));
-  TORCH_CHECK(pdmap.isExact(ParallelType::TIDy));
-  TORCH_CHECK(
-      pdmap.get(ParallelType::TIDx)->isConst() &&
-      pdmap.get(ParallelType::TIDx)->as<kir::Int>()->value().value() == 4);
-  TORCH_CHECK(
-      pdmap.get(ParallelType::TIDy)->isA<kir::NamedScalar>() &&
-      pdmap.get(ParallelType::TIDy)->as<kir::NamedScalar>()->name() ==
-          "blockDim.y");
+  // device inconsistency shouldn't fail
+  auto t3 = at::randn({16, 1, 8}, options.device(at::kCPU, 0));
+  TORCH_CHECK(complyWith(t3, tensor_type));
+}
 
+TEST(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) {
+  std::vector<int64_t> sizes_vec({16, 8, 8});
+  std::vector<int64_t> strides_vec({64, 1, 8});
+  auto tensor_type = TensorType::create(
+      at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input1 = at::randn({13}, options);
-  at::Tensor input2 = at::randn({13, 15}, options);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion({input1, input2});
+  // failing permutation
+  auto t0 = at::randn({16, 8, 8}, options);
+  TORCH_CHECK(!complyWith(t0, tensor_type));
+
+  // passing with dynamic shape
+  auto t1 = t0.permute({0, 2, 1});
+  TORCH_CHECK(complyWith(t1, tensor_type));
+}
+
+TEST(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) {
+  std::vector<int64_t> sizes_vec({16, 8, 8});
+  std::vector<int64_t> strides_vec({128, 16, 1});
+  auto tensor_type = TensorType::create(
+      at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
-  auto ref = (input1).unsqueeze(-1) + input2;
+  // contiguity check passes although it differs
+  auto t0 = at::randn({16, 16, 8}, options);
+  TORCH_CHECK(complyWith(t0, tensor_type));
 
-  testValidate(&fusion, outputs, {input1, input2}, {ref}, __LINE__, __FILE__);
+  // passing with dynamic shape
+  auto t1 = t0.slice(1, 0, 16, 2);
+  TORCH_CHECK(complyWith(t1, tensor_type));
 }
 
 } // namespace jit
 } // namespace torch
+
 #endif // #if defined(USE_CUDA)
diff --git a/test/cpp/jit/test_gpu_shift.cpp b/test/cpp/jit/test_gpu_shift.cpp
deleted file mode 100644 (file)
index 72a3b8b..0000000
+++ /dev/null
@@ -1,2870 +0,0 @@
-#if defined(USE_CUDA)
-#include <gtest/gtest.h>
-
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/codegen.h>
-#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
-#include <torch/csrc/jit/codegen/cuda/interface.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_graphviz.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/mutator.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
-
-// fuser and IR parser
-#include "test_gpu_validator.h"
-
-#include <ATen/cuda/Exceptions.h>
-#include <c10/cuda/CUDAStream.h>
-
-#include <algorithm>
-#include <iostream>
-
-// Tests go in torch::jit
-namespace torch {
-namespace jit {
-
-using namespace torch::jit::fuser::cuda;
-using namespace at::indexing;
-
-namespace {
-
-// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
-  return TensorViewBuilder()
-      .ndims(ndims)
-      .dtype(dtype)
-      .contiguity(std::vector<bool>(ndims, true))
-      .build();
-}
-
-// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
-  return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
-}
-
-// Make a non-contiguous tensor of compile-time known sizes
-TensorView* makeConcreteTensor(
-    std::vector<int64_t> shape,
-    DataType dtype = DataType::Float) {
-  return TensorViewBuilder().shape(shape).dtype(dtype).build();
-}
-
-void checkIntValue(
-    ExpressionEvaluator& evaluator,
-    Val* val,
-    Int::ScalarType expected_value) {
-  TORCH_CHECK(val->isAnInt());
-  const auto actual_value = evaluator.evaluate(val);
-  TORCH_CHECK(actual_value.has_value());
-  TORCH_CHECK(actual_value.value() == expected_value);
-}
-
-void checkIntValue(
-    kir::ExpressionEvaluator& evaluator,
-    const kir::Val* val,
-    kir::Int::ScalarType expected_value) {
-  const auto actual_value = evaluator.evaluate(val);
-  TORCH_CHECK(actual_value.has_value());
-  TORCH_CHECK(actual_value.value() == expected_value);
-}
-
-// ATen version of tensor shifting
-auto shift(at::Tensor tensor, const std::vector<int>& offsets) {
-  TORCH_INTERNAL_ASSERT(tensor.ndimension() == offsets.size());
-  at::Tensor t = tensor;
-  for (size_t i = 0; i < offsets.size(); ++i) {
-    const auto offset = offsets[i];
-    if (offset == 0) {
-      continue;
-    }
-    t = t.roll(offsets[i], i);
-    std::vector<at::indexing::TensorIndex> indices(
-        tensor.ndimension(), at::indexing::Slice(0, at::indexing::None));
-    if (offset > 0) {
-      indices[i] = at::indexing::Slice(0, offset);
-    } else {
-      indices[i] = at::indexing::Slice(offset, at::indexing::None);
-    }
-    t.index(indices) = 0;
-  }
-  return t;
-}
-
-// ATen version of tensor shifting
-auto gather(
-    at::Tensor tensor,
-    const std::vector<int>& window_shape,
-    const std::vector<std::vector<int>>& pad_width) {
-  TORCH_CHECK(
-      tensor.ndimension() == window_shape.size(),
-      "Invalid window shape: ",
-      window_shape,
-      ". Size of the window shape is different from the tensor dimension.");
-  TORCH_CHECK(
-      tensor.ndimension() == pad_width.size(),
-      "Invalid pad width: ",
-      pad_width,
-      ". Size of the pad width is different from the tensor dimension.");
-  at::Tensor t = tensor;
-  for (size_t i = 0; i < window_shape.size(); ++i) {
-    const auto w_size = window_shape[i];
-    TORCH_CHECK(w_size != 0);
-    const auto& pad = pad_width[i];
-    TORCH_CHECK(pad.size() == 2);
-    at::Tensor concat_tensor;
-    for (int w = 0; w < w_size; ++w) {
-      std::vector<int> shift_offsets(t.ndimension(), 0);
-      shift_offsets[i] = pad[0] - w;
-      auto shifted = shift(t, shift_offsets);
-      shifted = shifted.unsqueeze(-1);
-      if (w == 0) {
-        concat_tensor = shifted;
-      } else {
-        concat_tensor = at::cat({concat_tensor, shifted}, -1);
-      }
-    }
-    t = concat_tensor;
-  }
-  return t;
-}
-
-} // namespace
-
-// Shift an input tensor
-TEST(NVFuserTest, FusionShift1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  auto tv1 = shift(tv0, {-1, 0});
-  fusion.addOutput(tv1);
-
-  auto tv2 = shift(tv0, {0, 1});
-  fusion.addOutput(tv2);
-
-  auto tv3 = shift(tv0, {2, 2});
-  fusion.addOutput(tv3);
-
-  auto tv4 = shift(tv0, {-2, -2});
-  fusion.addOutput(tv4);
-
-  int numel_x = 9;
-  int numel_y = 11;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = shift(t0, {-1, 0});
-  TORCH_CHECK(t1.equal(outputs[0]));
-
-  auto t2 = shift(t0, {0, 1});
-  TORCH_CHECK(t2.equal(outputs[1]));
-
-  auto t3 = shift(t0, {2, 2});
-  TORCH_CHECK(t3.equal(outputs[2]));
-
-  auto t4 = shift(t0, {-2, -2});
-  TORCH_CHECK(t4.equal(outputs[3]));
-}
-
-// Shifts an intermediate tensor
-TEST(NVFuserTest, FusionShift2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = shift(tv1, {-1, 0});
-  fusion.addOutput(tv2);
-
-  // make it a little more complex
-  auto tv3 = add(tv0, new Double(3));
-  auto tv4 = add(tv3, new Double(4));
-  auto tv5 = shift(tv4, {-1, 0});
-  auto tv6 = shift(tv4, {0, -1});
-  auto tv7 = shift(tv4, {1, 0});
-  auto tv8 = shift(tv4, {0, 0});
-  auto tv9 = add(tv5, tv6);
-  auto tv10 = add(tv9, tv7);
-  auto tv11 = add(tv10, tv8);
-  fusion.addOutput(tv11);
-
-  for (auto tv : {tv1, tv2, tv3, tv4, tv5, tv6, tv7, tv8, tv9, tv10, tv11}) {
-    tv->setMemoryType(MemoryType::Global);
-  }
-
-  // t1 allocation: (t1.size[0] + 1) * (t1.size[1])
-  // t3 allocation: (t3.size[0] + 2) * (t3.size[1] + 1)
-  // t4 allocation: (t3.size[0] + 2) * (t3.size[1] + 1)
-  GpuLower gpulw(&fusion);
-
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == 1 || tensor_name == 3 || tensor_name == 4) {
-        TORCH_CHECK(alloc->shape().size() == 2);
-        for (int i = 0; i < 2; ++i) {
-          if (tensor_name == 1 && i == 1) {
-            TORCH_CHECK(alloc->shape().at(i)->isA<kir::NamedScalar>());
-            continue;
-          }
-          auto def =
-              dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
-          TORCH_CHECK(def != nullptr && def->operation() == BinaryOpType::Add);
-          TORCH_CHECK(def->as<kir::BinaryOp>()->lhs()->isA<kir::NamedScalar>());
-          auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-          TORCH_CHECK(rhs != nullptr && rhs->isConst());
-          int rhs_value = *rhs->value();
-          if (tensor_name == 1) {
-            TORCH_CHECK(i == 0);
-            TORCH_CHECK(rhs_value == 1);
-          } else {
-            if (i == 0) {
-              TORCH_CHECK(rhs_value == 2);
-            } else {
-              TORCH_CHECK(rhs_value == 1);
-            }
-          }
-        }
-      }
-    }
-  }
-
-  int numel_x = 9;
-  int numel_y = 11;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = shift(t1, {-1, 0});
-
-  auto t3 = t0 + 3;
-  auto t4 = t3 + 4;
-  auto t5 = shift(t4, {-1, 0});
-  auto t6 = shift(t4, {0, -1});
-  auto t7 = shift(t4, {1, 0});
-  auto t8 = shift(t4, {0, 0});
-  auto t9 = t5 + t6;
-  auto t10 = t9 + t7;
-  auto t11 = t10 + t8;
-
-  testValidate(&fusion, outputs, inputs, {t2, t11}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftRightOfCA_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = shift(tv1, {0, 1});
-  fusion.addOutput(tv2);
-
-  tv0->computeAt(tv2, -2);
-
-  tv1->setMemoryType(MemoryType::Global);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 100;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = shift(t1, {0, 1});
-
-  TORCH_CHECK(t2.allclose(outputs[0]));
-}
-
-TEST(NVFuserTest, FusionShiftLeftOfCA_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(1));
-  auto tv3 = shift(tv2, {-1, 0});
-  auto tv4 = add(tv3, new Double(1));
-  fusion.addOutput(tv4);
-
-  tv0->computeAt(tv4, -1);
-
-  // Lowering should trigger an assertion failure as a shifted axis is
-  // found inside an allocation position.
-  ASSERT_ANY_THROW(fusion.printKernel());
-}
-
-TEST(NVFuserTest, FusionShiftSplit1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = shift(tv1, {0, 1});
-  auto tv3 = shift(tv1, {0, -2});
-  fusion.addOutput(tv2);
-  fusion.addOutput(tv3);
-
-  int split_factor = 4;
-  tv2->split(-1, split_factor);
-  tv3->split(-1, split_factor);
-
-  tv0->computeAt(tv2, -2);
-  tv0->computeAt(tv3, -2);
-
-  // t1 allocation: (4 + 3)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == 1) {
-        TORCH_CHECK(alloc->shape().size() == 1);
-        auto def =
-            dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition());
-        auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-        TORCH_CHECK(lhs != nullptr && lhs->isConst());
-        int lhs_value = *lhs->value();
-        auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-        TORCH_CHECK(rhs != nullptr && rhs->isConst());
-        int rhs_value = *rhs->value();
-        TORCH_CHECK(lhs_value == split_factor && rhs_value == 3);
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 9;
-  int numel_y = 11;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = shift(t1, {0, 1});
-  auto t3 = shift(t1, {0, -2});
-
-  testValidate(&fusion, outputs, inputs, {t2, t3}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftSplit2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(1));
-  auto tv3 = shift(tv2, {0, -1});
-  auto tv4 = shift(tv2, {0, 1});
-  auto tv5 = add(tv3, tv4);
-  fusion.addOutput(tv5);
-
-  auto tv6 = add(tv0, new Double(1));
-  auto tv7 = shift(tv6, {0, 0});
-  auto tv8 = add(tv7, new Double(1));
-  fusion.addOutput(tv8);
-
-  int split_factor = 4;
-
-  tv5->split(-1, split_factor);
-  tv8->split(-1, split_factor);
-
-  tv0->computeAt(tv5, -2);
-  tv0->computeAt(tv8, -2);
-
-  // t1 and t2 allocation: (4 + 2)
-  // t4 allocation: (4)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == 1 || tensor_name == 2) {
-        TORCH_CHECK(alloc->shape().size() == 1);
-        auto def =
-            dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition());
-        auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-        TORCH_CHECK(lhs != nullptr && lhs->isConst());
-        int lhs_value = *lhs->value();
-        auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-        TORCH_CHECK(rhs != nullptr && rhs->isConst());
-        int rhs_value = *rhs->value();
-        TORCH_CHECK(lhs_value == split_factor && rhs_value == 2);
-      } else if (tensor_name == 4) {
-        TORCH_CHECK(alloc->shape().size() == 1);
-        auto size = dynamic_cast<kir::Int*>(alloc->shape().at(0));
-        TORCH_CHECK(size != nullptr && size->isConst());
-        int size_value = *size->value();
-        TORCH_CHECK(size_value == split_factor);
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 9;
-  int numel_y = 11;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 2;
-  auto t3 = shift(t1, {0, -1});
-  auto t4 = shift(t1, {0, 1});
-  auto t5 = t3 + t4;
-
-  auto t6 = t0 + 1;
-  auto t7 = t6;
-  auto t8 = t7 + 1;
-
-  testValidate(&fusion, outputs, inputs, {t5, t8}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftDoubleSplit_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(2));
-  auto tv3 = shift(tv2, {0, 1});
-  fusion.addOutput(tv3);
-
-  int split_factor1 = 8;
-  int split_factor2 = 4;
-
-  tv3->split(-1, split_factor1);
-
-  tv0->computeAt(tv3, -2);
-
-  tv1->split(-1, split_factor2);
-
-  // t1: [i1, i2/8, 8/4, 4]
-  // t2: [i1, i2/8, 8]
-  // t3: [i1, i2/8, 8]
-
-  // t1 and t2 allocation: (split_factor1 + 1)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == 1 || tensor_name == 2) {
-        TORCH_CHECK(alloc->shape().size() == 1);
-        auto def =
-            dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition());
-        auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-        TORCH_CHECK(lhs != nullptr && lhs->isConst());
-        int lhs_value = *lhs->value();
-        auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-        TORCH_CHECK(rhs != nullptr && rhs->isConst());
-        int rhs_value = *rhs->value();
-        TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1);
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 3;
-  auto ref = shift(t1, {0, 1});
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift3ptStencil_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // 3-pt stencil
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-
-  std::vector<std::vector<int>> offsets = {{-1}, {1}};
-
-  std::vector<TensorView*> tvs;
-  for (const auto& offset : offsets) {
-    tvs.push_back(shift(tv0, offset));
-  }
-
-  auto tv_out = tv0;
-
-  for (auto tv : tvs) {
-    tv_out = add(tv_out, tv);
-  }
-
-  tv_out = div(tv_out, new Double(tvs.size() + 1));
-
-  fusion.addOutput(tv_out);
-
-  int split_factor = 4;
-
-  tv_out->split(0, split_factor);
-
-  // This seems fine but not verified yet
-  // tv_out->axis(-1)->parallelize(ParallelType::Unswitch);
-
-  auto cache = tv0->cache_after();
-
-  tv0->computeAt(tv_out, 1);
-
-  // Inline completely except for the cache
-  for (auto tv : tvs) {
-    tv->computeAt(tv_out, -1);
-  }
-
-  // cache allocation: (split_factor + 2)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == cache->name()) {
-        TORCH_CHECK(alloc->shape().size() == 1);
-        auto def =
-            dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition());
-        auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-        TORCH_CHECK(lhs != nullptr && lhs->isConst());
-        int lhs_value = *lhs->value();
-        auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-        TORCH_CHECK(rhs != nullptr && rhs->isConst());
-        int rhs_value = *rhs->value();
-        TORCH_CHECK(lhs_value == split_factor && rhs_value == 2);
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto ref = (t0 + shift(t0, {-1}) + shift(t0, {1})) / 3;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift5ptStencil_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // 5-pt stencil
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
-
-  std::vector<TensorView*> tvs;
-  for (const auto& offset : offsets) {
-    tvs.push_back(shift(tv0, offset));
-  }
-
-  auto tv_out = tv0;
-
-  for (auto tv : tvs) {
-    tv_out = add(tv_out, tv);
-  }
-
-  tv_out = div(tv_out, new Double(tvs.size() + 1));
-
-  fusion.addOutput(tv_out);
-
-  std::vector<int> split_factor({4, 8});
-
-  tv_out->split(-1, split_factor[1]);
-  tv_out->split(0, split_factor[0]);
-  tv_out->reorder({{1, 2}, {2, 1}});
-
-  auto cache = tv0->cache_after();
-
-  tv0->computeAt(tv_out, 2);
-
-  // Inline completely except for the cache
-  for (auto tv : tvs) {
-    tv->computeAt(tv_out, -1);
-  }
-
-  // cache allocation: (split_factor + 2) * (split_factor + 2)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == cache->name()) {
-        TORCH_CHECK(alloc->shape().size() == 2);
-        for (int i = 0; i < 2; ++i) {
-          auto def =
-              dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
-          auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-          TORCH_CHECK(lhs != nullptr && lhs->isConst());
-          int lhs_value = *lhs->value();
-          auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-          TORCH_CHECK(rhs != nullptr && rhs->isConst());
-          int rhs_value = *rhs->value();
-          TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2);
-        }
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto ref = t0;
-  for (const auto& offset : offsets) {
-    ref = ref + shift(t0, offset);
-  }
-  ref = ref / int(offsets.size() + 1);
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift9ptStencil_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // 9-pt stencil
-  std::vector<std::vector<int>> offsets;
-  for (int i = -1; i < 2; ++i) {
-    for (int j = -1; j < 2; ++j) {
-      if (i == 0 && j == 0) {
-        continue;
-      }
-      offsets.push_back({i, j});
-    }
-  }
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  std::vector<TensorView*> tvs;
-  for (const auto& offset : offsets) {
-    tvs.push_back(shift(tv0, offset));
-  }
-
-  auto tv_out = tv0;
-
-  for (auto tv : tvs) {
-    tv_out = add(tv_out, tv);
-  }
-
-  tv_out = div(tv_out, new Double(tvs.size() + 1));
-
-  fusion.addOutput(tv_out);
-
-  std::vector<int> split_factor({4, 8});
-  tv_out->split(-1, split_factor[1]);
-  tv_out->split(0, split_factor[0]);
-  tv_out->reorder({{1, 2}, {2, 1}});
-
-  auto cache = tv0->cache_after();
-
-  tv0->computeAt(tv_out, 2);
-
-  // Inline completely except for the cache
-  for (auto tv : tvs) {
-    tv->computeAt(tv_out, -1);
-  }
-
-  // This seems fine but not yet verified
-  // tv_out->axis(-1)->parallelize(ParallelType::Unswitch);
-
-  // cache allocation: (split_factor + 2) * (split_factor + 2)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == cache->name()) {
-        TORCH_CHECK(alloc->shape().size() == 2);
-        for (int i = 0; i < 2; ++i) {
-          auto def =
-              dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
-          auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-          TORCH_CHECK(lhs != nullptr && lhs->isConst());
-          int lhs_value = *lhs->value();
-          auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-          TORCH_CHECK(rhs != nullptr && rhs->isConst());
-          int rhs_value = *rhs->value();
-          TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2);
-        }
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto ref = t0;
-  for (const auto& offset : offsets) {
-    ref = ref + shift(t0, offset);
-  }
-  ref = ref / int(offsets.size() + 1);
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftSmemBlocking_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = shift(tv1, {0, 1});
-  fusion.addOutput(tv2);
-
-  int smem_block_factor = 32;
-
-  tv2->split(-1, smem_block_factor);
-
-  tv0->computeAt(tv2, -2);
-
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv1->setMemoryType(MemoryType::Shared);
-
-  // tv1 allocation: (split_factor + 1)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == tv1->name()) {
-        TORCH_CHECK(alloc->shape().size() == 1);
-        for (int i = 0; i < 1; ++i) {
-          auto def =
-              dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
-          auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-          TORCH_CHECK(lhs != nullptr && lhs->isConst());
-          int lhs_value = *lhs->value();
-          auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-          TORCH_CHECK(rhs != nullptr && rhs->isConst());
-          int rhs_value = *rhs->value();
-          TORCH_CHECK(lhs_value == smem_block_factor && rhs_value == 1);
-        }
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 100;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = shift(t1, {0, 1});
-  auto ref = t2;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift3ptStencilParallel_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // 3-pt stencil
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  std::vector<TensorView*> tvs;
-  tvs.push_back(shift(tv0, {-1}));
-  tvs.push_back(shift(tv0, {1}));
-
-  auto tv_out = tv0;
-
-  for (auto tv : tvs) {
-    tv_out = add(tv_out, tv);
-  }
-
-  tv_out = div(tv_out, new Double(tvs.size() + 1));
-
-  fusion.addOutput(tv_out);
-
-  int smem_block_factor = 32;
-
-  tv_out->split(0, smem_block_factor);
-  // tv_out->axis(-1)->parallelize(ParallelType::Unswitch);
-
-  auto tv0_cache = tv0->cache_after();
-
-  tv0->computeAt(tv_out, 1);
-
-  for (auto tv : tvs) {
-    tv->computeAt(tv_out, -1);
-  }
-
-  tv0_cache->setMemoryType(MemoryType::Shared);
-  tv_out->axis(-1)->parallelize(ParallelType::TIDx);
-  tv0_cache->axis(-1)->parallelize(ParallelType::TIDx);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto ref = (t0 + shift(t0, {-1}) + shift(t0, {1})) / 3;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift5ptStencilParallel_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // 5-pt stencil
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
-
-  std::vector<TensorView*> tvs;
-  for (const auto& offset : offsets) {
-    tvs.push_back(shift(tv0, offset));
-  }
-
-  auto tv_out = tv0;
-
-  for (auto tv : tvs) {
-    tv_out = add(tv_out, tv);
-  }
-
-  tv_out = div(tv_out, new Double(tvs.size() + 1));
-
-  fusion.addOutput(tv_out);
-
-  int smem_block_factor = 32;
-
-  tv_out->split(-1, smem_block_factor);
-  tv_out->split(0, smem_block_factor);
-
-  tv_out->reorder({{1, 2}, {2, 1}});
-
-  auto tv0_cache = tv0->cache_after();
-
-  tv0->computeAt(tv_out, 2);
-
-  for (auto tv : tvs) {
-    tv->computeAt(tv_out, -1);
-  }
-
-  tv_out->axis(-1)->parallelize(ParallelType::TIDx);
-  tv_out->axis(-2)->parallelize(ParallelType::TIDy);
-  tv_out->axis(-3)->parallelize(ParallelType::BIDx);
-  tv_out->axis(-4)->parallelize(ParallelType::BIDy);
-
-  tv0_cache->setMemoryType(MemoryType::Shared);
-  tv0_cache->axis(-1)->parallelize(ParallelType::TIDx);
-  tv0_cache->axis(-2)->parallelize(ParallelType::TIDy);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto ref = t0;
-  for (const auto& offset : offsets) {
-    ref = ref + shift(t0, offset);
-  }
-  ref = ref / int(offsets.size() + 1);
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftMerge1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = shift(tv1, {-1, 1});
-  fusion.addOutput(tv2);
-
-  int split_factor = 4;
-
-  tv2->split(-1, split_factor);
-  tv2->split(0, split_factor);
-  tv2->reorder({{1, 2}, {2, 1}});
-  tv2->merge(2, 3);
-
-  tv0->computeAt(tv2, 2);
-
-  // t1 allocation: (split_factor + 1) * (split_factor + 1)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == 1) {
-        TORCH_CHECK(alloc->shape().size() == 2);
-        for (int i = 0; i < 2; ++i) {
-          auto def =
-              dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
-          auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-          TORCH_CHECK(lhs != nullptr && lhs->isConst());
-          int lhs_value = *lhs->value();
-          auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-          TORCH_CHECK(rhs != nullptr && rhs->isConst());
-          int rhs_value = *rhs->value();
-          TORCH_CHECK(lhs_value == split_factor && rhs_value == 1);
-        }
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = shift(t1, {-1, 1});
-  auto ref = t2;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftMerge2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = shift(tv1, {1, -1});
-  auto tv3 = shift(tv1, {-1, 1});
-  auto tv4 = add(tv2, tv3);
-  fusion.addOutput(tv4);
-
-  int split_factor = 4;
-
-  tv4->split(-1, split_factor);
-  tv4->split(0, split_factor);
-  tv4->reorder({{1, 2}, {2, 1}});
-  tv4->merge(2, 3);
-
-  tv0->computeAt(tv4, -2);
-
-  // t1 allocation: (split_factor + 2) * (split_factor + 2)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == 1) {
-        TORCH_CHECK(alloc->shape().size() == 2);
-        for (int i = 0; i < 2; ++i) {
-          auto def =
-              dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
-          auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-          TORCH_CHECK(lhs != nullptr && lhs->isConst());
-          int lhs_value = *lhs->value();
-          auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-          TORCH_CHECK(rhs != nullptr && rhs->isConst());
-          int rhs_value = *rhs->value();
-          TORCH_CHECK(lhs_value == split_factor && rhs_value == 2);
-        }
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = shift(t1, {1, -1});
-  auto t3 = shift(t1, {-1, 1});
-  auto t4 = t2 + t3;
-
-  TORCH_CHECK(t4.allclose(outputs[0]));
-}
-
-TEST(NVFuserTest, FusionShiftGlobal_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = shift(tv1, {0, 1});
-  auto tv3 = shift(tv1, {-1, 0});
-  auto tv4 = add(tv2, tv3);
-  fusion.addOutput(tv4);
-
-  tv1->split(-1, 4);
-  tv2->split(-1, 8);
-  tv3->split(-1, 2);
-  tv4->split(-1, 3);
-
-  tv1->merge(-2, -1);
-
-  tv1->setMemoryType(MemoryType::Global);
-  tv2->setMemoryType(MemoryType::Global);
-  tv3->setMemoryType(MemoryType::Global);
-
-  // t1 allocation: (t1.size[0] + 1) * (t1.size[1] + 1)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == 1) {
-        TORCH_CHECK(alloc->shape().size() == 2);
-        for (int i = 0; i < 2; ++i) {
-          auto def =
-              dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
-          TORCH_CHECK(def != nullptr && def->operation() == BinaryOpType::Add);
-          TORCH_CHECK(def->as<kir::BinaryOp>()->lhs()->isA<kir::NamedScalar>());
-          auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-          TORCH_CHECK(rhs != nullptr && rhs->isConst());
-          int rhs_value = *rhs->value();
-          TORCH_CHECK(rhs_value == 1);
-        }
-      }
-    }
-  }
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = shift(t1, {0, 1});
-  auto t3 = shift(t1, {-1, 0});
-  auto t4 = t2 + t3;
-  auto ref = t4;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftDoubleSplitMerge1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(2));
-  auto tv3 = shift(tv2, {0, 1});
-  fusion.addOutput(tv3);
-
-  int split_factor1 = 8;
-  int split_factor2 = 4;
-
-  tv3->split(-1, split_factor1);
-
-  tv0->computeAt(tv3, -2);
-
-  tv1->split(-1, split_factor2);
-  tv1->merge(-2, -1);
-
-  // t1 and t2 allocation: (split_factor1 + 1)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == 1 || tensor_name == 2) {
-        TORCH_CHECK(alloc->shape().size() == 1);
-        auto def =
-            dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition());
-        auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-        TORCH_CHECK(lhs != nullptr && lhs->isConst());
-        int lhs_value = *lhs->value();
-        auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-        TORCH_CHECK(rhs != nullptr && rhs->isConst());
-        int rhs_value = *rhs->value();
-        TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1);
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 3;
-  auto ref = shift(t1, {0, 1});
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(2));
-  auto tv3 = shift(tv2, {1, 1});
-  fusion.addOutput(tv3);
-
-  auto out = tv3;
-
-  int split_factor1 = 32;
-  int split_factor2 = 4;
-
-  out->split(-1, split_factor1);
-  out->split(-1, split_factor2);
-  out->split(0, split_factor1);
-  out->split(1, split_factor2);
-  out->reorder({{3, 1}, {1, 2}, {4, 3}, {2, 4}});
-  out->merge(2, 3);
-  out->merge(2, 3);
-  out->merge(2, 3);
-  out->merge(0, 1);
-
-  TransformPropagator::from(out);
-
-  tv0->computeAt(out, 1);
-
-  out->axis(0)->parallelize(ParallelType::BIDx);
-  out->axis(1)->parallelize(ParallelType::TIDx);
-
-  scheduler_utils::parallelizeAllLike(out, {tv1, tv2});
-
-  for (auto tv : {tv1, tv2}) {
-    tv->setMemoryType(MemoryType::Shared);
-  }
-
-  // t1 and t2 allocation: (split_factor1 + 1) * (split_factor1 + 1)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == 1 || tensor_name == 2) {
-        TORCH_CHECK(alloc->shape().size() == 2);
-        for (int i = 0; i < 2; ++i) {
-          auto def =
-              dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
-          auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-          TORCH_CHECK(lhs != nullptr && lhs->isConst());
-          int lhs_value = *lhs->value();
-          auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-          TORCH_CHECK(rhs != nullptr && rhs->isConst());
-          int rhs_value = *rhs->value();
-          TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1);
-        }
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto ref = shift(t0 + 1 + 2, {1, 1});
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // 5-pt stencil
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
-
-  std::vector<TensorView*> tvs;
-  for (const auto& offset : offsets) {
-    tvs.push_back(shift(tv0, offset));
-  }
-
-  auto tv_out = tv0;
-
-  for (auto tv : tvs) {
-    tv_out = add(tv_out, tv);
-  }
-
-  tv_out = div(tv_out, new Double(tvs.size() + 1));
-
-  fusion.addOutput(tv_out);
-
-  std::vector<int> split_factor({4, 32});
-
-  tv_out->split(-1, split_factor[1]);
-  tv_out->split(0, split_factor[0]);
-  tv_out->reorder({{1, 2}, {2, 1}});
-
-  auto tv0_cache = tv0->cache_after();
-
-  // Merge the inner-most two axes and create
-  // a 1D thread block of split_factor1*split_factor2 threads
-  tv_out->merge(-2, -1);
-
-  tv0->computeAt(tv_out, 2);
-
-  // Inline completely except for the cache
-  for (auto tv : tvs) {
-    tv->computeAt(tv_out, -1);
-  }
-
-  tv0_cache->merge(-2, -1);
-
-  tv_out->axis(-1)->parallelize(ParallelType::TIDx);
-  tv_out->axis(1)->parallelize(ParallelType::BIDx);
-  tv_out->axis(0)->parallelize(ParallelType::BIDy);
-
-  tv0_cache->setMemoryType(MemoryType::Shared);
-  tv0_cache->axis(-1)->parallelize(ParallelType::TIDx);
-
-  // cache allocation: (split_factor1 + 2) * (split_factor2 + 2)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == tv0_cache->name()) {
-        TORCH_CHECK(alloc->shape().size() == 2);
-        for (int i = 0; i < 2; ++i) {
-          auto def =
-              dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
-          auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-          TORCH_CHECK(lhs != nullptr && lhs->isConst());
-          int lhs_value = *lhs->value();
-          auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-          TORCH_CHECK(rhs != nullptr && rhs->isConst());
-          int rhs_value = *rhs->value();
-          TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2);
-        }
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto ref = t0;
-  for (const auto& offset : offsets) {
-    ref = ref + shift(t0, offset);
-  }
-  ref = ref / int(offsets.size() + 1);
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftChain1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = shift(tv0, {0, 1});
-  auto tv2 = shift(tv1, {0, 1});
-  fusion.addOutput(tv2);
-
-  int split_factor = 4;
-  tv2->split(-1, split_factor);
-
-  tv0->computeAt(tv2, -2);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto ref = shift(shift(t0, {0, 1}), {0, 1});
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftChain2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = shift(tv0, {0, 1});
-  auto tv2 = shift(tv1, {0, -1});
-  fusion.addOutput(tv2);
-
-  tv2->split(-1, 4);
-
-  tv0->computeAt(tv2, -2);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto ref = shift(shift(t0, {0, 1}), {0, -1});
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftChain3_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = shift(tv1, {0, 1});
-  auto tv3 = shift(tv2, {0, 1});
-  fusion.addOutput(tv3);
-
-  int split_factor = 4;
-  tv3->split(-1, split_factor);
-
-  tv0->computeAt(tv3, -2);
-
-  // Halo size of tv1 is 2 as it needs to account for both of the two
-  // shift operations , while that of tv2 is still just 1
-
-  // tv1: (split_factor + 2)
-  // tv2: (split_factor + 1)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == 1 || tensor_name == 2) {
-        TORCH_CHECK(alloc->shape().size() == 1);
-        for (int i = 0; i < 1; ++i) {
-          auto def =
-              dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
-          auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-          TORCH_CHECK(lhs != nullptr && lhs->isConst());
-          int lhs_value = *lhs->value();
-          auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-          TORCH_CHECK(rhs != nullptr && rhs->isConst());
-          int rhs_value = *rhs->value();
-          TORCH_CHECK(lhs_value == split_factor);
-          if (tensor_name == 1) {
-            TORCH_CHECK(rhs_value == 2);
-          } else if (tensor_name == 2) {
-            TORCH_CHECK(rhs_value == 1);
-          }
-        }
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = shift(t1, {0, 1});
-  auto t3 = shift(t2, {0, 1});
-  auto ref = t3;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftChain4_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = shift(tv0, {1, -1});
-  auto tv2 = shift(tv1, {2, -2});
-  auto tv3 = shift(tv2, {3, -3});
-  auto tv4 = shift(tv3, {4, -4});
-  auto tv_out = tv4;
-
-  fusion.addOutput(tv_out);
-
-  int split_factor = 4;
-
-  tv_out->split(-1, split_factor);
-  tv_out->split(0, split_factor);
-  tv_out->reorder({{1, 2}, {2, 1}});
-
-  tv0->computeAt(tv_out, 2);
-
-  tv1->merge(-2, -1);
-  tv2->merge(-2, -1);
-  tv3->merge(-2, -1);
-
-  // tv1: (split_factor + 9) * (split_factor + 9)
-  // tv2: (split_factor + 7) * (split_factor + 7)
-  // tv3: (split_factor + 4) * (split_factor + 4)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == 1 || tensor_name == 2) {
-        TORCH_CHECK(alloc->shape().size() == 2);
-        for (int i = 0; i < 2; ++i) {
-          auto def =
-              dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
-          auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-          TORCH_CHECK(lhs != nullptr && lhs->isConst());
-          int lhs_value = *lhs->value();
-          auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-          TORCH_CHECK(rhs != nullptr && rhs->isConst());
-          int rhs_value = *rhs->value();
-          TORCH_CHECK(lhs_value == split_factor);
-          if (tensor_name == 1) {
-            TORCH_CHECK(rhs_value == 9);
-          } else if (tensor_name == 2) {
-            TORCH_CHECK(rhs_value == 7);
-          } else if (tensor_name == 3) {
-            TORCH_CHECK(rhs_value == 4);
-          }
-        }
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = shift(t0, {1, -1});
-  auto t2 = shift(t1, {2, -2});
-  auto t3 = shift(t2, {3, -3});
-  auto t4 = shift(t3, {4, -4});
-  auto ref = t4;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
-
-  // First stencil: 5pt stencil
-  // stencil1 = (tv0 + tv0[+1][0] + tv0[-1][0] + tv0[0][+1] + tv0[0][-1]) / 5
-  std::vector<TensorView*> tv_stencil1_shifts;
-  for (const auto& offset : offsets) {
-    tv_stencil1_shifts.push_back(shift(tv0, offset));
-  }
-
-  auto tv_stencil1 = tv0;
-  for (auto tv : tv_stencil1_shifts) {
-    tv_stencil1 = add(tv_stencil1, tv);
-  }
-
-  tv_stencil1 = div(tv_stencil1, new Double(tv_stencil1_shifts.size() + 1));
-
-  // Second stencil: Same 5pt stencil
-  std::vector<TensorView*> tv_stencil2_shifts;
-  for (const auto& offset : offsets) {
-    tv_stencil2_shifts.push_back(shift(tv_stencil1, offset));
-  }
-
-  auto tv_stencil2 = tv_stencil1;
-  for (auto tv : tv_stencil2_shifts) {
-    tv_stencil2 = add(tv_stencil2, tv);
-  }
-
-  tv_stencil2 = div(tv_stencil2, new Double(tv_stencil2_shifts.size() + 1));
-
-  auto tv_out = tv_stencil2;
-
-  fusion.addOutput(tv_out);
-
-  auto tv0_cache = tv0->cache_after();
-
-  std::vector<int> split_factor({16, 16});
-
-  tv_out->split(-1, split_factor[1]);
-  tv_out->split(0, split_factor[0]);
-  tv_out->reorder({{1, 2}, {2, 1}});
-
-  tv0->computeAt(tv_out, 2);
-
-  // Inline completely all inputs to the first stencil output, except for the
-  // tv0 cache
-  for (auto tv : tv_stencil1_shifts) {
-    tv->computeAt(tv_stencil1, -1);
-  }
-
-  // Inline completely all inputs to the second stencil output, except
-  // for the first stencil output
-  for (auto tv : tv_stencil2_shifts) {
-    tv->computeAt(tv_stencil2, -1);
-  }
-
-  tv_out->axis(1)->parallelize(ParallelType::BIDx);
-  tv_out->axis(0)->parallelize(ParallelType::BIDy);
-
-  auto all_values = DependencyCheck::getAllValsBetween(
-      {fusion.inputs().begin(), fusion.inputs().end()}, fusion.outputs());
-  for (auto tv : ir_utils::filterByType<TensorView>(all_values)) {
-    tv->axis(-1)->parallelize(ParallelType::TIDx);
-    tv->axis(-2)->parallelize(ParallelType::TIDy);
-  }
-
-  tv0_cache->setMemoryType(MemoryType::Shared);
-  tv_stencil1->setMemoryType(MemoryType::Shared);
-
-  // tv0_cache: (split_factor + 4) * (split_factor + 4)
-  // tv_stencil1: (split_factor + 2) * (split_factor + 2)
-  GpuLower gpulw(&fusion);
-  for (const auto& kir_node : gpulw.kernel()->irNodes()) {
-    if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
-      auto tensor_name = alloc->buffer()->name();
-      if (tensor_name == tv0_cache->name() ||
-          tensor_name == tv_stencil1->name()) {
-        TORCH_CHECK(alloc->shape().size() == 2);
-        for (int i = 0; i < 2; ++i) {
-          auto def =
-              dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
-          auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
-          TORCH_CHECK(lhs != nullptr && lhs->isConst());
-          int lhs_value = *lhs->value();
-          auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
-          TORCH_CHECK(rhs != nullptr && rhs->isConst());
-          int rhs_value = *rhs->value();
-          TORCH_CHECK(lhs_value == split_factor[i]);
-          if (tensor_name == tv0_cache->name()) {
-            TORCH_CHECK(rhs_value == 4);
-          } else if (tensor_name == tv_stencil1->name()) {
-            TORCH_CHECK(rhs_value == 2);
-          }
-        }
-      }
-    }
-  }
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto stencil1 = t0;
-  for (const auto& offset : offsets) {
-    stencil1 = stencil1 + shift(t0, offset);
-  }
-  stencil1 = stencil1 / int(offsets.size() + 1);
-  auto stencil2 = stencil1;
-  for (const auto& offset : offsets) {
-    stencil2 = stencil2 + shift(stencil1, offset);
-  }
-  stencil2 = stencil2 / int(offsets.size() + 1);
-  auto ref = stencil2;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-// Shift a reduced tensor
-TEST(NVFuserTest, FusionShiftReduction1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = sum(tv1, {1});
-  auto tv3 = shift(tv2, {1});
-  fusion.addOutput(tv3);
-
-  tv3->split(0, 4);
-  tv0->computeAt(tv3, 1);
-  tv0->computeAt(tv2, -1);
-
-  const int numel_x = 9;
-  const int numel_y = 11;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = sum(t1, {1});
-  auto t3 = shift(t2, {1});
-  auto ref = t3;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-// Parallelized version of FusionShiftReduction1
-TEST(NVFuserTest, FusionShiftReduction2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = sum(tv1, {1});
-  auto tv3 = shift(tv2, {1});
-  fusion.addOutput(tv3);
-
-  tv3->split(0, 4);
-  tv0->computeAt(tv3, 1);
-
-  tv2->split(-1, 32);
-  tv0->computeAt(tv2, -1);
-
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv2->setMemoryType(MemoryType::Shared);
-
-  const int numel_x = 201;
-  const int numel_y = 301;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = sum(t1, {1});
-  auto t3 = shift(t2, {1});
-  auto ref = t3;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftRfactor1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = sum(tv1, {1});
-  auto tv3 = shift(tv2, {1});
-  fusion.addOutput(tv3);
-
-  tv3->split(0, 4);
-  tv0->computeAt(tv3, 1);
-
-  tv2->split(-1, 32);
-  auto rf = tv2->rFactor({-2});
-  tv0->computeAt(tv2, -1);
-  tv0->computeAt(rf, -1);
-
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
-  tv2->setMemoryType(MemoryType::Shared);
-
-  const int numel_x = 201;
-  const int numel_y = 301;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = sum(t1, {1});
-  auto t3 = shift(t2, {1});
-  auto ref = t3;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftBcast1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  auto tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-  auto tv2 = broadcast(tv0, {false, true});
-  auto tv3 = shift(tv2, {0, 1});
-  auto tv4 = add(tv3, tv1);
-  fusion.addOutput(tv4);
-
-  tv0->computeAt(tv4, -1);
-  tv1->computeAt(tv4, -1);
-
-  const int numel_x = 9;
-  const int numel_y = 11;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x}, options);
-  at::Tensor t1 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0, t1};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(inputs);
-
-  auto t4 = t0.unsqueeze(-1).expand({numel_x, numel_y}) + t1;
-  auto ref = t4;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftBcast2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  auto tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-  auto tv2 = broadcast(tv0, {false, true});
-  auto tv3 = shift(tv2, {1, 0});
-  auto tv4 = add(tv3, tv1);
-  fusion.addOutput(tv4);
-
-  tv4->split(0, 4);
-  tv0->computeAt(tv4, 1);
-
-  const int numel_x = 9;
-  const int numel_y = 11;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x}, options);
-  at::Tensor t1 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0, t1};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(inputs);
-
-  auto t2 = t0.unsqueeze(-1).expand({numel_x, numel_y});
-  auto t3 = shift(t2, {1, 0});
-  auto ref = t3 + t1;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-// Combine ShiftBcast1 and ShiftBcast2 with parallelization
-TEST(NVFuserTest, FusionShiftBcast3_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  auto tv1 = makeSymbolicTensor(2);
-  fusion.addInput(tv1);
-  auto tv2 = broadcast(tv0, {false, true});
-  auto tv3 = shift(tv2, {1, 0});
-  auto tv4 = shift(tv2, {0, 1});
-  auto tv5 = shift(tv2, {-1, -1});
-  auto tv6 = add(tv3, tv4);
-  auto tv7 = add(tv6, tv5);
-  auto tv8 = add(tv7, tv1);
-  fusion.addOutput(tv8);
-
-  tv8->split(0, 4);
-  tv8->split(-1, 4);
-  tv0->computeAt(tv8, 1);
-
-  tv8->axis(-1)->parallelize(ParallelType::TIDx);
-  for (auto tv : {tv8, tv7, tv6, tv5, tv4, tv3, tv2}) {
-    tv->axis(1)->parallelize(ParallelType::TIDy);
-  }
-
-  tv2->setMemoryType(MemoryType::Shared);
-
-  const int numel_x = 101;
-  const int numel_y = 201;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x}, options);
-  at::Tensor t1 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0, t1};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(inputs);
-
-  auto t2 = t0.unsqueeze(-1).expand({numel_x, numel_y});
-  auto t3 = shift(t2, {1, 0});
-  auto t4 = t2;
-  auto t5 = shift(t2, {-1, 0});
-  auto ref = t3 + t4 + t5 + t1;
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-// See issue #893
-TEST(NVFuserTest, FusionShiftSyncPlacement1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv0, new Double(2));
-  auto tv3 = add(tv1, tv2);
-  auto tv4 = shift(tv3, {0, 1});
-  fusion.addOutput(tv4);
-
-  tv4->split(1, 8);
-  tv0->computeAt(tv4, 2);
-
-  tv2->computeAt(tv3, -1);
-
-  tv1->setMemoryType(MemoryType::Shared);
-  tv3->setMemoryType(MemoryType::Shared);
-
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-  int numel_y = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = t0 + 2;
-  auto t3 = add(t1, t2);
-  auto t4 = shift(t3, {0, 1});
-
-  testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__);
-}
-
-// See issue #893. Top-level placement.
-TEST(NVFuserTest, FusionShiftSyncPlacement2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv0, new Double(2));
-  auto tv3 = add(tv1, tv2);
-  auto tv4 = shift(tv3, {1});
-  fusion.addOutput(tv4);
-
-  tv2->computeAt(tv3, -1);
-
-  tv1->setMemoryType(MemoryType::Shared);
-  tv3->setMemoryType(MemoryType::Shared);
-
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  tv4->axis(-1)->parallelize(ParallelType::TIDx);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 99;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({numel_x}, options);
-  std::vector<IValue> inputs = {t0};
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = t0 + 2;
-  auto t3 = add(t1, t2);
-  auto t4 = shift(t3, {1});
-
-  testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftSyncPlacement3_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(1);
-  fusion.addInput(tv0);
-  auto tv1 = add(tv0, new Double(1));
-  auto tv2 = add(tv1, new Double(2));
-  auto tv3 = shift(tv2, {1});
-  fusion.addOutput(tv3);
-
-  // This doesn't work. syncthreads is needed between tv1 and tv2, but
-  // both the loop extent of both tv1 and tv2 has halo, so the loop is
-  // not eliminated even though it is parallelized. Moving syncthreads
-  // out of the loop would make it placed before tv1, which would make
-  // it meaningless.
-  // Ideally, an exception should be thrown at this computeAt, but at
-  // this point, the fusion is not yet parallelized, nor memory type
-  // is set, so this computeAt itself is not an error yet.
-  tv1->computeAt(tv2, -1);
-
-  tv1->setMemoryType(MemoryType::Shared);
-  tv2->setMemoryType(MemoryType::Shared);
-
-  tv1->axis(-1)->parallelize(ParallelType::TIDx);
-  tv2->axis(-1)->parallelize(ParallelType::TIDx);
-  tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
-  // The error should be detected when the fusion is lowered.
-  ASSERT_ANY_THROW(fusion.printKernel());
-}
-
-// Based on original CUDA provided by Vishal Mehta.
-// Major differences with the original version:
-// - Boundary processing. We always pad by zero. The original version
-//   is only defined for the interior domain.
-// - The original version uses additional 2 warps to load the halos
-//   along the Y dimension. The other 10 warps are used to load a 32x10
-//   tile, and all warps will do coalesced loads. No such optimization
-//   is done in the fuser version.
-TEST(NVFuserTest, FusionHorizontalDiffusion_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto inp = makeSymbolicTensor(3);
-  fusion.addInput(inp);
-  auto coeff = makeSymbolicTensor(3);
-  fusion.addInput(coeff);
-
-  std::vector<std::vector<int>> offsets{
-      {0, 1, 0}, {0, -1, 0}, {0, 0, 1}, {0, 0, -1}};
-
-  // T2, T3, T4, T5
-  std::vector<TensorView*> inp_neighbors;
-  for (const auto& offset : offsets) {
-    inp_neighbors.push_back(shift(inp, offset));
-  }
-
-  // T8
-  TensorView* sum_of_neighbors = nullptr;
-  for (auto inp_neighbor : inp_neighbors) {
-    if (sum_of_neighbors == nullptr) {
-      sum_of_neighbors = inp_neighbor;
-    } else {
-      sum_of_neighbors = add(sum_of_neighbors, inp_neighbor);
-    }
-  }
-
-  // T9 = T0 * 4
-  // T10 = T9 - T8
-  auto lap = sub(mul(inp, new Double(4)), sum_of_neighbors);
-
-  // T11 = shift(T10)
-  // T12 = T11 - T10
-  auto flx = sub(shift(lap, {0, 0, -1}), lap);
-  // T14 = T13 - T0
-  // T15 = T12 * T14
-  // T16 = T15 > 0
-  // T17 = T16 ? 0 : T12
-  auto flx_cond = gt(mul(flx, sub(shift(inp, {0, 0, -1}), inp)), new Double(0));
-  auto flx0 = where(flx_cond, new Double(0), flx);
-
-  // T18 = shift(T10)
-  // T19 = T18 - T10
-  auto fly = sub(shift(lap, {0, -1, 0}), lap);
-  // T20 = shift(T0)
-  // T21 = T20 - T0
-  // T22 = T19 * T21
-  // T23 = T22 > 0
-  auto fly_cond = gt(mul(fly, sub(shift(inp, {0, -1, 0}), inp)), new Double(0));
-  // T24 = T23 ? 0 : T19
-  auto fly0 = where(fly_cond, new Double(0), fly);
-
-  // T25 = shift(flx0)
-  // T26 = T17 - T25
-  // T27 = shift(fly0)
-  // T28 = T24 - T27
-  // T29 = T26 + T28
-  // T30 = T1 * T29
-  // T31 = T0 - T30
-  auto out =
-      sub(inp,
-          mul(coeff,
-              add(sub(flx0, shift(flx0, {0, 0, 1})),
-                  sub(fly0, shift(fly0, {0, 1, 0})))));
-
-  fusion.addOutput(out);
-
-  /////////////////////////////////
-  // Scheduling
-  /////////////////////////////////
-
-  // Step 1: 2D Tiling
-
-  const int tile_x = 32;
-  const int tile_y = 8;
-
-  out->split(-1, tile_x);
-  out->split(-3, tile_y);
-  out->reorder({{-2, -3}});
-  inp->computeAt(out, -3);
-  coeff->computeAt(out, -3);
-
-  // Step 2: Inlining
-
-  // Inline inputs to lap
-  auto lap_vals = DependencyCheck::getAllValsBetween({inp}, {lap});
-  for (auto val : ir_utils::filterByType<TensorView>(lap_vals)) {
-    if (val != lap && val != inp) {
-      val->computeAt(lap, -1);
-    }
-  }
-
-  // Inline inputs to flx0
-  auto flx0_vals = DependencyCheck::getAllValsBetween({lap, inp}, {flx0});
-  for (auto val : ir_utils::filterByType<TensorView>(flx0_vals)) {
-    if (val != lap && val != flx0 && val != inp) {
-      val->computeAt(flx0, -1);
-    }
-  }
-
-  // Inline inputs to fly0
-  auto flxy_vals = DependencyCheck::getAllValsBetween({lap, inp}, {fly0});
-  for (auto val : ir_utils::filterByType<TensorView>(flxy_vals)) {
-    if (val != lap && val != fly0 && val != inp) {
-      val->computeAt(fly0, -1);
-    }
-  }
-
-  // Inline inputs to out
-  auto out_vals = DependencyCheck::getAllValsBetween({flx0, fly0}, {out});
-  for (auto val : ir_utils::filterByType<TensorView>(out_vals)) {
-    if (val != flx0 && val != fly0 && val != out) {
-      val->computeAt(out, -1);
-    }
-  }
-
-  // Step 3: Parallelization
-
-  // Block parallelization
-  out->axis(0)->parallelize(ParallelType::BIDz);
-  out->axis(1)->parallelize(ParallelType::BIDy);
-  out->axis(2)->parallelize(ParallelType::BIDx);
-
-  // Thread parallelization
-  for (auto tv : {out, flx0, fly0, lap}) {
-    tv->axis(3)->parallelize(ParallelType::TIDy);
-    tv->axis(4)->parallelize(ParallelType::TIDx);
-    if (tv != out) {
-      tv->setMemoryType(MemoryType::Shared);
-    }
-  }
-
-  /////////////////////////////////
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  int numel_x = 101;
-  int numel_y = 99;
-  int numel_z = 10;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor inp_at = at::randn({numel_z, numel_y, numel_x}, options);
-  at::Tensor coeff_at = at::randn({numel_z, numel_y, numel_x}, options);
-  std::vector<IValue> inputs = {inp_at, coeff_at};
-  auto outputs = fe.runFusion(inputs);
-
-  {
-    at::Tensor zeros = at::zeros({numel_z, numel_y, numel_x}, options);
-    auto lap = inp_at * 4 -
-        (shift(inp_at, {0, 1, 0}) + shift(inp_at, {0, -1, 0}) +
-         shift(inp_at, {0, 0, 1}) + shift(inp_at, {0, 0, -1}));
-    auto flx = shift(lap, {0, 0, -1}) - lap;
-    auto flx_cond = (flx * (shift(inp_at, {0, 0, -1}) - inp_at)) > 0;
-    auto flx0 = at::where(flx_cond, zeros, flx);
-    auto fly = shift(lap, {0, -1, 0}) - lap;
-    auto fly_cond = (fly * (shift(inp_at, {0, -1, 0}) - inp_at)) > 0;
-    auto fly0 = at::where(fly_cond, zeros, fly);
-
-    auto ref = inp_at -
-        coeff_at *
-            ((flx0 - shift(flx0, {0, 0, 1})) + (fly0 - shift(fly0, {0, 1, 0})));
-
-    testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-  }
-}
-
-// 3x3 max pooling
-TEST(NVFuserTest, FusionMaxPooling_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Format: CHW
-  auto inp = makeSymbolicTensor(3);
-  fusion.addInput(inp);
-
-  // 3x3 pooling of the HW spatial domain
-  std::vector<std::vector<int>> offsets;
-  for (int i = -1; i <= 1; ++i) {
-    for (int j = -1; j <= 1; ++j) {
-      if (i == 0 && j == 0) {
-        continue;
-      }
-      offsets.push_back({i, j});
-    }
-  }
-
-  std::vector<TensorView*> inp_tile({inp});
-  for (auto offset : offsets) {
-    offset.insert(offset.begin(), 0);
-    inp_tile.push_back(shift(inp, offset));
-  }
-
-  TensorView* max_tensor = nullptr;
-  for (auto tv : inp_tile) {
-    if (max_tensor == nullptr) {
-      max_tensor = tv;
-    } else {
-      max_tensor = binaryOp(BinaryOpType::Max, max_tensor, tv);
-    }
-  }
-
-  fusion.addOutput(max_tensor);
-
-  ////////////////////////////////////
-
-  // Cache the input and weight tensors
-  auto inp_cache = inp->cache_after();
-
-  // Tiling the spatial domain
-  const int tile_x = 32;
-  const int tile_y = 8;
-
-  max_tensor->split(-2, tile_y);
-  max_tensor->axis(-2)->parallelize(ParallelType::TIDy);
-  max_tensor->split(-1, tile_x);
-  max_tensor->axis(-1)->parallelize(ParallelType::TIDx);
-  max_tensor->reorder({{-3, -2}});
-
-  inp_cache->computeAt(max_tensor, 3);
-  inp_cache->axis(-2)->parallelize(ParallelType::TIDy);
-  inp_cache->axis(-1)->parallelize(ParallelType::TIDx);
-  inp_cache->setMemoryType(MemoryType::Shared);
-
-  auto max_tensor_dep =
-      DependencyCheck::getAllValsBetween({inp_cache}, {max_tensor});
-  for (auto tv : ir_utils::filterByType<TensorView>(max_tensor_dep)) {
-    if (tv == inp_cache || tv == max_tensor) {
-      continue;
-    }
-    tv->computeAt(max_tensor, -1);
-  }
-
-  max_tensor->axis(0)->parallelize(ParallelType::BIDx);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  const int hw = 50;
-  const int num_channels = 20;
-  const int pooling_window = 3;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor aten_inp = at::randn({num_channels, hw, hw}, options);
-  // shift always pads by zero, so if all surrounding values are
-  // negative, max pooling would pick a padded value, which isn't the
-  // correct behavior. We need to be able to choose the value of
-  // padding. In this case, padding by the minimum value would not
-  // have this problem. For now, avoid the problem by making sure all
-  // values are not negative.
-  aten_inp = at::abs(aten_inp);
-  std::vector<IValue> inputs = {aten_inp};
-
-  auto outputs = fe.runFusion(inputs);
-
-  auto ref = at::max_pool2d(
-      aten_inp, {pooling_window, pooling_window}, {1, 1}, {1, 1});
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionGatherPadding1_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  const std::vector<int> window_shape = {1, 3};
-  const std::vector<std::vector<int>> padding_width = {{0, 0}, {1, 1}};
-
-  auto tv1 = gather(tv0, window_shape, padding_width);
-
-  fusion.addOutput(tv1);
-
-  const int s1 = 11;
-  const int s2 = 13;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({s1, s2}, options);
-
-  auto ref = gather(t0, window_shape, padding_width);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion({t0});
-
-  TORCH_CHECK(ref.equal(outputs[0]));
-}
-
-TEST(NVFuserTest, FusionGatherPadding2_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  const std::vector<int> window_shape = {1, 3};
-  const std::vector<std::vector<int>> padding_width = {{0, 0}, {1, 1}};
-
-  auto tv0 = makeSymbolicTensor(2);
-  fusion.addInput(tv0);
-
-  auto tv1 = add(tv0, new Double(1));
-
-  auto tv2 = gather(tv1, window_shape, padding_width);
-
-  auto tv3 = sum(tv2, {-1});
-
-  fusion.addOutput(tv3);
-
-  tv3->split(1, 32);
-  tv0->computeAt(tv3, 2);
-  tv2->computeAt(tv3, -1);
-
-  tv3->axis(0)->parallelize(ParallelType::BIDy);
-  tv3->axis(1)->parallelize(ParallelType::BIDx);
-  tv3->axis(2)->parallelize(ParallelType::TIDx);
-  tv1->axis(2)->parallelize(ParallelType::TIDx);
-
-  tv1->setMemoryType(MemoryType::Shared);
-
-  const int s1 = 99;
-  const int s2 = 101;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor t0 = at::randn({s1, s2}, options);
-  std::vector<IValue> inputs = {t0};
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-  auto outputs = fe.runFusion(inputs);
-
-  auto t1 = t0 + 1;
-  auto t2 = gather(t1, window_shape, padding_width);
-  auto ref = sum(t2, {-1});
-
-  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionConv2DStatic_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Input: [C, H, W]
-  auto inp = makeSymbolicTensor(3);
-  fusion.addInput(inp);
-
-  // Weights: [K, C, 3, 3]
-  auto w = makeSymbolicTensor(4);
-  fusion.addInput(w);
-
-  // Gather a neighbor tile of [3, 3] with padding size of 1 for each
-  // side of the spatial dimensions
-  auto inp_tile = gather(inp, {1, 3, 3}, {{0, 0}, {1, 1}, {1, 1}});
-  // inp_tile: [C, H, W, 1, 3, 3]
-
-  auto inp_bc =
-      broadcast(inp_tile, {true, false, false, false, false, false, false});
-  auto w_bc = broadcast(w, {false, false, true, true, true, false, false});
-
-  auto inp_times_w = mul(inp_bc, w_bc);
-
-  // Reduce the channel and neighbor tile dimensions
-  auto out = sum(inp_times_w, {1, 4, 5, 6});
-
-  fusion.addOutput(out);
-
-  ////////////////////////////////////
-
-  // Cache the input and weight tensors
-  auto inp_cache = inp->cache_after();
-
-  // Blocking the spatial dimensions
-  const int block_w = 16;
-  const int block_h = 4;
-  // Blocking the channel dimension
-  const int block_c = 8;
-
-  out->split(2, block_h);
-  out->split(4, block_w);
-  out->reorder({{3, 4}});
-  // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3]
-
-  out->split(1, block_c);
-  // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3]
-
-  auto out_rf = out->rFactor({1, -3, -2, -1});
-  // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3]
-  // out_rf: [K, Ci, Ho, Wo, Hi, Wi]
-
-  // Create a [block_x, block_y] tile on smem
-  inp_cache->computeAt(out, 4);
-  // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
-  inp_cache->setMemoryType(MemoryType::Shared);
-
-  // Move Ci forward
-  out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}});
-  inp_cache->computeAt(out_rf, 5);
-
-  inp_tile->computeAt(out_rf, -1);
-  w->computeAt(out_rf, -1);
-
-  out->axis(0)->parallelize(ParallelType::BIDx);
-  out->axis(1)->parallelize(ParallelType::TIDz);
-  out->axis(4)->parallelize(ParallelType::TIDy);
-  out->axis(5)->parallelize(ParallelType::TIDx);
-
-  scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  const int dim_h = 99;
-  const int dim_w = 101;
-  const int dim_c = 10;
-  const int dim_f = 20;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
-  at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options);
-  std::vector<IValue> inputs = {at_inp, at_w};
-
-  auto cg_outputs = fe.runFusion(inputs);
-
-  at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
-  auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1);
-  at_out = at_out.squeeze(0); // drop the N axis
-
-  testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
-}
-
-// Mostly the same as the static conv test, but the shape of the weights,
-// 3x3 in this case, is given dynamically
-TEST(NVFuserTest, FusionConv2DDynamic_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Input: [C, H, W]
-  auto inp = makeSymbolicTensor(3);
-  fusion.addInput(inp);
-
-  // Weights: [K, C, S, T]
-  auto w = makeSymbolicTensor(4);
-  fusion.addInput(w);
-
-  auto w_h = new Int();
-  fusion.addInput(w_h);
-  auto w_w = new Int();
-  fusion.addInput(w_w);
-
-  auto pad_h = new Int();
-  fusion.addInput(pad_h);
-  auto pad_w = new Int();
-  fusion.addInput(pad_w);
-
-  // Gather a neighbor tile of [w_dim_h, w_dim_w] with padding
-  auto inp_tile = gather(
-      inp,
-      {new Int(1), w_h, w_w},
-      {{new Int(0), new Int(0)}, {pad_h, pad_h}, {pad_w, pad_w}});
-  // inp_tile: [C, 1, H - w_h + 1, W - w_w + 1, w_h, w_w]
-
-  auto inp_bc =
-      broadcast(inp_tile, {true, false, false, false, false, false, false});
-  auto w_bc = broadcast(w, {false, false, true, true, true, false, false});
-
-  auto inp_times_w = mul(inp_bc, w_bc);
-
-  // Reduce the channel and neighbor tile dimensions
-  auto out = sum(inp_times_w, {1, 4, 5, 6});
-
-  fusion.addOutput(out);
-
-  ////////////////////////////////////
-  // Cache the input and weight tensors
-  auto inp_cache = inp->cache_after();
-
-  // Blocking the spatial dimensions
-  const int block_w = 16;
-  const int block_h = 4;
-  // Blocking the channel dimension
-  const int block_c = 8;
-
-  out->split(2, block_h);
-  out->split(4, block_w);
-  out->reorder({{3, 4}});
-  // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3]
-
-  out->split(1, block_c);
-  // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3]
-
-  auto out_rf = out->rFactor({1, -3, -2, -1});
-  // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3]
-  // out_rf: [K, Ci, Ho, Wo, Hi, Wi]
-
-  // Create a [block_x, block_y] tile on smem
-  inp_cache->computeAt(out, 4);
-  // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
-  inp_cache->setMemoryType(MemoryType::Shared);
-
-  // Move Ci forward
-  out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}});
-  inp_cache->computeAt(out_rf, 5);
-
-  inp_tile->computeAt(out_rf, -1);
-  w->computeAt(out_rf, -1);
-
-  out->axis(0)->parallelize(ParallelType::BIDx);
-  out->axis(1)->parallelize(ParallelType::TIDz);
-  out->axis(4)->parallelize(ParallelType::TIDy);
-  out->axis(5)->parallelize(ParallelType::TIDx);
-
-  scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  const int dim_h = 99;
-  const int dim_w = 101;
-  const int dim_c = 10;
-  const int dim_f = 20;
-  const int dim_w_h = 3;
-  const int dim_w_w = 3;
-  const int dim_pad_h = (dim_w_h - 1) / 2;
-  const int dim_pad_w = (dim_w_w - 1) / 2;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
-  at::Tensor at_w = at::randn({dim_f, dim_c, dim_w_h, dim_w_w}, options);
-  std::vector<IValue> inputs = {
-      at_inp, at_w, dim_w_h, dim_w_w, dim_pad_h, dim_pad_w};
-
-  auto cg_outputs = fe.runFusion(inputs);
-
-  at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
-  auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1);
-  at_out = at_out.squeeze(0); // drop the N axis
-
-  testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
-}
-
-// 5x5 followed by 3x3
-TEST(NVFuserTest, FusionConv2DDynamicChain_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Input: [K1, H, W]
-  auto inp = makeSymbolicTensor(3);
-  fusion.addInput(inp);
-
-  // Weights: [K2, K1, S1, T1]
-  auto w1 = makeSymbolicTensor(4);
-  fusion.addInput(w1);
-
-  // Weights: [K3, K2, S2, T2]
-  auto w2 = makeSymbolicTensor(4);
-  fusion.addInput(w2);
-
-  auto w1_h = new Int();
-  fusion.addInput(w1_h);
-  auto w1_w = new Int();
-  fusion.addInput(w1_w);
-
-  auto w2_h = new Int();
-  fusion.addInput(w2_h);
-  auto w2_w = new Int();
-  fusion.addInput(w2_w);
-
-  auto pad_h1 = new Int();
-  fusion.addInput(pad_h1);
-  auto pad_w1 = new Int();
-  fusion.addInput(pad_w1);
-
-  auto pad_h2 = new Int();
-  fusion.addInput(pad_h2);
-  auto pad_w2 = new Int();
-  fusion.addInput(pad_w2);
-
-  // Gather a neighbor tile of [w1_h, w1_w] with padding
-  auto inp_tile = gather(
-      inp,
-      {new Int(1), w1_h, w1_w},
-      {{new Int(0), new Int(0)}, {pad_h1, pad_h1}, {pad_w1, pad_w1}});
-  // inp_tile: [C, 1, H - w1_h + 1, W - w1_w + 1, w1_h, w1_w]
-
-  auto inp_bc =
-      broadcast(inp_tile, {true, false, false, false, false, false, false});
-  auto w1_bc = broadcast(w1, {false, false, true, true, true, false, false});
-
-  auto inp_times_w1 = mul(inp_bc, w1_bc);
-
-  // Reduce the channel and neighbor tile dimensions
-  auto out1 = sum(inp_times_w1, {1, 4, 5, 6});
-
-  // Second conv
-  auto out1_tile = gather(
-      out1,
-      {new Int(1), w2_h, w2_w},
-      {{new Int(0), new Int(0)}, {pad_h2, pad_h2}, {pad_w2, pad_w2}});
-
-  auto out1_bc =
-      broadcast(out1_tile, {true, false, false, false, false, false, false});
-  auto w2_bc = broadcast(w2, {false, false, true, true, true, false, false});
-
-  auto out1_times_w2 = mul(out1_bc, w2_bc);
-
-  auto out2 = sum(out1_times_w2, {1, 4, 5, 6});
-
-  fusion.addOutput(out2);
-
-  ////////////////////////////////////
-  // Cache the input and weight tensors
-  auto inp_cache = inp->cache_after();
-
-  // Blocking the spatial dimensions
-  const int block_w = 16;
-  const int block_h = 4;
-
-  out2->split(2, block_h);
-  out2->split(4, block_w);
-  out2->reorder({{3, 4}});
-  // out2: [K3, K2, Ho, Wo, Hi, Wi, 1, 3, 3]
-
-  // Create a [block_x, block_y] tile on smem
-  inp_cache->computeAt(out2, 4);
-  // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
-  inp_cache->setMemoryType(MemoryType::Shared);
-
-  // Move Ci forward
-  out1->reorder({{5, 3}, {3, 4}, {4, 5}});
-  out1->setMemoryType(MemoryType::Shared);
-
-  inp_cache->computeAt(out1, 4);
-
-  inp_tile->computeAt(out1, -1);
-  w1->computeAt(out1, -1);
-
-  out1_tile->computeAt(out2, -1);
-  w2->computeAt(out2, -1);
-
-  out2->axis(0)->parallelize(ParallelType::BIDx);
-  out2->axis(4)->parallelize(ParallelType::TIDy);
-  out2->axis(5)->parallelize(ParallelType::TIDx);
-
-  scheduler_utils::parallelizeAllLike(out2, {inp_cache, out1});
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  const int dim_h = 99;
-  const int dim_w = 101;
-  const int dim_k1 = 3;
-  const int dim_k2 = 5;
-  const int dim_k3 = 7;
-  const int dim_w1_h = 5;
-  const int dim_w1_w = 5;
-  const int dim_pad1_h = (dim_w1_h - 1) / 2;
-  const int dim_pad1_w = (dim_w1_w - 1) / 2;
-  const int dim_w2_h = 3;
-  const int dim_w2_w = 3;
-  const int dim_pad2_h = (dim_w2_h - 1) / 2;
-  const int dim_pad2_w = (dim_w2_w - 1) / 2;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor at_inp = at::randn({dim_k1, dim_h, dim_w}, options);
-  at::Tensor at_w1 = at::randn({dim_k2, dim_k1, dim_w1_h, dim_w1_w}, options);
-  at::Tensor at_w2 = at::randn({dim_k3, dim_k2, dim_w2_h, dim_w2_w}, options);
-  std::vector<IValue> inputs = {
-      at_inp,
-      at_w1,
-      at_w2,
-      dim_w1_h,
-      dim_w1_w,
-      dim_w2_h,
-      dim_w2_w,
-      dim_pad1_h,
-      dim_pad1_w,
-      dim_pad2_h,
-      dim_pad2_w};
-
-  auto cg_outputs = fe.runFusion(inputs);
-
-  at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
-  auto at_out1 = at::conv2d(at_inp, at_w1, {}, 1, 2);
-  auto at_out2 = at::conv2d(at_out1, at_w2, {}, 1, 1);
-  at_out2 = at_out2.squeeze(0); // drop the N axis
-
-  testValidate(&fusion, cg_outputs, inputs, {at_out2}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // Input: [C, H, W]
-  auto inp = makeSymbolicTensor(3);
-  fusion.addInput(inp);
-
-  // Weights: [K, C, 2, 2]
-  auto w = makeSymbolicTensor(4);
-  fusion.addInput(w);
-
-  // Gather a neighbor tile of [2, 2] with padding size of 1 only for
-  // the right side of the spatial dimensions. The left padding is
-  // zero so that the output axis stays the same.
-  auto inp_tile = gather(inp, {1, 2, 2}, {{0, 0}, {0, 1}, {0, 1}});
-  // inp_tile: [C, H, W, 1, 2, 2]
-
-  auto inp_bc =
-      broadcast(inp_tile, {true, false, false, false, false, false, false});
-  auto w_bc = broadcast(w, {false, false, true, true, true, false, false});
-
-  auto inp_times_w = mul(inp_bc, w_bc);
-
-  // Reduce the channel and neighbor tile dimensions
-  auto out = sum(inp_times_w, {1, 4, 5, 6});
-
-  fusion.addOutput(out);
-
-  ////////////////////////////////////
-
-  // Cache the input and weight tensors
-  auto inp_cache = inp->cache_after();
-
-  // Blocking the spatial dimensions
-  const int block_w = 16;
-  const int block_h = 4;
-  // Blocking the channel dimension
-  const int block_c = 8;
-
-  out->split(2, block_h);
-  out->split(4, block_w);
-  out->reorder({{3, 4}});
-  // out: [K, C, Ho, Wo, Hi, Wi, 1, 2, 2]
-
-  out->split(1, block_c);
-  // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 2, 2]
-
-  auto out_rf = out->rFactor({1, -3, -2, -1});
-  // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 2, 2]
-  // out_rf: [K, Ci, Ho, Wo, Hi, Wi]
-
-  // Create a [block_x, block_y] tile on smem
-  inp_cache->computeAt(out, 4);
-  // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
-  inp_cache->setMemoryType(MemoryType::Shared);
-
-  // Move Ci forward
-  out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}});
-  inp_cache->computeAt(out_rf, 5);
-
-  inp_tile->computeAt(out_rf, -1);
-  w->computeAt(out_rf, -1);
-
-  out->axis(0)->parallelize(ParallelType::BIDx);
-  out->axis(1)->parallelize(ParallelType::TIDz);
-  out->axis(4)->parallelize(ParallelType::TIDy);
-  out->axis(5)->parallelize(ParallelType::TIDx);
-
-  scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion);
-
-  const int dim_h = 99;
-  const int dim_w = 101;
-  const int dim_c = 10;
-  const int dim_f = 20;
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::manual_seed(0);
-  at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
-  at::Tensor at_w = at::randn({dim_f, dim_c, 2, 2}, options);
-  std::vector<IValue> inputs = {at_inp, at_w};
-
-  auto cg_outputs = fe.runFusion(inputs);
-
-  at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
-  auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1);
-  at_out = at_out.squeeze(0); // drop the N axis
-  // The shape of the spatial domain is (dim_h+1)x(dim_w+1), whereas
-  // the fuser output has dim_h*dim_w. Drop the first elements to make
-  // it match with the fuser output.
-  std::vector<at::indexing::TensorIndex> indices{
-      at::indexing::Slice(0, at::indexing::None),
-      at::indexing::Slice(1, at::indexing::None),
-      at::indexing::Slice(1, at::indexing::None)};
-  ;
-  at_out = at_out.index(indices);
-
-  testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
-}
-
-} // namespace jit
-} // namespace torch
-#endif // #if defined(USE_CUDA)
diff --git a/test/cpp/jit/test_gpu_validator.h b/test/cpp/jit/test_gpu_validator.h
deleted file mode 100644 (file)
index dee05ea..0000000
+++ /dev/null
@@ -1,387 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-struct ValidationConstants {
-  // Tolerances generated from randn + add + sum fusion
-  // compared against double precision
-  std::array<std::array<double, 2>, 20> sum_tolerances_float = {
-      {{4, 1.51992e-06},      {8, 2.23704e-06},      {16, 2.95788e-06},
-       {32, 4.4778e-06},      {64, 6.75395e-06},     {128, 8.57934e-06},
-       {256, 1.30594e-05},    {512, 2.19122e-05},    {1024, 3.3451e-05},
-       {2048, 5.78476e-05},   {4096, 0.000108292},   {8192, 0.00012207},
-       {16384, 0.000136882},  {32768, 0.000248561},  {65536, 0.000407594},
-       {131072, 0.000500901}, {262144, 0.000923019}, {524288, 0.00156909},
-       {1048576, 0.00223107}, {2097152, 0.00343043}}};
-
-  // Tolerances generated from randn + add + sum fusion
-  // compared against double precision
-  std::array<std::array<double, 2>, 20> sum_tolerances_half = {
-      {{4, 0.00390625},    {8, 0.0078125},    {16, 0.0078125},
-       {32, 0.0155334},    {64, 0.0156269},   {128, 0.0312042},
-       {256, 0.0312548},   {512, 0.0619979},  {1024, 0.0625103},
-       {2048, 0.124686},   {4096, 0.12501},   {8192, 0.24945},
-       {16384, 0.250049},  {32768, 0.498946}, {65536, 0.500071},
-       {131072, 0.985087}, {262144, 1.00006}, {524288, 1.99234},
-       {1048576, 2.00032}, {2097152, 3.99073}}};
-
-  double base_half_abs_tol = -1;
-  double base_half_rel_tol = -1;
-  double base_float_abs_tol = -1;
-  double base_float_rel_tol = -1;
-};
-
-namespace {
-
-// Returns abs and relative values to use for validation
-std::pair<double, double> getTolerance(
-    DataType dtype,
-    int64_t reduction_size,
-    const ValidationConstants& tolerances) {
-  switch (dtype) {
-    case DataType::Float:
-    // TODO: Pull new tolerances for Double, for now we will just use float
-    // tolerances as it should be no worse.
-    case DataType::Double: {
-      const auto& sum_tolerance_entry = tolerances.sum_tolerances_float;
-      const auto& base_abs = tolerances.base_float_abs_tol;
-      const auto& base_rel = tolerances.base_float_rel_tol;
-
-      if (reduction_size <= 1) {
-        // No reduction case
-        if (base_abs == -1 || base_rel == -1) {
-          return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]};
-        } else {
-          return {base_abs, base_rel};
-        }
-      } else {
-        // Reduction case
-        size_t entry = 0;
-        while (sum_tolerance_entry[entry][0] < reduction_size &&
-               entry < sum_tolerance_entry.size()) {
-          entry++;
-        }
-        double abs_tol = 0.0;
-        if (entry + 1 < sum_tolerance_entry.size()) {
-          // Grab the next entry up so we have some margin
-          abs_tol = sum_tolerance_entry[entry + 1][1];
-        } else {
-          // If we hit the end of the list, return twice the max error we
-          // measured
-          abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.;
-        }
-        // Relative tol we're going to set to 1% of abs tol just for
-        // a small margin of rel error.
-        return {abs_tol, abs_tol * 0.01};
-      }
-    }
-    case DataType::Half: {
-      // Copied from float case
-      const auto& sum_tolerance_entry = tolerances.sum_tolerances_half;
-      const auto& base_abs = tolerances.base_half_abs_tol;
-      const auto& base_rel = tolerances.base_half_rel_tol;
-
-      if (reduction_size <= 1) {
-        // No reduction case
-        if (base_abs == -1 || base_rel == -1) {
-          return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]};
-        } else {
-          return {base_abs, base_rel};
-        }
-      } else {
-        // Reduction case
-        size_t entry = 0;
-        while (sum_tolerance_entry[entry][0] < reduction_size &&
-               entry < sum_tolerance_entry.size()) {
-          entry++;
-        }
-        double abs_tol = 0.0;
-        if (entry + 1 < sum_tolerance_entry.size()) {
-          // Grab the next entry up so we have some margin
-          abs_tol = sum_tolerance_entry[entry + 1][1];
-        } else {
-          // If we hit the end of the list, return twice the max error we
-          // measured
-          abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.;
-        }
-        // Relative tol we're going to set to 1% of abs tol just for
-        // a small margin of rel error.
-        return {abs_tol, abs_tol * 0.01};
-      }
-    }
-    case DataType::Int:
-      return {0.0, 0.0};
-    case DataType::Int32:
-      return {0.0, 0.0};
-    case DataType::Bool:
-      return {0.0, 0.0};
-    default:
-      TORCH_INTERNAL_ASSERT(
-          false, "Do not have tolerance computation for type ", dtype, ".");
-  }
-}
-
-class ReductionSizeMapper : private IterVisitor {
- public:
-  //! Runs through the fusion and determines how many reductions were performed
-  //! to compute each tensorview.
-  static std::unordered_map<TensorView*, int64_t> computeReductionSizes(
-      Fusion* fusion,
-      ExpressionEvaluator& expr_eval) {
-    ReductionSizeMapper mapper(fusion, expr_eval);
-    return mapper.reduction_map;
-  }
-
- private:
-  ReductionSizeMapper(Fusion* fusion, ExpressionEvaluator& expr_eval)
-      : expr_eval_(expr_eval) {
-    // Initialize input values
-    for (auto inp : fusion->inputs()) {
-      if (inp->isA<TensorView>()) {
-        auto tv = inp->as<TensorView>();
-        // Shouldn't have any reductions, but run it through analysis anyways.
-        reduction_map[tv] = getReductionSize(tv);
-      }
-    }
-
-    IterVisitor::traverse(fusion);
-
-    // catch up with dangling outputs;
-    for (auto out : fusion->outputs()) {
-      if (out->isA<TensorView>()) {
-        auto tv = out->as<TensorView>();
-        // possible that we have a dangling output that's not generated by any
-        // expression. e.g. 0 workspace or null tensor
-        if (reduction_map.count(tv) == 0) {
-          // Shouldn't have any reductions, but run it through analysis anyways.
-          reduction_map[tv] = getReductionSize(tv);
-        }
-      }
-    }
-  }
-
-  int64_t getReductionSize(const TensorView* tv) {
-    int64_t reduction_elements = 1;
-    for (auto id : tv->getMaybeRFactorDomain()) {
-      if (id->isReduction()) {
-        auto inferred_extent = expr_eval_.evaluate(id->extent());
-        TORCH_INTERNAL_ASSERT(
-            inferred_extent.has_value(),
-            "Couldn't figure out what the dimensions of a tensorview is in evaluation for validation. ",
-            id,
-            " in ",
-            tv);
-        reduction_elements = reduction_elements * inferred_extent.value();
-      }
-    }
-    return reduction_elements;
-  }
-
-  void handle(Expr* expr) override {
-    if (!ir_utils::isTVOp(expr)) {
-      return;
-    }
-
-    int64_t inp_reduction_elements = 1;
-    for (auto inp : expr->inputs()) {
-      if (inp->isA<TensorView>()) {
-        if (auto tv = inp->as<TensorView>()) {
-          inp_reduction_elements =
-              std::max(inp_reduction_elements, reduction_map.at(tv));
-        }
-      }
-    }
-
-    for (auto out : expr->outputs()) {
-      if (out->isA<TensorView>()) {
-        auto tv = out->as<TensorView>();
-        reduction_map[tv] = getReductionSize(tv) * inp_reduction_elements;
-      }
-    }
-  }
-
- private:
-  using IterVisitor::handle;
-
-  std::unordered_map<TensorView*, int64_t> reduction_map;
-  ExpressionEvaluator& expr_eval_;
-};
-
-ExpressionEvaluator bindInputsAndLaunchParams(
-    Fusion* fusion,
-    const at::ArrayRef<IValue>& aten_inputs,
-    const LaunchParams& launch_constraints) {
-  auto expr_eval = executor_utils::bindFusionInputs(aten_inputs, fusion);
-  for (auto val : fusion->vals()) {
-    if (!val->isA<TensorView>()) {
-      continue;
-    }
-
-    // Roughly taken from executor.cpp/computeLaunchParams
-    auto tv = val->as<TensorView>();
-    for (auto id : tv->domain()->domain()) {
-      if (!(id->isThread() && id->extent()->definition() == nullptr)) {
-        continue;
-      }
-
-      if (id->isBroadcast()) {
-        continue;
-      }
-
-      auto extent = id->extent();
-      auto inferred_extent = expr_eval.evaluate(extent);
-      auto p_type = id->getParallelType();
-
-      if (inferred_extent.has_value()) {
-        // This value could have been inferred, make sure it was set right.
-        TORCH_CHECK(
-            inferred_extent.value() == launch_constraints.getDim(p_type) ||
-                launch_constraints.getRawVal(p_type) == -1,
-            "inferred that ",
-            p_type,
-            " should be set to ",
-            inferred_extent.value(),
-            " but launch constraints specified ",
-            launch_constraints.getRawVal(p_type));
-      } else {
-        // Bind the launch constraint into our evaluation context
-        if (launch_constraints.hasDim(id->getParallelType())) {
-          expr_eval.bind(extent, launch_constraints.getDim(p_type));
-        }
-      }
-    }
-  }
-  return expr_eval;
-}
-
-} // namespace
-
-// Validation will look through the fusion and figure out how many elements were
-// reduced to create each output. It will then compute a tolernace to use for
-// allclose based on experimental results. The experimental results were based
-// on adding two tensors then summing them. This of course has an assumption
-// that we're always summing values between -2 and 2. If we start summing values
-// larger than that this approach might not hold.
-inline void testValidate(
-    Fusion* fusion,
-    const std::vector<at::Tensor>& fusion_outputs,
-    const at::ArrayRef<IValue>& aten_inputs,
-    const std::vector<at::Tensor>& aten_outputs,
-    int line_number,
-    const char* file_name,
-    std::string err_msg = "",
-    const LaunchParams& lparams = LaunchParams(),
-    const ValidationConstants& tolerances = ValidationConstants()) {
-  FusionGuard fg(fusion);
-
-  auto expr_eval = bindInputsAndLaunchParams(fusion, aten_inputs, lparams);
-
-  auto reduction_sizes =
-      ReductionSizeMapper::computeReductionSizes(fusion, expr_eval);
-
-  TORCH_INTERNAL_ASSERT(
-      fusion_outputs.size() == aten_outputs.size() &&
-          aten_outputs.size() == fusion->outputs().size(),
-      "Number of outputs don't match.");
-
-  TORCH_INTERNAL_ASSERT(
-      fusion->inputs().size() == aten_inputs.size(),
-      "Number of inputs don't match.");
-
-  for (size_t i = 0; i < fusion->inputs().size(); i++) {
-    if (fusion->inputs()[i]->isA<TensorView>()) {
-      TORCH_INTERNAL_ASSERT(
-          aten_inputs[i].isTensor(), "Mismatch of tensor inputs.");
-
-      auto fusion_input_tv = fusion->inputs()[i]->as<TensorView>();
-      auto at_tensor = aten_inputs[i].toTensor();
-
-      TORCH_INTERNAL_ASSERT(
-          at_tensor.dim() ==
-              TensorDomain::noReductions(
-                  fusion_input_tv->getMaybeRFactorDomain())
-                  .size(),
-          "Dimensionality mismatch in inputs.");
-    }
-  }
-
-  for (size_t i = 0; i < fusion->outputs().size(); i++) {
-    TORCH_INTERNAL_ASSERT(
-        fusion->outputs()[i]->isA<TensorView>(), "Mismatch of tensor outputs.");
-
-    auto fusion_output_tensor = fusion_outputs[i];
-    auto fusion_output_tv = fusion->outputs()[i]->as<TensorView>();
-    auto aten_output_tensor = aten_outputs[i];
-
-    TORCH_INTERNAL_ASSERT(
-        reduction_sizes.count(fusion_output_tv),
-        "Missed reduction size count on fusion output at index: ",
-        i);
-
-    int64_t reduction_size = reduction_sizes.at(fusion_output_tv);
-
-    TORCH_INTERNAL_ASSERT(
-        aten_output_tensor.dim() == fusion_output_tensor.dim() &&
-            fusion_outputs[i].dim() ==
-                TensorDomain::noReductions(
-                    fusion_output_tv->getMaybeRFactorDomain())
-                    .size(),
-        "Dimensionality mismatch in inputs.");
-
-    auto tolerance_values = getTolerance(
-        fusion_output_tv->getDataType().value(), reduction_size, tolerances);
-
-    if (aten_output_tensor.is_floating_point()) {
-      TORCH_INTERNAL_ASSERT(
-          aten_output_tensor.allclose(
-              fusion_output_tensor.to(aten_output_tensor.dtype()),
-              tolerance_values.second,
-              tolerance_values.first),
-          "\n",
-          err_msg,
-          "\nValidation error in output ",
-          i,
-          " on line ",
-          line_number,
-          " in file ",
-          file_name,
-          ".\n  Detected abs error of: ",
-          aten_output_tensor.sub(fusion_output_tensor)
-              .abs()
-              .max()
-              .item()
-              .to<double>(),
-          "\n    absolute tolerance was set to ",
-          tolerance_values.first,
-          "\n    and relative tolerance set to ",
-          tolerance_values.second);
-    } else {
-      TORCH_INTERNAL_ASSERT(
-          aten_output_tensor.equal(
-              fusion_output_tensor.to(aten_output_tensor.dtype())),
-          "\n",
-          err_msg,
-          ".\n  Validation error in output ",
-          i,
-          " on line ",
-          line_number,
-          " in file ",
-          file_name,
-          ".\n Values are not equal and are not a floating type.");
-    }
-  }
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index adb59be..41bc7d5 100644 (file)
@@ -1,14 +1,10 @@
 import unittest
 import os
-import random
 
 import torch
-from torch.nn import functional
 
-from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR  # TEST_WITH_ROCM
-from torch.testing._internal.common_cuda import TEST_MULTIGPU
+from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR
 from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed
-from torch.testing import FileCheck
 
 from test_jit import JitTestCase, RUN_CUDA
 
@@ -16,17 +12,10 @@ from jit.test_fuser_common import TestFuserCommon  # noqa: F401
 
 import itertools
 import numpy as np
-import math
 
-from typing import List
-
-CUDA_MAJOR, CUDA_MINOR = (int(x) for x in torch.version.cuda.split('.'))
-
-os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1'
-os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1'
-os.environ['PYTORCH_NVFUSER_DISABLE_FASTMATH'] = '1'
-os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0'
-os.environ['PYTORCH_NVFUSER_DISABLE_RNG_UNROLL'] = '1'
+os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1'
+os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1'
+os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0'
 
 if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
     torch._C._jit_set_texpr_fuser_enabled(False)
@@ -36,35 +25,8 @@ if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
 FUSION_GROUP = 'prim::CudaFusionGroup'
 FUSION_GUARD = 'prim::CudaFusionGuard'
 
-def is_pre_volta():
-    prop = torch.cuda.get_device_properties(torch.cuda.current_device())
-    return prop.major < 7
-
 class TestCudaFuser(JitTestCase):
 
-    special_values = torch.tensor(
-        [float("-inf"), -10, -math.pi,
-            -1, -0.5, 0, 1, 0.5,
-            math.pi, 10, float("inf"),
-            float("nan")], dtype=torch.float, device='cuda')
-
-    int_types = [
-        torch.int8,
-        torch.uint8,
-        torch.int16,
-        torch.int32,
-        torch.int64
-    ]
-
-    support_tensor_dtypes = [
-        torch.int32,
-        torch.int64,
-        torch.float16,
-        torch.float32,
-        torch.float64,
-        torch.bool
-    ]
-
     def _getSubgraphInFusion(self, graph):
         num_node = 0
         subgraph = None
@@ -89,7 +51,6 @@ class TestCudaFuser(JitTestCase):
         torch._C._jit_override_can_fuse_on_cpu(False)
         torch._C._jit_override_can_fuse_on_gpu(False)
         self.old_guard = torch._C._jit_set_nvfuser_guard_mode(False)
-        torch._C._debug_set_autodiff_subgraph_inlining(False)
 
         if(RUN_CUDA):
             self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True)
@@ -100,7 +61,6 @@ class TestCudaFuser(JitTestCase):
         torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse)
         torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse)
         torch._C._jit_set_nvfuser_guard_mode(self.old_guard)
-        torch._C._debug_set_autodiff_subgraph_inlining(True)
         super(TestCudaFuser, self).tearDown()
 
     def _run_helper(self, jit_op, op, *args):
@@ -111,7 +71,7 @@ class TestCudaFuser(JitTestCase):
         torch.cuda.manual_seed_all(123)
         o = op(*args)
         self.assertEqual(o, jit_o)
-        self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, 1, consider_subgraphs=True)
+        self.assertGraphContains(jit_op.graph_for(*args), FUSION_GUARD)
 
     def _run_training_helper(self, jit_op, op, grads, *args):
         torch.cuda.manual_seed_all(123)
@@ -200,31 +160,6 @@ class TestCudaFuser(JitTestCase):
         self.assertEqual(o, jit_o)
         self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GUARD)
 
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_reduction_dtypes(self):
-
-        for op in [torch.sum, torch.mean]:
-            for dtype in [torch.float16, torch.float32, torch.double]:
-                def make_func(op):
-                    def func(x: torch.Tensor):
-                        o = torch.mul(x, 1.0)
-                        o = op(o, dim=[2])
-                        return o
-                    return func
-
-                x = torch.randn(8, 4, 16, dtype=dtype, device="cuda")
-                t = make_func(op)
-                t_jit = torch.jit.trace(t, x)
-                jit_o = t_jit(x)
-                jit_o = t_jit(x)
-                o = t(x)
-                self.assertEqual(o.dtype, jit_o.dtype)
-                self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
-                self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
-
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")
@@ -364,7 +299,7 @@ class TestCudaFuser(JitTestCase):
         o = t(x, y, z)
         self.assertEqual(o, jit_o)
         subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z))
-        self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False)
+        self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
 
     @unittest.skipIf(True, "Broadcast with different output not supported yet")
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
@@ -410,6 +345,20 @@ class TestCudaFuser(JitTestCase):
         # Currently cannot fuse this
         self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
 
+    def _binary_test_helper(self, operation):
+        def t(x: torch.Tensor, y: torch.Tensor, z: float):
+            o = x + z
+            o = operation(o, y)
+            return o
+        t_jit = torch.jit.script(t)
+        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
+        y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
+        jit_o = t_jit(x, y, 2.0)
+        jit_o = t_jit(x, y, 2.0)
+        o = t(x, y, 2.0)
+        self.assertEqual(o, jit_o)
+        self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
+
     def _unary_test_helper(self, operation):
         def t(x: torch.Tensor, z: float):
             o = x + z
@@ -456,217 +405,14 @@ class TestCudaFuser(JitTestCase):
                       torch.relu,
                       torch.sigmoid,
                       torch.tanh,
-                      torch.nn.functional.silu]
+                      torch.nn.functional.gelu]
         for op in operations:
             self._unary_test_helper(op)
 
-    def _unary_type_test_helper(self, operation, dtype, random_data=True):
-        shape = (4, 8, 32, 32)
-
-        # need additional def of t for boolean ops
-        def t(x: torch.Tensor, y: torch.Tensor):
-            o = x * y
-            o = operation(o)
-            return o
-
-        y = torch.tensor([1], device="cuda").to(dtype)
-
-        if random_data:
-            x = torch.randn(shape, dtype=torch.float32, device="cuda")
-            if dtype in self.int_types:
-                # prefer a larger variance for integer types
-                x *= 5
-            x = x.to(dtype=dtype)
-        else:
-            x = self.special_values.to(dtype=dtype)
-        try:
-            ref = t(x, y)
-        except Exception:
-            # same way as TE checker, if eager mode throws, ignore this test
-            return
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x, y)
-        jit_o = t_jit(x, y)
-        if dtype in self.support_tensor_dtypes:
-            self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
-        o = t(x, y)
-        self.assertEqual(o, jit_o, msg=f"""
-        failing case:
-            {dtype} {operation} {x}
-        """)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_data_compatibility(self):
-        dtypes = [
-            *self.int_types,
-            torch.float16,
-            torch.float32,
-            torch.float64
-        ]
-        operations = [torch.neg,
-                      torch.abs,
-                      torch.log,
-                      torch.log10,
-                      torch.log1p,
-                      torch.log2,
-                      torch.lgamma,
-                      torch.exp,
-                      torch.expm1,
-                      torch.erf,
-                      torch.erfc,
-                      torch.cos,
-                      torch.acos,
-                      torch.cosh,
-                      torch.sin,
-                      torch.asin,
-                      torch.tan,
-                      torch.atan,
-                      torch.sqrt,
-                      torch.rsqrt,
-                      torch.ceil,
-                      torch.floor,
-                      torch.round,
-                      torch.trunc,
-                      torch.frac,
-                      torch.reciprocal,
-                      torch.relu,
-                      torch.sigmoid,
-                      torch.tanh,
-                      torch.nn.functional.silu]
-        prev_fallback = os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK']
-        os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '0'
-        for op, dtype in itertools.product(operations, dtypes):
-            self._unary_type_test_helper(op, dtype, False)  # test special numbers
-            self._unary_type_test_helper(op, dtype)  # test random data
-        os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = prev_fallback
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_category_rule(self):
-        def run_tensor(x, z):
-            def t(x: torch.Tensor, z: torch.Tensor):
-                o = x + z
-                o = torch.abs(o)
-                return o
-            t_jit = torch.jit.script(t)
-            jit_o = t_jit(x, z)
-            jit_o = t_jit(x, z)
-            o = t(x, z)
-            self.assertEqual(o.dtype, jit_o.dtype)
-            self.assertEqual(o, jit_o)
-            self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD)
-
-        def run_scalar(x, z):
-            def t(x: torch.Tensor, z: float):
-                o = x + z
-                o = torch.abs(o)
-                return o
-            t_jit = torch.jit.script(t)
-            jit_o = t_jit(x, z)
-            jit_o = t_jit(x, z)
-            o = t(x, z)
-            self.assertEqual(o.dtype, jit_o.dtype)
-            self.assertEqual(o, jit_o)
-            self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD)
-
-        # n-dim with 0-dim (no type-promote)
-        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
-        z = torch.tensor(2.0, dtype=torch.double, device="cuda")
-        run_tensor(x, z)
-
-        # n-dim with 0-dim (type-promote)
-        x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long)
-        z = torch.tensor(2.0, dtype=torch.double, device="cuda")
-        run_tensor(x, z)
-
-        # n-dim with n-dim (type-promote)
-        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
-        z = torch.randn(4, 8, 32, 32, dtype=torch.double, device="cuda")
-        run_tensor(x, z)
-
-        # n-dim with scalar (no type-promote)
-        x = torch.randn(4, 8, 32, 32, dtype=torch.float16, device="cuda")
-        z = torch.tensor(3., dtype=torch.double)
-        run_scalar(x, z)
-
-        # n-dim with scalar (type-promote)
-        x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long)
-        z = torch.tensor(3., dtype=torch.double)
-        run_scalar(x, z)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_unary_bitwise(self):
-        def bit_not(x: torch.Tensor):
-            return ~(x + 0)
-
-        jitted = torch.jit.script(bit_not)
-        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(torch.long)
-        jit_o = bit_not(x)
-        jit_o = bit_not(x)
-        o = bit_not(x)
-        self.assertEqual(o, jit_o)
-        jitted.graph_for(x)  # Shows up in second instance, not first
-        self.assertGraphContains(jitted.graph_for(x), FUSION_GUARD)
-
-        def bool_not(x: torch.Tensor, y: torch.Tensor):
-            return ~(x & y)
-
-        jitted = torch.jit.script(bool_not)
-        x = torch.rand(4, 8, 32, 32, dtype=torch.float, device="cuda").round().to(torch.bool)
-        y = torch.rand(4, 8, 32, 32, dtype=torch.float, device="cuda").round().to(torch.bool)
-        jit_o = bool_not(x, y)
-        jit_o = bool_not(x, y)
-        o = bool_not(x, y)
-        self.assertEqual(o, jit_o)
-        jitted.graph_for(x, y)  # Shows up in second instance, not first
-        self.assertGraphContains(jitted.graph_for(x, y), FUSION_GUARD)
-
-    def _binary_test_helper(self, operation, dtype):
-        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
-            o = x + z
-            o = operation(o, y)
-            return o
-        x = (torch.randn(4, 32, 32, dtype=torch.float, device="cuda") * 5).to(dtype)
-        y = (torch.randn(4, 32, 32, dtype=torch.float, device="cuda") * 5).to(dtype)
-        # Avoid division by zero for integer tensors
-        div_like = [torch.div, torch.fmod, torch.remainder]
-        if operation in div_like and (dtype == torch.int32 or dtype == torch.int64):
-            y[y == 0] = 1
-        z = torch.tensor([2], device="cuda").to(dtype)
-        o = t(x, y, z)
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x, y, z)
-        jit_o = t_jit(x, y, z)
-
-        self.assertEqual(o, jit_o)
-        self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
-
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")
     def test_binary_ops(self):
-        data_types = [
-            torch.float32,
-            torch.float64,
-            torch.int32,
-            torch.int64
-        ]
-        # need some extra support
-        # to handle below with integer inputs, and they
-        # don't look like popular integer ops in models
-        # , TODO: insert assertions in cpp
-        # if decide not to fuse these on int
-        skip_for_integer = [
-            torch.atan2,
-            torch.fmod,
-            torch.pow,
-            torch.div
-        ]
         operations = [torch.div,
                       torch.mul,
                       torch.atan2,
@@ -681,73 +427,8 @@ class TestCudaFuser(JitTestCase):
                       torch.gt,
                       torch.le,
                       torch.lt]
-        for op, dtype in itertools.product(operations, data_types):
-            if (dtype not in self.int_types) or (op not in skip_for_integer):
-                self._binary_test_helper(op, dtype)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_binary_bitwise(self):
-        def jit_or(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
-            return (x & y) | z
-
-        def jit_xor(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
-            return (x & y) ^ z
-
-        def jit_lshift(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
-            return (x & y) << z
-
-        def jit_rshift(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
-            return (x & y) >> z
-
-        for jit_func in [jit_or, jit_xor, jit_lshift, jit_rshift]:
-            x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(torch.long)
-            y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(torch.long)
-            z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(2).to(torch.long)
-
-            jitted = torch.jit.script(jit_func)
-            jit_o = jitted(x, y, z)
-            jit_o = jitted(x, y, z)
-            o = jit_func(x, y, z)
-            self.assertEqual(o, jit_o)
-            self.assertGraphContains(jitted.graph_for(x, y, z), FUSION_GUARD)
-
-        # We shouldn't need this redefinition of the function, but otherwise it won't recompile for a new type
-        def jit_or(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
-            return (x & y) | z
-
-        def jit_xor(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
-            return (x & y) ^ z
-
-        for jit_func in [jit_or, jit_xor]:
-            x = torch.rand(4, 2, dtype=torch.float, device="cuda").round().to(torch.bool)
-            y = torch.rand(4, 2, dtype=torch.float, device="cuda").round().to(torch.bool)
-            z = torch.rand(4, 2, dtype=torch.float, device="cuda").round().to(torch.bool)
-
-            jitted = torch.jit.script(jit_func)
-            jit_o = jitted(x, y, z)
-            jit_o = jitted(x, y, z)
-            o = jit_func(x, y, z)
-            self.assertEqual(o, jit_o)
-            self.assertGraphContains(jitted.graph_for(x, y, z), FUSION_GUARD)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_type_as_op(self):
-        def t(x: torch.Tensor, y: torch.Tensor, z: float):
-            o = torch.lt(x, z)
-            o = o.type_as(y)
-            return o
-        t_jit = torch.jit.script(t)
-        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
-        y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
-        jit_o = t_jit(x, y, 0.5)
-        jit_o = t_jit(x, y, 0.5)
-        o = t(x, y, 0.5)
-        self.assertEqual(o, jit_o)
-        self.assertGraphContains(t_jit.graph_for(x, y, 0.5), FUSION_GUARD)
+        for op in operations:
+            self._binary_test_helper(op)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     # legacy fuser does not work for rand_like, see issue #34361
@@ -877,7 +558,7 @@ class TestCudaFuser(JitTestCase):
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     def test_random_topo(self):
-        os.environ["PYTORCH_NVFUSER_DISABLE_FALLBACK"] = "1"
+        os.environ["PYTORCH_CUDA_FUSER_DISABLE_FALLBACK"] = "1"
         self.assertTrue(runDefaultTestWithSeed(28449))
 
     def _compare(self, desc, inp1, inp2, error):
@@ -904,12 +585,10 @@ class TestCudaFuser(JitTestCase):
             o = torch.relu(o)
             return o
 
-        x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute(
-            [perm0.index(i) for i in range(len(sizes))])
+        x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute([perm0.index(i) for i in range(len(sizes))])
         if broadcast_axis >= 0:
             sizes[broadcast_axis] = 1
-        y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute(
-            [perm1.index(i) for i in range(len(sizes))])
+        y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute([perm1.index(i) for i in range(len(sizes))])
         t_jit = torch.jit.script(t)
         jit_o = t_jit(x, y)
         jit_o = t_jit(x, y)
@@ -936,26 +615,23 @@ class TestCudaFuser(JitTestCase):
                     x = [7, 8, 12]
                     self._permutation_helper(x, b_axis, torch.float32, "cuda", perm0, perm1)
 
-    def _reduction_helper(self, sizes, reduction_axis, dtype, device, perm0, perm1, keepdim=False):
+    def _reduction_helper(self, sizes, reduction_axis, dtype, device, perm0, perm1):
         class MyReduction(torch.nn.Module):
-            __constants__ = ['reduction_axis', 'keepdim']
+            __constants__ = ['reduction_axis']
 
             def __init__(self):
                 super(MyReduction, self).__init__()
                 self.reduction_axis = reduction_axis
-                self.keepdim = keepdim
 
             def forward(self, x: torch.Tensor, y: torch.Tensor):
                 o = torch.add(x, y)
-                o = torch.sum(o, dim=self.reduction_axis, keepdim=self.keepdim)
+                o = torch.sum(o, dim=self.reduction_axis)
                 return o
 
         t = MyReduction()
 
-        x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute(
-            [perm0.index(i) for i in range(len(sizes))])
-        y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute(
-            [perm1.index(i) for i in range(len(sizes))])
+        x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute([perm0.index(i) for i in range(len(sizes))])
+        y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute([perm1.index(i) for i in range(len(sizes))])
         t_jit = torch.jit.script(t)
         jit_o = t_jit(x, y)
         jit_o = t_jit(x, y)
@@ -966,7 +642,6 @@ class TestCudaFuser(JitTestCase):
         self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
         self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
 
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")
@@ -976,329 +651,10 @@ class TestCudaFuser(JitTestCase):
             # to single element (codegen limitation at this moment)
             for num_reduce_dim in range(1, len(x)):
                 for axes in itertools.combinations(range(len(x)), num_reduce_dim):
-                    for keepdim in (True, False):
-                        perm0 = range(len(x))
-                        perm1 = range(len(x))
-                        self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1, keepdim)
-
-    def _layer_norm_autodiff_helper(self, model, grad, shapes, args):
-        jit_model = torch.jit.script(model)
-
-        eps = np.random.random() * 1e-4
-        use_cudnn = bool(np.random.randint(0, 2))
-
-        # profile/optimization runs
-        for i in range(3):
-            jit_o = jit_model(shapes, *args, eps, use_cudnn)
-            jit_o.backward(grad)
-
-        ref_args = [t.detach().clone().requires_grad_() for t in args]
-        [t.grad.zero_() for t in args]
-        jit_o = jit_model(shapes, *args, eps, use_cudnn)
-        jit_o.backward(grad)
-
-        o = model(shapes, *ref_args, eps, use_cudnn)
-        o.backward(grad)
-        self.assertEqual(jit_o, o)
-        for arg, ref_arg in zip(args, ref_args):
-            self.assertEqual(arg.grad, ref_arg.grad)
-
-        # check fusion in fw & bw
-        g = jit_model.graph_for(shapes, *args, eps, use_cudnn)
-        for node in g.nodes():
-            n = node
-        dbg_state = jit_model.get_debug_state()
-        for val in dbg_state.execution_plans.values():
-            v = val
-        state2 = v.code.grad_executor_states()
-        for val in state2[0].execution_plans.values():
-            v2 = val
-        FileCheck().check(FUSION_GUARD).run(g)
-        FileCheck().check(FUSION_GUARD).run(v2.graph)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_layer_norm_autodiff(self):
-        def t_wb(shapes: List[int], x, w, b, eps: float, cudnn: bool):
-            o = torch.layer_norm(x, shapes, w, b, eps, cudnn)
-            o = torch.relu(o)
-            return o
-
-        def t_w(shapes: List[int], x, w, eps: float, cudnn: bool):
-            o = torch.layer_norm(x, shapes, w, None, eps, cudnn)
-            o = torch.relu(o)
-            return o
-
-        def t_b(shapes: List[int], x, b, eps: float, cudnn: bool):
-            o = torch.layer_norm(x, shapes, None, b, eps, cudnn)
-            o = torch.relu(o)
-            return o
-
-        def t(shapes: List[int], x, eps: float, cudnn: bool):
-            o = torch.layer_norm(x, shapes, None, None, eps, cudnn)
-            o = torch.relu(o)
-            return o
-
-        model = {3: t_wb, 2: t_w, 1: t_b, 0: t}
+                    perm0 = range(len(x))
+                    perm1 = range(len(x))
+                    self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1)
 
-        for w, b in itertools.product([True, False], repeat=2):
-            batch = [4]
-            shapes = [2, 3, 4]
-            m = model[w * 2 + b]
-
-            grad = torch.randn(batch + shapes, dtype=torch.float32, device="cuda")
-            args = [torch.randn(batch + shapes, dtype=torch.float32, device="cuda").requires_grad_()]
-            if w:
-                args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
-            if b:
-                args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
-            self._layer_norm_autodiff_helper(m, grad, shapes, args)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_layer_norm_parser(self):
-        dtype = torch.float32
-        device = "cuda"
-        x = torch.randn([4, 4, 2], dtype=dtype, device=device)
-        w = torch.randn([4, 2], dtype=dtype, device=device)
-        b = torch.randn([4, 2], dtype=dtype, device=device)
-
-        def t(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
-            o = torch.relu(x)
-            o = torch.layer_norm(o, [4, 2], w, b, 1e-5)
-            return o
-
-        o = t(x, w, b)
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x, w, b)
-        jit_o = t_jit(x, w, b)
-        o = t(x, w, b)
-        self.assertGraphContains(t_jit.graph_for(x, w, b), FUSION_GUARD)
-
-    def _native_layer_norm_helper(self, shape, norm_shape, dtype, device, error, affine=True):
-        class MyLayerNorm(torch.nn.Module):
-            __constants__ = ['norm_shape']
-
-            def __init__(self, elementwise_affine=True):
-                super(MyLayerNorm, self).__init__()
-                self.norm_shape = norm_shape
-                if elementwise_affine:
-                    self.weight = torch.randn(norm_shape, dtype=dtype, device=device)
-                    self.bias = torch.randn(norm_shape, dtype=dtype, device=device)
-                    with torch.no_grad():
-                        self.weight.fill_(1)
-                        self.bias.fill_(0)
-                else:
-                    self.weight = None
-                    self.bias = None
-
-            def forward(self, x: torch.Tensor):
-                o = torch.relu(x)
-                o = torch.native_layer_norm(o, self.norm_shape, self.weight, self.bias, 1e-5)
-                return o
-
-        t = MyLayerNorm(affine)
-
-        x = torch.randn(shape, dtype=dtype, device=device)
-        t_jit = torch.jit.script(t)
-        jit_o, jit_mean, jit_rstd = t_jit(x)
-        jit_o, jit_mean, jit_rstd = t_jit(x)
-        o, mean, rstd = t(x)
-        self.assertEqual(o.dtype, jit_o.dtype)
-        # numerical issues here due to our scheduling.
-        # can't use `self.assertEqual(o, jit_o)`
-        self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
-        self.assertTrue(self._compare("comparing mean failed", mean, jit_mean, error))
-        self.assertTrue(self._compare("comparing rstd failed", rstd, jit_rstd, error))
-        self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_native_layer_norm(self):
-        dims = 4
-        rnds = 3
-        for idx in range(rnds):
-            for offset in range(1, dims):
-                for affine in (True, False):
-                    input_shape = [random.randint(10, 30) for idx in range(dims)]
-                    norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
-                    self._native_layer_norm_helper(input_shape, norm_shape, torch.float32, "cuda", 1e-4, affine)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_native_layer_norm_half(self):
-        dims = 4
-        rnds = 3
-        for idx in range(rnds):
-            for offset in range(1, dims):
-                input_shape = [random.randint(10, 30) for idx in range(dims)]
-                norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
-                self._native_layer_norm_helper(input_shape, norm_shape, torch.float16, "cuda", 5e-3)
-
-    def _norm_helper(self, shape, dtype, device, error, is_batch_norm_else_instance_norm):
-        class MyBatchNorm(torch.nn.Module):
-            def __init__(self):
-                super(MyBatchNorm, self).__init__()
-
-            def forward(self, x: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
-                o = torch.nn.functional.batch_norm(x, r_mean, r_var, training=True)
-                o = torch.relu(o)
-                return o
-
-        class MyInstanceNorm(torch.nn.Module):
-            def __init__(self):
-                super(MyInstanceNorm, self).__init__()
-
-            def forward(self, x: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
-                o = torch.nn.functional.instance_norm(x, r_mean, r_var, use_input_stats=True)
-                o = torch.relu(o)
-                return o
-
-        t = MyBatchNorm() if is_batch_norm_else_instance_norm else MyInstanceNorm()
-
-        x = torch.randn(shape, dtype=dtype, device=device)
-        running_mean = torch.zeros(shape[1], dtype=torch.float32, device=device)
-        running_var = torch.ones(shape[1], dtype=torch.float32, device=device)
-        t_jit = torch.jit.script(t)
-
-        eager_running_mean = running_mean.clone()
-        eager_running_var = running_var.clone()
-        jit_running_mean = running_mean.clone()
-        jit_running_var = running_var.clone()
-
-        jit_o = t_jit(x, running_mean.clone(), running_var.clone())
-
-        self.assertTrue(self._compare("prerun comparing running_mean failed", eager_running_mean, jit_running_mean, error))
-        self.assertTrue(self._compare("prerun comparing running_var failed", eager_running_var, jit_running_var, error))
-
-        jit_o = t_jit(x, jit_running_mean, jit_running_var)
-        o = t(x, eager_running_mean, eager_running_var)
-        self.assertEqual(o.dtype, jit_o.dtype)
-        # numerical issues here due to our scheduling.
-        # can't use `self.assertEqual(o, jit_o)`
-        self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
-        self.assertTrue(self._compare("comparing running_mean failed", eager_running_mean, jit_running_mean, error))
-        self.assertTrue(self._compare("comparing running_var failed", eager_running_var, jit_running_var, error))
-        self.assertGraphContains(t_jit.graph_for(x, running_mean, running_var), FUSION_GUARD)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_norm(self):
-        output_elements = 10000
-        channel_sizes = [67, 457, 1024, 4096]
-
-        with torch.backends.cudnn.flags(enabled=False):
-            for is_batch_norm_else_instance_norm in [False, True]:
-                for dims in range(3, 6):
-                    output_size = int(pow(output_elements, 1. / (dims - 1)))
-                    for C in channel_sizes:
-                        x = [output_size for idx in range(dims)]
-                        x[1] = C
-                        self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_norm_large(self):
-        output_elements = 262144
-        channel_sizes = 67, 457, 1024
-
-        for is_batch_norm_else_instance_norm in [True, False]:
-            for dims in range(3, 6):
-                output_size = int(pow(output_elements, 1. / (dims - 1)))
-                for C in channel_sizes:
-                    x = [output_size for idx in range(dims)]
-                    x[1] = C
-                    self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_norm_half(self):
-        output_elements = 10000
-        channel_sizes = [67, 457, 1024, 4096]
-
-        with torch.backends.cudnn.flags(enabled=False):
-            for is_batch_norm_else_instance_norm in [False, True]:
-                for dims in range(3, 6):
-                    output_size = int(pow(output_elements, 1. / (dims - 1)))
-                    for C in channel_sizes:
-                        x = [output_size for idx in range(dims)]
-                        x[1] = C
-                        self._norm_helper(x, torch.float16, "cuda", 5e-3, is_batch_norm_else_instance_norm)
-
-    def _softmax_helper(self, shape, reduction_axis, dtype, device, error):
-        class MySoftmax(torch.nn.Module):
-            __constants__ = ['reduction_axis']
-
-            def __init__(self):
-                super(MySoftmax, self).__init__()
-                self.reduction_axis = reduction_axis
-
-            def forward(self, x: torch.Tensor, y: torch.Tensor):
-                o = torch.add(x, y)
-                o = torch.nn.functional.softmax(o, dim=self.reduction_axis)
-                return o
-
-        t = MySoftmax()
-
-        x = torch.randn(shape, dtype=dtype, device=device)
-        y = torch.randn(shape, dtype=dtype, device=device)
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x, y)
-        jit_o = t_jit(x, y)
-        o = t(x, y)
-        self.assertEqual(o.dtype, jit_o.dtype)
-        # numerical issues here due to our scheduling.
-        # can't use `self.assertEqual(o, jit_o)`
-        self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
-        self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_softmax(self):
-        output_size = 10000
-        dims = 4
-        output_size = int(pow(output_size, 1. / dims))
-        reduction_sizes = [67, 256, 1024, 4096]
-
-        for reduction_dim in range(dims):
-            for reduction_size in reduction_sizes:
-                x = [output_size for idx in range(dims)]
-                x[reduction_dim] = reduction_size
-                self._softmax_helper(x, reduction_dim, torch.float32, "cuda", 1e-4)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_softmax_half(self):
-        output_size = 10000
-        dims = 4
-        output_size = int(pow(output_size, 1. / dims))
-        reduction_sizes = [67, 256, 1024, 4096]
-
-        for reduction_dim in range(dims):
-            for reduction_size in reduction_sizes:
-                x = [output_size for idx in range(dims)]
-                x[reduction_dim] = reduction_size
-                self._softmax_helper(x, reduction_dim, torch.float16, "cuda", 5e-3)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")
@@ -1312,7 +668,6 @@ class TestCudaFuser(JitTestCase):
                     for perm1 in itertools.permutations(range(len(x))):
                         self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1)
 
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")
@@ -1355,119 +710,49 @@ class TestCudaFuser(JitTestCase):
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")
-    def test_channels_last_with_broadcast(self):
-        # setting this true forces a new graph to be generated with a new
-        # input a different broadcast shape
-        torch._C._jit_set_nvfuser_guard_mode(True)
-
-        def t(x: torch.Tensor, y: torch.Tensor):
-            o = torch.mul(x, y)
-            o = o + 2.0
+    def test_reduction_dtype(self):
+        def t(x: torch.Tensor):
+            o = torch.mul(x, 1.0)
+            o = torch.sum(o, dim=[2], dtype=torch.float32)
             return o
         t_jit = torch.jit.script(t)
 
-        # Single Channel broadcasts
-        # Test 1
-        x = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
-        x = x.to(memory_format=torch.channels_last)
-
-        y = torch.randn(8, 4, 10, 1, dtype=torch.float, device="cuda")
-        y = y.to(memory_format=torch.channels_last)
-
-        jit_o = t_jit(x, y)
-        jit_o = t_jit(x, y)
-        o = t(x, y)
-
+        x = torch.randn(8, 4, 16, dtype=torch.float, device="cuda")
+        jit_o = t_jit(x)
+        jit_o = t_jit(x)
+        o = t(x)
         self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
-                         jit_o.is_contiguous(memory_format=torch.channels_last))
-        self.assertEqual(o, jit_o)
-
-        # Test 2
-        y = torch.randn(8, 4, 1, 16, dtype=torch.float, device="cuda")
-        y = y.to(memory_format=torch.channels_last)
+        self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
+        self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
 
-        jit_o = t_jit(x, y)
-        jit_o = t_jit(x, y)
-        o = t(x, y)
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
+                     "Requires fusion optimization pass to be effective")
+    def test_reduction_half(self):
+        def t(x: torch.Tensor):
+            o = torch.mul(x, 1.0)
+            o = torch.sum(o, dim=[2])
+            return o
 
+        t_jit = torch.jit.script(t)
+        x = torch.randn(8, 4, 16, dtype=torch.float16, device="cuda")
+        jit_o = t_jit(x)
+        jit_o = t_jit(x)
+        o = t(x)
         self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
-                         jit_o.is_contiguous(memory_format=torch.channels_last))
-        self.assertEqual(o, jit_o)
+        self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
+        self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
 
-        # Test 3
-        y = torch.randn(8, 1, 10, 16, dtype=torch.float, device="cuda")
-        y = y.to(memory_format=torch.channels_last)
-
-        jit_o = t_jit(x, y)
-        jit_o = t_jit(x, y)
-        o = t(x, y)
-
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
-                         jit_o.is_contiguous(memory_format=torch.channels_last))
-        self.assertEqual(o, jit_o)
-
-        # Test 3
-        y = torch.randn(1, 4, 10, 16, dtype=torch.float, device="cuda")
-        y = y.to(memory_format=torch.channels_last)
-
-        jit_o = t_jit(x, y)
-        jit_o = t_jit(x, y)
-        o = t(x, y)
-
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
-                         jit_o.is_contiguous(memory_format=torch.channels_last))
-        self.assertEqual(o, jit_o)
-
-        '''
-        Currently, the JIT doesn't have tensor merge logic to handle adding
-        a broadcast tensor with more than one broadcast into a non-broadcast
-        tensor.  Therefore, either of these tests can fail depending on the
-        sort implementation.  The second test is known to fail.
-
-        # Two Channel broadcasts
-        # Test 1
-        y = torch.randn(8, 4, 1, 1, dtype=torch.float, device="cuda")
-        y = y.to(memory_format=torch.channels_last)
-
-        jit_o = t_jit(x, y)
-        jit_o = t_jit(x, y)
-        o = t(x, y)
-
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
-                         jit_o.is_contiguous(memory_format=torch.channels_last))
-        self.assertEqual(o, jit_o)
-
-        # Test 2
-        y = torch.randn(8, 4, 1, 1, dtype=torch.float, device="cuda")
-        y = y.to(memory_format=torch.channels_last).transpose(2,3)
-        x = x.transpose(2,3)
-
-        jit_o = t_jit(x, y)
-        jit_o = t_jit(x, y)
-        o = t(x, y)
-
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
-                         jit_o.is_contiguous(memory_format=torch.channels_last))
-        self.assertEqual(o, jit_o)
-        '''
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_pw_single_reduction_partition(self):
-        sizes = [2, 2, 2]
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn(sizes, dtype=dtype, device=device)
-        y = torch.randn(sizes, dtype=dtype, device=device)
-        z = torch.randn(sizes, dtype=dtype, device=device)
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
+                     "Requires fusion optimization pass to be effective")
+    def test_pw_single_reduction_partition(self):
+        sizes = [8, 8, 8]
+        dtype = torch.float
+        device = "cuda"
+        x = torch.randn(sizes, dtype=dtype, device=device)
+        y = torch.randn(sizes, dtype=dtype, device=device)
+        z = torch.randn(sizes, dtype=dtype, device=device)
 
         def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
             o = torch.add(x, y)
@@ -1482,94 +767,6 @@ class TestCudaFuser(JitTestCase):
         self.assertEqual(o, jit_o)
         self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
 
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_permutation_preservation(self):
-        sizes = [2, 2, 2, 2]
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
-
-        def t(x: torch.Tensor):
-            o = torch.relu(x)
-            o = torch.sum(o, dim=[0])
-            return o
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x)
-        jit_o = t_jit(x)
-        o = t(x)
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o, jit_o)
-        self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
-        # we should preserve permutation to inputs
-        self.assertEqual(jit_o.stride(), (1, 4, 2))
-
-        def t(x: torch.Tensor):
-            o = torch.relu(x)
-            o = torch.add(o, 1.0)
-            return o
-
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x)
-        jit_o = t_jit(x)
-        o = t(x)
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o, jit_o)
-        self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
-        self.assertTrue(jit_o.is_contiguous(memory_format=torch.channels_last))
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_normalization_partition(self):
-        sizes = [8, 8, 8]
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn(sizes, dtype=dtype, device=device)
-        y = torch.randn(sizes, dtype=dtype, device=device)
-        z = torch.randn(sizes, dtype=dtype, device=device)
-        r_m = torch.randn(8, dtype=dtype, device=device)
-        r_v = torch.randn(8, dtype=dtype, device=device)
-
-        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
-            o = torch.add(x, y)
-            o = torch.nn.functional.softmax(o, dim=0)
-            o = torch.add(o, z)
-            o = torch.nn.functional.batch_norm(o, r_mean, r_var, training=True)
-            return o
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x, y, z, r_m, r_v)
-        jit_o = t_jit(x, y, z, r_m, r_v)
-        o = t(x, y, z, r_m, r_v)
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o, jit_o)
-        self.assertGraphContains(t_jit.graph_for(x, y, z, r_m, r_v), FUSION_GUARD)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_sum_to_one(self):
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn([4, 5, 6], dtype=dtype, device=device)
-
-        def t(x: torch.Tensor):
-            o = torch.add(x, 0)
-            o = torch.sum(o, dim=[0, 1, 2])
-            return o
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x)
-        jit_o = t_jit(x)
-        o = t(x)
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o, jit_o)
-        self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")
@@ -1593,28 +790,6 @@ class TestCudaFuser(JitTestCase):
         self.assertEqual(o, jit_o)
         self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
 
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_trivial_reduction(self):
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn([1, 4, 8], dtype=dtype, device=device)
-
-        def t(x: torch.Tensor):
-            o = torch.add(x, 0)
-            o = torch.sum(o, dim=[0])
-            o = torch.sum(o, dim=[0])
-            return o
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x)
-        jit_o = t_jit(x)
-        o = t(x)
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o, jit_o)
-        self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
-
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")
@@ -1630,7 +805,6 @@ class TestCudaFuser(JitTestCase):
         repro_jit = torch.jit.script(repro)
         self._run_helper(repro_jit, repro, x, 0.6)
 
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")
@@ -1654,809 +828,23 @@ class TestCudaFuser(JitTestCase):
         # have been optimized away
         self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 0)
 
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_profile_ivalue(self):
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn([7, 4, 7], dtype=dtype, device=device)
-        y = torch.randn([7, 4, 7], dtype=dtype, device=device)
-
-        def t(x: torch.Tensor, y: torch.Tensor, dim: List[int], keepdim: bool):
-            o = torch.add(x, y)
-            o = o.sum(dim, keepdim=keepdim)
-            return o
-
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x, y, (0, 1), False)
-        jit_o = t_jit(x, y, (0, 1), False)
-        o = t(x, y, (0, 1), False)
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o, jit_o)
-        self.assertGraphContains(t_jit.graph_for(x, y, (0, 1), False), FUSION_GUARD)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_sum_to_size(self):
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn([2, 4, 4], dtype=dtype, device=device)
-        y = torch.randn([2, 4, 4], dtype=dtype, device=device)
-
-        def t(x: torch.Tensor, y: torch.Tensor, new_size: List[int]):
-            o = torch.add(x, y)
-            o = o.sum_to_size(new_size)
-            return o
-
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x, y, (4, 1))
-        jit_o = t_jit(x, y, (4, 1))
-        o = t(x, y, (4, 1))
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o, jit_o)
-        self.assertGraphContains(t_jit.graph_for(x, y, (4, 1)), FUSION_GUARD)
-
-        # update shape: old kernel should handle dynamic shape well without
-        # recompilation
-        x = torch.randn([2, 5, 8], dtype=dtype, device=device)
-        y = torch.randn([2, 5, 8], dtype=dtype, device=device)
-        # (TODO) check executed kernel, should extend autograd.profiler to fused
-        # kernels
-        jit_o = t_jit(x, y, (5, 1))
-        o = t(x, y, (5, 1))
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o, jit_o)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_grad_sum_to_size(self):
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn([2, 4, 4], dtype=dtype, device=device).requires_grad_()
-        y = torch.randn([4], dtype=dtype, device=device).requires_grad_()
-        grad = torch.randn([2, 4, 4], dtype=dtype, device=device)
-
-        ref_x = x.detach().clone().requires_grad_()
-        ref_y = y.detach().clone().requires_grad_()
-
-        def t(x: torch.Tensor, y: torch.Tensor):
-            o = torch.add(x, y)
-            o = torch.relu(o)
-            return o
-
-        # profiling runs for forward & backward
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x, y)
-        jit_o.backward(grad)
-        jit_o = t_jit(x, y)
-        jit_o.backward(grad)
-
-        x.grad = None
-        y.grad = None
-        jit_o = t_jit(x, y)
-        jit_o.backward(grad)
-        o = t(ref_x, ref_y)
-        o.backward(grad)
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o, jit_o)
-        self.assertEqual(x.grad, ref_x.grad)
-        self.assertEqual(y.grad, ref_y.grad)
-        bwd_graph = list(
-            list(t_jit.get_debug_state().execution_plans.values())[
-                0].code.grad_executor_states()[0].execution_plans.values()
-        )[0].graph
-        FileCheck().check(FUSION_GUARD).run(bwd_graph)
-
-        # update shape: old kernel should handle dynamic shape well without
-        # recompilation
-        x = torch.randn([2, 5, 8], dtype=dtype, device=device).requires_grad_()
-        y = torch.randn([8], dtype=dtype, device=device).requires_grad_()
-        ref_x = x.detach().clone().requires_grad_()
-        ref_y = y.detach().clone().requires_grad_()
-        grad = torch.randn([2, 5, 8], dtype=dtype, device=device)
-        jit_o = t_jit(x, y)
-        # (TODO) check executed kernel, should extend autograd.profiler to fused
-        # kernels
-        jit_o.backward(grad)
-        o = t(ref_x, ref_y)
-        o.backward(grad)
-        self.assertEqual(o.dtype, jit_o.dtype)
-        self.assertEqual(o, jit_o)
-        self.assertEqual(x.grad, ref_x.grad)
-        self.assertEqual(y.grad, ref_y.grad)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_add_backward_with_alpha(self):
-        x = torch.randn(4, 2, dtype=torch.float32, device='cuda', requires_grad=True)
-        y = torch.randn(4, 2, dtype=torch.float32, device='cuda', requires_grad=True)
-        grad = torch.randn(4, 2, dtype=torch.float32, device='cuda')
-
-        # Test that a mul is not generated when not needed
-        # Alpha=1.0 or is not used
-        def test1(x: torch.Tensor, y: torch.Tensor):
-            o = torch.add(x, y, alpha=1.0)
-            o = o + 1.0
-            return o
-
-        test1_jit = torch.jit.script(test1)
-        for i in range(3):
-            jit_o = test1_jit(x, y)
-            jit_o.backward(grad)
-
-        bwd1_graph = list(
-            list(test1_jit.get_debug_state().execution_plans.values())[
-                0].code.grad_executor_states()[0].execution_plans.values()
-        )[0].graph
-        FileCheck().check_not("aten::mul_").run(bwd1_graph)
-
-        # Alpha is set to something other than 1.0
-        def test2(x: torch.Tensor, y: torch.Tensor):
-            o = torch.add(x, y, alpha=2.0)
-            o = o + 1.0
-            return o
-
-        test2_jit = torch.jit.script(test2)
-        for i in range(3):
-            jit_o = test2_jit(x, y)
-            jit_o.backward(grad)
-
-        bwd2_graph = list(
-            list(test2_jit.get_debug_state().execution_plans.values())[
-                0].code.grad_executor_states()[0].execution_plans.values()
-        )[0].graph
-        FileCheck().check("aten::mul_").run(bwd2_graph)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_dropout_inference_fusion(self):
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn([10, 4, 8], dtype=dtype, device=device)
-
-        def t(x: torch.Tensor, p: float, train: bool):
-            o = torch.nn.functional.dropout(x, p, training=train)
-            o = o + 1.0
-            return o
-
-        t_jit = torch.jit.script(t)
-
-        self._run_helper(t_jit, t, x, 0.15, False)
-
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
                      "Requires fusion optimization pass to be effective")
-    def test_dropout_train_nograd_fusion(self):
+    def test_gelu_fusion(self):
         dtype = torch.float
         device = "cuda"
-        x = torch.randn([10, 4, 8], dtype=dtype, device=device)
-
-        def t(x: torch.Tensor, p: float, train: bool):
-            o = torch.nn.functional.dropout(x, p, training=train)
-            o = o + 1.0
-            return o
-
-        t_jit = torch.jit.script(t)
-
-        self._run_helper(t_jit, t, x, 0.0, True)
+        x = torch.randn([64, 128, 1024], dtype=dtype, device=device, requires_grad=True)
+        grads = torch.randn([64, 128, 1024], dtype=dtype, device=device)
 
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_dropout_train_nograd_prob_check(self):
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn([1024, 1024], dtype=dtype, device=device)
-
-        def t(x: torch.Tensor, p: float, train: bool):
-            o = torch.nn.functional.dropout(x, p, training=train)
-            o = o + 0.0
-            return o
-
-        t_jit = torch.jit.script(t)
-
-        for prob in [0.0, 0.15, 0.5, 0.85, 1.]:
-            torch.cuda.manual_seed_all(123)
-            jit_o = t_jit(x, prob, True)
-            torch.cuda.manual_seed_all(123)
-            jit_o = t_jit(x, prob, True)
-
-            self.assertTrue(jit_o.detach().isfinite().all().item())
-
-            num_elems = x.numel()
-            num_zeros = num_elems - jit_o.detach().count_nonzero().item()
-            percent_zeros = num_zeros / num_elems
-
-            self.assertTrue((percent_zeros >= (prob - 0.01)) and (percent_zeros <= (prob + 0.01)))
-            self.assertGraphContainsExactly(t_jit.graph_for(x, prob, True), FUSION_GUARD, 1, consider_subgraphs=True)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_dropout_training_fusion(self):
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn([10, 4, 8], dtype=dtype, device=device, requires_grad=True)
-        grads = torch.randn([10, 4, 8], dtype=dtype, device=device)
-
-        def t(x: torch.Tensor, p: float, train: bool):
-            o = torch.nn.functional.dropout(x, p, training=train)
-            o = o * 1.0
-            return o
-
-        t_jit = torch.jit.script(t)
-
-        # The drop probability needs to be set to zero given that the order of picking random
-        # numbers between eager mode and the jit is different
-        self._run_training_helper(t_jit, t, grads, x, 0.0, True)
-
-        def t2(x: torch.Tensor, p: float, train: bool):
-            o = torch.nn.functional.softmax(x, dim=-1)
-            o = torch.nn.functional.dropout(o, p, training=train)
-            return o
-
-        t2_jit = torch.jit.script(t2)
-
-        # The drop probability needs to be set to zero given that the order of picking random
-        # numbers between eager mode and the jit is different
-        self._run_training_helper(t2_jit, t2, grads, x, 0.0, True)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_gelu(self):
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True)
-        grads = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=False)
-
-        def t(x: torch.Tensor, fast : bool):
-            o = torch.nn.functional.gelu(x, fast)
-            o = o * 1.0
-            return o
-
-        t_jit = torch.jit.script(t)
-
-        for approximate in [False, True]:
-            self._run_training_helper(t_jit, t, grads, x, approximate)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_dropout_training_prob_check(self):
-        dtype = torch.float
-        device = "cuda"
-        x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True)
-        x_nograd = torch.randn([1024, 1024], dtype=dtype, device=device)
-
-        def t(x: torch.Tensor, p: float, train: bool):
-            o = torch.nn.functional.dropout(x, p, training=train)
-            o = o + 0.0
-            return o
-
-        t_jit = torch.jit.script(t)
-
-        for prob in [0.0, 0.15, 0.5, 0.85, 1.]:
-            torch.cuda.manual_seed_all(123)
-            jit_o = t_jit(x, prob, True)
-            torch.cuda.manual_seed_all(123)
-            jit_o = t_jit(x, prob, True)
-            torch.cuda.manual_seed_all(123)
-            jit_o = t_jit(x, prob, True)
-
-            self.assertTrue(jit_o.detach().isfinite().all().item())
-
-            num_elems = x.numel()
-            num_zeros = num_elems - jit_o.detach().count_nonzero().item()
-            percent_zeros = num_zeros / num_elems
-
-            self.assertTrue((percent_zeros >= (prob - 0.01)) and (percent_zeros <= (prob + 0.01)))
-            self.assertGraphContainsExactly(t_jit.graph_for(x, prob, True), FUSION_GUARD, 1, consider_subgraphs=True)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_linear(self):
-        in_feature = 2
-        out_feature = 8
-        x = torch.randn(4, in_feature, dtype=torch.float32, device='cuda')
-        weight = torch.randn(out_feature, in_feature, dtype=torch.float32, device='cuda')
-        bias = torch.randn(out_feature, dtype=torch.float32, device='cuda')
-
-        def t(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
-            o = torch.nn.functional.linear(x, weight, bias)
-            o = torch.relu(o)
-            return o
-
-        # bias set to true.
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x, weight, bias)
-        jit_o = t_jit(x, weight, bias)
-        o = t(x, weight, bias)
-        self.assertEqual(o, jit_o)
-        # since the output value is not used at all, the fusion operator should
-        # have been optimized away
-        self.assertGraphContainsExactly(t_jit.graph_for(x, weight, bias), FUSION_GUARD, 1)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_backward_type(self):
-        # not super useful to check gradient of integer/bool, so skipping here
-        type_pairs = [
-            (torch.float, torch.half),
-            (torch.double, torch.half),
-            (torch.float, torch.double),
-        ]
-        for x_type, y_type in type_pairs:
-            x = torch.randn(4, 2, dtype=x_type, device='cuda', requires_grad=True)
-            y = torch.randn(4, 2, dtype=y_type, device='cuda', requires_grad=True)
-            grad = torch.randn(4, 2, dtype=torch.float, device='cuda')
-
-            def test1(x: torch.Tensor, y: torch.Tensor):
-                o = torch.add(x, y)
-                o = torch.add(o, y)
-                o = torch.add(o, y)
-                o = torch.add(o, y)
-                o = o + 1.0
-                return o
-
-            test1_jit = torch.jit.script(test1)
-            for i in range(3):
-                jit_o = test1_jit(x, y)
-                jit_o.backward(grad)
-
-            bwd_graph = list(
-                list(test1_jit.get_debug_state().execution_plans.values())[
-                    0].code.grad_executor_states()[0].execution_plans.values()
-            )[0].graph
-
-            FileCheck().check(FUSION_GROUP).run(bwd_graph)
-            self.assertEqual(x.grad.dtype, x.dtype)
-            self.assertEqual(y.grad.dtype, y.dtype)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_autocast_1(self):
-        def t(x: torch.Tensor, y: torch.Tensor):
-            o = x * 2.0
-            o = torch.softmax(o, dim=-1)
-            o = o * 3.0
-            o = torch.matmul(o, y)
-            return o
-
-        x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True)
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
-        grad = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=False)
-        t_jit = torch.jit.script(t)
-
-        for i in range(3):
-            with torch.cuda.amp.autocast():
-                jit_o = t_jit(x, y)
-                if i == 2 :
-                    fwd_graph = t_jit.graph_for(x, y)
-            jit_o.backward(grad)
-
-        self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
-
-        with torch.cuda.amp.autocast():
-            bwd_graph = list(
-                list(t_jit.get_debug_state().execution_plans.values())[
-                    0].code.grad_executor_states()[0].execution_plans.values()
-            )[0].graph
-        FileCheck().check(FUSION_GROUP).run(bwd_graph)
-
-        self.assertEqual(jit_o.dtype, torch.half)
-        self.assertEqual(x.grad.dtype, x.dtype)
-        self.assertEqual(y.grad.dtype, y.dtype)
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_autocast_2(self):
         def t(x: torch.Tensor):
-            o = x * 2.0
-            o = torch.softmax(o, dim=-1)
-            o = o * 3.0
-            o = torch.softmax(o, dim=-1)
-            o = o * 4.0
-            return o
-
-        x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True)
-        grad = torch.randn(8, 4, dtype=torch.float, device='cuda', requires_grad=False)
-        t_jit = torch.jit.script(t)
-
-        for i in range(3):
-            with torch.cuda.amp.autocast() :
-                jit_o = t_jit(x)
-                if i == 2 :
-                    fwd_graph = t_jit.graph_for(x)
-            jit_o.backward(grad)
-
-        self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
-
-        with torch.cuda.amp.autocast():
-            bwd_graph = list(
-                list(t_jit.get_debug_state().execution_plans.values())[
-                    0].code.grad_executor_states()[0].execution_plans.values()
-            )[0].graph
-        FileCheck().check(FUSION_GROUP).run(bwd_graph)
-
-        self.assertEqual(jit_o.dtype, torch.float)
-        self.assertEqual(x.grad.dtype, x.dtype)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_to_dtype_fp32_to_fp16(self):
-        def t(x: torch.Tensor):
-            o = x * 2.0
-            o = o.to(dtype=torch.half)
-            o = o * 3.0
-            return o
-
-        x = torch.randn(8, 4, dtype=torch.float, device='cuda')
-        t_jit = torch.jit.script(t)
-
-        for i in range(3):
-            jit_o = t_jit(x)
-
-        self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
-        self.assertEqual(jit_o.dtype, torch.half)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_to_dtype_fp16_to_fp32(self):
-        def t(x: torch.Tensor):
-            o = x * 2.0
-            o = o.to(dtype=torch.float)
-            o = o * 3.0
-            return o
-
-        x = torch.randn(8, 4, dtype=torch.half, device='cuda')
-        t_jit = torch.jit.script(t)
-
-        for i in range(3):
-            jit_o = t_jit(x)
-
-        self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
-        self.assertEqual(jit_o.dtype, torch.float)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_to_dtype_fp16_to_fp16(self):
-        def t(x: torch.Tensor):
-            o = x * 2.0
-            o = o.to(dtype=torch.half)
-            o = o * 3.0
-            return o
-
-        x = torch.randn(8, 4, dtype=torch.half, device='cuda')
-        t_jit = torch.jit.script(t)
-
-        for i in range(3):
-            jit_o = t_jit(x)
-
-        self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
-        self.assertEqual(jit_o.dtype, torch.half)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(not TEST_MULTIGPU, "requires multiple CUDA device")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_multiple_device_pw(self):
-
-        def t(x):
-            o = x + 1.0
-            o = torch.relu(o)
-            return o
-
-        x = torch.randn(2, dtype=torch.float32, device="cuda")
-        t_jit = torch.jit.script(t)
-
-        for i in range(3):
-            jit_o = t_jit(x)
-
-        self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
-        torch.cuda.device(1)
-        x = x.to("cuda:1")
-        jit_o = t_jit(x)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_graph_for_with_missing_optimized_engine(self):
-        x = torch.randn(8, 4, 2, dtype=torch.float, device="cuda").requires_grad_()
-
-        def t(x: torch.Tensor, flag: bool):
-            x = x + 1.0
-            x = torch.relu(x)
-            if flag:
-                o = x + 1.0
-                o = torch.relu(o)
-            else:
-                o = x + 2.0
-                o = torch.relu(o)
-            return o
-
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x, False)
-        jit_o = t_jit(x, False)
-        jit_o = t_jit(x, True)
-        o = t(x, True)
-        self.assertEqual(o, jit_o)
-        # since the output value is not used at all, the fusion operator should
-        # have been optimized away
-        self.assertGraphContainsExactly(t_jit.graph_for(x, True), FUSION_GUARD, 1, True)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_branches(self):
-        in_feature = 2
-        out_feature = 4
-        x = torch.randn(4, in_feature, dtype=torch.float32, device='cuda')
-        weight = torch.randn(out_feature, in_feature, dtype=torch.float32, device='cuda')
-        bias = torch.randn(out_feature, dtype=torch.float32, device='cuda')
-
-        def t(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, flag: bool):
-            if flag:
-                o = torch.nn.functional.linear(x, weight, bias)
-                o = o + 1.0
-                o = torch.relu(o)
-            else:
-                o = x.sum()
-                o = o + 2.0
-                o = torch.relu(o)
-            return o
-
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x, weight, bias, True)
-        jit_o = t_jit(x, weight, bias, True)
-        o = t(x, weight, bias, True)
-        self.assertEqual(o, jit_o)
-        # since the output value is not used at all, the fusion operator should
-        # have been optimized away
-        self.assertGraphContainsExactly(t_jit.graph_for(x, weight, bias, True), FUSION_GUARD, 1)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_scalar_tensor(self):
-        x = torch.empty([], device="cuda", dtype=torch.float32)
-
-        def t(x: torch.Tensor):
-            o = x + 1.0
-            o = torch.nn.functional.relu(o)
-            return o
-
-        # bias set to true.
-        t_jit = torch.jit.script(t)
-        jit_o = t_jit(x)
-        jit_o = t_jit(x)
-        o = t(x)
-        self.assertEqual(o, jit_o)
-        # since the output value is not used at all, the fusion operator should
-        # have been optimized away
-        self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
-
-    @unittest.skipIf(os.environ.get('PYTORCH_NO_CUDA_MEMORY_CACHING') is not None,
-                     "skipping graph_rng when caching allocator is disabled")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(CUDA_MAJOR < 11, "requires CUDA11 or above")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_graph_rng(self):
-        self.assertTrue(torch._C._jit_nvfuser_enabled())
-        size = 10000
-        a = torch.randn((size,), device="cuda", dtype=torch.float)
-
-        def t(x):
-            o = x + 1.0
-            o = torch.nn.functional.dropout(o, p=0.1)
-            o = o + 1.0
-            o = torch.nn.functional.dropout(o, p=0.1)
+            o = torch.nn.functional.gelu(x)
+            o = o * 1.0
             return o
 
         t_jit = torch.jit.script(t)
 
-        for _ in range(3):
-            t_jit(a)
-
-        self.assertGraphContainsExactly(t_jit.graph_for(a), FUSION_GUARD, 1)
-
-        # Control (jitted, ungraphed)
-        torch.cuda.manual_seed(5)
-        eager_out = a.clone()
-        for _ in range(3):
-            eager_out = t_jit(eager_out)
-
-        graph_in = a.clone()
-        g = torch.cuda._Graph()
-        s = torch.cuda.Stream()
-        s.wait_stream(torch.cuda.current_stream())
-        with torch.cuda.stream(s):
-            torch.cuda.manual_seed(5)
-            g.capture_begin()
-            graph_out = t_jit(graph_in)
-            g.capture_end()
-        torch.cuda.current_stream().wait_stream(s)
-        # g is now a jitted, graphed version of t.
-
-        # Runs a (jitted, graphed) -> (jitted, ungraphed) -> (jitted, graphed) sequence.
-        # The ops in the overall sequence should be the same as Control.
-        g.replay()
-        # graph_out is now filled with g's result. Use it as ungraphed input.
-        out = t_jit(graph_out)
-        graph_in.copy_(out)
-        g.replay()
-
-        # If replay() updated RNG state correctly, graph_out should now equal eager_out
-        self.assertEqual(graph_out, eager_out)
-
-    def _test_batch_norm_impl_index_helper(self, batch, c, hw, affine=True, track_running_stats=True, train=True):
-        # enabling inlining to avoid counter increment in BN forward
-        torch._C._debug_set_autodiff_subgraph_inlining(True)
-        dtype = torch.float32
-
-        class MyModule(torch.nn.Module):
-            def __init__(self, num_features=10, affine=True, track_running_stats=True):
-                super(MyModule, self).__init__()
-                self.bn = torch.nn.BatchNorm2d(num_features,
-                                               1e-5,
-                                               affine=affine,
-                                               track_running_stats=track_running_stats).to(dtype=dtype)
-
-            def forward(self, x):
-                o = x * 1.0
-                o = self.bn(o)
-                return o
-
-        x = torch.randn(batch, c, hw, hw, dtype=torch.float, device="cuda").to(dtype=dtype).requires_grad_()
-        grad = torch.randint(-20, 20, (batch, c, hw, hw), device="cuda").to(dtype=dtype).div(-10)
-
-        my_module = MyModule(c, affine, track_running_stats).cuda()
-        ref_module = MyModule(c, affine, track_running_stats).cuda()
-
-        if not train:
-            my_module.eval()
-            ref_module.eval()
-
-        t_jit = torch.jit.script(my_module)
-        ref_module.load_state_dict(my_module.state_dict())
-
-        ref_x = x.detach().requires_grad_()
-
-        for i in range(0, 3):
-            jit_o = t_jit(x)
-            jit_o.backward(grad)
-
-        # TODO: remove this run?
-        o = ref_module(ref_x)
-        o.backward(grad)
-
-        has_affine = ref_module.bn.weight is not None
-        has_running_stats = ref_module.bn.running_mean is not None
-
-        if has_running_stats:
-            my_module.bn.running_mean.zero_()
-            my_module.bn.running_var.fill_(1.0)
-            ref_module.bn.running_mean.zero_()
-            ref_module.bn.running_var.fill_(1.0)
-
-        # Verify that when train is False, we don't have grad for weight/bias.
-        if has_affine and train:
-            my_module.bn.weight.grad.zero_()
-            my_module.bn.bias.grad.zero_()
-            ref_module.bn.weight.grad.zero_()
-            ref_module.bn.bias.grad.zero_()
-
-        x.grad.zero_()
-        ref_x.grad.zero_()
-
-        # real runs
-        jit_o = t_jit(x)
-        jit_o.backward(grad)
-
-        o = ref_module(ref_x)
-        o.backward(grad)
-
-        # assert forward graph fusion
-        self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1, consider_subgraphs=True)
-        # assert backward graph fusion
-        bwd_graph = list(
-            list(t_jit.get_debug_state().execution_plans.values())[0].code.grad_executor_states()[0]
-            .execution_plans.values())[0].graph
-        self.assertGraphContainsExactly(bwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
-
-        self.assertTrue(self._compare("comparing output failed", jit_o, o, 1e-5))
-        self.assertTrue(self._compare("comparing input grad failed", x.grad, ref_x.grad, 1e-4))
-        # TODO: switch to welford and reduce this to 1e-5
-        # The 1e-3 looks bad, but we don't have welford in codegen, so numeric
-        # is very different between reference and codegen.
-        if has_affine and train:
-            self.assertTrue(self._compare("comparing weight grad failed",
-                                          my_module.bn.weight.grad,
-                                          ref_module.bn.weight.grad,
-                                          1e-3))
-            self.assertTrue(self._compare("comparing bias grad failed",
-                                          my_module.bn.bias.grad,
-                                          ref_module.bn.bias.grad,
-                                          1e-4))
-        if has_running_stats:
-            self.assertTrue(self._compare("comparing running_mean failed",
-                                          my_module.bn.running_mean,
-                                          ref_module.bn.running_mean,
-                                          1e-5))
-            self.assertTrue(self._compare("comparing running_var failed",
-                                          my_module.bn.running_var,
-                                          ref_module.bn.running_var,
-                                          1e-5))
-
-    @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_batch_norm_impl_index_correctness(self):
-        with torch.backends.cudnn.flags(enabled=True):
-            batch = [2, 7, 16]
-            channels = [4, 89, 19, 32]
-            hw = [1, 8, 17, 32]
-
-            # avoid tolerance failure in CI
-            torch.cuda.manual_seed_all(211)
-
-            # failing sizes (2, 1, 1, 1)
-            # failing sizes (2, 89, 8, 8) training False, track True, affine: False
-            for b, c, hw in itertools.product(batch, channels, hw):
-                setups = [
-                    [True, True],
-                    [False, False],
-                    [True, False],
-                    [False, True]]
-                for training_and_track, affine in itertools.product(setups, [True, False]):
-                    training, track_running_stats = training_and_track
-                    self._test_batch_norm_impl_index_helper(b, c, hw, affine, track_running_stats, training)
-
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
-                     "Requires fusion optimization pass to be effective")
-    def test_softplus_fuser(self):
-        def shifted_softplus(x: torch.Tensor, shift: float):
-            return functional.softplus(x) - shift
-
-        jitted = torch.jit.script(shifted_softplus)
-        inp = torch.randn(4, 2, dtype=torch.float32, device="cuda").requires_grad_()
-        inp_ref = inp.detach().clone().requires_grad_()
-        grad = torch.randn(4, 2, dtype=torch.float32, device="cuda")
-
-        aten_o = shifted_softplus(inp_ref, 0.693147)
-        aten_o.backward(grad)
-        aten_grad = inp_ref.grad
-
-        for i in range(3):
-            jit_o = jitted(inp, 0.693147)
-            inp.grad = None         # avoid accumulation on grad
-            jit_o.backward(grad)
-            jit_grad = inp.grad
-
-        assert torch.allclose(jit_o, aten_o)
-        assert torch.allclose(jit_grad, aten_grad)
-        self.assertGraphContains(jitted.graph_for(inp, 0.693147), FUSION_GROUP, True)
+        self._run_training_helper(t_jit, t, grads, x)
 
 class TestPassManagerCudaFuser(JitTestCase):
 
index 1d99fa8..a139515 100644 (file)
@@ -32,15 +32,12 @@ GENERATED_CPP = [
 # NVFuser runtime library
 libtorch_nvfuser_runtime_sources = [
     "torch/csrc/jit/codegen/cuda/runtime/block_reduction.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/block_sync_default.cu",
     "torch/csrc/jit/codegen/cuda/runtime/broadcast.cu",
     "torch/csrc/jit/codegen/cuda/runtime/fp16_support.cu",
     "torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu",
     "torch/csrc/jit/codegen/cuda/runtime/helpers.cu",
     "torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu",
     "torch/csrc/jit/codegen/cuda/runtime/tensor.cu",
-    "torch/csrc/jit/codegen/cuda/runtime/welford.cu",
     "aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh",
     "aten/src/ATen/cuda/detail/UnpackRaw.cuh",
 ]
@@ -502,7 +499,6 @@ libtorch_cuda_core_sources = [
     "torch/csrc/autograd/functions/comm.cpp",
     "torch/csrc/jit/codegen/cuda/arith.cpp",
     "torch/csrc/jit/codegen/cuda/compute_at.cpp",
-    "torch/csrc/jit/codegen/cuda/compute_at_map.cpp",
     "torch/csrc/jit/codegen/cuda/codegen.cpp",
     "torch/csrc/jit/codegen/cuda/dispatch.cpp",
     "torch/csrc/jit/codegen/cuda/expr_evaluator.cpp",
@@ -513,61 +509,40 @@ libtorch_cuda_core_sources = [
     "torch/csrc/jit/codegen/cuda/fusion.cpp",
     "torch/csrc/jit/codegen/cuda/graph_fuser.cpp",
     "torch/csrc/jit/codegen/cuda/index_compute.cpp",
-    "torch/csrc/jit/codegen/cuda/index_reference_replay.cpp",
     "torch/csrc/jit/codegen/cuda/instrumentation.cpp",
     "torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp",
     "torch/csrc/jit/codegen/cuda/ir_cloner.cpp",
     "torch/csrc/jit/codegen/cuda/ir_graphviz.cpp",
     "torch/csrc/jit/codegen/cuda/ir_nodes.cpp",
     "torch/csrc/jit/codegen/cuda/ir_iostream.cpp",
-    "torch/csrc/jit/codegen/cuda/ir_utils.cpp",
     "torch/csrc/jit/codegen/cuda/iter_visitor.cpp",
     "torch/csrc/jit/codegen/cuda/kernel.cpp",
     "torch/csrc/jit/codegen/cuda/kernel_cache.cpp",
-    "torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp",
     "torch/csrc/jit/codegen/cuda/kernel_ir.cpp",
     "torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp",
     "torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_allocation.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp",
     "torch/csrc/jit/codegen/cuda/lower_index.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp",
     "torch/csrc/jit/codegen/cuda/lower_loops.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_predicate.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_shift.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp",
-    "torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp",
+    "torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp",
+    "torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp",
     "torch/csrc/jit/codegen/cuda/lower_unroll.cpp",
+    "torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp",
     "torch/csrc/jit/codegen/cuda/lower_utils.cpp",
     "torch/csrc/jit/codegen/cuda/lower_validation.cpp",
     "torch/csrc/jit/codegen/cuda/lower2device.cpp",
     "torch/csrc/jit/codegen/cuda/manager.cpp",
     "torch/csrc/jit/codegen/cuda/mutator.cpp",
-    "torch/csrc/jit/codegen/cuda/ops/composite.cpp",
-    "torch/csrc/jit/codegen/cuda/ops/normalization.cpp",
-    "torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp",
-    "torch/csrc/jit/codegen/cuda/parallel_type_bitmap.cpp",
     "torch/csrc/jit/codegen/cuda/parser.cpp",
     "torch/csrc/jit/codegen/cuda/partition.cpp",
     "torch/csrc/jit/codegen/cuda/predicate_compute.cpp",
     "torch/csrc/jit/codegen/cuda/register_interface.cpp",
-    "torch/csrc/jit/codegen/cuda/root_domain_map.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/registry.cpp",
-    "torch/csrc/jit/codegen/cuda/scheduler/utils.cpp",
+    "torch/csrc/jit/codegen/cuda/scheduler.cpp",
     "torch/csrc/jit/codegen/cuda/shape_inference.cpp",
-    "torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp",
     "torch/csrc/jit/codegen/cuda/tensor_view.cpp",
     "torch/csrc/jit/codegen/cuda/transform_iter.cpp",
     "torch/csrc/jit/codegen/cuda/transform_replay.cpp",
     "torch/csrc/jit/codegen/cuda/transform_rfactor.cpp",
     "torch/csrc/jit/codegen/cuda/type.cpp",
-    "torch/csrc/jit/codegen/cuda/utils.cpp",
     "torch/csrc/jit/tensorexpr/cuda_codegen.cpp",
     "torch/csrc/jit/runtime/register_cuda_ops.cpp",
 ]
index 2c14f5a..fab206a 100644 (file)
@@ -3,9 +3,7 @@
 #include <c10/util/Exception.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/type.h>
-#include <cfloat>
 
 namespace torch {
 namespace jit {
@@ -22,10 +20,10 @@ Val* newScalar(ValType vtype, DataType dtype) {
       switch (dtype) {
         case DataType::Bool:
           return new Bool();
-        case DataType::Double:
         case DataType::Float:
+          return new Float();
         case DataType::Half:
-          return new Double();
+          return new Half();
         case DataType::Int:
           return new Int();
         default:
@@ -37,11 +35,10 @@ Val* newScalar(ValType vtype, DataType dtype) {
 
   TORCH_CHECK(
       false,
-      "Cannot handle ValType: ",
+      "Was expecting a scalar type, but received ValType: ",
       vtype,
       " with DataType:",
-      dtype,
-      " in newScalar.");
+      dtype);
 }
 
 TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
@@ -70,7 +67,7 @@ TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
         continue;
       if (dom[i]->isBroadcast())
         continue;
-      out_domain[i] = dom[i]->clone();
+      out_domain[i] = new IterDomain(dom[i]->start(), dom[i]->extent());
     }
   }
   for (const auto dim_i : c10::irange(out_domain.size())) {
@@ -127,11 +124,31 @@ std::vector<Val*> maybeBroadcast(const std::vector<Val*>& vals) {
   return out_vals;
 }
 
+Val* newOutputVal(const std::vector<Val*>& vals) {
+  ValType out_vtype = vals[0]->getValType().value();
+  DataType out_dtype = vals[0]->getDataType().value();
+
+  for (auto val : vals) {
+    TORCH_CHECK(val->isVal(), "Invalid statement found during promotion.");
+    TORCH_CHECK(
+        val->getDataType().value() != DataType::Null,
+        "Invalid datatype found during prmotion.");
+    out_vtype = promote_type(out_vtype, val->getValType().value());
+    out_dtype = promote_type(out_dtype, val->getDataType().value());
+  }
+
+  if (out_vtype == ValType::TensorView)
+    return newOutputTV(vals, out_dtype);
+
+  return newScalar(out_vtype, out_dtype);
+}
+
 Val* newValLike(Val* val, DataType dtype) {
+  TORCH_CHECK(val->isVal(), "Invalid statement provided to create new value.");
   TORCH_CHECK(
       dtype != DataType::Null, "Invalid datatype provided for new value.");
 
-  const ValType vtype = val->getValType().value();
+  ValType vtype = val->getValType().value();
 
   if (vtype == ValType::TensorView)
     return newOutputTV({val}, dtype);
@@ -167,20 +184,7 @@ TensorView* castOp(DataType dtype, TensorView* v1) {
 // UNARY OPERATIONS
 
 Val* unaryOp(UnaryOpType type, Val* v1) {
-  TORCH_INTERNAL_ASSERT(
-      type != UnaryOpType::Address,
-      "The reference operator & is not accessible in the Fusion IR");
-  Val* out = newValLike(v1, v1->getDataType().value());
-  // TODO: We should add the following, but we need to go through shchedulers
-  // and make sure all calls to "fusion->inputs" includes the output of RandLike
-  //
-  //  If rand like, there isn't a real dependency on the input value, so map it
-  //  to a dummy scalar. if
-  //
-  // (type == UnaryOpType::RandLike) {
-  //   v1 = new NamedScalar("__rnd", v1->getDataType().value());
-  // }
-
+  Val* out = newOutputVal({v1});
   new UnaryOp(type, out, v1);
   return out;
 }
@@ -192,7 +196,6 @@ TensorView* unaryOp(UnaryOpType type, TensorView* v1) {
 Val* neg(Val* v) {
   return unaryOp(UnaryOpType::Neg, v);
 }
-
 TensorView* neg(TensorView* v) {
   return unaryOp(UnaryOpType::Neg, v);
 }
@@ -206,13 +209,11 @@ TensorView* arithOpOverloads(Val* (*func)(Val*, Val*), T1* v1, T2* v2) {
   return func(v1->template as<Val>(), v2->template as<Val>())
       ->template as<TensorView>();
 }
-
 template <typename T1, typename T2>
 TensorView* arithOpOverloads(BinaryOpType type, T1* v1, T2* v2) {
   return binaryOp(type, v1->template as<Val>(), v2->template as<Val>())
       ->template as<TensorView>();
 }
-
 template <typename T1, typename T2, typename T3>
 TensorView* arithOpOverloads(
     Val* (*func)(Val*, Val*, Val*),
@@ -226,7 +227,6 @@ TensorView* arithOpOverloads(
              vals[2]->template as<Val>())
       ->template as<TensorView>();
 }
-
 template <typename T1, typename T2, typename T3, typename T4>
 TensorView* arithOpOverloads(
     Val* (*func)(Val*, Val*, Val*, Val*),
@@ -242,126 +242,28 @@ TensorView* arithOpOverloads(
              vals[3]->template as<Val>())
       ->template as<TensorView>();
 }
-
-namespace {
-enum class Category { Scalar, ZeroDimTensor, DimTensor };
-
-inline Category getCategory(const Val* v) {
-  if (v->isA<TensorView>()) {
-    if (v->as<TensorView>()->nDims() > 0) {
-      return Category::DimTensor;
-    } else {
-      return Category::ZeroDimTensor;
-    }
-  } else {
-    return Category::Scalar;
-  }
-}
-
-// replicated logic from Aten/native/TypeProperties.cpp, minus complex support
-DataType getCommonType(DataType higher, DataType lower) {
-  if (isFloatingPointType(higher)) {
-    return higher;
-  }
-  if (higher == DataType::Bool || isFloatingPointType(lower)) {
-    return promote_type(higher, lower);
-  }
-  if (higher != DataType::Null) {
-    return higher;
-  }
-  return lower;
-}
-} // namespace
-
-// Type promotion logic for binary operators
-DataType getOutputType(BinaryOpType op_type, Val* v1, Val* v2) {
-  DataType v1_dtype = v1->getDataType().value();
-  DataType v2_dtype = v2->getDataType().value();
-
-  const bool floating_input =
-      isFloatingPointType(v1_dtype) || isFloatingPointType(v2_dtype);
-
-  const bool integer_input =
-      isIntegralType(v1_dtype) || isIntegralType(v2_dtype);
-
-  const bool all_integer_input =
-      isIntegralType(v1_dtype) && isIntegralType(v2_dtype);
-
-  if (all_integer_input) {
-    TORCH_INTERNAL_ASSERT(
-        !(noFullIntegerSupport(op_type)) || (v1->isScalar() && v2->isScalar()),
-        "unsupported op with all integer tensor inputs");
-  }
-
-  // Combine categories
-  const auto v1_cat = getCategory(v1);
-  const auto v2_cat = getCategory(v2);
-  if (v1_cat != v2_cat) {
-    const DataType higher = v1_cat > v2_cat ? v1_dtype : v2_dtype;
-    const DataType lower = v1_cat > v2_cat ? v2_dtype : v1_dtype;
-    const DataType common_type = getCommonType(higher, lower);
-    v1_dtype = common_type;
-    v2_dtype = common_type;
-  }
-
-  if (isIntegerOp(op_type) || (alsoBooleanOperator(op_type) && integer_input)) {
-    // If integer op or maybe bool op with integer inputs meaning binary op
-    if (integer_input && all_integer_input) {
-      return promote_type(v1_dtype, v2_dtype);
-    } else if (integer_input && !all_integer_input) {
-      TORCH_CHECK(
-          !floating_input,
-          "Operator ",
-          op_type,
-          " not supported with floating point inputs.");
-      return isIntegralType(v1_dtype) ? v1_dtype : v2_dtype;
-    } else {
-      TORCH_INTERNAL_ASSERT(
-          false,
-          "Currently no support for float inputs to int operations. ",
-          "Inputs should be manually casted first.");
-    }
-  } else if (isLogicalOp(op_type)) {
-    return DataType::Bool;
-  } else if (alsoBooleanOperator(op_type)) {
-    // If boolean op that can't have floating inputs (& or |)
-    TORCH_CHECK(
-        !floating_input,
-        "Operator ",
-        op_type,
-        " not supported with floating point inputs.");
-    return DataType::Bool;
-  } else {
-    // Otherwise do normal type promotion
-    return promote_type(v1_dtype, v2_dtype);
-  }
-}
-
 } // namespace
 
 TORCH_CUDA_CU_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) {
-  const auto out_dtype = getOutputType(type, v1, v2);
-  const auto out_vtype =
-      promote_type(v1->getValType().value(), v2->getValType().value());
   auto vals = maybeBroadcast({v1, v2});
-  Val* out = nullptr;
-  if (out_vtype == ValType::TensorView) {
-    out = newOutputTV(vals, out_dtype);
-  } else {
-    out = newScalar(out_vtype, out_dtype);
+  Val* out = newOutputVal({vals[0], vals[1]});
+  if (is_logical_op(type)) {
+    if (out->getDataType().value() != DataType::Bool)
+      out = newValLike(out, DataType::Bool);
+  } else if (type >= BinaryOpType::Mod) {
+    if (out->getDataType().value() != DataType::Int)
+      out = newValLike(out, DataType::Int);
   }
+
   new BinaryOp(type, out, vals[0], vals[1]);
   return out;
 }
-
 TensorView* binaryOp(BinaryOpType type, TensorView* v1, Val* v2) {
   return arithOpOverloads(type, v1, v2);
 }
-
 TensorView* binaryOp(BinaryOpType type, Val* v1, TensorView* v2) {
   return arithOpOverloads(type, v1, v2);
 }
-
 TensorView* binaryOp(BinaryOpType type, TensorView* v1, TensorView* v2) {
   return arithOpOverloads(type, v1, v2);
 }
@@ -379,7 +281,6 @@ TensorView* add(Val* v1, TensorView* v2) {
 TensorView* add(TensorView* v1, TensorView* v2) {
   return arithOpOverloads(add, v1, v2);
 }
-
 // sub
 Val* sub(Val* v1, Val* v2) {
   return binaryOp(BinaryOpType::Sub, v1, v2);
@@ -393,7 +294,6 @@ TensorView* sub(Val* v1, TensorView* v2) {
 TensorView* sub(TensorView* v1, TensorView* v2) {
   return arithOpOverloads(sub, v1, v2);
 }
-
 // mul
 Val* mul(Val* v1, Val* v2) {
   return binaryOp(BinaryOpType::Mul, v1, v2);
@@ -407,7 +307,6 @@ TensorView* mul(Val* v1, TensorView* v2) {
 TensorView* mul(TensorView* v1, TensorView* v2) {
   return arithOpOverloads(mul, v1, v2);
 }
-
 // div
 Val* div(Val* v1, Val* v2) {
   return binaryOp(BinaryOpType::Div, v1, v2);
@@ -421,7 +320,6 @@ TensorView* div(Val* v1, TensorView* v2) {
 TensorView* div(TensorView* v1, TensorView* v2) {
   return arithOpOverloads(div, v1, v2);
 }
-
 // mod
 Val* mod(Val* v1, Val* v2) {
   return binaryOp(BinaryOpType::Mod, v1, v2);
@@ -435,7 +333,6 @@ TensorView* mod(Val* v1, TensorView* v2) {
 TensorView* mod(TensorView* v1, TensorView* v2) {
   return arithOpOverloads(mod, v1, v2);
 }
-
 // lt
 Val* lt(Val* v1, Val* v2) {
   return binaryOp(BinaryOpType::LT, v1, v2);
@@ -449,20 +346,6 @@ TensorView* lt(Val* v1, TensorView* v2) {
 TensorView* lt(TensorView* v1, TensorView* v2) {
   return arithOpOverloads(lt, v1, v2);
 }
-
-// gt
-Val* gt(Val* v1, Val* v2) {
-  return binaryOp(BinaryOpType::GT, v1, v2);
-}
-TensorView* gt(TensorView* v1, Val* v2) {
-  return arithOpOverloads(gt, v1, v2);
-}
-TensorView* gt(Val* v1, TensorView* v2) {
-  return arithOpOverloads(gt, v1, v2);
-}
-TensorView* gt(TensorView* v1, TensorView* v2) {
-  return arithOpOverloads(gt, v1, v2);
-}
 // eq
 Val* eq(Val* v1, Val* v2) {
   return binaryOp(BinaryOpType::Eq, v1, v2);
@@ -476,7 +359,6 @@ TensorView* eq(Val* v1, TensorView* v2) {
 TensorView* eq(TensorView* v1, TensorView* v2) {
   return arithOpOverloads(eq, v1, v2);
 }
-
 // ceilDiv
 Val* ceilDiv(Val* v1, Val* v2) {
   return binaryOp(BinaryOpType::CeilDiv, v1, v2);
@@ -490,16 +372,15 @@ TensorView* ceilDiv(Val* v1, TensorView* v2) {
 TensorView* ceilDiv(TensorView* v1, TensorView* v2) {
   return arithOpOverloads(ceilDiv, v1, v2);
 }
-
 // andOp
 Val* andOp(Val* v1, Val* v2) {
   TORCH_CHECK(
-      !isFloatingPointType(v1->getDataType().value()),
-      "Input1 should not be a floating point type, but received: ",
+      v1->getDataType().value() == DataType::Bool,
+      "Input1 should be of type bool, not ",
       v1->getDataType().value());
   TORCH_CHECK(
-      !isFloatingPointType(v2->getDataType().value()),
-      "Input2 should not be a floating point type, but received: ",
+      v2->getDataType().value() == DataType::Bool,
+      "Input2 should be of type bool, not ",
       v2->getDataType().value());
   return binaryOp(BinaryOpType::And, v1, v2);
 }
@@ -518,8 +399,7 @@ TensorView* andOp(TensorView* v1, TensorView* v2) {
 // TODO: How do we adjust this so we can reduce to a single scalar value?
 static TensorView* newForReduction(
     TensorView* tv,
-    const std::vector<unsigned int>& axes,
-    DataType data_type = DataType::Null) {
+    const std::vector<unsigned int>& axes) {
   auto orig_domain = TensorDomain::noReductions(tv->getRootDomain());
   std::set<unsigned int> axes_set(axes.begin(), axes.end());
 
@@ -544,7 +424,7 @@ static TensorView* newForReduction(
     const IterDomain* id = orig_domain[dim];
 
     TORCH_CHECK(
-        !(isReduction && id->isBroadcast() && !id->isImplicitBroadcast()),
+        !(isReduction && id->isBroadcast()),
         "Cannot reduce an axis that is marked as broadcasted as it has an undetermined size. Tried to reduce ID = ",
         id,
         " of tensor ",
@@ -559,18 +439,14 @@ static TensorView* newForReduction(
 
   TensorDomain* td =
       new TensorDomain(new_domain, std::vector<bool>(new_domain.size(), true));
-
-  data_type =
-      data_type == DataType::Null ? tv->getDataType().value() : data_type;
-  return new TensorView(td, data_type);
+  return new TensorView(td, tv->getDataType().value());
 }
 
 TensorView* reductionOp(
     BinaryOpType reduction_op_type,
     const std::vector<int>& axes,
     Val* init,
-    TensorView* tv,
-    bool keep_dim /*=false*/) {
+    TensorView* tv) {
   TORCH_CHECK(
       init->isConstScalar(),
       "Cannot create a reduction operation where the initial value is not a const scalar.");
@@ -600,98 +476,30 @@ TensorView* reductionOp(
   }
 
   TensorView* out = newForReduction(tv, uint_axes);
-  const auto out_type = out->getDataType().value();
-  const auto init_type = init->getDataType().value();
-  TORCH_CHECK(
-      (isFloatingPointType(out_type) && isFloatingPointType(init_type)) ||
-          (isIntegralType(out_type) && isIntegralType(init_type)) ||
-          (out_type == DataType::Bool && init_type == DataType::Bool),
-      "Types should match for reduction ops but received: ",
-      out_type,
-      " and ",
-      init_type);
+  if (init->getDataType().value() != tv->getDataType().value())
+    init = castOp(tv->getDataType().value(), init);
   new ReductionOp(reduction_op_type, init, out, tv);
-
-  if (keep_dim) {
-    auto tv_root = TensorDomain::noReductions(tv->getRootDomain());
-    std::vector<bool> is_broadcast(tv_root.size(), false);
-    for (int axis : axes) {
-      is_broadcast[axis] = true;
-    }
-
-    out = broadcast(out, is_broadcast);
-  }
   return out;
 }
 
-TensorView* sum(
-    TensorView* v1,
-    const std::vector<int>& axes,
-    bool keep_dim /*=false*/) {
-  Val* init = nullptr;
-  auto dtype = v1->getDataType().value();
-  if (isFloatingPointType(dtype)) {
-    init = new Double(0.0);
-  } else if (isIntegralType(dtype)) {
-    init = new Int(0);
-  } else {
-    TORCH_CHECK(
-        false,
-        "Could not generate a sum op for tensor with type: ",
-        v1->getDataType().value());
-  }
-
-  return reductionOp(BinaryOpType::Add, axes, init, v1, keep_dim);
-}
-
-TensorView* max(
-    TensorView* v1,
-    const std::vector<int>& axes,
-    bool keep_dim /*=false*/) {
-  Val* init = nullptr;
-  switch (v1->getDataType().value()) {
-    case (DataType::Double):
-      init = new Double(DBL_MIN);
-      break;
-    case (DataType::Float):
-      init = new Double(FLT_MIN);
-      break;
-    case (DataType::Int):
-      init = new Int(INT_MIN);
-      break;
-    default:
-      TORCH_CHECK(
-          false,
-          "Could not generate a max op for tensor with type: ",
-          v1->getDataType().value());
-  }
-
-  return reductionOp(BinaryOpType::Max, axes, init, v1, keep_dim);
-}
-
-TensorView* min(
-    TensorView* v1,
-    const std::vector<int>& axes,
-    bool keep_dim /*=false*/) {
-  Val* init = nullptr;
+TensorView* sum(TensorView* v1, const std::vector<int>& axes) {
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  Val* init;
   switch (v1->getDataType().value()) {
-    case (DataType::Double):
-      init = new Double(DBL_MAX);
-      break;
     case (DataType::Float):
-      init = new Double(FLT_MAX);
+      init = new Float(0.0);
       break;
     case (DataType::Int):
-      init = new Int(INT_MAX);
+      init = new Int(0);
       break;
     default:
       TORCH_CHECK(
           false,
-          "Could not generate a min op for tensor with type: ",
+          "Could not generate a sum op for tensor with type: ",
           v1->getDataType().value());
   }
 
-  return reductionOp(BinaryOpType::Min, axes, init, v1, keep_dim);
+  return reductionOp(BinaryOpType::Add, axes, init, v1);
 }
 
 TensorView* broadcast(
@@ -720,7 +528,6 @@ TensorView* broadcast(
   }
 
   std::vector<IterDomain*> out_domain;
-  // Don't propagate reduction IDs through arith ops.
   auto inp_domain = TensorDomain::noReductions(inp->getRootDomain());
   size_t iinp = 0, ibdim = 0;
   while (ibdim < is_broadcast_dim.size()) {
@@ -731,7 +538,8 @@ TensorView* broadcast(
           ParallelType::Serial,
           IterType::BroadcastWithoutStride));
     } else {
-      out_domain.push_back(inp_domain[iinp]->clone());
+      // Don't propagate reduction IDs through arith ops.
+      out_domain.push_back(inp_domain[iinp]);
       iinp++;
     }
     ibdim++;
@@ -740,114 +548,7 @@ TensorView* broadcast(
   TensorView* out_tensor = new TensorView(
       new TensorDomain(out_domain, std::vector<bool>(out_domain.size(), true)),
       inp->getDataType().value());
-  new BroadcastOp(out_tensor, inp, is_broadcast_dim);
-  return out_tensor;
-}
-
-WelfordResult Welford(
-    TensorView* tv,
-    const std::vector<int>& axes,
-    TensorView* init_avg,
-    TensorView* init_var,
-    Int* init_N) {
-  TORCH_CHECK(
-      TensorDomain::sameAs(tv->getRootDomain(), tv->domain()->domain()),
-      "Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");
-
-  TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor");
-  TORCH_CHECK(axes.size() > 0, "No reduction axis specified");
-
-  // Initial values for welford op are tensors, so their dims have to match the
-  // output dim,
-  // i.e. original_dims - dims_to_be_reduced
-  Val* init_avg_val = nullptr;
-  Val* init_var_val = nullptr;
-  if (!init_N->isZeroInt()) {
-    TORCH_CHECK(
-        init_avg != nullptr && init_var != nullptr && init_N != nullptr,
-        "welford op: all init values need to be provided");
-    TORCH_CHECK(
-        (axes.size() + init_avg->getRootDomain().size()) ==
-            tv->getRootDomain().size(),
-        "welford op: initial tensor mismatch");
-    TORCH_CHECK(
-        (axes.size() + init_var->getRootDomain().size()) ==
-            tv->getRootDomain().size(),
-        "welford op: initial tensor mismatch");
-    init_avg_val = init_avg;
-    init_var_val = init_var;
-  } else {
-    init_avg_val = new Double(0);
-    init_var_val = new Double(0);
-  }
-
-  // Check and collect reduction axes
-  std::vector<unsigned int> uint_axes;
-  for (int axis : axes) {
-    if (axis < 0)
-      axis += int(tv->nDims());
-
-    TORCH_CHECK(
-        axis >= 0 && (unsigned int)axis < tv->nDims(),
-        "Reduction on invalid axis, recieved: ",
-        axis,
-        " however tensor view only has ",
-        tv->nDims(),
-        " dims.");
-
-    uint_axes.push_back((unsigned int)axis);
-  }
-
-  // Create tensor outputs
-  TensorView* out_avg = newForReduction(tv, uint_axes);
-  TensorView* out_var = newForReduction(tv, uint_axes);
-  TensorView* out_N = newForReduction(tv, uint_axes, DataType::Int);
-
-  new WelfordOp(
-      out_avg,
-      out_var,
-      out_N, /*out var/avg/count */
-      init_avg_val,
-      init_var_val,
-      init_N, /*init var/avg/count */
-      tv,
-      nullptr,
-      new Int(1)); /*in var/avg/count */
-
-  return WelfordResult(out_avg, out_var, out_N);
-}
-
-WelfordResult::WelfordResult(
-    TensorView* in_avg,
-    TensorView* in_var_sum,
-    TensorView* in_n)
-    : avg(in_avg), var_sum(in_var_sum), n(in_n) {
-  TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(var_sum->definition()));
-  TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(n->definition()));
-}
-
-WelfordResult WelfordResult::rFactor(const std::vector<int>& axes) {
-  auto o_tv = avg->definition()->as<WelfordOp>()->out()->as<TensorView>();
-  return o_tv->rFactor(axes, avg, var_sum, n);
-}
-
-TensorView* transpose(
-    TensorView* inp,
-    const std::unordered_map<int, int>& old2new) {
-  auto inp_domain = TensorDomain::noReductions(inp->getRootDomain());
-  std::vector<IterDomain*> out_domain(inp_domain.size());
-
-  auto new2old = ir_utils::normalizeOld2New(old2new, inp_domain.size());
-
-  for (size_t i = 0; i < out_domain.size(); ++i) {
-    auto in_id = inp_domain[new2old[i]];
-    out_domain[i] = in_id->clone();
-  }
-
-  TensorView* out_tensor = new TensorView(
-      new TensorDomain(out_domain, std::vector<bool>(out_domain.size(), true)),
-      inp->getDataType().value());
-  new TransposeOp(out_tensor, inp, new2old);
+  new BroadcastOp(out_tensor, inp);
   return out_tensor;
 }
 
@@ -956,28 +657,18 @@ TensorView* addcmul(TensorView* v1, TensorView* v2, TensorView* v3, Val* v4) {
 }
 
 // TERNARY OPERATIONS
-// where (c ? v1 : v2)
+// where
 Val* where(Val* c, Val* v1, Val* v2) {
   TORCH_CHECK(
       c->getDataType().value() == DataType::Bool,
       "Condition should be of DataType Bool, not ",
       c->getDataType().value());
 
-  // Not actually an add, but need to send a binary op to get output type
-  auto out_dtype = getOutputType(BinaryOpType::Add, v1, v2);
-  auto out_vtype =
-      promote_type(v1->getValType().value(), v2->getValType().value());
   auto vals = maybeBroadcast({c, v1, v2});
-  Val* out = nullptr;
-  if (out_vtype == ValType::TensorView) {
-    out = newOutputTV(vals, out_dtype);
-  } else {
-    out = newScalar(out_vtype, out_dtype);
-  }
+  Val* out = newOutputVal({vals[1], vals[2]});
   new TernaryOp(TernaryOpType::Where, out, vals[0], vals[1], vals[2]);
   return out;
 }
-
 TensorView* where(TensorView* v1, Val* v2, Val* v3) {
   return arithOpOverloads(where, v1, v2, v3);
 }
@@ -1003,36 +694,17 @@ TensorView* where(TensorView* v1, TensorView* v2, TensorView* v3) {
 // TERNARY OPERATIONS
 
 Val* threshold(Val* in, Val* thresh, Val* value) {
-  const auto in_type = in->getDataType().value();
-  const auto thresh_type = thresh->getDataType().value();
-  const auto value_type = value->getDataType().value();
-  if (isFloatingPointType(in_type)) {
-    TORCH_CHECK(
-        isFloatingPointType(thresh_type) && isFloatingPointType(value_type),
-        "All input DataType values should match the input type ",
-        in_type,
-        " vs ",
-        thresh_type,
-        " and ",
-        value_type);
-  } else if (isIntegralType(in_type)) {
-    TORCH_CHECK(
-        isIntegralType(thresh_type) && isIntegralType(value_type),
-        "All input DataType values should match the input ",
-        in_type,
-        " vs ",
-        thresh_type,
-        " and ",
-        value_type);
-  }
   TORCH_CHECK(
-      (thresh->getValType().value() == ValType::Scalar ||
-       thresh->getValType().value() == ValType::NamedScalar) &&
-          (value->getValType().value() == ValType::Scalar ||
-           value->getValType().value() == ValType::NamedScalar),
-      "For Threshold operation: Thresh and Value values should be Scalars.");
+      in->getDataType().value() == thresh->getDataType().value() &&
+          in->getDataType().value() == value->getDataType().value(),
+      "All input DataType values should match the input ",
+      in->getDataType().value());
+  TORCH_CHECK(
+      thresh->getValType().value() == ValType::Scalar &&
+          value->getValType().value() == ValType::Scalar,
+      "Thresh and Value values should be Scalars");
 
-  Val* out = newValLike(in, in_type);
+  Val* out = newOutputVal({in});
 
   new TernaryOp(TernaryOpType::Threshold, out, in, thresh, value);
   return out;
@@ -1043,36 +715,17 @@ TensorView* threshold(TensorView* in, Val* thresh, Val* value) {
 }
 
 Val* clamp(Val* in, Val* min_val, Val* max_val) {
-  const auto in_type = in->getDataType().value();
-  const auto min_type = min_val->getDataType().value();
-  const auto max_type = max_val->getDataType().value();
-  if (isFloatingPointType(in_type)) {
-    TORCH_CHECK(
-        isFloatingPointType(min_type) && isFloatingPointType(max_type),
-        "All input DataType values should match the input type ",
-        in_type,
-        " vs ",
-        min_type,
-        " and ",
-        max_type);
-  } else if (isIntegralType(in_type)) {
-    TORCH_CHECK(
-        isIntegralType(min_type) && isIntegralType(max_type),
-        "All input DataType values should match the input ",
-        in_type,
-        " vs ",
-        min_type,
-        " and ",
-        max_type);
-  }
   TORCH_CHECK(
-      (min_val->getValType().value() == ValType::Scalar ||
-       min_val->getValType().value() == ValType::NamedScalar) &&
-          (max_val->getValType().value() == ValType::Scalar ||
-           max_val->getValType().value() == ValType::NamedScalar),
-      "For Threshold operation: Thresh and Value values should be Scalars.");
+      in->getDataType().value() == min_val->getDataType().value() &&
+          in->getDataType().value() == max_val->getDataType().value(),
+      "All input DataType values should match the input ",
+      in->getDataType().value());
+  TORCH_CHECK(
+      min_val->getValType().value() == ValType::Scalar &&
+          max_val->getValType().value() == ValType::Scalar,
+      "Min and Max values should be Scalars");
 
-  Val* out = newValLike(in, in_type);
+  Val* out = newOutputVal({in});
 
   new TernaryOp(TernaryOpType::Clamp, out, in, min_val, max_val);
   return out;
@@ -1082,205 +735,6 @@ TensorView* clamp(TensorView* in, Val* min_val, Val* max_val) {
   return clamp(in->as<Val>(), min_val, max_val)->as<TensorView>();
 }
 
-// sum_to operator
-
-TensorView* sum_to(TensorView* in, const std::vector<Int*>& sum_to_size) {
-  const auto& root = TensorDomain::noReductions(in->getRootDomain());
-
-  TORCH_CHECK(
-      root.size() >= sum_to_size.size(),
-      "sum_to: Error trying to reduce",
-      in,
-      "into a shape of size",
-      sum_to_size.size());
-
-  // If no reduction is needed sum_to returns the input tv
-  TensorView* out = in;
-
-  const int64_t leading_dims = root.size() - sum_to_size.size();
-
-  // Generate reduction axes for leading dims
-  std::vector<int> reduce_dims(leading_dims);
-  std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
-
-  // Generate reduction axes for dims within sum_to_size
-  std::vector<bool> inner_red_dims(sum_to_size.size(), false);
-  bool reduction_within_shape = false;
-
-  // Reduce rest of the dims with keep_dim
-  for (int i = leading_dims; i < int(root.size()); i++) {
-    if (sum_to_size[i - leading_dims]->isOneInt() &&
-        !root[i]->extent()->isOneInt()) {
-      inner_red_dims[i - leading_dims] = true;
-      reduce_dims.push_back(i);
-      reduction_within_shape = true;
-    }
-  }
-
-  // Reduction step
-  if (!reduce_dims.empty()) {
-    out = sum(in, reduce_dims);
-  }
-
-  // Broadcast back reduced dims within shape
-  if (reduction_within_shape) {
-    out = broadcast(out, inner_red_dims);
-  }
-
-  return out;
-}
-
-TensorView* sum_to(TensorView* in, const std::vector<int64_t>& sum_to_size) {
-  const auto& root = TensorDomain::noReductions(in->getRootDomain());
-
-  TORCH_CHECK(
-      root.size() >= sum_to_size.size(),
-      "sum_to: Error trying to reduce",
-      in,
-      "into a shape of size",
-      sum_to_size.size());
-
-  // If no reduction is needed sum_to returns the input tv
-  TensorView* out = in;
-
-  const int64_t leading_dims = root.size() - sum_to_size.size();
-
-  // Generate reduction axes for leading dims
-  std::vector<int> reduce_dims(leading_dims);
-  std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
-
-  // Generate reduction axes for dims within sum_to_size
-  std::vector<bool> inner_red_dims(sum_to_size.size(), false);
-  bool reduction_within_shape = false;
-
-  // Reduce rest of the dims with keep_dim
-  for (int i = leading_dims; i < int(root.size()); i++) {
-    if (sum_to_size[i - leading_dims] == 1 && !root[i]->extent()->isOneInt()) {
-      inner_red_dims[i - leading_dims] = true;
-      reduce_dims.push_back(i);
-      reduction_within_shape = true;
-    }
-  }
-
-  // Reduction step
-  if (!reduce_dims.empty()) {
-    out = sum(in, reduce_dims);
-  }
-
-  // Broadcast back reduced dims within shape
-  if (reduction_within_shape) {
-    out = broadcast(out, inner_red_dims);
-  }
-
-  return out;
-}
-
-TensorView* shift(TensorView* inp, const std::vector<int>& offsets) {
-  TORCH_CHECK(
-      TensorDomain::noReductions(inp->getRootDomain()).size() == offsets.size(),
-      "Invalid shift offsets, number of entries in offsets expected to be ",
-      TensorDomain::noReductions(inp->getRootDomain()).size(),
-      " but received ",
-      offsets.size());
-
-  auto out = newValLike(inp, inp->getDataType().value())->as<TensorView>();
-  new ShiftOp(out, inp, offsets);
-  return out;
-}
-
-namespace {
-std::vector<Int*> convertToIntVector(const std::vector<int>& x) {
-  std::vector<Int*> converted;
-  std::transform(x.begin(), x.end(), std::back_inserter(converted), [](int x) {
-    return new Int(x);
-  });
-  return converted;
-}
-} // namespace
-
-TensorView* gather(
-    TensorView* inp,
-    const std::vector<int>& window_shape,
-    const std::vector<std::vector<int>>& pad_width) {
-  std::vector<Int*> window_shape_int = convertToIntVector(window_shape);
-  std::vector<std::vector<Int*>> pad_width_int;
-  std::transform(
-      pad_width.begin(),
-      pad_width.end(),
-      std::back_inserter(pad_width_int),
-      [](const std::vector<int>& x) { return convertToIntVector(x); });
-  return gather(inp, window_shape_int, pad_width_int);
-}
-
-TensorView* gather(
-    TensorView* inp,
-    const std::vector<Int*>& window_shape,
-    const std::vector<std::vector<Int*>>& pad_width) {
-  auto inp_dom = TensorDomain::noReductions(inp->getRootDomain());
-  const auto ndims = inp_dom.size();
-
-  TORCH_CHECK(
-      ndims == window_shape.size(),
-      "Invalid window shape: number of entries expected to be ",
-      ndims,
-      " but received ",
-      window_shape.size());
-
-  TORCH_CHECK(
-      ndims == pad_width.size(),
-      "Invalid pad width: number of entries expected to be ",
-      ndims,
-      " but received ",
-      pad_width.size());
-
-  std::for_each(pad_width.begin(), pad_width.end(), [](const auto& p) {
-    TORCH_CHECK(
-        p.size() == 2,
-        "Each entry of pad_width must have two non-negative integers.");
-  });
-
-  std::vector<IterDomain*> out_dom;
-  std::vector<IterDomain*> out_gather_dom;
-
-  for (size_t i = 0; i < ndims; ++i) {
-    const auto inp_axis = inp_dom[i];
-    const auto window_dim = window_shape[i];
-    const auto pad_left = pad_width[i][0];
-    const auto pad_right = pad_width[i][1];
-    TORCH_INTERNAL_ASSERT(inp_axis->start()->isZeroInt());
-    Val* out_axis_dim = nullptr;
-    if (window_dim->isConst() && pad_left->isConst() && pad_right->isConst()) {
-      const int64_t extent_adjustment =
-          -(-window_dim->value().value() + 1 + pad_left->value().value() +
-            pad_right->value().value());
-      out_axis_dim = extent_adjustment == 0
-          ? inp_axis->extent()
-          : sub(inp_axis->extent(), new Int(extent_adjustment));
-    } else {
-      out_axis_dim =
-          add(add(sub(inp_axis->extent(), window_dim), new Int(1)),
-              add(pad_left, pad_right));
-    }
-    out_dom.push_back(new IterDomain(
-        new Int(0),
-        out_axis_dim,
-        ParallelType::Serial,
-        inp_axis->getIterType()));
-    // create a new axis for the gathered domain
-    out_gather_dom.push_back(new IterDomain(
-        new Int(0), window_dim, ParallelType::Serial, IterType::Gather));
-  }
-
-  out_dom.insert(out_dom.end(), out_gather_dom.begin(), out_gather_dom.end());
-
-  auto out = new TensorView(
-      new TensorDomain(out_dom, std::vector<bool>(out_dom.size(), true)),
-      inp->getDataType().value());
-
-  new GatherOp(out, inp, window_shape, pad_width);
-  return out;
-}
-
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
index c8df67a..c8f8177 100644 (file)
@@ -49,34 +49,7 @@ TORCH_CUDA_CU_API TensorView* reductionOp(
     BinaryOpType reduction_op_type,
     const std::vector<int>& axes,
     Val* init,
-    TensorView* v1,
-    bool keep_dim = false);
-
-//! Auxiliary Struct holding result of
-//! a single welford op in ternsorview
-class TORCH_CUDA_CU_API WelfordResult {
- public:
-  TensorView* avg;
-  TensorView* var_sum;
-  TensorView* n;
-
-  explicit WelfordResult(
-      TensorView* in_avg,
-      TensorView* in_var_sum,
-      TensorView* in_n);
-
-  WelfordResult rFactor(const std::vector<int>& axes);
-};
-
-//! Welford operator on specified axes. This is currently the only scan op with
-//! multiple outputs that is supported. May consider generalization if more scan
-//! ops are added.
-TORCH_CUDA_CU_API WelfordResult Welford(
-    TensorView* tv,
-    const std::vector<int>& axes,
-    TensorView* init_avg = nullptr,
-    TensorView* init_var = nullptr,
-    Int* init_N = new Int(0));
+    TensorView* v1);
 
 // UNARY OPERATIONS
 TORCH_CUDA_CU_API Val* neg(Val* v);
@@ -90,18 +63,6 @@ TORCH_CUDA_CU_API TensorView* broadcast(
     TensorView* inp,
     const std::vector<bool>& is_broadcast_dim);
 
-//! Transpose a tensor as specified by axis mappings.
-//!
-//! The transposition mapping is specified with a list of pairs from
-//! old to new positions. Positions are relative to the noReduction
-//! domain.
-//!
-//! \param inp Tensor to transpose
-//! \param old2new Pairs of mapping from old to new positions.
-TORCH_CUDA_CU_API TensorView* transpose(
-    TensorView* inp,
-    const std::unordered_map<int, int>& old2new);
-
 // BINARY OPERATIONS
 // add
 TORCH_CUDA_CU_API Val* add(Val* v1, Val* v2);
@@ -133,11 +94,6 @@ TORCH_CUDA_CU_API Val* lt(Val* v1, Val* v2);
 TORCH_CUDA_CU_API TensorView* lt(TensorView* v1, Val* v2);
 TORCH_CUDA_CU_API TensorView* lt(Val* v1, TensorView* v2);
 TORCH_CUDA_CU_API TensorView* lt(TensorView* v1, TensorView* v2);
-// gt
-TORCH_CUDA_CU_API Val* gt(Val* v1, Val* v2);
-TORCH_CUDA_CU_API TensorView* gt(TensorView* v1, Val* v2);
-TORCH_CUDA_CU_API TensorView* gt(Val* v1, TensorView* v2);
-TORCH_CUDA_CU_API TensorView* gt(TensorView* v1, TensorView* v2);
 // eq
 TORCH_CUDA_CU_API Val* eq(Val* v1, Val* v2);
 TORCH_CUDA_CU_API TensorView* eq(TensorView* v1, Val* v2);
@@ -157,18 +113,7 @@ TORCH_CUDA_CU_API TensorView* andOp(TensorView* v1, TensorView* v2);
 // REDUCTION OPERATIONS
 TORCH_CUDA_CU_API TensorView* sum(
     TensorView* v1,
-    const std::vector<int>& reduction_axes,
-    bool keep_dim = false);
-
-TORCH_CUDA_CU_API TensorView* max(
-    TensorView* v1,
-    const std::vector<int>& reduction_axes,
-    bool keep_dim = false);
-
-TORCH_CUDA_CU_API TensorView* min(
-    TensorView* v1,
-    const std::vector<int>& reduction_axes,
-    bool keep_dim = false);
+    const std::vector<int>& reduction_axes);
 
 // COMPOUND OPERATIONS
 // add_alpha
@@ -251,67 +196,6 @@ TORCH_CUDA_CU_API TensorView* threshold(
 TORCH_CUDA_CU_API Val* clamp(Val* in, Val* min_val, Val* max_val);
 TORCH_CUDA_CU_API TensorView* clamp(TensorView* in, Val* min_val, Val* max_val);
 
-//! Internal operator for supporting backward graphs
-//!
-//! example:
-//!   v1 = T1 [I0(10),I1(20),I2(30),I3(40)]
-//!   v2 = sum_to(v1,{30,1}) ------> v2 = T2[I2,R3 (keep_dim)]
-//!
-//!  This operator will return v1* directly if sizes of v1 root domain
-//!  is already the same as shape.
-//!
-//!  Name of sum_to is different from NV fuser naming,
-//!  this is to align with the operator name of at::sum_to.
-
-TORCH_CUDA_CU_API TensorView* sum_to(
-    TensorView* v1,
-    const std::vector<Int*>& sum_to_size);
-
-TORCH_CUDA_CU_API TensorView* sum_to(
-    TensorView* v1,
-    const std::vector<int64_t>& sum_to_size);
-
-//! Shift a tensor to a direction specified by offsets.
-//!
-//! Example:
-//!   t0: 2D tensor of size N by M
-//!   t1 = shift(t0, {1, -1});
-//!
-//!   then:
-//!     t1[i, j] = t0[i-1, j+1] for 1 <= i < N and 0 <= j < M-1.
-//!     t1[i, j] = 0, otherwise
-TORCH_CUDA_CU_API TensorView* shift(
-    TensorView* inp,
-    const std::vector<int>& offsets);
-
-//! Gather a window of nearby elements for each element.
-//!
-//! Each window of size window_shape is stored as a additional
-//! innermost domain, meaning that the number of dimensions of the
-//! output tensor doubles. The pad_width parameter specifies the
-//! padding width of each side of each axis.
-//!
-//! Example:
-//!   t0: 2D tensor of [N, M]
-//!   t1 = gather(t0, {1, 3}, {{0, 0}, {1, 1}});
-//!
-//!   then:
-//!     t1: [N, M, 1, 3]
-//!     t1[i, j, k, l] = The value at the window position of [k, l]
-//!                      for t0[i, j]
-TORCH_CUDA_CU_API TensorView* gather(
-    TensorView* inp,
-    const std::vector<int>& window_shape,
-    const std::vector<std::vector<int>>& pad_width);
-
-//! Gather a window of nearby elements for each element.
-//!
-//! Same as the another gather interface but with Int* parameters.
-TORCH_CUDA_CU_API TensorView* gather(
-    TensorView* inp,
-    const std::vector<Int*>& window_shape,
-    const std::vector<std::vector<Int*>>& pad_width);
-
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
index 00d7ce0..d68bade 100644 (file)
@@ -1,13 +1,11 @@
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/codegen/cuda/codegen.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
 #include <torch/csrc/jit/codegen/cuda/type.h>
 #include <torch/csrc/jit/codegen/cuda/utils.h>
 
-#include <array>
 #include <sstream>
 #include <vector>
 
@@ -19,12 +17,12 @@ namespace codegen {
 
 namespace {
 
-class CudaKernelGenerator : private kir::IrVisitor {
-  static constexpr const char* kTab = "  ";
+class CudaKernelGenerator : private OptInConstDispatch {
+  static constexpr char* kTab = "  ";
 
  public:
   static std::string generateKernelDefinition(
-      const kir::Kernel* kernel,
+      const Kernel* kernel,
       const std::string& kernel_name) {
     CudaKernelGenerator codegen(kernel);
     codegen.genDeclaration(kernel_name);
@@ -37,7 +35,7 @@ class CudaKernelGenerator : private kir::IrVisitor {
   }
 
  private:
-  explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) {}
+  explicit CudaKernelGenerator(const Kernel* kernel) : kernel_(kernel) {}
 
   // Generates the kernel function declaration
   void genDeclaration(const std::string& kernel_name) {
@@ -45,28 +43,41 @@ class CudaKernelGenerator : private kir::IrVisitor {
 
     code_ << "__global__ void " << kernel_name << "(";
 
-    std::vector<kir::Val*> params;
+    std::vector<Val*> params;
 
-    // Inputs & Outputs
+    // Inputs
     for (auto val : kernel_->inputs()) {
       params.push_back(val);
     }
+
+    // Outputs
     for (auto val : kernel_->outputs()) {
       params.push_back(val);
     }
 
+    // Global buffers
+    for (auto allocate : kernel_summary.global_allocations) {
+      params.push_back(allocate->buffer());
+    }
+
     // Generate parameter declarations
-    for (kir::Val* val : params) {
-      if (const auto tv = dynamic_cast<kir::TensorView*>(val)) {
-        code_ << "Tensor<" << val->dtype() << ", "
-              << TensorDomain::noReductions(
-                     tv->fuserTv()->getMaybeRFactorDomain())
-                     .size()
-              << "> " << varName(tv);
-      } else {
-        TORCH_INTERNAL_ASSERT(val->isScalar()); // NOLINT (LLVM bug 48525)
-        TORCH_INTERNAL_ASSERT(val->definition() == nullptr);
-        code_ << val->dtype() << " " << gen(val);
+    for (Val* val : params) {
+      switch (val->getValType().value()) {
+        case ValType::KirTensorView: {
+          // TODO(kir): review this
+          const auto tv = val->as<kir::TensorView>();
+          code_ << "Tensor<" << val->getDataType().value() << ", "
+                << TensorDomain::noReductions(
+                       tv->fuserTv()->getMaybeRFactorDomain())
+                       .size()
+                << "> " << gen(tv);
+          break;
+        }
+        case ValType::KirScalar:
+          code_ << val->getDataType().value() << " " << gen(val);
+          break;
+        default:
+          TORCH_CHECK(!"Unexpected parameter type");
       }
 
       if (val != params.back()) {
@@ -74,27 +85,9 @@ class CudaKernelGenerator : private kir::IrVisitor {
       }
     }
 
-    // Global buffers
-    for (auto allocate : kernel_summary.global_allocations) {
-      TORCH_INTERNAL_ASSERT(allocate->buffer()->isA<kir::TensorView>());
-      const auto tv = allocate->buffer()->as<kir::TensorView>();
-      const auto& maybe_rfactor_domain = tv->domain()->hasRFactor()
-          ? tv->domain()->rfactorDomain()
-          : tv->domain()->rootDomain();
-      const auto nDims = std::count_if(
-          maybe_rfactor_domain.begin(),
-          maybe_rfactor_domain.end(),
-          [](const kir::IterDomain* id) {
-            return !id->isReduction() &&
-                id->iterType() != IterType::BroadcastWithoutStride;
-          });
-      code_ << ", Tensor<" << tv->dtype() << ", " << nDims << "> "
-            << varName(tv);
-    }
-
     // Kernels generating random numbers take extra (seed, offset) arguments
     if (kernel_summary.is_stochastic) {
-      code_ << ", at::PhiloxCudaState philox_args";
+      code_ << ", unsigned long long seed, unsigned long long offset";
     }
 
     code_ << ") ";
@@ -107,11 +100,7 @@ class CudaKernelGenerator : private kir::IrVisitor {
     // Random number generator (optional)
     if (kernel_summary.is_stochastic) {
       indent() << "const int idx = blockIdx.x*blockDim.x + threadIdx.x;\n";
-      indent() << "auto offset = philox_args.captured_ ?\n";
-      indent()
-          << "  static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n";
-      indent() << "  philox_args.offset_.val;\n";
-      indent() << "Philox rnd(philox_args.seed_, idx, offset);\n";
+      indent() << "Philox rnd(seed, idx, offset);\n";
     }
 
     // Do we have any dynamic shared memory buffers?
@@ -120,13 +109,10 @@ class CudaKernelGenerator : private kir::IrVisitor {
 
     // Do we have any reductions?
     const bool has_reductions = kernel_summary.has_block_reductions ||
-        kernel_summary.number_of_grid_reductions > 0;
-
-    const bool has_parallel_welford =
-        kernel_summary.has_block_welford || kernel_summary.has_grid_welford;
+        kernel_summary.has_grid_reductions;
 
     // Shared memory
-    if (has_dynamic_smem || has_reductions || has_parallel_welford) {
+    if (has_dynamic_smem || has_reductions) {
       indent() << "alignas("
 #ifndef __HIP_PLATFORM_HCC__
                << dataTypeSize(kernel_summary.largest_smem_data_type)
@@ -139,45 +125,20 @@ class CudaKernelGenerator : private kir::IrVisitor {
         indent() << "unsigned offset = 0;\n";
       }
 
-      if (has_reductions || has_parallel_welford) {
+      if (has_reductions) {
         indent() << "void* shared_mem = array;\n";
         if (has_dynamic_smem) {
-          if (has_parallel_welford) {
-            indent() << "offset += "
-                     << "((blockDim.x * blockDim.y * blockDim.z) * 3 * sizeof("
-                     << kernel_summary.largest_smem_data_type << "));\n";
-          } else {
-            indent() << "offset += "
-                     << "((blockDim.x * blockDim.y * blockDim.z) * sizeof("
-                     << kernel_summary.largest_smem_data_type << "));\n";
-          }
-        }
-
-        if (has_parallel_welford) {
-          // Unpack shared mem pointer
-          auto space_type = kernel_summary.largest_smem_data_type;
-          indent()
-              << "nvfuser_index_t block_size = blockDim.x*blockDim.y*blockDim.z;\n";
-          indent() << space_type << " *shared_mem_var = "
-                   << "static_cast<" << space_type << "*>("
-                   << "shared_mem);\n";
-          indent() << space_type
-                   << " *shared_mem_avg = shared_mem_var + block_size;\n";
-          indent() << space_type
-                   << " *shared_mem_n = shared_mem_avg + block_size;\n";
+          indent() << "offset += "
+                   << "((blockDim.x * blockDim.y * blockDim.z) * sizeof("
+                   << kernel_summary.largest_smem_data_type << "));\n";
         }
       }
     }
-
-    // Call the initialization function if using a custom block sync
-    if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
-      indent() << "block_sync::init();\n";
-    }
   }
 
   void genBody() {
     for (auto expr : kernel_->topLevelExprs()) {
-      expr->accept(this);
+      OptInConstDispatch::handle(expr);
     }
   }
 
@@ -204,97 +165,92 @@ class CudaKernelGenerator : private kir::IrVisitor {
     return code_;
   }
 
-  std::string gen(const kir::Node* node) {
+  std::string gen(const Statement* stmt) {
     std::stringstream tmp_code;
     std::swap(tmp_code, code_);
-    auto replacement = replacement_map_.find(node);
-    if (replacement != replacement_map_.end()) {
-      node = replacement->second;
-    }
-    node->accept(this);
+    handle(stmt);
     std::swap(tmp_code, code_);
     return tmp_code.str();
   }
 
-  // TODO(kir): consider automatic var naming
-  std::string varName(const kir::Val* val) {
-    std::string prefix = "";
-    if (val->isA<kir::TensorView>()) {
-      prefix = "T";
-    } else {
-      prefix = typePrefix(val->dtype());
-    }
-
-    std::stringstream value_name;
-    if (val->name() != kInvalidStmName) {
-      value_name << prefix << val->name();
-    } else {
-      value_name << "k" << prefix << val->id();
-    }
-    return value_name.str();
+  std::string gen(const kir::TensorView* tv) {
+    std::stringstream tv_name;
+    tv_name << "T" << tv->name();
+    return tv_name.str();
   }
 
-  std::string genInline(const kir::Node* node) {
+  std::string genInline(const Statement* stmt) {
     const bool saved_inline = print_inline_;
     print_inline_ = true;
-    auto result = gen(node);
+    const auto result = gen(stmt);
     print_inline_ = saved_inline;
     // NOLINTNEXTLINE(performance-no-automatic-move)
     return result;
   }
 
-  void visit(const kir::Predicate* node) final {
-    TORCH_INTERNAL_ASSERT(node->hasValue());
-    code_ << gen(node->value());
+  void handle(const Statement* node) final {
+    OptInConstDispatch::handle(node);
+  }
+
+  void handle(const Expr* node) final {
+    OptInConstDispatch::handle(node);
   }
 
-  void visit(const kir::Bool* node) final {
-    const auto def = node->definition();
+  void handle(const Val* node) final {
+    OptInConstDispatch::handle(node);
+  }
+
+  void handle(const kir::Bool* node) final {
+    const auto def = node->getOrigin();
     if (print_inline_ && def != nullptr) {
       code_ << "(" << gen(def) << ")";
-    } else if (node->isConst()) {
-      code_ << (*node->value() ? "true" : "false");
+    } else if (node->isSymbolic()) {
+      code_ << "b" << node->name();
     } else {
-      code_ << varName(node);
+      code_ << *node->value();
     }
   }
 
-  void visit(const kir::Double* node) final {
-    const auto def = node->definition();
+  void handle(const kir::Float* node) final {
+    const auto def = node->getOrigin();
     if (print_inline_ && def != nullptr) {
       code_ << "(" << gen(def) << ")";
-    } else if (node->isConst()) {
-      const int digits = std::numeric_limits<Double::ScalarType>::max_digits10;
-      code_ << std::setprecision(digits) << *node->value();
+    } else if (node->isSymbolic()) {
+      code_ << "f" << node->name();
     } else {
-      code_ << varName(node);
+      const int digits = std::numeric_limits<Float::ScalarType>::max_digits10;
+      code_ << "float(" << std::setprecision(digits) << *node->value() << ")";
     }
   }
 
-  void visit(const kir::Int* node) final {
-    const auto def = node->definition();
+  void handle(const kir::Half* node) final {
+    const auto def = node->getOrigin();
     if (print_inline_ && def != nullptr) {
       code_ << "(" << gen(def) << ")";
-    } else if (node->isConst()) {
-      code_ << *node->value();
+    } else if (node->isSymbolic()) {
+      code_ << "h" << node->name();
     } else {
-      code_ << varName(node);
+      code_ << "__float2half(" << *node->value() << ")";
     }
   }
 
-  void visit(const kir::NamedScalar* node) final {
-    // dim3 components are unsigned int. Cast to signed integer to
-    // support negative indexing
-    if (node->getParallelIndex().has_value() ||
-        node->getParallelDim().has_value()) {
-      code_ << "((nvfuser_index_t)" << node->name() << ")";
+  void handle(const kir::Int* node) final {
+    const auto def = node->getOrigin();
+    if (print_inline_ && def != nullptr) {
+      code_ << "(" << gen(def) << ")";
+    } else if (node->isSymbolic()) {
+      code_ << "i" << node->name();
     } else {
-      code_ << node->name();
+      code_ << *node->value();
     }
   }
 
-  void visit(const kir::TensorIndex* node) final {
-    code_ << varName(node->view()) << "[";
+  void handle(const kir::NamedScalar* node) final {
+    code_ << node->name();
+  }
+
+  void handle(const kir::TensorIndex* node) final {
+    code_ << gen(node->view()) << "[";
 
     bool first = true;
     for (auto* ind : node->indices()) {
@@ -314,96 +270,19 @@ class CudaKernelGenerator : private kir::IrVisitor {
     code_ << "]";
   }
 
-  void visit(const kir::IterDomain* node) final {
+  void handle(const kir::IterDomain* node) final {
     TORCH_INTERNAL_ASSERT(!"Unreachable");
   }
 
-  void visit(const kir::TensorDomain* node) final {
+  void handle(const kir::TensorDomain* node) final {
     TORCH_INTERNAL_ASSERT(!"Unreachable");
   }
 
-  void visit(const kir::TensorView* tv) final {
+  void handle(const kir::TensorView* node) final {
     TORCH_INTERNAL_ASSERT(!"Unreachable");
   }
 
-  void visit(const kir::UnaryOp* node) final {
-    bool is_vector_op = false;
-    size_t vector_word_size = 1;
-
-    if (vectorize_scope_ && node->out()->isA<kir::TensorIndex>()) {
-      auto ti = node->out()->as<kir::TensorIndex>();
-
-      bool vectorize_op = false;
-      bool misaligned_op = false;
-
-      for (auto id : ti->view()->fuserTv()->domain()->domain()) {
-        if (!isParallelTypeVectorize(id->getParallelType())) {
-          continue;
-        }
-
-        ExpressionEvaluator expr_eval(id->fusion());
-        auto vector_size_optional = expr_eval.evaluate(id->extent());
-
-        TORCH_INTERNAL_ASSERT(
-            vector_size_optional.has_value(),
-            "Could not evaluate constant value bound to vectorized dim.");
-
-        vector_word_size = vector_size_optional.value();
-
-        vectorize_op = id->getParallelType() == ParallelType::Vectorize;
-        misaligned_op =
-            id->getParallelType() == ParallelType::MisalignedVectorize;
-        break;
-      }
-
-      if (vectorize_op) {
-        TORCH_INTERNAL_ASSERT(
-            node->operation() == UnaryOpType::Set,
-            "Cannot vectorize operations that are not sets. ",
-            "Use cache_before and cache_after to store/load with vectorized reads into buffers.");
-        is_vector_op = true;
-      }
-
-      if (misaligned_op) {
-        is_vector_op = (node->operation() == UnaryOpType::Set);
-      }
-
-      if (is_vector_op && !node->in()->isScalar()) {
-        TORCH_INTERNAL_ASSERT(
-            node->out()->dtype() == node->in()->dtype(),
-            "Vectorized store/load requires input and output datatypes match.");
-      }
-    }
-
-    if (is_vector_op) {
-      if (node->in()->isScalar()) {
-        indent() << "reinterpret_cast<"
-                 << "Array<" << node->out()->dtype() << ", " << vector_word_size
-                 << ">*>"
-                 << "(&" << gen(node->out()) << ")->set(" << gen(node->in())
-                 << ");\n";
-      } else {
-        indent() << "*reinterpret_cast<"
-                 << "Array<" << node->out()->dtype() << ", " << vector_word_size
-                 << ">*>"
-                 << "(&" << gen(node->out()) << ")"
-                 << " = *reinterpret_cast<"
-                 << "Array<" << node->in()->dtype() << ", " << vector_word_size
-                 << ">*>"
-                 << "(&" << gen(node->in()) << ");\n";
-      }
-      return;
-    }
-
-    if (node->out()->isA<kir::NamedScalar>()) {
-      const auto op_type = node->operation();
-      if (auto op = inline_op_str(op_type)) {
-        indent() << gen(node->out()) << " = " << *op << genInline(node->in())
-                 << ";\n";
-      }
-      return;
-    }
-
+  void handle(const kir::UnaryOp* node) final {
     if (!print_inline_) {
       indent() << gen(node->out());
       if (!node->out()->isScalar() && !node->in()->isScalar()) {
@@ -413,35 +292,20 @@ class CudaKernelGenerator : private kir::IrVisitor {
       code_ << " = ";
     }
 
-    const auto op_type = node->operation();
-    if (auto op = inline_op_str(op_type)) {
-      if (alsoBooleanOperator(op_type) &&
-          node->out()->dtype() == DataType::Bool) {
-        code_ << stringifyBooleanOp(op_type) << gen(node->in());
-      } else {
-        code_ << *op << gen(node->in());
-      }
+    if (auto op = inline_op_str(node->getUnaryOpType())) {
+      code_ << *op << gen(node->in());
     } else {
-      if (op_type == UnaryOpType::Cast) {
-        const auto cast_str =
-            cast_func_str({node->in()->dtype(), node->out()->dtype()});
-        TORCH_INTERNAL_ASSERT(
-            cast_str.has_value(),
-            "Invalid cast. Input type: ",
-            node->in()->dtype(),
-            ", output type: ",
-            node->out()->dtype());
+      if (node->getUnaryOpType() == UnaryOpType::Cast) {
+        const auto cast_str = cast_func_str(
+            {node->in()->getDataType().value(),
+             node->out()->getDataType().value()});
         code_ << cast_str.value();
       } else {
-        code_ << op_type;
-        if (needFloatSuffix(op_type) &&
-            node->out()->dtype() == DataType::Float) {
-          code_ << "f";
-        }
+        code_ << node->getUnaryOpType();
       }
 
       code_ << "(";
-      if (op_type == UnaryOpType::RandLike) {
+      if (node->getUnaryOpType() == UnaryOpType::RandLike) {
         code_ << "rnd";
       } else {
         code_ << gen(node->in());
@@ -456,77 +320,28 @@ class CudaKernelGenerator : private kir::IrVisitor {
 
   std::string genBinaryOp(
       BinaryOpType op_type,
-      kir::Val* out,
       const std::string& lhs,
       const std::string& rhs) {
     std::stringstream expr;
     if (auto op = inline_op_str(op_type)) {
-      expr << lhs << " ";
-      if (alsoBooleanOperator(op_type) && out->dtype() == DataType::Bool) {
-        expr << stringifyBooleanOp(op_type);
-      } else {
-        expr << *op;
-      }
-      expr << " " << rhs;
+      expr << lhs << " " << *op << " " << rhs;
     } else {
-      expr << op_type;
-      if (needFloatSuffix(op_type) && out->dtype() == DataType::Float) {
-        expr << "f";
-      }
-      expr << "(" << lhs << ", " << rhs << ")";
+      expr << op_type << "(" << lhs << ", " << rhs << ")";
     }
     return expr.str();
   }
 
-  // If one argument is a tensorview and the other is a scalar, make sure we
-  // cast the scalar to the tensorview type
-  std::string scalarCast(kir::Val* lhs, kir::Val* rhs) {
-    // If neither are scalars return
-    if (!((lhs->isScalar() || rhs->isScalar()) &&
-          (lhs->isA<kir::TensorIndex>() || rhs->isA<kir::TensorIndex>()))) {
-      return "";
-    }
-
-    // Looking for mixed tensorview scalar options where types don't match
-    // but are either both floating or both int types. We should cast
-    // scalar to tensorview type in these instances.
-    auto lhs_t = lhs->dtype();
-    auto rhs_t = rhs->dtype();
-
-    // If same type, don't cast anything
-    if (lhs_t == rhs_t) {
-      return "";
-    }
-
-    // Don't do anything when dealing with bools
-    if (lhs_t == DataType::Bool || rhs_t == DataType::Bool) {
-      return "";
-    }
-
-    // Mixing floating and int combination
-    if ((isFloatingPointType(lhs_t) != isFloatingPointType(rhs_t)) ||
-        (isIntegralType(lhs_t) != isIntegralType(rhs_t))) {
-      return "";
-    }
-
-    std::stringstream cast;
-    cast << "(" << (lhs->isA<kir::TensorIndex>() ? lhs_t : rhs_t) << ") ";
-    return cast.str();
-  }
-
-  void visit(const kir::BinaryOp* node) final {
-    const auto op_type = node->operation();
+  void handle(const kir::BinaryOp* node) final {
+    const auto op_type = node->getBinaryOpType();
     if (print_inline_) {
       // Inline expression: `lhs op rhs`
-      code_ << genBinaryOp(
-          op_type, node->out(), gen(node->lhs()), gen(node->rhs()));
+      code_ << genBinaryOp(op_type, gen(node->lhs()), gen(node->rhs()));
     } else {
       indent() << gen(node->out());
       if (node->out()->isScalar()) {
         // Single line: `out = lhs op rhs;`
         code_ << " = "
-              << genBinaryOp(
-                     op_type, node->out(), gen(node->lhs()), gen(node->rhs()));
+              << genBinaryOp(op_type, gen(node->lhs()), gen(node->rhs()));
       } else {
         // Split TensorView expressions across multiple lines:
         //
@@ -534,39 +349,21 @@ class CudaKernelGenerator : private kir::IrVisitor {
         //    =  lhs
         //    op rhs;
         //
-
-        auto cast = scalarCast(node->lhs(), node->rhs());
         if (auto op = inline_op_str(op_type)) {
           code_ << "\n";
-          indent() << kTab << "= " << (node->lhs()->isScalar() ? cast : "")
-                   << gen(node->lhs()) << "\n";
-          indent() << kTab;
-          if (alsoBooleanOperator(op_type) &&
-              node->out()->dtype() == DataType::Bool) {
-            code_ << stringifyBooleanOp(op_type);
-          } else {
-            code_ << *op;
-          }
-          code_ << " " << (node->rhs()->isScalar() ? cast : "")
-                << gen(node->rhs());
+          indent() << kTab << "= " << gen(node->lhs()) << "\n";
+          indent() << kTab << *op << " " << gen(node->rhs());
         } else {
-          if (integer_op_str(op_type) && isIntegralType(node->out()->dtype())) {
-            auto int_op = integer_op_str(op_type);
-            code_ << " = " << *int_op << "(\n";
-          } else {
-            code_ << " = " << op_type << "(\n";
-          }
-          indent() << kTab << (node->lhs()->isScalar() ? cast : "")
-                   << gen(node->lhs()) << ",\n";
-          indent() << kTab << (node->rhs()->isScalar() ? cast : "")
-                   << gen(node->rhs()) << ")";
+          code_ << " = " << op_type << "(\n";
+          indent() << kTab << gen(node->lhs()) << ",\n";
+          indent() << kTab << gen(node->rhs()) << ")";
         }
       }
       code_ << ";\n";
     }
   }
 
-  void visit(const kir::TernaryOp* node) final {
+  void handle(const kir::TernaryOp* node) final {
     if (!print_inline_) {
       indent() << gen(node->out());
       if (!node->out()->isScalar()) {
@@ -576,39 +373,25 @@ class CudaKernelGenerator : private kir::IrVisitor {
       code_ << " = ";
     }
 
-    code_ << node->operation() << "(" << gen(node->in1()) << ", ";
-
-    // Make sure the two operands of where has the same
-    // type. Note that compiling "where(0.0f, 0.0)" fails because of
-    // the overloading ambiguity.
-    if (node->operation() == TernaryOpType::Where) {
-      auto cast = scalarCast(node->in2(), node->in3());
-      code_ << (node->in2()->isScalar() ? cast : "") << gen(node->in2()) << ", "
-            << (node->in3()->isScalar() ? cast : "") << gen(node->in3()) << ")";
-    } else {
-      code_ << gen(node->in2()) << ", " << gen(node->in3()) << ")";
-    }
+    code_ << node->getTernaryOpType() << "(" << gen(node->in1()) << ", "
+          << gen(node->in2()) << ", " << gen(node->in3()) << ")";
 
     if (!print_inline_) {
       code_ << ";\n";
     }
   }
 
-  std::string genReductionOp(BinaryOpType op_type, kir::Val* out) {
+  std::string genReductionOp(BinaryOpType op_type, DataType data_type) {
     std::stringstream lambda;
-    DataType data_type = out->dtype();
     lambda << "[](" << data_type << " &a, " << data_type << " b) "
-           << "{ a = " << genBinaryOp(op_type, out, "a", "b") << "; }";
+           << "{ a = " << genBinaryOp(op_type, "a", "b") << "; }";
     return lambda.str();
   }
 
-  void visit(const kir::BroadcastOp* node) final {
-    TORCH_INTERNAL_ASSERT(node->out()->isA<kir::TensorIndex>());
-    const auto tensor_index = node->out()->as<kir::TensorIndex>();
-
-    const ParallelTypeBitmap domains =
-        kernel_->predicateMap().getParallelBroadcastDomains(
-            tensor_index->view()->fuserTv());
+  void handle(const kir::BroadcastOp* node) final {
+    const ir_utils::ParallelTypeBitmap domains =
+        ir_utils::getParallelBroadcastDomains(
+            node->out(), kernel_->predicateMap());
 
     const bool thread_x = domains.get(ParallelType::TIDx);
     const bool thread_y = domains.get(ParallelType::TIDy);
@@ -625,24 +408,21 @@ class CudaKernelGenerator : private kir::IrVisitor {
         "Parallel broadcast across blocks not supported");
 
     if (block_broadcast_needed) {
-      const auto data_type = node->out()->dtype();
+      const auto data_type = node->out()->getDataType().value();
       indent() << "broadcast::blockBroadcast<" << (thread_x ? "true" : "false")
                << ", " << (thread_y ? "true" : "false") << ", "
                << (thread_z ? "true" : "false") << ">(\n";
       indent() << kTab << gen(node->out()) << ",\n";
       indent() << kTab << gen(node->in()) << ",\n";
-      indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
-      TORCH_INTERNAL_ASSERT(
-          node->predicate() != nullptr && node->predicate()->hasValue());
-      indent() << kTab << genInline(node->predicate()) << ");\n";
+      indent() << kTab << "static_cast<" << data_type << "*>(shared_mem));\n";
     } else {
       indent() << gen(node->out()) << "\n";
       indent() << kTab << " = " << gen(node->in()) << ";\n";
     }
   }
 
-  void visit(const kir::ReductionOp* node) final {
-    TORCH_INTERNAL_ASSERT(node->out()->isA<kir::TensorIndex>());
+  void handle(const kir::ReductionOp* node) final {
+    TORCH_CHECK(node->out()->getValType() == ValType::TensorIndex);
 
     const auto out = node->out()->as<kir::TensorIndex>();
     const auto domain = out->view()->domain();
@@ -652,9 +432,9 @@ class CudaKernelGenerator : private kir::IrVisitor {
 
     if (!has_block_reduce && !has_grid_reduce) {
       const auto gen_out = gen(out);
-      const auto op_type = node->operation();
+      const auto op_type = node->getReductionOpType();
       indent() << gen_out << " = "
-               << genBinaryOp(op_type, out, gen_out, gen(node->in())) << ";\n";
+               << genBinaryOp(op_type, gen_out, gen(node->in())) << ";\n";
       return;
     }
 
@@ -663,455 +443,184 @@ class CudaKernelGenerator : private kir::IrVisitor {
     const bool tidy = par_domains.find(ParallelType::TIDy) != par_domains.end();
     const bool tidz = par_domains.find(ParallelType::TIDz) != par_domains.end();
 
-    const auto data_type = node->out()->dtype();
-    const auto op_type = node->operation();
+    const auto data_type = node->out()->getDataType().value();
+    const auto op_type = node->getReductionOpType();
 
     if (has_block_reduce) {
       if (has_grid_reduce) {
         indent() << data_type << " "
-                 << "block_result_" << block_reduce_name_ << "="
-                 << gen(node->init()) << ";\n";
+                 << "block_result"
+                 << ";\n";
       }
       indent() << "blockReduce<" << (tidx ? "true" : "false") << ", "
                << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
                << ">(\n";
       if (has_grid_reduce) {
-        indent() << kTab << "block_result_" << block_reduce_name_ << ",\n";
+        indent() << kTab << "block_result"
+                 << ",\n";
       } else {
         indent() << kTab << gen(node->out()) << ",\n";
       }
       indent() << kTab << gen(node->in()) << ",\n";
-      indent() << kTab << genReductionOp(op_type, node->out()) << ",\n";
+      indent() << kTab << genReductionOp(op_type, data_type) << ",\n";
       indent() << kTab << "threadIdx,\n";
       indent() << kTab << "blockDim,\n";
       indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
-      TORCH_INTERNAL_ASSERT(
-          node->predicate() != nullptr && node->predicate()->hasValue());
-      auto read_pred = genInline(node->predicate());
-      indent() << kTab << read_pred << ",\n";
-      // Pass the write predicate if available and different from the
-      // default predicate. The blockReduce runtime function uses the
-      // default predicate for both read and write when only the
-      // default one is given.
-      if (node->writePredicate() != nullptr) {
-        TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue());
-        auto write_pred = genInline(node->writePredicate());
-        indent() << kTab << write_pred << ",\n";
-      }
-      indent() << kTab << data_type << "(" << genInline(node->init())
-               << "));\n";
-    }
-  }
-
-  void visit(const kir::WelfordOp* node) final {
-    TORCH_INTERNAL_ASSERT(node->out()->isA<kir::TensorIndex>());
-
-    const auto out = node->out()->as<kir::TensorIndex>();
-    const auto domain = out->view()->domain();
-
-    const auto out_var = node->outVar();
-    const auto out_avg = node->outAvg();
-    const auto out_N = node->outN();
-
-    const auto in_var = node->inVar();
-    const auto in_avg = node->inAvg();
-    const auto in_N = node->inN();
-
-    const bool has_block_reduce = domain->hasBlockReduction();
-    const bool has_grid_reduce = domain->hasGridReduction();
-
-    // Serial WelfordOp generation
-    if (!has_block_reduce && !has_grid_reduce) {
-      indent() << "welfordCombine ("
-               << "\n";
-      indent() << " " << gen(out_avg) << ",\n";
-      indent() << " " << gen(out_var) << ",\n";
-      indent() << " " << gen(out_N) << ",\n";
-      indent() << " " << gen(in_avg) << ",\n";
-      if (in_var) {
-        indent() << " " << gen(in_var) << ",\n";
-      } else {
-        indent() << " (" << in_avg->dtype() << ") 0"
-                 << ",\n";
-      }
-      indent() << " (" << out_N->dtype() << ")" << gen(in_N) << ");\n";
-      return;
-    }
-
-    const auto par_domains = node->getParallelReductionDomains();
-    const bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end();
-    const bool tidy = par_domains.find(ParallelType::TIDy) != par_domains.end();
-    const bool tidz = par_domains.find(ParallelType::TIDz) != par_domains.end();
-
-    const auto data_type = node->out()->dtype();
-
-    if (has_block_reduce) {
-      if (has_grid_reduce) {
-        // allocate block result
-        indent() << data_type << " "
-                 << "block_result_avg_" << block_reduce_name_ << " = "
-                 << gen(node->initAvg()) << ";\n";
-        indent() << data_type << " "
-                 << "block_result_var_" << block_reduce_name_ << " = "
-                 << gen(node->initVar()) << ";\n";
-        indent() << DataType::Int << " "
-                 << "block_result_n_" << block_reduce_name_ << " = "
-                 << gen(node->initN()) << ";\n";
-      }
-      indent() << "blockWelford<" << (tidx ? "true" : "false") << ", "
-               << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
-               << ">(\n";
-      if (has_grid_reduce) {
-        indent() << kTab << "block_result_avg_" << block_reduce_name_ << ",\n"
-                 << kTab << "block_result_var_" << block_reduce_name_ << ",\n"
-                 << kTab << "block_result_n_" << block_reduce_name_ << ",\n";
+      if (node->pred() == nullptr) {
+        indent() << kTab << "true,\n";
       } else {
-        indent() << kTab << gen(node->outAvg()) << ",\n";
-        indent() << kTab << gen(node->outVar()) << ",\n";
-        indent() << kTab << gen(node->outN()) << ",\n";
-      }
-      indent() << " " << gen(in_avg) << ",\n";
-      if (in_var) {
-        indent() << " " << gen(in_var) << ",\n";
-      } else {
-        indent() << " (" << in_avg->dtype() << ") 0"
-                 << ",\n";
+        indent() << kTab << genInline(node->pred()) << ",\n";
       }
-      indent() << out_N->dtype() << "(" << gen(in_N) << "),\n";
-      indent() << kTab << "threadIdx,\n";
-      indent() << kTab << "blockDim,\n";
-      indent() << kTab << "reinterpret_cast<" << data_type
-               << "*>(shared_mem_avg),\n";
-      indent() << kTab << "reinterpret_cast<" << data_type
-               << "*>(shared_mem_var),\n";
-      indent() << kTab << "reinterpret_cast<" << DataType::Int
-               << "*>(shared_mem_n),\n";
-      TORCH_INTERNAL_ASSERT(node->predicate() != nullptr);
-      TORCH_INTERNAL_ASSERT(
-          node->predicate() != nullptr && node->predicate()->hasValue());
-      auto read_pred = genInline(node->predicate());
-      indent() << kTab << read_pred << ",\n";
-      if (node->writePredicate() != nullptr) {
-        TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue());
-        auto write_pred = genInline(node->writePredicate());
-        indent() << kTab << write_pred << ",\n";
-      }
-      indent() << kTab << data_type << "(0));\n";
+      indent() << kTab << genInline(node->init()) << ");\n";
     }
   }
 
-  // Support ReductionOp and WelfordOp
-  template <typename REDUCTION_OP>
-  std::string generateGridReduceTemplateFlags(
-      const REDUCTION_OP* rop,
-      const ParallelTypeBitmap& thread_pred) {
-    const auto par_domains = rop->getParallelReductionDomains();
-    const std::array<ParallelType, 6> ptypes{
-        ParallelType::BIDx,
-        ParallelType::BIDy,
-        ParallelType::BIDz,
-        ParallelType::TIDx,
-        ParallelType::TIDy,
-        ParallelType::TIDz};
-    std::stringstream flags;
-    for (const ParallelType pt : ptypes) {
-      const bool parallel_reduction = par_domains.find(pt) != par_domains.end();
-      const bool pred = thread_pred.get(pt);
-      TORCH_INTERNAL_ASSERT(
-          !(parallel_reduction && pred), "Cannot reduce predicated axis: ", pt);
-      bool flag = false;
-      // Currently assumed that no dimensions parallelized with blocks
-      // are predicated. This assumption may be lifted, but
-      // gridReduction would need some changes.
-      if (isParallelTypeBlockDim(pt)) {
-        TORCH_INTERNAL_ASSERT(
-            !pred, "Predication on block dimensions not allowed: ", pt);
-        flag = parallel_reduction;
-      } else {
-        flag = !pred && !parallel_reduction;
-      }
-      if (pt != ptypes[0]) {
-        flags << ", ";
-      }
-      flags << (flag ? "true" : "false");
-    }
-    return flags.str();
-  }
-
-  void visit(const kir::GridReduction* node) final {
+  void handle(const kir::GridReduction* node) final {
     const auto rop = node->reduction_op();
-    TORCH_INTERNAL_ASSERT(rop->out()->isA<kir::TensorIndex>());
+    TORCH_INTERNAL_ASSERT(rop->out()->getValType() == ValType::TensorIndex);
 
     const auto out = rop->out()->as<kir::TensorIndex>();
     const auto domain = out->view()->domain();
     TORCH_INTERNAL_ASSERT(domain->hasGridReduction());
 
-    const auto data_type = rop->out()->dtype();
-    const auto op_type = rop->operation();
+    const auto par_domains = rop->getParallelReductionDomains();
+    const bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end();
+    const bool tidy = par_domains.find(ParallelType::TIDy) != par_domains.end();
+    const bool tidz = par_domains.find(ParallelType::TIDz) != par_domains.end();
+    const bool bidx = par_domains.find(ParallelType::BIDx) != par_domains.end();
+    const bool bidy = par_domains.find(ParallelType::BIDy) != par_domains.end();
+    const bool bidz = par_domains.find(ParallelType::BIDz) != par_domains.end();
+
+    const auto data_type = rop->out()->getDataType().value();
+    const auto op_type = rop->getReductionOpType();
 
     TORCH_INTERNAL_ASSERT(
-        node->reduction_buffer()->buffer()->isA<kir::TensorView>());
+        node->reduction_buffer()->buffer()->getValType().value() ==
+        ValType::KirTensorView);
     TORCH_INTERNAL_ASSERT(
-        node->sync_buffer()->buffer()->isA<kir::TensorView>());
+        node->sync_buffer()->buffer()->getValType().value() ==
+        ValType::KirTensorView);
     const auto work_buffer =
         node->reduction_buffer()->buffer()->as<kir::TensorView>();
     const auto sync_buffer =
         node->sync_buffer()->buffer()->as<kir::TensorView>();
 
-    const std::string flags_str =
-        generateGridReduceTemplateFlags(rop, node->threadPredicate());
-
     // Since block-level reduction is already done, those dimensions
     // with tidx/y/z being true do not participate in the grid reduction.
     indent() << kir::GridReduction::getPredicateFlagName(out->view()) << " = "
-             << "reduction::gridReduce<" << flags_str << ">(\n";
+             << "reduction::gridReduce<" << (bidx ? "true" : "false") << ", "
+             << (bidy ? "true" : "false") << ", " << (bidz ? "true" : "false")
+             << ", " << (!tidx ? "true" : "false") << ", "
+             << (!tidy ? "true" : "false") << ", " << (!tidz ? "true" : "false")
+             << ">(\n";
     indent() << kTab << gen(rop->out()) << ",\n";
     if (domain->hasBlockReduction()) {
-      indent() << kTab << "block_result_" << block_reduce_name_ << ",\n";
-      block_reduce_name_++;
+      indent() << kTab << "block_result"
+               << ",\n";
     } else {
       indent() << kTab << gen(rop->in()) << ",\n";
     }
-    indent() << kTab << genReductionOp(op_type, out) << ",\n";
-    indent() << kTab << "&" << varName(work_buffer) << "[0],\n";
-    indent() << kTab << varName(sync_buffer) << ",\n";
+    indent() << kTab << genReductionOp(op_type, data_type) << ",\n";
+    indent() << kTab << "&" << gen(work_buffer) << "[0],\n";
+    indent() << kTab << gen(sync_buffer) << ",\n";
     indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
-    TORCH_INTERNAL_ASSERT(
-        node->predicate() != nullptr && node->predicate()->hasValue());
-    auto read_pred = genInline(node->predicate());
-    indent() << kTab << read_pred << ",\n";
-    if (node->writePredicate() != nullptr) {
-      TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue());
-      auto write_pred = genInline(node->writePredicate());
-      indent() << kTab << write_pred << ",\n";
+    if (node->pred() == nullptr) {
+      indent() << kTab << "true,\n";
     } else {
-      indent() << kTab << read_pred << ",\n";
+      indent() << kTab << genInline(node->pred()) << ",\n";
     }
-    indent() << kTab << data_type << "("
-             << genInline(node->reduction_op()->init()) << "));\n";
+    indent() << kTab << genInline(node->reduction_op()->init()) << ");\n";
   }
 
-  void visit(const kir::GridWelford* node) final {
-    const auto wop = node->welford_op();
-    TORCH_INTERNAL_ASSERT(wop->outAvg()->isA<kir::TensorIndex>());
-
-    const auto out = wop->out()->as<kir::TensorIndex>();
-    const auto domain = out->view()->domain();
-    TORCH_INTERNAL_ASSERT(domain->hasGridReduction());
-
-    const auto data_type = out->dtype();
-
-    TORCH_INTERNAL_ASSERT(node->var_buffer()->buffer()->isA<kir::TensorView>());
-    TORCH_INTERNAL_ASSERT(
-        node->sync_buffer()->buffer()->isA<kir::TensorView>());
-
-    const auto avg_buffer = node->avg_buffer()->buffer()->as<kir::TensorView>();
-    const auto var_buffer = node->var_buffer()->buffer()->as<kir::TensorView>();
-    const auto n_buffer = node->N_buffer()->buffer()->as<kir::TensorView>();
-    const auto sync_buffer =
-        node->sync_buffer()->buffer()->as<kir::TensorView>();
-
-    const std::string flags_str =
-        generateGridReduceTemplateFlags(wop, node->threadPredicate());
-
-    // Since block-level reduction is already done, those dimensions
-    // with tidx/y/z being true do not participate in the grid reduction.
-    indent() << kir::GridWelford::getPredicateFlagName(out->view()) << " = "
-             << "welford::gridWelford<" << flags_str << ">(\n";
-    indent() << kTab << gen(wop->outAvg()) << ",\n"
-             << kTab << gen(wop->outVar()) << ",\n"
-             << kTab << gen(wop->outN()) << ",\n";
-    if (domain->hasBlockReduction()) {
-      indent() << kTab << "block_result_avg_" << block_reduce_name_ << ",\n"
-               << kTab << "block_result_var_" << block_reduce_name_ << ",\n"
-               << kTab << "block_result_n_" << block_reduce_name_ << ",\n";
-      block_reduce_name_++;
-    } else {
-      indent() << kTab << gen(wop->inAvg()) << ",\n";
-      if (wop->inVar() == nullptr) {
-        indent() << kTab << "(" << data_type << ") 0,\n";
-      } else {
-        indent() << kTab << gen(wop->inVar()) << ",\n";
-      }
-      indent() << kTab << "(" << wop->outN()->dtype() << ")" << gen(wop->inN())
-               << ",\n";
-    }
-    indent() << kTab << "&" << varName(avg_buffer) << "[0],\n";
-    indent() << kTab << "&" << varName(var_buffer) << "[0],\n";
-    indent() << kTab << "&" << varName(n_buffer) << "[0],\n";
-    indent() << kTab << varName(sync_buffer) << ",\n";
-    indent() << kTab << "reinterpret_cast<" << data_type
-             << "*>(shared_mem_avg),\n";
-    indent() << kTab << "reinterpret_cast<" << data_type
-             << "*>(shared_mem_var),\n";
-    indent() << kTab << "reinterpret_cast<" << wop->outN()->dtype()
-             << "*>(shared_mem_n),\n";
-    TORCH_INTERNAL_ASSERT(
-        node->predicate() != nullptr && node->predicate()->hasValue());
-    auto read_pred = genInline(node->predicate());
-    indent() << kTab << read_pred << ",\n";
-    if (node->writePredicate() != nullptr) {
-      TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue());
-      auto write_pred = genInline(node->writePredicate());
-      indent() << kTab << write_pred << ",\n";
-    } else {
-      indent() << kTab << read_pred << ",\n";
-    }
-    // TODO : init value support or remove.
-    indent() << kTab << data_type << "(0));\n";
-  }
-
-  void handleScope(const kir::Scope& scope) {
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Woverloaded-virtual"
+  // TODO(Kir): fix me
+  void handle(const kir::Scope& scope) {
     for (auto expr : scope.exprs()) {
-      expr->accept(this);
+      handle(expr);
     }
   }
+#pragma clang diagnostic pop
 
-  void visit(const kir::ForLoop* node) final {
+  void handle(const kir::ForLoop* node) final {
     // TODO(kir): handle this during lowering
-    if (node->iter_domain()->isBroadcast()) {
-      handleScope(node->body());
-      return;
-    } else if (node->vectorize()) {
-      vectorize_scope_ = node->vectorize();
-      handleScope(node->body());
-      vectorize_scope_ = false;
-      return;
-    }
-
-    // By default, a parallelized loop would look like:
-    //
-    //   for (int x = threadIdx.x; x < stop; x += blockDim.x) {
-    //     do_some_comp(x);
-    //   }
-    //
-    // When stop is guaranteed to be smaller or equal to the number of
-    // threads, the for-loop is not necessary. In the above case, we
-    // would just generate the loop body without the for clause but
-    // references to the loop index replaced by the loop start value.
-    //
-    // When the loop end is the same as the IterDomain extent, the
-    // assumption can be safely made. This is more conservative than
-    // necessary since the loop stop value just needs to be <= the
-    // IterDomain extent. However, at this point, this conservative
-    // analysis seems sufficient.
-    if (node->stop() == node->iter_domain()->extent() &&
-        node->iter_domain()->isThread()) {
-      // Register a replacement of references to the loop index with
-      // the loop start value.
-      replacement_map_.insert({node->index(), node->start()});
-      handleScope(node->body());
-      replacement_map_.erase(node->index());
-      return;
-    }
-
-    if (node->start()->isZeroInt() && node->stop()->isOneInt()) {
-      indent() << "constexpr "
-               << "nvfuser_index_t"
-               << " " << gen(node->index()) << " = 0;\n";
-      handleScope(node->body());
+    if (node->iter_domain()->isThread() || node->iter_domain()->isBroadcast()) {
+      handle(node->body());
       return;
     }
 
     const auto gen_index = gen(node->index());
-    const auto gen_start = genInline(node->start());
-    const auto gen_stop = genInline(node->stop());
-    const auto gen_step = genInline(node->step());
+    const auto gen_start = genInline(node->iter_domain()->start());
+    const auto gen_extent = genInline(node->iter_domain()->extent());
+    indent() << "for(size_t " << gen_index << " = " << gen_start << "; "
+             << gen_index << " < " << gen_extent << "; ++" << gen_index << ") ";
 
-    std::stringstream step_code;
-    if (node->step()->isOneInt()) {
-      step_code << "++" << gen_index;
-    } else {
-      step_code << gen_index << " += " << gen_step;
-    }
-    if (node->isUnrollable()) {
-      indent() << "#pragma unroll\n";
-    } else {
-      indent() << "#pragma unroll 1\n";
-    }
-    indent() << "for(nvfuser_index_t " << gen_index << " = " << gen_start
-             << "; " << gen_index << " < " << gen_stop << "; "
-             << step_code.str() << ") ";
     startBlock(true);
-    handleScope(node->body());
+    handle(node->body());
     endBlock();
   }
 
-  void visit(const kir::IfThenElse* node) final {
-    auto conditional = node->predicate()->value();
-    if (conditional->isConst()) {
-      // If the conditional is a constant, then the IfThenElse is not required
-      if (conditional->value().value()) {
-        handleScope(node->thenBody());
-      } else {
-        handleScope(node->elseBody());
-      }
-      return;
-    }
-
-    indent() << "if (" << genInline(conditional) << ") ";
+  void handle(const kir::IfThenElse* node) final {
+    indent() << "if (" << genInline(node->cond()) << ") ";
 
     // "then" block
     startBlock(true);
-    handleScope(node->thenBody());
+    handle(node->thenBody());
 
     // "else" block (optional)
     if (node->hasElse()) {
       endBlock(" else ");
       startBlock(true);
-      handleScope(node->elseBody());
+      handle(node->elseBody());
     }
 
     endBlock();
   }
 
   // TODO(kir): fold initialization into Allocate
-  void visit(const kir::Allocate* node) final {
-    const auto buffer_dtype = node->buffer()->dtype();
-
-    if (!node->buffer()->isA<kir::TensorView>()) {
-      indent() << buffer_dtype << " " << gen(node->buffer()) << ";\n";
+  void handle(const kir::Allocate* node) final {
+    if (node->buffer()->getValType().value() != ValType::KirTensorView) {
+      indent() << node->buffer_type() << " " << gen(node->buffer()) << ";\n";
       return;
     }
 
     const auto tv = node->buffer()->as<kir::TensorView>();
-
-    const auto size = node->size();
-    TORCH_INTERNAL_ASSERT(size != nullptr);
+    TORCH_INTERNAL_ASSERT(tv->domain()->nDims() > 0);
+    TORCH_INTERNAL_ASSERT(node->size() != nullptr);
 
     if (node->alias() != nullptr) {
       // Allocate alias another Allocate node
       const auto alias_tv = node->alias()->buffer()->as<kir::TensorView>();
-      indent() << "// Alias Allocation - " << node->memoryType() << "\n";
-      indent() << buffer_dtype << "* " << varName(tv) << " = "
-               << varName(alias_tv) << ";\n";
+      indent() << "// Alias Allocation - " << node->getMemoryType() << "\n";
+      indent() << node->buffer_type() << "* " << gen(tv) << " = "
+               << gen(alias_tv) << ";\n";
     } else {
       // Standard Memory Allocation
       switch (tv->memoryType()) {
         case MemoryType::Global:
-          indent() << "// Allocate global tensor " << varName(tv) << "\n";
+          indent() << "// Allocate global tensor " << gen(tv) << "\n";
           break;
         case MemoryType::Shared:
-          if (kir::ExpressionEvaluator::isConst(size)) {
+          if (node->size()->isConstScalar()) {
             // Static shared memory
-            indent() << "__shared__ " << buffer_dtype << " " << varName(tv)
-                     << "[" << genInline(size) << "];\n";
+            indent() << "__shared__ " << node->buffer_type() << " " << gen(tv)
+                     << "[" << genInline(node->size()) << "];\n";
           } else {
             // Align Offset Position
             indent() << "offset = alignBufferSize(offset,"
-                     << dataTypeSize(buffer_dtype) << ");\n";
+                     << dataTypeSize(node->buffer_type()) << ");\n";
             // Shared Memory Pointer
-            indent() << buffer_dtype << "* " << varName(tv)
-                     << " = reinterpret_cast<" << buffer_dtype << "*>"
+            indent() << node->buffer_type() << "* " << gen(tv)
+                     << " = reinterpret_cast<" << node->buffer_type() << "*>"
                      << "(array + offset);\n";
             // Increment Offset Position
-            indent() << "offset += (" << genInline(size) << " * sizeof("
-                     << buffer_dtype << "));\n";
+            indent() << "offset += (" << genInline(node->size()) << " * sizeof("
+                     << node->buffer_type() << "));\n";
           }
           break;
         case MemoryType::Local:
-          indent() << buffer_dtype << " " << varName(tv) << "["
-                   << genInline(size) << "];\n";
+          indent() << node->buffer_type() << " " << gen(tv) << "["
+                   << genInline(node->size()) << "];\n";
           break;
         default:
           TORCH_INTERNAL_ASSERT(false, "Unexpected memory type");
@@ -1119,43 +628,23 @@ class CudaKernelGenerator : private kir::IrVisitor {
     }
   }
 
-  void visit(const kir::Sync* node) final {
-    // Use a custom synchronization method if enabled
-    if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
-      indent() << "block_sync::sync();\n";
-    } else {
-      indent() << "__barrier_sync(0);\n";
-    }
-  }
-
-  void visit(const kir::InitMagicZero* node) final {
-    indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n";
-  }
-
-  void visit(const kir::UpdateMagicZero* node) final {
-    indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
+  void handle(const kir::Sync* node) final {
+    indent() << "__syncthreads();\n";
   }
 
  private:
   std::stringstream code_;
-  const kir::Kernel* kernel_;
+  const Kernel* kernel_;
   int block_nest_level_ = 0;
-  int block_reduce_name_ = 0;
 
   // TODO(kir): replace with explicit assignment statements
   bool print_inline_ = false;
-
-  // Mark when we are inside of a vectorized for-loop
-  bool vectorize_scope_ = false;
-
-  //! Holds active replacement mappings during codegen
-  std::unordered_map<const kir::Node*, const kir::Node*> replacement_map_;
 };
 
 } // namespace
 
 std::string generateCudaKernel(
-    const kir::Kernel* kernel,
+    const Kernel* kernel,
     const std::string& kernel_name) {
   FUSER_PERF_SCOPE("generateCudaKernel");
   return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name);
index 5f9b4f2..8205c85 100644 (file)
@@ -13,7 +13,7 @@ namespace codegen {
 
 //! Generates a CUDA kernel definition for the given kernel
 TORCH_CUDA_CU_API std::string generateCudaKernel(
-    const kir::Kernel* kernel,
+    const Kernel* kernel,
     const std::string& kernel_name = "CUDAGeneratedKernel");
 
 } // namespace codegen
index 265d47f..b1e8d70 100644 (file)
@@ -3,8 +3,6 @@
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
 #include <torch/csrc/jit/codegen/cuda/transform_iter.h>
 #include <torch/csrc/jit/codegen/cuda/transform_replay.h>
 
@@ -15,6 +13,95 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
+ComputeAtData::ComputeAtData(TensorView* tv)
+    : tv_ref_(tv),
+      original_has_compute_at_(tv->hasComputeAt()),
+      original_compute_at_position(tv->getThisComputeAtAxis()),
+      original_domain_(tv->domain()),
+      new_compute_at_domain_(tv->domain()) {}
+
+// Clear pass based data
+void ComputeAtData::clearPass() {
+  // If the last pass set a position, update the new_compute_at_position if
+  // latest position would be greater than previously set.
+  if (current_traversal_position_set &&
+      current_traversal_position > new_compute_at_position) {
+    new_compute_at_position = current_traversal_position;
+  }
+
+  current_traversal_position_set = false;
+  current_traversal_position = 0;
+}
+
+void ComputeAtData::setPassPosition(unsigned int pos) {
+  if (current_traversal_position_set) {
+    // A single traversal cannot try to enforce more than one position on a
+    // TensorView as it would produce in incorrect code. If this is hit, then
+    // the given tensor and its production should be duplicated.
+    TORCH_CHECK(
+        pos == current_traversal_position,
+        "Error during computeAt. ComputeAt pass wanted to set position of ",
+        tv_ref_,
+        " at position ",
+        pos,
+        " but was already set to position ",
+        current_traversal_position,
+        ". This tensor would have to be recomputed to satsify the selected computeAt position.");
+  }
+
+  current_traversal_position = pos;
+  touched_ = true;
+  current_traversal_position_set = true;
+}
+
+unsigned int ComputeAtData::getNewPosition() const {
+  // If the last pass set a position, return the latest position if
+  // it would be greater than previously set.
+  if (current_traversal_position_set &&
+      current_traversal_position > new_compute_at_position) {
+    return current_traversal_position;
+  } else {
+    return new_compute_at_position;
+  }
+}
+
+void ComputeAtData::validateNewComputeAt() const {
+  FUSER_PERF_SCOPE("validateNewComputeAt");
+
+  TORCH_INTERNAL_ASSERT(
+      getNewPosition() >= original_compute_at_position,
+      "Invalid computeAt detected. This computeAt would invalidate the set computeAt on ",
+      tv_ref_,
+      " as the new computeAt position was found to be ",
+      getNewPosition(),
+      ".");
+  auto mismatch = BestEffortReplay::findFirstMismatchedID(
+      tv_ref_->domain(), original_domain_);
+  TORCH_CHECK(
+      mismatch >= (int)original_compute_at_position,
+      "Invalid computeAt detected. This computeAt call would invalidate the set computeAt on ",
+      tv_ref_,
+      " as the previous set computeAt was on the domain ",
+      original_domain_,
+      " with a computeAt position of ",
+      original_compute_at_position,
+      ".");
+}
+
+void ComputeAtData::setComputeAtDomain(TensorDomain* td) {
+  if (new_compute_at_domain_ != original_domain_) {
+    TORCH_INTERNAL_ASSERT(
+        *new_compute_at_domain_ == *td,
+        "TensorDomain, ",
+        td,
+        ", does not match with the previously set domain of ",
+        tv_ref_,
+        ", which is ",
+        new_compute_at_domain_);
+  }
+  new_compute_at_domain_ = td;
+}
+
 namespace {
 
 // Wrapper around set_intersection
@@ -30,268 +117,38 @@ std::set<T> set_intersection(const std::set<T>& set1, const std::set<T>& set2) {
   return intersection;
 }
 
+// convert an iterable of Val* to be an iterable of TensorView*
+template <typename T1, typename T2>
+T1 tvIterable(const T2& val_iterable) {
+  T1 tv_iterable = T1();
+  std::transform(
+      val_iterable.begin(),
+      val_iterable.end(),
+      std::back_inserter(tv_iterable),
+      [](Val* v) {
+        TORCH_INTERNAL_ASSERT(
+            v->getValType().value() == ValType::TensorView,
+            "When following the computeAt dependency chain, a non TensorView value was found.");
+        return v->as<TensorView>();
+      });
+  return tv_iterable;
+}
+
 std::deque<std::deque<TensorView*>> tvChains(
     std::deque<std::deque<Val*>> val_chains) {
   std::deque<std::deque<TensorView*>> tv_chains(val_chains.size());
   for (const auto i : c10::irange(val_chains.size())) {
-    auto tv_iterable = ir_utils::filterByType<TensorView>(val_chains[i]);
-    tv_chains[i] =
-        std::deque<TensorView*>(tv_iterable.begin(), tv_iterable.end());
+    tv_chains[i] = tvIterable<std::deque<TensorView*>>(val_chains[i]);
   }
   return tv_chains;
 }
 
-bool validateDomain(TensorView* tv, TensorDomain* new_td) {
-  auto first_mismatch =
-      BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td);
-  return first_mismatch >= (int)tv->getMaxProducerPosition() &&
-      first_mismatch >= (int)tv->getComputeAtPosition();
-}
-
-// Return the max position in consumer that producer can be inlined to
-// Cannot inline:
-//   Reduction dimensions in producer
-//   Block broadcast dimensions in producer
-//   Vectorized dimensions in producer or consumer
-//   Unrolled dimensions in producer or consumer
-//   Dimensions derived from root dimensions that exist in both but are
-//   unmappable
-unsigned int getReplayablePosPasC(
-    TensorView* producer,
-    TensorView* consumer,
-    const ComputeAtRootDomainMap& root_map_,
-    ComputeAtMode mode) {
-  // Grab dimensions in producer and consumer that are mappable to eachother
-  // based on the computeAtRootDomainMap. This will tell us which dimensions
-  // can be inlined based on avoiding trying to inline reduction structures.
-  auto mappable_roots =
-      root_map_.getMappableDims(producer->domain(), consumer->domain());
-
-  // Check if any consumer dimensions are marked as vectorize as producer can
-  // not be inlined to vectorized dimensions in consumer.
-  auto c_dom = consumer->domain()->domain();
-  auto vector_dim_it =
-      std::find_if(c_dom.begin(), c_dom.end(), [&mode](IterDomain* id) {
-        return isParallelTypeVectorize(id->getParallelType()) ||
-            ((mode == ComputeAtMode::BestEffort ||
-              mode == ComputeAtMode::MostInlined) &&
-             id->getParallelType() == ParallelType::Unroll);
-      });
-
-  // Limit max position based on vectorized dims in consumer.
-  auto max_consumer_pos = std::distance(c_dom.begin(), vector_dim_it);
-
-  auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer);
-  auto c2p_root_map =
-      PairwiseRootDomainMap(producer, consumer)
-          .mapConsumerToProducer(consumer->domain(), producer->domain());
-
-  auto replay_PasC =
-      BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_root_map);
-
-  // Look for id's that map to a consumer id that's vectorized
-  auto c2p_replay_map = replay_PasC.getReplay();
-
-  for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0;
-       consumer_pos--) {
-    auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1));
-    if (map_it != c2p_replay_map.end()) {
-      auto p_id = map_it->second;
-      // If we find a consumer dim that maps to a producer dim that's
-      // vectorized or unrolled limit max compute at by it.
-      if (isParallelTypeVectorize(p_id->getParallelType()) ||
-          ((mode == ComputeAtMode::BestEffort ||
-            mode == ComputeAtMode::MostInlined) &&
-           p_id->getParallelType() == ParallelType::Unroll)) {
-        max_consumer_pos = consumer_pos - 1;
-      }
-    }
-  }
-
-  // Start at max position and work backwards,  try to find a location where
-  // producer can be inlined.
-  for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0;
-       consumer_pos--) {
-    // Grab all root dimensions of consumer as roots must be used to understand
-    // inlining potential.
-    auto consumer_root_dim_vals =
-        IterVisitor::getInputsTo({c_dom.begin(), c_dom.begin() + consumer_pos});
-    // convert to iter domains
-    auto consumer_root_dim_ids =
-        ir_utils::filterByType<IterDomain>(consumer_root_dim_vals);
-    // If any root dimensions cannot be mapped to producer we can't inline. If
-    // any root dimension
-    if (std::any_of(
-            consumer_root_dim_ids.begin(),
-            consumer_root_dim_ids.end(),
-            [&mappable_roots, &c2p_root_map](IterDomain* root_id) {
-              return mappable_roots.find(root_id) == mappable_roots.end() &&
-                  c2p_root_map.find(root_id) != c2p_root_map.end();
-            })) {
-      continue;
-    }
-    return consumer_pos;
-  }
-
-  return 0;
-}
-
-// Return the max position in producer that can be inlined to consumer
-// Cannot inline:
-//   Reduction dimensions in producer
-//   Vectorized dimensions in producer or consumer
-//   Unrolled dimensions in producer or consumer
-//   Dimensions derived from root dimensions that exist in both but are
-//   unmappable
-unsigned int getReplayablePosCasP(
-    TensorView* consumer,
-    TensorView* producer,
-    const ComputeAtRootDomainMap& root_map_,
-    ComputeAtMode mode) {
-  // Grab dimensions in producer and consumer that are mappable to eachother
-  // based on the computeAtRootDomainMap. This will tell us which dimensions
-  // can be inlined based on avoiding trying to inline reduction structures.
-  auto mappable_roots =
-      root_map_.getMappableDims(producer->domain(), consumer->domain());
-
-  auto p_dom = producer->domain()->domain();
-  auto first_reduction =
-      std::find_if(p_dom.begin(), p_dom.end(), [](IterDomain* id) {
-        return id->isReduction();
-      });
-
-  auto first_vectorized_axis =
-      std::find_if(p_dom.begin(), first_reduction, [&mode](IterDomain* id) {
-        return isParallelTypeVectorize(id->getParallelType()) ||
-            ((mode == ComputeAtMode::BestEffort ||
-              mode == ComputeAtMode::MostInlined) &&
-             id->getParallelType() == ParallelType::Unroll);
-      });
-
-  auto max_producer_pos = std::distance(p_dom.begin(), first_vectorized_axis);
-
-  auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer);
-  auto p2c_root_map = pairwise_root_map.mapProducerToConsumer(
-      producer->domain(), consumer->domain());
-
-  auto replay_CasP =
-      BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map);
-
-  // Look for id's that map to a consumer id that's vectorized
-  auto p2c_replay_map = replay_CasP.getReplay();
-
-  for (size_t producer_pos = max_producer_pos; producer_pos > 0;
-       producer_pos--) {
-    auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1));
-    if (map_it != p2c_replay_map.end()) {
-      auto c_id = map_it->second;
-      // If we find a producer dim that maps to a consumer vectorized or
-      // unrolled dim, limit max compute at by it
-      if (isParallelTypeVectorize(c_id->getParallelType()) ||
-          ((mode == ComputeAtMode::BestEffort ||
-            mode == ComputeAtMode::MostInlined) &&
-           c_id->getParallelType() == ParallelType::Unroll)) {
-        max_producer_pos = producer_pos - 1;
-      }
-    }
-  }
-
-  for (size_t producer_pos = max_producer_pos; producer_pos > 0;
-       producer_pos--) {
-    auto all_vals = DependencyCheck::getAllValsBetween(
-        {producer->getMaybeRFactorDomain().begin(),
-         producer->getMaybeRFactorDomain().end()},
-        {p_dom.begin(), p_dom.begin() + producer_pos});
-
-    // If any root dims could have mapped to consumer, but don't, then we can't
-    // compute at this point
-    if (std::any_of(
-            producer->getMaybeRFactorDomain().begin(),
-            producer->getMaybeRFactorDomain().end(),
-            [&mappable_roots, &all_vals](IterDomain* root_id) {
-              return std::find(all_vals.begin(), all_vals.end(), root_id) !=
-                  all_vals.end() &&
-                  mappable_roots.find(root_id) == mappable_roots.end();
-            })) {
-      continue;
-    }
-
-    return producer_pos;
-  }
-  return 0;
-}
-
-unsigned int getInnermostNonBroadcastIdFrom(TensorView* tv) {
-  unsigned int ret = tv->getComputeAtPosition();
-
-  // Still assuming we only have block broadcast for now.
-  //  This part may change
-  while (ret > 0 && tv->axis((int)ret - 1)->isBroadcast()) {
-    ret--;
-  }
-
-  return ret;
-}
-
-// Try to find the aligned position on consumer's domain corresponding to the
-//  compute at position of producer domain. Used in computeAt pass only. No
-//  checking on actual producer-consumer relationship.
-unsigned int getConsumerPosAlignedToProducerCA(
-    TensorView* consumer,
-    TensorView* producer) {
-  // Locate consumer's position that aligns with
-  //  the producer's new compute at axis. We need broadcast axes forwarded so we
-  //  need to replay PasC as CasP will not forward braodcast dims. For example
-  //  if we have:
-  // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 )
-  // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will
-  // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to
-  // NVFuserTest.FusionComplexBCast1_CUDA
-
-  auto c2p_map =
-      BestEffortReplay::replayPasC(
-          producer,
-          consumer,
-          -1,
-          // Compute at root domain may not be valid here, as all
-          // producers don't have to be able to map into consumer at
-          // max producer position. Since computeAt should be valid
-          // and this mechanism is only intended to lower produce
-          // position of consumer, we can simply use the pairwise map.
-          PairwiseRootDomainMap(producer, consumer))
-          .getReplay();
-
-  // Find the innermost position of consumer that has
-  //  been mapped within the producer ca axis.
-  unsigned int consumer_pos = consumer->nDims();
-  while (consumer_pos > 0) {
-    auto consumer_id = consumer->axis((int)consumer_pos - 1);
-    auto p_dom = producer->domain()->domain();
-    if (std::any_of(
-            p_dom.begin(),
-            p_dom.begin() + producer->getComputeAtPosition(),
-            [&consumer_id, &c2p_map](IterDomain* p_id) {
-              auto c_id_it = c2p_map.find(consumer_id);
-              if (c_id_it != c2p_map.end()) {
-                return c_id_it->second == p_id;
-              }
-              return false;
-            })) {
-      break;
-    }
-    consumer_pos--;
-  }
-
-  return consumer_pos;
-}
-
 } // namespace
 
-void ComputeAt::runAt(
+void ComputeAt::run(
     TensorView* producer,
     TensorView* consumer,
-    unsigned int consumer_position,
-    ComputeAtMode mode) {
+    unsigned int consumer_position) {
   FUSER_PERF_SCOPE("ComputeAt::run");
 
   // Make sure the correct fusion is setup between this and consumer.
@@ -305,178 +162,101 @@ void ComputeAt::runAt(
   // Make sure Fusion Guard is set appropriately
   FusionGuard fg(producer->fusion());
 
-  TORCH_CHECK(
-      DependencyCheck::isDependencyOf(producer, consumer),
-      "Compute At expects ",
-      producer->name(),
-      " is a dependency of ",
-      consumer->name(),
-      ", however it is not.");
+  std::vector<TensorView*> producers;
 
-  // Run computeAt on our potentially modified producer(s)
-  ComputeAt ca(producer, consumer, consumer, consumer_position, mode);
-  ca.runPass();
-}
+  // It doesn't make sense to set computeAt on an input as it's not generated,
+  // it's provided. If this was called, move the computeAt to users of the
+  // producer that are in a dependency between prodcer and consumer.
+  if (producer->fusion()->hasInput(producer)) {
+    auto all_chains =
+        tvChains(DependencyCheck::getAllDependencyChains(producer, consumer));
 
-void ComputeAt::runWith(
-    TensorView* producer,
-    TensorView* consumer,
-    unsigned int producer_position,
-    ComputeAtMode mode) {
-  FUSER_PERF_SCOPE("ComputeAt::runWith");
-
-  // Make sure the correct fusion is setup between this and consumer.
-  TORCH_CHECK(
-      producer->fusion() == consumer->fusion(),
-      producer,
-      " and ",
-      consumer,
-      " are not in the same fusion.");
-
-  TORCH_CHECK(
-      DependencyCheck::isDependencyOf(producer, consumer),
-      "Compute At expects ",
-      producer->name(),
-      " is a dependency of ",
-      consumer->name(),
-      ", however it is not.");
-
-  // Make sure Fusion Guard is set appropriately
-  FusionGuard fg(producer->fusion());
+    TORCH_CHECK(
+        !all_chains.empty(),
+        "Compute At expects ",
+        producer,
+        " is a dependency of ",
+        consumer,
+        ", however it is not.");
+
+    std::unordered_set<TensorView*> added_producers;
+
+    // Check all dependency chains, select the next TV after producer towards
+    // consumer. These are the TVs we're going to actually call computeAt on.
+    for (const auto& tv_chain : all_chains) {
+      // When a chain only has two tensors, they must be the producer,
+      // which is an input, and the consumer. There is nothing we need
+      // to do for such chains.
+      if (tv_chain.size() > 2) {
+        // Make sure we only add once, but we want to add in a determinsitic
+        // order
+        if (added_producers.find(tv_chain[1]) == added_producers.end()) {
+          producers.push_back(tv_chain[1]);
+          added_producers.emplace(tv_chain[1]);
+        }
+      }
+    }
+  } else {
+    // If producer is not an input, it's the only one.
+    producers.push_back(producer);
+  }
 
-  ComputeAt ca(producer, consumer, producer, producer_position, mode);
-  ca.runPass();
+  // Run computeAt on our potentially modified producer(s)
+  if (!producers.empty()) {
+    for (auto producer_to_run : producers) {
+      ComputeAt ca(producer_to_run, consumer, consumer_position);
+      ca.runPass();
+    }
+  }
 }
 
 // Actually applies transformation
 unsigned int ComputeAt::backwardComputeAt_impl(
     TensorView* producer,
     TensorView* consumer,
-    unsigned int consumer_compute_at_pos) {
+    unsigned int consumer_compute_at_axis) {
   FUSER_PERF_SCOPE("backwardComputeAt_impl");
 
-  auto max_consumer_compute_at_pos =
-      getReplayablePosPasC(producer, consumer, root_map_, mode_);
-  if (mode_ == ComputeAtMode::BestEffort) {
-    consumer_compute_at_pos =
-        std::min(consumer_compute_at_pos, max_consumer_compute_at_pos);
-  } else if (mode_ == ComputeAtMode::MostInlined) {
-    consumer_compute_at_pos = max_consumer_compute_at_pos;
-  } else {
-    TORCH_INTERNAL_ASSERT(
-        consumer_compute_at_pos <= max_consumer_compute_at_pos,
-        "Invalid compute at position detected in compute at when trying to replay producer: ",
-        producer,
-        " as consumer: ",
-        consumer,
-        " tried to do this at position: ",
-        consumer_compute_at_pos,
-        " but max position that's allowed is ",
-        max_consumer_compute_at_pos);
-  }
-
-  auto replay_producer_pair = TransformReplay::replayPasC(
-      producer, consumer, (int)consumer_compute_at_pos, root_map_);
-
-  if (replay_producer_pair.second == 0) {
-    return 0;
-  }
+  auto& producer_entry = tv_data.at(producer);
 
-  if (replay_producer_pair.second >= producer->getComputeAtPosition()) {
-    const TensorDomain* current_domain = producer->domain();
-    TensorDomain* new_domain = replay_producer_pair.first;
+  // Use TensorDomain interface so it doesn't set computeAt automatically
+  auto replay = TransformReplay::replayPasC(
+      producer, consumer, (int)consumer_compute_at_axis);
 
-    TORCH_INTERNAL_ASSERT(
-        validateDomain(producer, new_domain),
-        "Tried to set the domain of ",
-        producer,
-        " to ",
-        new_domain,
-        " but that would invalidate previously compute at position or max producer position.");
+  producer_entry.setPassPosition(replay.second);
 
-    producer->setDomain(new_domain);
-    if (!producer->isFusionInput()) {
-      producer->setComputeAt(replay_producer_pair.second);
-    }
-
-    consumer->setMaxProducer(consumer_compute_at_pos);
-    for (auto other_consumer : ir_utils::consumerTvsOf(producer)) {
-      if (other_consumer != consumer) {
-        auto max_consumer_pos =
-            getConsumerPosAlignedToProducerCA(other_consumer, producer);
-        other_consumer->setMaxProducer(max_consumer_pos);
-      }
-    }
-    root_map_.setAlias(current_domain, new_domain);
+  if (producer_entry.shouldSetComputeAt(replay.second)) {
+    producer->setComputeAt(consumer, (int)consumer_compute_at_axis);
+    producer_entry.setComputeAtDomain(producer->domain());
   }
 
-  return replay_producer_pair.second;
+  return replay.second;
 }
 
-// Actually applies transformation, replay consumer based on producer, set
-// compute at of producer, set pass position of consumer, return position
-// relative to consumer
+// Actually applies transformation
 unsigned int ComputeAt::forwardComputeAt_impl(
     TensorView* producer,
     TensorView* consumer,
-    unsigned int producer_compute_at_pos) {
+    unsigned int producer_compute_at_axis) {
   FUSER_PERF_SCOPE("forwardComputeAt_impl");
 
-  auto max_producer_compute_at_pos =
-      getReplayablePosCasP(consumer, producer, root_map_, mode_);
-
-  if (mode_ == ComputeAtMode::BestEffort) {
-    producer_compute_at_pos =
-        std::min(producer_compute_at_pos, max_producer_compute_at_pos);
-  } else if (mode_ == ComputeAtMode::MostInlined) {
-    producer_compute_at_pos = max_producer_compute_at_pos;
-  } else {
-    TORCH_INTERNAL_ASSERT(
-        producer_compute_at_pos <= max_producer_compute_at_pos,
-        "Invalid compute at position detected in compute at when trying to replay consumer: ",
-        consumer,
-        " as producer: ",
-        producer,
-        " tried to do this at position: ",
-        producer_compute_at_pos,
-        " but max position that's allowed is ",
-        max_producer_compute_at_pos);
-  }
+  auto& consumer_entry = tv_data.at(consumer);
+  const auto& producer_entry = tv_data.at(producer);
 
-  auto replay_consumer_pair = TransformReplay::replayCasP(
-      consumer, producer, (int)producer_compute_at_pos, root_map_);
+  auto replay = TransformReplay::replayCasP(
+      consumer, producer, (int)producer_compute_at_axis);
 
-  if (producer_compute_at_pos > producer->getComputeAtPosition()) {
-    if (!producer->isFusionInput()) {
-      producer->setComputeAt((int)producer_compute_at_pos);
-    }
+  if (producer_entry.shouldSetComputeAt(producer_compute_at_axis)) {
+    producer->setComputeAt(consumer, replay.second);
   }
 
-  if (replay_consumer_pair.second > consumer->getMaxProducerPosition()) {
-    const TensorDomain* current_domain = consumer->domain();
-    TensorDomain* new_domain = replay_consumer_pair.first;
-
-    TORCH_INTERNAL_ASSERT(
-        validateDomain(consumer, new_domain),
-        "Tried to set the domain of ",
-        consumer,
-        " to ",
-        new_domain,
-        " but that would invalidate previously compute at position or max producer position.");
-
-    consumer->setDomain(new_domain);
-    consumer->setMaxProducer(replay_consumer_pair.second);
-    for (auto other_consumer : ir_utils::consumerTvsOf(producer)) {
-      if (other_consumer != consumer) {
-        auto max_consumer_pos =
-            getConsumerPosAlignedToProducerCA(other_consumer, producer);
-        other_consumer->setMaxProducer(max_consumer_pos);
-      }
-    }
-    root_map_.setAlias(current_domain, new_domain);
+  consumer_entry.setPassPosition(replay.second);
+  if (consumer_entry.shouldSetComputeAt(replay.second) &&
+      consumer != consumer_) {
+    consumer_entry.setComputeAtDomain(consumer->domain());
   }
 
-  return replay_consumer_pair.second;
+  return replay.second;
 }
 
 void ComputeAt::setCommonConsumer() {
@@ -502,9 +282,9 @@ void ComputeAt::setCommonConsumer() {
   TORCH_CHECK(
       !all_chains.empty(),
       "Compute At expects ",
-      producer_->name(),
+      producer_,
       " is a dependency of ",
-      consumer_->name(),
+      consumer_,
       ", however it is not.");
 
   // Remove all TVs from producer to consumer as common consumer must be at or
@@ -535,11 +315,6 @@ void ComputeAt::setCommonConsumer() {
 // computeAt if it will increase computeAt positions.
 void ComputeAt::traverseBackward() {
   FUSER_PERF_SCOPE("ComputeAt::traverseBackward");
-  if (reference_ == producer_) {
-    // Forward compute at don't need to run backward traversal
-    producer_position_ = reference_position_;
-    return;
-  }
 
   // propagate *backward* through all *producer* use_chains or from *producer*
   // to common_consumer if common_consumer exists. Only apply transform if
@@ -550,7 +325,7 @@ void ComputeAt::traverseBackward() {
   for (auto tv_chain : chains) {
     TensorView* running_producer = tv_chain.back();
     TensorView* running_consumer = nullptr;
-    unsigned int running_consumer_pos = reference_position_;
+    unsigned int running_consumer_pos = consumer_position_;
     tv_chain.pop_back();
 
     TORCH_INTERNAL_ASSERT(running_producer == consumer_);
@@ -563,11 +338,6 @@ void ComputeAt::traverseBackward() {
       running_consumer_pos = backwardComputeAt_impl(
           running_producer, running_consumer, running_consumer_pos);
     }
-
-    TORCH_INTERNAL_ASSERT(
-        running_producer == producer_,
-        "Compute at backward traversal ended up on something other than the producer.");
-    producer_position_ = running_consumer_pos;
   }
 }
 
@@ -582,12 +352,14 @@ void ComputeAt::traverseForward() {
         DependencyCheck::getAllDependencyChains(producer_, common_consumer_));
   }
 
+  unsigned int producer_pos = tv_data.at(producer_).getNewPosition();
+
   // propagate forward through all chains
   for (auto tv_dep_chain : chains) {
     TensorView* running_producer = nullptr;
     TensorView* running_consumer = tv_dep_chain.front();
     tv_dep_chain.pop_front();
-    unsigned int running_producer_pos = producer_position_;
+    unsigned int running_producer_pos = producer_pos;
 
     TORCH_INTERNAL_ASSERT(running_consumer == producer_);
 
@@ -595,211 +367,109 @@ void ComputeAt::traverseForward() {
       running_producer = running_consumer;
       running_consumer = tv_dep_chain.front();
       tv_dep_chain.pop_front();
+
       running_producer_pos = forwardComputeAt_impl(
           running_producer, running_consumer, running_producer_pos);
     }
   }
 }
 
-void ComputeAt::resetMaxProducerPos(TensorView* consumer_tv) {
-  if (consumer_tv->definition() == nullptr) {
-    consumer_tv->setMaxProducer(0, true);
+void ComputeAt::runPass() {
+  FUSER_PERF_SCOPE("ComputeAt::runPass");
+
+  // Initialize tv_data for all TensorViews we may modify
+  auto chains = producer_use_chains_;
+  if (common_consumer_ != nullptr) {
+    chains = tvChains(
+        DependencyCheck::getAllDependencyChains(producer_, common_consumer_));
   }
 
-  unsigned int new_consummer_pa_pos = 0;
-
-  // Re-compute the max producer position as one or more
-  //  of the producers of this consumer have updated their
-  //  compute at position.
-  for (auto inp : ir_utils::producerTvsOf(consumer_tv)) {
-    if (!inp->isFusionInput()) {
-      // Locate consumer's position that aligns with
-      //  the producer's new compute at axis.
-      unsigned int inp_ca_pos_to_consumer =
-          getConsumerPosAlignedToProducerCA(consumer_tv, inp);
-
-      // Populate the max consumer position required by
-      //  producer compute at.
-      new_consummer_pa_pos =
-          std::max(new_consummer_pa_pos, inp_ca_pos_to_consumer);
+  for (const auto& tv_chain : chains) {
+    for (auto tv : tv_chain) {
+      if (tv_data.find(tv) == tv_data.end()) {
+        tv_data[tv] = ComputeAtData(tv);
+      }
     }
   }
 
-  consumer_tv->setMaxProducer(new_consummer_pa_pos, true);
-}
+  // Traverse backward through all dep chains from producer to consumer
+  traverseBackward();
 
-void ComputeAt::hoistInnermostBroadcast() {
-  auto fusion = producer_->fusion();
+  // Clear data from backward traversal:
+  for (auto& entry : tv_data) {
+    entry.second.clearPass();
+  }
 
-  std::unordered_set<TensorView*> consumers_to_update;
+  // Start at producer and traverse forward through all chains
+  traverseForward();
 
-  auto all_vals = fusion->usedMathVals();
-  auto all_tvs = ir_utils::filterByType<TensorView>(all_vals);
+  setupOutputs();
 
-  for (auto running_producer : all_tvs) {
-    if (!running_producer->isFusionInput()) {
-      auto producer_ca_pos = running_producer->getComputeAtPosition();
-      // Find the innermost iterdomain that is not a broadcast
-      auto new_ca_pos = getInnermostNonBroadcastIdFrom(running_producer);
-      // Update the compute at pos of this producer if the original
-      //  compute at is within inner most broadcast axes
-      if (new_ca_pos < producer_ca_pos) {
-        running_producer->setComputeAt(new_ca_pos, true);
-      }
-      // Mark all consumers of this producer for later produce
-      //  position update.
-      // This is safe with segmented fusion. TV uses will reset
-      //  when FusionSegmentGuard try to change the IO.
-      auto tv_consumers = ir_utils::consumerTvsOf(running_producer);
-      consumers_to_update.insert(tv_consumers.begin(), tv_consumers.end());
-    }
+  for (const auto& entry : tv_data) {
+    entry.first->setDomain(entry.second.getComputeAtDomain());
+    entry.second.validateNewComputeAt();
   }
 
-  // Update the produce positions of all affected consumers
-  for (auto running_consumer : consumers_to_update) {
-    TORCH_INTERNAL_ASSERT(running_consumer->definition() != nullptr);
-    resetMaxProducerPos(running_consumer);
-  }
+  TORCH_INTERNAL_ASSERT(
+      BestEffortReplay::findFirstMismatchedID(
+          consumer_->domain(), tv_data.at(consumer_).getOriginalDomain()) ==
+          (int)consumer_->domain()->nDims(),
+      "ComputeAt logic changed the consumer domain which should not happen. Domain was ",
+      tv_data.at(consumer_).getOriginalDomain(),
+      " but is now: ",
+      consumer_->domain());
 }
 
-void ComputeAt::updateSiblings() {
-  // Track which consumers may have a wrong produce at position to update
-  // later
-  auto updateSiblingsOfTv = [&](TensorView* tv) {
-    if (tv->definition() == nullptr) {
-      return;
-    }
-
-    std::unordered_set<TensorView*> consumers_to_update;
+void ComputeAt::setupOutputs() {
+  FUSER_PERF_SCOPE("ComputeAt::setupOutputs");
 
-    if (tv->definition()->outputs().size() > 1) {
-      auto outs = tv->definition()->outputs();
-      auto out_tvs = ir_utils::filterByType<TensorView>(outs);
-      for (auto sibling_tv : out_tvs) {
-        if (sibling_tv == tv) {
-          continue;
-        }
+  if (common_consumer_ != nullptr)
+    return;
 
-        std::unordered_map<IterDomain*, IterDomain*> tv_to_sibling_map;
-        TORCH_INTERNAL_ASSERT(
-            tv->getRootDomain().size() == sibling_tv->getRootDomain().size(),
-            "Error replaying multiple output expressions in computeAt.");
-
-        // Propagate any root parallelization as fullSelfReplay expects it.
-        for (size_t i = 0; i < sibling_tv->getRootDomain().size(); i++) {
-          auto id = tv->getRootDomain()[i];
-          auto sibling_id = sibling_tv->getRootDomain()[i];
-          if (id->getParallelType() != ParallelType::Serial &&
-              sibling_id->getParallelType() == ParallelType::Serial) {
-            sibling_id->parallelize(id->getParallelType());
-          } else if (
-              id->getParallelType() == ParallelType::Serial &&
-              sibling_id->getParallelType() != ParallelType::Serial) {
-            id->parallelize(sibling_id->getParallelType());
-          }
-        }
-        if (tv->getComputeAtPosition() > sibling_tv->getComputeAtPosition()) {
-          auto sibling_domain = TransformReplay::fullSelfReplay(
-              sibling_tv->domain(), tv->domain());
-          validateDomain(sibling_tv, sibling_domain);
-          sibling_tv->setDomain(sibling_domain);
-          sibling_tv->setComputeAt(tv->getComputeAtPosition());
-          sibling_tv->setMaxProducer(tv->getMaxProducerPosition());
-          auto consumer_tvs = ir_utils::consumerTvsOf(sibling_tv);
-          consumers_to_update.insert(consumer_tvs.begin(), consumer_tvs.end());
+  std::vector<TensorView*> touched_output_order;
+  const auto& terminating_outputs =
+      FusionGuard::getCurFusion()->getTerminatingOutputs();
+
+  for (auto out : ir_utils::filterByType<TensorView>(
+           FusionGuard::getCurFusion()->outputs())) {
+    if (tv_data.find(out) != tv_data.end()) {
+      if (tv_data[out].touched()) {
+        // No need to adjust computeAt when an output is not
+        // a terminating output.
+        if (std::find(
+                terminating_outputs.begin(), terminating_outputs.end(), out) !=
+            terminating_outputs.end()) {
+          touched_output_order.push_back(out);
         }
       }
     }
-    for (auto consumer : consumers_to_update) {
-      this->resetMaxProducerPos(consumer);
-    }
-  };
-
-  // Find all tensor views that may have been modified
-  auto chains = producer_use_chains_;
-  if (common_consumer_ != nullptr) {
-    chains = tvChains(
-        DependencyCheck::getAllDependencyChains(producer_, common_consumer_));
   }
 
-  std::unordered_set<TensorView*> participating_tvs;
-  for (auto chain : chains) {
-    participating_tvs.insert(chain.begin(), chain.end());
-  }
-
-  for (auto tv : participating_tvs) {
-    updateSiblingsOfTv(tv);
-  }
-}
-
-void ComputeAt::updateInputProduceAts() {
-  std::unordered_set<TensorView*> consumers_to_check;
-
-  // Find all tensor views that may have been modified
-  auto chains = producer_use_chains_;
-  if (common_consumer_ != nullptr) {
-    chains = tvChains(
-        DependencyCheck::getAllDependencyChains(producer_, common_consumer_));
-  }
-
-  for (auto chain : chains) {
-    if (chain.size() > 1 && chain[0]->isFusionInput()) {
-      consumers_to_check.emplace(chain[1]);
+  if (touched_output_order.size() > 0) {
+    for (size_t i = 0; i < touched_output_order.size() - 1; i++) {
+      touched_output_order[i]->setComputeAt(
+          touched_output_order[i + 1],
+          (int)tv_data.at(touched_output_order[i]).getNewPosition(),
+          (int)tv_data.at(touched_output_order[i + 1]).getNewPosition());
     }
   }
-
-  for (auto tv : consumers_to_check) {
-    resetMaxProducerPos(tv);
-  }
-}
-
-void ComputeAt::runPass() {
-  FUSER_PERF_SCOPE("ComputeAt::runPass");
-
-  // Traverse backward through all dep chains from producer to consumer
-  traverseBackward();
-
-  // Start at producer and traverse forward through all chains
-  traverseForward();
-
-  // Back off on inlining the inner broadcast axes
-  hoistInnermostBroadcast();
-
-  // Clear max producer position of consumers from fusion inputs.
-  updateInputProduceAts();
-
-  // Update siblings of multi output expressions
-  updateSiblings();
 }
 
 ComputeAt::ComputeAt(
     TensorView* _producer,
     TensorView* _consumer,
-    TensorView* _reference,
-    unsigned int _reference_position,
-    ComputeAtMode _mode)
+    unsigned int _consumer_position)
     : producer_(_producer),
       consumer_(_consumer),
-      reference_(_reference),
-      reference_position_(_reference_position),
-      mode_(_mode) {
+      consumer_position_(_consumer_position) {
   TORCH_INTERNAL_ASSERT(
-      reference_ == producer_ || reference_ == consumer_,
-      "For compute at reference must be producer or consumer, it's neither.",
-      " reference: ",
-      reference_,
-      " consumer: ",
-      consumer_,
-      " producer: ",
-      producer_);
-  TORCH_INTERNAL_ASSERT(
-      reference_position_ >= 0 && reference_position_ <= reference_->nDims(),
+      consumer_position_ >= 0 && consumer_position_ <= consumer_->nDims(),
       "Invalid computeAt axis, received ",
-      reference_position_,
+      _consumer_position,
       " but should be > -",
-      reference_->nDims(),
+      consumer_->nDims(),
       " and <= ",
-      reference_->nDims(),
+      consumer_->nDims(),
       ".");
 
   producer_use_chains_ = tvChains(DependencyCheck::getAllUseChains(producer_));
@@ -808,8 +478,6 @@ ComputeAt::ComputeAt(
   // consumer for all chains at or after the consumer specified in the computeAt
   // call.
   setCommonConsumer();
-
-  root_map_.build();
 }
 
 } // namespace cuda
index 71e3950..d8328c1 100644 (file)
@@ -1,7 +1,5 @@
 #pragma once
 
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-
 #include <c10/util/Exception.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
@@ -17,23 +15,93 @@ namespace cuda {
 class TensorDomain;
 class TensorView;
 
-class ComputeAt {
+// We're going to keep data related to the computeAt pass for each TensorView in
+// this structure, this will allow us to keep a single entry in a map from a
+// TensorView to this one.
+// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+class ComputeAtData {
  public:
-  // Runs the compute at pass making producer look like consumer, computing
-  // producer relative to consumer
-  static void runAt(
-      TensorView* producer,
-      TensorView* consumer,
-      unsigned int consumer_position,
-      ComputeAtMode mode = ComputeAtMode::Standard);
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+  ComputeAtData() = default;
+  ComputeAtData(TensorView* tv);
+
+  // Clear after a given traversal. There will be more than one.
+  void clearPass();
+
+  // Makes sure value matches current_traversal_position if
+  // current_traversal_position_set is true. If this is not the case we're in
+  // an invalid compute_at that would require tensor replication.
+  void setPassPosition(unsigned int pos);
+
+  // Returns if new postion is greater or equal to previous seen, if
+  bool shouldSetComputeAt(unsigned int pos) const {
+    return pos > original_compute_at_position &&
+        pos > new_compute_at_position && pos >= current_traversal_position;
+  }
+
+  // Will return new_compute_at_position, after making sure we cleared out the
+  // last pass
+  unsigned int getNewPosition() const;
+
+  // Will make sure we haven't invalidated previous computeAt calls by
+  // checking that any axes previously in computeAt are still there.
+  void validateNewComputeAt() const;
+
+  // Did we ever compute a value for this TV?
+  bool touched() const {
+    return touched_;
+  }
+
+  TensorDomain* getOriginalDomain() const {
+    return original_domain_;
+  }
+
+  // If we set computeAt, save the domain so we can reset it after traversal.
+  // Traversal state can deviate from the domain we will want to save after the
+  // entire computeAt pass.
+  void setComputeAtDomain(TensorDomain* td);
+
+  // Return domain set in setComputeAtDomain
+  TensorDomain* getComputeAtDomain() const {
+    return new_compute_at_domain_;
+  }
 
-  // Runs the compute with pass making consumer look like producer, computing
-  // producer relative to consumer
-  static void runWith(
-      TensorView* producer,
-      TensorView* consumer,
-      unsigned int producer_position,
-      ComputeAtMode mode = ComputeAtMode::Standard);
+ private:
+  // Was the position ever modified?
+  bool touched_ = false;
+
+  // Hold onto the provided TensorView
+  TensorView* tv_ref_ = nullptr;
+
+  // Did this tv have computeAt set before calling this computeAt pass?
+  bool original_has_compute_at_ = false;
+
+  // What was the computeAt position before the computeAt pass started
+  unsigned int original_compute_at_position = 0;
+
+  // and what was the previous domain that position was set relative to.
+  TensorDomain* original_domain_ = nullptr;
+
+  // Position we can update during a traversal
+  unsigned int current_traversal_position = 0;
+
+  // Did this traversal set a position or not yet
+  bool current_traversal_position_set = false;
+
+  // Position to update after a traversal
+  unsigned int new_compute_at_position = 0;
+
+  // Domain when we actually set computeAt, will set back to this after the
+  // pass.
+  TensorDomain* new_compute_at_domain_;
+};
+
+class ComputeAt {
+ public:
+  static void run(
+      TensorView* _producer,
+      TensorView* _consumer,
+      unsigned int _consumer_position);
 
   ComputeAt() = delete;
   ComputeAt(ComputeAt&) = delete;
@@ -42,26 +110,21 @@ class ComputeAt {
  private:
   TensorView* producer_;
   TensorView* consumer_;
-  TensorView* reference_;
-  unsigned int reference_position_;
-  ComputeAtMode mode_ = ComputeAtMode::Standard;
-
-  unsigned int producer_position_ = 0;
-  ComputeAtRootDomainMap root_map_;
+  unsigned int consumer_position_;
 
   // Runs replayPasC and sets producer computeAt settings. Returns
-  // producer_compute_at_pos.
+  // producer_compute_at_axis.
   unsigned int backwardComputeAt_impl(
       TensorView* producer,
       TensorView* consumer,
-      unsigned int consumer_compute_at_pos);
+      unsigned int consumer_compute_at_axis);
 
   // Runs replayCasP and sets producer computeAt settings. Returns
-  // consumer_compute_at_pos.
+  // consumer_compute_at_axis.
   unsigned int forwardComputeAt_impl(
       TensorView* producer,
       TensorView* consumer,
-      unsigned int producer_compute_at_pos);
+      unsigned int producer_compute_at_axis);
 
   // Look through all the use chains of producer. Check if there's a single
   // consumer for all chains at or after the consumer specified in the computeAt
@@ -76,42 +139,25 @@ class ComputeAt {
   // of producer
   void traverseForward();
 
-  // Looks at producer tensor views of consumer_tv, recomputes its max
-  // producer position, and sets max producer position. This function can
-  // only potentially lower the max producer position of consumer_tv.
-  void resetMaxProducerPos(TensorView* consumer_tv);
-
-  // Undo the inlining of block broadcast at the innermost positions
-  //  to avoid generating repeated block broadcasts
-  void hoistInnermostBroadcast();
-
-  // Update multi-output expressions. If one output is modified, all outputs
-  // should be modified as well. Propagate transformations, compute at, and
-  // produce at from tv to siblings. Run as final pass as it will invalidate the
-  // computeAt map originally computed.
-  void updateSiblings();
-
-  // Compute at pass requires tracking "maxProducerPosition" even if set simply
-  // from input tensor views. However, when lowering, we need a valid produce at
-  // position of all tensors, so inputs should never actually set their
-  // consumers maxProduceAt position.
-  void updateInputProduceAts();
-
   // Run the computeAt pass
   void runPass();
 
+  // Set outputs relative to eachother if there is not a common consumer
+  void setupOutputs();
+
   // Common consumer if it exists
   TensorView* common_consumer_ = nullptr;
 
   // Producer use chains set in, used in a few spots.
   std::deque<std::deque<TensorView*>> producer_use_chains_;
 
+  // All we need to know and keep track of for each TensorView in this pass.
+  std::unordered_map<TensorView*, ComputeAtData> tv_data;
+
   ComputeAt(
       TensorView* _producer,
       TensorView* _consumer,
-      TensorView* _reference,
-      unsigned int _reference_position,
-      ComputeAtMode _mode);
+      unsigned int _consumer_position);
 
   ~ComputeAt() = default;
 };
diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
deleted file mode 100644 (file)
index a753d3b..0000000
+++ /dev/null
@@ -1,620 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-namespace {
-
-//! Class to figure out how many non-broadcast axes and how many broadcast axes
-//! were used to produce an iter domain. This is important for figuring out what
-//! the correct broadcasted extent is of an iteration domain.
-//!
-//! When GpuLower is available, trivial reductions are not counted as
-//! concrete domains so that they should not be used to generate
-//! for-loops.
-class InputDomainCounter : public IterVisitor {
- public:
-  // Returns number of {non-braodcast non-reduction iteration domains, broadcast
-  // and trivial reduction domains} used to generate the iteration domains in
-  // provided target domain.
-  static std::unordered_map<IterDomain*, std::pair<int, int>> produceCounts(
-      const std::vector<IterDomain*>& domain,
-      GpuLower* gpu_lower) {
-    if (domain.empty()) {
-      return std::unordered_map<IterDomain*, std::pair<int, int>>();
-    }
-
-    InputDomainCounter counter(domain);
-
-    std::unordered_map<IterDomain*, std::pair<int, int>> count_map;
-    for (auto entry : counter.domain_set_) {
-      auto id = entry.first;
-      auto input_id_set = entry.second;
-      int concrete_counts = 0;
-      int broadcast_counts = 0;
-      for (auto input_id : input_id_set) {
-        if (input_id->isBroadcast() ||
-            (gpu_lower &&
-             gpu_lower->trivialReductionInfo().isDerived(input_id))) {
-          broadcast_counts++;
-        } else {
-          concrete_counts++;
-        }
-      }
-      count_map[id] = {concrete_counts, broadcast_counts};
-    }
-
-    // Inputs may be root domains which wouldn't have any entries if no exprs
-    // were traversed, so manually insert their count
-    for (auto id : domain) {
-      if (count_map.find(id) == count_map.end()) {
-        count_map[id] =
-            (id->isBroadcast() ||
-             (gpu_lower && gpu_lower->trivialReductionInfo().isDerived(id)))
-            ? std::make_pair(0, 1)
-            : std::make_pair(1, 0);
-      }
-    }
-    return count_map;
-  }
-
- private:
-  InputDomainCounter(const std::vector<IterDomain*>& domain_) {
-    traverseFrom(
-        domain_[0]->fusion(),
-        std::vector<Val*>(domain_.begin(), domain_.end()));
-  }
-
- private:
-  std::unordered_set<IterDomain*>& getEntry(IterDomain* id) {
-    auto domain_set_it = domain_set_.find(id);
-    if (domain_set_it == domain_set_.end()) {
-      domain_set_it =
-          domain_set_
-              .emplace(std::make_pair(id, std::unordered_set<IterDomain*>()))
-              .first;
-      domain_set_it->second.emplace(id);
-    }
-
-    return domain_set_it->second;
-  }
-
-  void handle(Expr* expr) override {
-    // If we end up moving swizzle to an Expr it would be identity here, instead
-    // of outputs being a function of all inputs
-    switch (expr->getExprType().value()) {
-      case (ExprType::Split):
-      case (ExprType::Merge):
-        break;
-      default:
-        TORCH_INTERNAL_ASSERT(
-            false, "Invalid expr type found in transform traversal.");
-    }
-
-    // Gather all non-broadcast input domains
-    std::unordered_set<IterDomain*> resulting_set;
-    for (auto input_id : ir_utils::filterByType<IterDomain>(expr->inputs())) {
-      auto input_entry = getEntry(input_id);
-      resulting_set.insert(input_entry.begin(), input_entry.end());
-    }
-    for (auto output_id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
-      domain_set_.emplace(std::make_pair(output_id, resulting_set));
-    }
-  }
-
-  std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>> domain_set_;
-};
-
-// Only used once, consider removing.
-template <class T>
-std::deque<T*> deduplicateDeque(const std::deque<T*>& deque) {
-  std::unordered_set<T*> used;
-  std::deque<T*> deduped;
-  for (auto entry : deque) {
-    if (used.find(entry) == used.end()) {
-      deduped.push_back(entry);
-      used.emplace(entry);
-    }
-  }
-  return deduped;
-}
-
-void assertLowered(bool lowered) {
-  TORCH_INTERNAL_ASSERT(
-      lowered,
-      "Tried to accessed lowered values of compute at map,",
-      " however a valid lowering was not set when compute at map was created.");
-}
-
-} // namespace
-
-void ComputeAtMap::mapIds(IterDomain* id0, IterDomain* id1) {
-  auto set_it_0 = disjoint_iter_set_maps_.find(id0);
-  auto set_it_1 = disjoint_iter_set_maps_.find(id1);
-  if (set_it_0 == disjoint_iter_set_maps_.end() &&
-      set_it_1 == disjoint_iter_set_maps_.end()) {
-    // Neither iter domain has been mapped, so make a new disjoint set
-    auto new_set = std::make_shared<std::deque<IterDomain*>>();
-    new_set.get()->push_back(id0);
-    new_set.get()->push_back(id1);
-    disjoint_iter_set_maps_.emplace(std::make_pair(id0, new_set));
-    disjoint_iter_set_maps_.emplace(std::make_pair(id1, new_set));
-    disjoint_iter_sets_.push_back(new_set);
-
-    // Update parallel type map
-    if (mapping_mode_ == MappingMode::PARALLEL) {
-      if (id0->isParallelized() && id1->isParallelized()) {
-        // Both are parallelized, make sure they're the same, set entry for
-        // parallel map
-        TORCH_INTERNAL_ASSERT(id0->getParallelType() == id1->getParallelType());
-        parallel_type_map_[new_set] = id0->getParallelType();
-      } else if (id0->isParallelized() || id1->isParallelized()) {
-        // Only one is parallelized, set entry for parallel map
-        parallel_type_map_[new_set] = id0->isParallelized()
-            ? id0->getParallelType()
-            : id1->getParallelType();
-      }
-    }
-
-  } else if (
-      set_it_0 != disjoint_iter_set_maps_.end() &&
-      set_it_1 != disjoint_iter_set_maps_.end()) {
-    // Both iter domains have been mapped, so join their sets together
-    auto set0_ptr = set_it_0->second;
-    auto set1_ptr = set_it_1->second;
-
-    // If the sets are already the same, do nothing
-    if (set0_ptr == set1_ptr) {
-      return;
-    }
-
-    // Place everything in set1 into set0 and remap all ID's in set1 to set0
-    auto& set1 = *set1_ptr;
-    for (auto id : set1) {
-      set0_ptr->push_back(id);
-      disjoint_iter_set_maps_[id] = set0_ptr;
-    }
-
-    // set1 no longer needed as its IDs are copied into set0
-    disjoint_iter_sets_.erase(std::find(
-        disjoint_iter_sets_.begin(), disjoint_iter_sets_.end(), set1_ptr));
-
-    // Update parallel type map
-    if (mapping_mode_ == MappingMode::PARALLEL) {
-      auto parallel_type_0_it = parallel_type_map_.find(set0_ptr);
-      auto parallel_type_1_it = parallel_type_map_.find(set1_ptr);
-      if (parallel_type_0_it != parallel_type_map_.end() &&
-          parallel_type_1_it != parallel_type_map_.end()) {
-        // If both sets had a parallel type associated with them, make sure they
-        // are the same
-        TORCH_INTERNAL_ASSERT(
-            parallel_type_0_it->second == parallel_type_1_it->second);
-      } else if (parallel_type_1_it != parallel_type_map_.end()) {
-        // Set 1 has a parallel type, set 0 does not, set parallel entry
-        parallel_type_map_[set0_ptr] = parallel_type_1_it->second;
-      }
-      // Else set 0 already has the right parallel type set in the map, if at
-      // all
-
-      // Remove set1 from the parallel type map as it shouldn't exist anymore
-      parallel_type_map_.erase(set1_ptr);
-    }
-
-  } else {
-    auto existing_set = set_it_0 != disjoint_iter_set_maps_.end()
-        ? set_it_0->second
-        : set_it_1->second;
-    auto missing_id = set_it_0 != disjoint_iter_set_maps_.end() ? id1 : id0;
-    existing_set->push_back(missing_id);
-    disjoint_iter_set_maps_[missing_id] = existing_set;
-
-    // Update parallel type map
-    if (mapping_mode_ == MappingMode::PARALLEL) {
-      auto parallel_type_it = parallel_type_map_.find(existing_set);
-      if (parallel_type_it != parallel_type_map_.end() &&
-          missing_id->isParallelized()) {
-        // existing_set has a parallel type already and missing_id has a
-        // parallel type, make sure they match. No need to update map
-        TORCH_INTERNAL_ASSERT(
-            parallel_type_it->second == missing_id->getParallelType());
-      } else if (
-          parallel_type_it == parallel_type_map_.end() &&
-          id1->isParallelized()) {
-        // Set parallel type of existing_set as the newly added missing_id is
-        // parallel
-        parallel_type_map_[existing_set] = missing_id->getParallelType();
-      }
-    }
-  }
-}
-
-void ComputeAtMap::build(Fusion* fusion, GpuLower* gpu_lower) {
-  // Consumers can only show up once in an expression, keep track of all of them
-  std::vector<TensorView*> consumer_tvs;
-
-  for (auto expr : fusion->exprs()) {
-    if (!expr->outputs()[0]->isA<TensorView>()) {
-      continue;
-    }
-
-    auto tv_outputs = ir_utils::filterByType<TensorView>(expr->outputs());
-    TensorView* first_output_tv = nullptr;
-    for (auto c_tv : tv_outputs) {
-      consumer_tvs.push_back(c_tv);
-
-      if (first_output_tv == nullptr) {
-        first_output_tv = c_tv;
-      } else {
-        // Map multi outputs of an expression to eachother. c is current output,
-        // and f as first output. Keep consistent with the later section of
-        // producer and consumers. Which here producer is now "first output",
-        // and consumer is still consumer.
-
-        TORCH_INTERNAL_ASSERT(
-            c_tv->getRootDomain().size() ==
-                first_output_tv->getRootDomain().size(),
-            "Multiple outputs with mismatched dimensions is not supported. ",
-            "Only supported case is welford op where all outputs tvs have idential domains.");
-        // p->f, c->c
-        std::unordered_map<IterDomain*, IterDomain*> c2f_root_map;
-        for (size_t i = 0; i < first_output_tv->getRootDomain().size(); i++) {
-          c2f_root_map.insert(std::make_pair(
-              c_tv->getRootDomain()[i], first_output_tv->getRootDomain()[i]));
-        }
-
-        // Multi output mapping
-        auto replay_FasC = BestEffortReplay(
-            first_output_tv->domain()->domain(),
-            c_tv->domain()->domain(),
-            c2f_root_map);
-
-        auto c2f_map = replay_FasC.getReplay();
-
-        // If we're creating parallel map, only map the leaf
-        // axes. Also, the producer axis must be left of the CA
-        // point.
-        // Otherwise, map the entire replay map.
-        if (mapping_mode_ == MappingMode::PARALLEL) {
-          // Mark axes left of compute at point for parallel type tracking
-          std::unordered_set<IterDomain*> producer_axes_to_map(
-              first_output_tv->domain()->domain().begin(),
-              first_output_tv->domain()->domain().begin() +
-                  first_output_tv->getComputeAtPosition());
-
-          for (auto c_id : c_tv->domain()->domain()) {
-            auto it = c2f_map.find(c_id);
-            if (it == c2f_map.end()) {
-              continue;
-            }
-            auto f_id = it->second;
-            if (producer_axes_to_map.find(f_id) == producer_axes_to_map.end()) {
-              continue;
-            }
-            mapIds(f_id, c_id);
-          }
-        } else {
-          for (auto entry : c2f_map) {
-            auto c_id = entry.first;
-            auto f_id = entry.second;
-            // Map the id's together
-            mapIds(f_id, c_id);
-          }
-        }
-      }
-
-      auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
-
-      for (auto p_tv : tv_inputs) {
-        // If outside computeAt axis, we don't want to directly map
-        // consumer/producer as their thread mappings could change as long as
-        // it's across shared/global memory.
-        auto pairwise_map = PairwiseRootDomainMap(p_tv, c_tv);
-        auto c2p_root_map =
-            pairwise_map.mapConsumerToProducer(c_tv->domain(), p_tv->domain());
-
-        // Look for matching ID transformations in producer and consumer, replay
-        // producer as consumer. We want to replay producer as consumer instead
-        // of the other way around since consumer may have some broadcasted axes
-        // producer doesn't have merged into loops producer may use. If we did
-        // consumer as producer we wouldn't have this information in the
-        // mapping. If we're using this map for indexing, we do not want to
-        // propagate broadcast mismatches. If we're using it to identify loop
-        // nests, we do want to propagate mismatches.
-        auto replay_PasC = mapping_mode_ == MappingMode::LOOP ||
-                mapping_mode_ == MappingMode::PARALLEL
-            ? BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map)
-            : BestEffortReplay(
-                  p_tv->domain()->domain(),
-                  c_tv->domain()->domain(),
-                  c2p_root_map);
-
-        auto c2p_map = replay_PasC.getReplay();
-
-        // If we're creating parallel map, only map the leaf
-        // axes. Also, the producer axis must be left of the CA
-        // point.
-        // Otherwise, map the entire replay map.
-        if (mapping_mode_ == MappingMode::PARALLEL) {
-          // Mark axes left of compute at point for parallel type tracking
-          std::unordered_set<IterDomain*> producer_axes_to_map(
-              p_tv->domain()->domain().begin(),
-              p_tv->domain()->domain().begin() + p_tv->getComputeAtPosition());
-
-          for (auto c_id : c_tv->domain()->domain()) {
-            auto it = c2p_map.find(c_id);
-            if (it == c2p_map.end()) {
-              continue;
-            }
-            auto p_id = it->second;
-            if (producer_axes_to_map.find(p_id) == producer_axes_to_map.end()) {
-              continue;
-            }
-            mapIds(p_id, c_id);
-          }
-        } else {
-          for (auto entry : c2p_map) {
-            auto c_id = entry.first;
-            auto p_id = entry.second;
-            // Map the id's together
-            mapIds(p_id, c_id);
-          }
-        }
-      }
-    }
-  }
-
-  // deduplicate iter domain entries in each set
-  for (const auto& iter_set : disjoint_iter_sets_) {
-    *iter_set = deduplicateDeque(*iter_set);
-  }
-
-  // For each IterDomain set we will track how many concrete root domains were
-  // used to generate the IterDomain. Used to populate conrete_id_map. Concrete
-  // ID has maximum of concrete ids, ties are decided based on n_broadcast_ids.
-  // Refer to AdvancedLowering5 for why we need to split ties with broadcast
-  // dims.
-  std::unordered_map<IterDomain*, int> n_concrete_ids_;
-  std::unordered_map<IterDomain*, int> n_broadcast_ids_;
-
-  for (auto c_tv : consumer_tvs) {
-    auto counts =
-        InputDomainCounter::produceCounts(c_tv->domain()->domain(), gpu_lower);
-    std::transform(
-        counts.begin(),
-        counts.end(),
-        std::inserter(n_concrete_ids_, n_concrete_ids_.end()),
-        [](auto counts_entry) {
-          return std::make_pair(counts_entry.first, counts_entry.second.first);
-        });
-    std::transform(
-        counts.begin(),
-        counts.end(),
-        std::inserter(n_broadcast_ids_, n_broadcast_ids_.end()),
-        [](auto counts_entry) {
-          return std::make_pair(counts_entry.first, counts_entry.second.second);
-        });
-  }
-
-  for (auto inp_tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
-    auto counts = InputDomainCounter::produceCounts(
-        inp_tv->domain()->domain(), gpu_lower);
-    std::transform(
-        counts.begin(),
-        counts.end(),
-        std::inserter(n_concrete_ids_, n_concrete_ids_.end()),
-        [](auto counts_entry) {
-          return std::make_pair(counts_entry.first, counts_entry.second.first);
-        });
-    std::transform(
-        counts.begin(),
-        counts.end(),
-        std::inserter(n_broadcast_ids_, n_broadcast_ids_.end()),
-        [](auto counts_entry) {
-          return std::make_pair(counts_entry.first, counts_entry.second.second);
-        });
-  }
-
-  // Populate concrete id map
-  for (const auto& set : disjoint_iter_sets_) {
-    int max_concrete_count = -1;
-    int max_broadcast_count = -1;
-    IterDomain* concrete_id = nullptr;
-    for (auto id : *set) {
-      int concrete_count = n_concrete_ids_.at(id);
-      if (concrete_count >= max_concrete_count) {
-        int broadcast_count = n_broadcast_ids_.at(id);
-        if (concrete_count > max_concrete_count ||
-            broadcast_count > max_broadcast_count) {
-          max_concrete_count = concrete_count;
-          max_broadcast_count = broadcast_count;
-          concrete_id = id;
-        }
-      }
-    }
-
-    TORCH_INTERNAL_ASSERT(
-        concrete_id != nullptr, "Could not concretize an IterDomain set.");
-
-    for (auto id : *set) {
-      concrete_id_map_[id] = concrete_id;
-      if (mapping_mode_ == MappingMode::PARALLEL) {
-        auto parallel_map_it = parallel_type_map_.find(set);
-        // Parallelize all IterDomains to simplify lowering and codegen
-        if (parallel_map_it != parallel_type_map_.end()) {
-          // Don't propogate vectorize like other parallel types
-          if (parallel_map_it->second != ParallelType::Vectorize) {
-            id->parallelize(parallel_map_it->second);
-          }
-        }
-      }
-    }
-  }
-
-  if (gpu_lower != nullptr) {
-    convertToKir(fusion, gpu_lower);
-  }
-}
-
-void ComputeAtMap::convertToKir(Fusion* fusion, GpuLower* gpu_lower) {
-  TORCH_INTERNAL_ASSERT(fusion != nullptr);
-  TORCH_INTERNAL_ASSERT(gpu_lower != nullptr);
-
-  has_lowered_kir_ = true;
-
-  std::unordered_map<
-      std::shared_ptr<std::deque<IterDomain*>>,
-      std::shared_ptr<std::deque<kir::IterDomain*>>>
-      disjoint_set_2_kir;
-
-  for (const auto& disjoint_iter_set : disjoint_iter_set_maps_) {
-    auto fusion_set = disjoint_iter_set.second;
-    auto kir_set_it = disjoint_set_2_kir.find(fusion_set);
-    std::shared_ptr<std::deque<kir::IterDomain*>> kir_set;
-    if (kir_set_it == disjoint_set_2_kir.end()) {
-      kir_set = std::make_shared<std::deque<kir::IterDomain*>>();
-      std::transform(
-          fusion_set->begin(),
-          fusion_set->end(),
-          std::inserter(*kir_set, kir_set->begin()),
-          [&gpu_lower](IterDomain* id) {
-            return gpu_lower->lowerValue(id)->as<kir::IterDomain>();
-          });
-      disjoint_set_2_kir.emplace(std::make_pair(fusion_set, kir_set));
-    } else {
-      kir_set = kir_set_it->second;
-    }
-    kir_disjoint_iter_set_maps_.emplace(std::make_pair(
-        gpu_lower->lowerValue(disjoint_iter_set.first)->as<kir::IterDomain>(),
-        kir_set));
-  }
-
-  for (auto entry : concrete_id_map_) {
-    kir_concrete_id_map_.emplace(std::make_pair(
-        gpu_lower->lowerValue(entry.first)->as<kir::IterDomain>(),
-        gpu_lower->lowerValue(entry.second)->as<kir::IterDomain>()));
-  }
-
-  for (const auto& entry : disjoint_iter_set_maps_) {
-    kir_2_fusion_[gpu_lower->lowerValue(entry.first)->as<kir::IterDomain>()] =
-        entry.first;
-  }
-
-  // Make sure we have all IterDomains that could be used to generate a ForLoop
-  for (auto expr : fusion->exprs()) {
-    if (!expr->outputs()[0]->isA<TensorView>()) {
-      continue;
-    }
-
-    auto tv_outputs = ir_utils::filterByType<TensorView>(expr->outputs());
-
-    for (auto out : tv_outputs) {
-      for (auto entry : out->domain()->domain()) {
-        kir_2_fusion_[gpu_lower->lowerValue(entry)->as<kir::IterDomain>()] =
-            entry;
-      }
-    }
-  }
-}
-
-bool ComputeAtMap::areMapped(IterDomain* id0, IterDomain* id1) const {
-  if (id0 == id1) {
-    return true;
-  }
-  auto set0_it = disjoint_iter_set_maps_.find(id0);
-  auto set1_it = disjoint_iter_set_maps_.find(id1);
-  if (set0_it == disjoint_iter_set_maps_.end() ||
-      set1_it == disjoint_iter_set_maps_.end()) {
-    return false;
-  }
-  return (set0_it->second.get() == set1_it->second.get());
-}
-
-bool ComputeAtMap::areMapped(kir::IterDomain* id0, kir::IterDomain* id1) const {
-  assertLowered(has_lowered_kir_);
-  if (id0 == id1) {
-    return true;
-  }
-  auto set0_it = kir_disjoint_iter_set_maps_.find(id0);
-  auto set1_it = kir_disjoint_iter_set_maps_.find(id1);
-  if (set0_it == kir_disjoint_iter_set_maps_.end() ||
-      set1_it == kir_disjoint_iter_set_maps_.end()) {
-    return false;
-  }
-  return (set0_it->second.get() == set1_it->second.get());
-}
-
-IterDomain* ComputeAtMap::getConcreteMappedID(IterDomain* id) const {
-  auto it = concrete_id_map_.find(id);
-  if (it != concrete_id_map_.end()) {
-    return it->second;
-  }
-  return id;
-}
-
-kir::IterDomain* ComputeAtMap::getConcreteMappedID(kir::IterDomain* id) const {
-  assertLowered(has_lowered_kir_);
-  auto it = kir_concrete_id_map_.find(id);
-  if (it != kir_concrete_id_map_.end()) {
-    return it->second;
-  }
-  return id;
-}
-
-IterDomain* ComputeAtMap::toFusion(kir::IterDomain* kir) const {
-  assertLowered(has_lowered_kir_);
-  auto kir_2_fusion_it = kir_2_fusion_.find(kir);
-  TORCH_INTERNAL_ASSERT(
-      kir_2_fusion_it != kir_2_fusion_.end(),
-      "Kernel ir is not guarneteed to be reversible into fusion ir, could not find fusion entry. ",
-      kir::toString(kir, false));
-  return kir_2_fusion_it->second;
-}
-
-std::string ComputeAtMap::toString() const {
-  std::stringstream ss;
-
-  // We may not have cleaned up non active sets as this is intended for debug,
-  // so first grab unique entries and iterate over them.
-  std::unordered_set<std::shared_ptr<std::deque<IterDomain*>>> disjoint_sets;
-
-  for (const auto& entry : disjoint_iter_set_maps_) {
-    disjoint_sets.emplace(entry.second);
-  }
-
-  for (const auto& disjoint_set : disjoint_sets) {
-    ss << "  disjoint_set{ ";
-    TORCH_INTERNAL_ASSERT(disjoint_set->size() > 0);
-    auto concrete_id = concrete_id_map_.at(disjoint_set->front());
-    for (auto it = disjoint_set->begin(); it != disjoint_set->end(); it++) {
-      if (it != disjoint_set->begin()) {
-        ss << ", ";
-      }
-      ss << (*it);
-      if (*it == concrete_id) {
-        ss << "*";
-      }
-    }
-    ss << " }";
-    if (mapping_mode_ == MappingMode::PARALLEL) {
-      if (parallel_type_map_.find(disjoint_set) != parallel_type_map_.end()) {
-        ss << "  -> " << parallel_type_map_.at(disjoint_set);
-      } else {
-        ss << "  -> " << ParallelType::Serial;
-      }
-    }
-    ss << "\n";
-  }
-  return ss.str();
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h
deleted file mode 100644 (file)
index 6515bc3..0000000
+++ /dev/null
@@ -1,127 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <deque>
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class GpuLower;
-
-class TORCH_CUDA_CU_API ComputeAtMap {
- public:
-  // There's three modes of these iter domain mappings. For indexing, for loop
-  // nest mapping/generation, and to figure out the parallelization strategy.
-  //
-  // For index/loop mode consider:
-  //
-  // consumer[i0, b1] = producer[i0]
-  // consumer->merge(0) (consumer will now be [i0 * b1])
-  // When producer is replayed as consumer (the direction we use for mapping)
-  // with BestEffortReplay forward_bcast_mismatch = True the producer to
-  // consumer map will have both a mapping of consumer(i0) to producer(i0) as
-  // well as consumer(i0*b1) to producer(i0). This latter mapping is important
-  // for loop nest mappings as the consumer will generate a loop based on i0*b1
-  // and the producer may be computeAt inside this loop nest. However, for
-  // indexing we do not want these two maps as producer may be indexed as i0*i1
-  // depending on the loop nest structure and how it was built. Therefore we
-  // really need to carry two sets of maps around for lowering.
-  //
-  // Parallel mode is important if we have something like:
-  // consumer[i0o, threadIdx.x{i0i}] = producer[i0o, threadIdx.y{i0i}](computeAt
-  // = 1) which can easily happen when using shared memory. We want to make sure
-  // that the iteration domain used for loop construction (concreteId) has the
-  // proper parallelization strategy. In parallel mode we do typical iteration
-  // domain mapping, however we remove from it any iteration domains outside the
-  // computeAt of producer when mapping. This guarentees we won't map
-  // IterDomains that could have different parallelization strategies. We also
-  // propagate the parallel strategy in parallel mode so all mapped IDs that
-  // must have the same parallel type, do.
-  enum class MappingMode { PARALLEL, LOOP, INDEX };
-
-  ComputeAtMap() = default;
-  ComputeAtMap(MappingMode mapping_mode) : mapping_mode_(mapping_mode) {}
-
-  //! Builds all valid mappings. When gpu_lower is not nullptr,
-  //! equivalent mappings for KIR are also created.
-  void build(Fusion* fusion, GpuLower* gpu_lower = nullptr);
-
-  //! Returns if id0 and id1 are mapped to eachother, meaning they represent the
-  //! same loop nest in the lowered code
-  bool areMapped(IterDomain* id0, IterDomain* id1) const;
-
-  bool areMapped(kir::IterDomain* id0, kir::IterDomain* id1) const;
-
-  //! Returns an iter domain that is the maximum expanded size of all iter
-  //! domains the one provided maps to. Useful for opening loops to the correct
-  //! iteration size. Not guarenteed to return the same ID every call, but is
-  //! guarenteed to return iter domains in the same disjoint set.
-  IterDomain* getConcreteMappedID(IterDomain* id) const;
-
-  kir::IterDomain* getConcreteMappedID(kir::IterDomain* id) const;
-
-  // TODO: Would be great if we didn't need this, but we have nice functionality
-  // in iter_visitor that isn't moved over. Use of this is limited to indexing
-  // and this should definitely be removed by building out kernel ir to have
-  // better parity with fusion ir.
-  IterDomain* toFusion(kir::IterDomain* kir) const;
-
-  // Prints mapping information via Fusion IR
-  std::string toString() const;
-
- private:
-  bool has_lowered_kir_ = false;
-
-  void mapIds(IterDomain* id0, IterDomain* id1);
-
-  //! Convert everything to lowered structures (kernel ir), as we will use
-  //! this class frequently during lowering.
-  void convertToKir(Fusion* fusion, GpuLower* gpu_lower);
-
- private:
-  MappingMode mapping_mode_ = MappingMode::LOOP;
-
-  // This is actually only used when mapping mode == LOOP. Only used in expr
-  // sorting, it's actually maximum position where a loop is shared across any
-  // neighbor.
-  std::unordered_map<TensorView*, unsigned int> produce_at_map_;
-
-  // Disjoint sets of iter domains, only defined if iter domain is within
-  // compute at of a tensor view. Maps these iter domains to a set containing
-  // all other iter domains in the fusion that map to the same loop nest.
-  std::unordered_map<IterDomain*, std::shared_ptr<std::deque<IterDomain*>>>
-      disjoint_iter_set_maps_;
-
-  std::unordered_map<
-      kir::IterDomain*,
-      std::shared_ptr<std::deque<kir::IterDomain*>>>
-      kir_disjoint_iter_set_maps_;
-
-  // Keep a list of disjoint_iter_sets that's deterministic to iterate over
-  std::deque<std::shared_ptr<std::deque<IterDomain*>>> disjoint_iter_sets_;
-
-  // Tracks if there's a parallel iter domain associated a disjoint iter domain
-  // set
-  std::unordered_map<std::shared_ptr<std::deque<IterDomain*>>, ParallelType>
-      parallel_type_map_;
-
-  // For each IterDomain set we will track how many concrete root domains were
-  // used to generate the IterDomain
-  std::unordered_map<IterDomain*, IterDomain*> concrete_id_map_;
-
-  std::unordered_map<kir::IterDomain*, kir::IterDomain*> kir_concrete_id_map_;
-
-  // Map kir::IterDomain* back to the fusion IR IterDomain*.
-  // TODO: Would be great if we didn't need this.
-  std::unordered_map<kir::IterDomain*, IterDomain*> kir_2_fusion_;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/disjoint_set.h b/torch/csrc/jit/codegen/cuda/disjoint_set.h
deleted file mode 100644 (file)
index 99647a0..0000000
+++ /dev/null
@@ -1,174 +0,0 @@
-#pragma once
-
-#include <c10/util/Exception.h>
-
-#include <algorithm>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Container class DisjointSet models equivalence relationships
-//!
-//! Each instance of this class keeps a set of equivalent classes
-//! DisjointSet::join(a,b) makes the full class of a and b equivalent
-//! DisjointSet::areEqual(a,b) checks if a and b belong same class
-template <typename T, typename Hash = std::hash<T>>
-class DisjointSet {
- public:
-  DisjointSet() = default;
-
-  //! Joins the equivalent class that a and b belong to
-  //! areEqual(a',b') will be true for each a'=a and b'=b
-  //!
-  //! \param a An element from a equivalent class
-  //!          will create a new equivalent class if a does
-  //!          not belong to any
-  //! \param b An element from another equivalent class
-  //!          will create a new equivalent class if b does
-  //!          not belong to any
-  void join(T a, T b) {
-    // cases where either of the quiv class doesn't exist
-    if (!entry_map.count(a) && !entry_map.count(b)) {
-      createPoint(a);
-      entry_map[b] = fixedPoint(a);
-    } else if (!entry_map.count(a)) {
-      entry_map[a] = fixedPoint(b);
-    } else if (!entry_map.count(b)) {
-      entry_map[b] = fixedPoint(a);
-    } else {
-      // case where both equiv classes exist and need to join
-      const int i0 = fixedPoint(a);
-      const int i1 = fixedPoint(b);
-      int new_parent = 0;
-      int new_child = 0;
-
-      // Either order here is correct but joining larger class to smaller class
-      // tend to be faster
-      std::tie(new_parent, new_child) = (weights[i0] < weights[i1])
-          ? std::make_pair(i0, i1)
-          : std::make_pair(i1, i0);
-      weights[new_parent] += weights[new_child];
-      set_map[new_child] = new_parent;
-    }
-  }
-
-  //! Checks if a and b belong to the same equivalent class
-  //!
-  //! \param a An element from a equivalent class
-  //! \param b An element from another equivalent class
-  //! \returns Boolean value representing if a and b are
-  //!          recorded to be in the same equivalent class
-  //!          will return false if any of a or b doesn't
-  //!          have an equivalent class recorded
-  bool areEquivalent(T a, T b) const {
-    if (!entry_map.count(a) || !entry_map.count(b)) {
-      return false;
-    }
-    return fixedPoint(a) == fixedPoint(b);
-  }
-
-  //! Queries if an element exists in this set
-  bool contains(T a) const {
-    return entry_map.count(a) > 0;
-  }
-
-  //! Returns all elements added to this set
-  std::vector<T> getAllElements() const {
-    std::vector<T> elms(entry_map.size());
-    std::transform(
-        entry_map.begin(),
-        entry_map.end(),
-        elms.begin(),
-        [](const auto& entry_map_kv) { return entry_map_kv.first; });
-    return elms;
-  }
-
-  //! Clears the equivalence relationships
-  void clear() {
-    set_map.clear();
-    weights.clear();
-    entry_map.clear();
-    next_index_ = 0;
-  }
-
-  //! Dumps the equivalent relationships
-  std::ostream& print(std::ostream& os) const {
-    std::unordered_map<int, std::unordered_set<T, Hash>> fixedPointMap;
-    for (const auto& kv : entry_map) {
-      int fixed_point = fixedPoint(kv.first);
-      auto it = fixedPointMap.find(fixed_point);
-      if (it == fixedPointMap.end()) {
-        it = fixedPointMap.insert({fixed_point, {}}).first;
-      }
-      it->second.insert(kv.first);
-    }
-    os << "{\n";
-    for (const auto& kv : fixedPointMap) {
-      os << "\t{ ";
-      for (const auto& val : kv.second) {
-        os << toString(val) << " ";
-      }
-      os << "}\n";
-    }
-    os << "}\n";
-    return os;
-  }
-
- private:
-  // Internal fixed point implementation:
-  //  Returns the equivalent class that e belongs to
-  int getFixedPointForClass(int e) const {
-    TORCH_INTERNAL_ASSERT(static_cast<int>(set_map.size()) > e);
-    while (set_map[e] != e) {
-      // Chasing to fixed point
-      e = set_map[e];
-    }
-    return e;
-  }
-
-  //! Utility to check the class e belongs to:
-  //!
-  //! \param e element e to find the equiv class for
-  //! \returns the equivalent class that e belongs to
-  //!
-  int fixedPoint(T e) const {
-    // Handles case when i doesn't have an equivalence class
-    TORCH_INTERNAL_ASSERT(entry_map.count(e));
-
-    // Use fixed point as a representation for the equiv class
-    return getFixedPointForClass(entry_map.at(e));
-  }
-
-  //! Utility to create a new equiv class for i
-  //
-  //! \param i Element i to create the equiv class for
-  void createPoint(T i) {
-    entry_map[i] = next_index_;
-    set_map.push_back(next_index_++);
-    weights.push_back(1);
-  }
-
- private:
-  // Internal representation of the equivalence class as integers
-  // set_map implements the "parent" relationship
-  std::vector<int> set_map;
-  // Weights is used for preliminary perf optimization
-  std::vector<int> weights;
-
-  // Map the input of type T to its equivalence class
-  std::unordered_map<T, int, Hash> entry_map;
-
-  // Running counter for generating new index when
-  // Creating new equiv classes
-  int next_index_ = 0;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index b6ba175..f3a8837 100644 (file)
@@ -48,8 +48,11 @@ void Val::dispatch(T handler, Val* val) {
         case DataType::Bool:
           ptr(handler)->handle(val->as<Bool>());
           return;
-        case DataType::Double:
-          ptr(handler)->handle(val->as<Double>());
+        case DataType::Float:
+          ptr(handler)->handle(val->as<Float>());
+          return;
+        case DataType::Half:
+          ptr(handler)->handle(val->as<Half>());
           return;
         case DataType::Int:
           ptr(handler)->handle(val->as<Int>());
@@ -70,6 +73,42 @@ void Val::dispatch(T handler, Val* val) {
     case ValType::NamedScalar:
       ptr(handler)->handle(val->as<NamedScalar>());
       return;
+
+    // TODO: remove once the Kernel IR has its own visitor
+    case ValType::TensorIndex:
+      ptr(handler)->handle(val->as<kir::TensorIndex>());
+      return;
+    case ValType::KirScalar:
+      switch (*(val->getDataType())) {
+        case DataType::Bool:
+          ptr(handler)->handle(val->as<kir::Bool>());
+          return;
+        case DataType::Float:
+          ptr(handler)->handle(val->as<kir::Float>());
+          return;
+        case DataType::Half:
+          ptr(handler)->handle(val->as<kir::Half>());
+          return;
+        case DataType::Int:
+          ptr(handler)->handle(val->as<kir::Int>());
+          return;
+        default:
+          break;
+      }
+      break;
+    case ValType::KirNamedScalar:
+      ptr(handler)->handle(val->as<kir::NamedScalar>());
+      return;
+    case ValType::KirIterDomain:
+      ptr(handler)->handle(val->as<kir::IterDomain>());
+      return;
+    case ValType::KirTensorDomain:
+      ptr(handler)->handle(val->as<kir::TensorDomain>());
+      return;
+    case ValType::KirTensorView:
+      ptr(handler)->handle(val->as<kir::TensorView>());
+      return;
+
     default:
       break;
   }
@@ -97,21 +136,42 @@ void Expr::dispatch(T handler, Expr* expr) {
     case ExprType::ReductionOp:
       ptr(handler)->handle(expr->as<ReductionOp>());
       return;
-    case ExprType::WelfordOp:
-      ptr(handler)->handle(expr->as<WelfordOp>());
-      return;
     case ExprType::BroadcastOp:
       ptr(handler)->handle(expr->as<BroadcastOp>());
       return;
-    case ExprType::TransposeOp:
-      ptr(handler)->handle(expr->as<TransposeOp>());
+
+    case ExprType::KirUnaryOp:
+      ptr(handler)->handle(expr->as<kir::UnaryOp>());
+      return;
+    case ExprType::KirBinaryOp:
+      ptr(handler)->handle(expr->as<kir::BinaryOp>());
+      return;
+    case ExprType::KirTernaryOp:
+      ptr(handler)->handle(expr->as<kir::TernaryOp>());
       return;
-    case ExprType::ShiftOp:
-      ptr(handler)->handle(expr->as<ShiftOp>());
+    case ExprType::KirReductionOp:
+      ptr(handler)->handle(expr->as<kir::ReductionOp>());
       return;
-    case ExprType::GatherOp:
-      ptr(handler)->handle(expr->as<GatherOp>());
+    case ExprType::KirBroadcastOp:
+      ptr(handler)->handle(expr->as<kir::BroadcastOp>());
       return;
+
+    case ExprType::GridReduction:
+      ptr(handler)->handle(expr->as<kir::GridReduction>());
+      return;
+    case ExprType::ForLoop:
+      ptr(handler)->handle(expr->as<kir::ForLoop>());
+      return;
+    case ExprType::IfThenElse:
+      ptr(handler)->handle(expr->as<kir::IfThenElse>());
+      return;
+    case ExprType::Allocate:
+      ptr(handler)->handle(expr->as<kir::Allocate>());
+      return;
+    case ExprType::Sync:
+      ptr(handler)->handle(expr->as<kir::Sync>());
+      return;
+
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
   }
@@ -135,8 +195,11 @@ void Val::constDispatch(T handler, const Val* val) {
         case DataType::Bool:
           ptr(handler)->handle(val->as<Bool>());
           return;
-        case DataType::Double:
-          ptr(handler)->handle(val->as<Double>());
+        case DataType::Float:
+          ptr(handler)->handle(val->as<Float>());
+          return;
+        case DataType::Half:
+          ptr(handler)->handle(val->as<Half>());
           return;
         case DataType::Int:
           ptr(handler)->handle(val->as<Int>());
@@ -157,6 +220,42 @@ void Val::constDispatch(T handler, const Val* val) {
     case ValType::NamedScalar:
       ptr(handler)->handle(val->as<NamedScalar>());
       return;
+
+    // TODO: remove once the Kernel IR has its own visitor
+    case ValType::TensorIndex:
+      ptr(handler)->handle(val->as<kir::TensorIndex>());
+      return;
+    case ValType::KirScalar:
+      switch (*(val->getDataType())) {
+        case DataType::Bool:
+          ptr(handler)->handle(val->as<kir::Bool>());
+          return;
+        case DataType::Float:
+          ptr(handler)->handle(val->as<kir::Float>());
+          return;
+        case DataType::Half:
+          ptr(handler)->handle(val->as<kir::Half>());
+          return;
+        case DataType::Int:
+          ptr(handler)->handle(val->as<kir::Int>());
+          return;
+        default:
+          break;
+      }
+      break;
+    case ValType::KirNamedScalar:
+      ptr(handler)->handle(val->as<kir::NamedScalar>());
+      return;
+    case ValType::KirIterDomain:
+      ptr(handler)->handle(val->as<kir::IterDomain>());
+      return;
+    case ValType::KirTensorDomain:
+      ptr(handler)->handle(val->as<kir::TensorDomain>());
+      return;
+    case ValType::KirTensorView:
+      ptr(handler)->handle(val->as<kir::TensorView>());
+      return;
+
     default:
       break;
   }
@@ -184,21 +283,42 @@ void Expr::constDispatch(T handler, const Expr* expr) {
     case ExprType::ReductionOp:
       ptr(handler)->handle(expr->as<ReductionOp>());
       return;
-    case ExprType::WelfordOp:
-      ptr(handler)->handle(expr->as<WelfordOp>());
-      return;
     case ExprType::BroadcastOp:
       ptr(handler)->handle(expr->as<BroadcastOp>());
       return;
-    case ExprType::TransposeOp:
-      ptr(handler)->handle(expr->as<TransposeOp>());
+
+    case ExprType::KirUnaryOp:
+      ptr(handler)->handle(expr->as<kir::UnaryOp>());
+      return;
+    case ExprType::KirBinaryOp:
+      ptr(handler)->handle(expr->as<kir::BinaryOp>());
+      return;
+    case ExprType::KirTernaryOp:
+      ptr(handler)->handle(expr->as<kir::TernaryOp>());
       return;
-    case ExprType::ShiftOp:
-      ptr(handler)->handle(expr->as<ShiftOp>());
+    case ExprType::KirReductionOp:
+      ptr(handler)->handle(expr->as<kir::ReductionOp>());
       return;
-    case ExprType::GatherOp:
-      ptr(handler)->handle(expr->as<GatherOp>());
+    case ExprType::KirBroadcastOp:
+      ptr(handler)->handle(expr->as<kir::BroadcastOp>());
       return;
+
+    case ExprType::GridReduction:
+      ptr(handler)->handle(expr->as<kir::GridReduction>());
+      return;
+    case ExprType::ForLoop:
+      ptr(handler)->handle(expr->as<kir::ForLoop>());
+      return;
+    case ExprType::IfThenElse:
+      ptr(handler)->handle(expr->as<kir::IfThenElse>());
+      return;
+    case ExprType::Allocate:
+      ptr(handler)->handle(expr->as<kir::Allocate>());
+      return;
+    case ExprType::Sync:
+      ptr(handler)->handle(expr->as<kir::Sync>());
+      return;
+
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
   }
@@ -232,8 +352,10 @@ Statement* Val::mutatorDispatch(T mutator, Val* val) {
       switch (*(val->getDataType())) {
         case DataType::Bool:
           return ptr(mutator)->mutate(val->as<Bool>());
-        case DataType::Double:
-          return ptr(mutator)->mutate(val->as<Double>());
+        case DataType::Float:
+          return ptr(mutator)->mutate(val->as<Float>());
+        case DataType::Half:
+          return ptr(mutator)->mutate(val->as<Half>());
         case DataType::Int:
           return ptr(mutator)->mutate(val->as<Int>());
         default:
@@ -246,6 +368,8 @@ Statement* Val::mutatorDispatch(T mutator, Val* val) {
       return ptr(mutator)->mutate(val->as<TensorDomain>());
     case ValType::TensorView:
       return ptr(mutator)->mutate(val->as<TensorView>());
+    case ValType::TensorIndex:
+      return ptr(mutator)->mutate(val->as<kir::TensorIndex>());
     case ValType::NamedScalar:
       return ptr(mutator)->mutate(val->as<NamedScalar>());
     default:
@@ -269,16 +393,18 @@ Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
       return ptr(mutator)->mutate(expr->as<TernaryOp>());
     case ExprType::ReductionOp:
       return ptr(mutator)->mutate(expr->as<ReductionOp>());
-    case ExprType::WelfordOp:
-      return ptr(mutator)->mutate(expr->as<WelfordOp>());
+    case ExprType::GridReduction:
+      return ptr(mutator)->mutate(expr->as<kir::GridReduction>());
     case ExprType::BroadcastOp:
       return ptr(mutator)->mutate(expr->as<BroadcastOp>());
-    case ExprType::TransposeOp:
-      return ptr(mutator)->mutate(expr->as<TransposeOp>());
-    case ExprType::ShiftOp:
-      return ptr(mutator)->mutate(expr->as<ShiftOp>());
-    case ExprType::GatherOp:
-      return ptr(mutator)->mutate(expr->as<GatherOp>());
+    case ExprType::ForLoop:
+      return ptr(mutator)->mutate(expr->as<kir::ForLoop>());
+    case ExprType::IfThenElse:
+      return ptr(mutator)->mutate(expr->as<kir::IfThenElse>());
+    case ExprType::Allocate:
+      return ptr(mutator)->mutate(expr->as<kir::Allocate>());
+    case ExprType::Sync:
+      return ptr(mutator)->mutate(expr->as<kir::Sync>());
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
   }
index e83ac4e..18ae573 100644 (file)
@@ -1,48 +1,48 @@
 #pragma once
 
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-
 #include <c10/util/Exception.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <unordered_map>
 
-// dispatch.h prevents the need from adding manual dispatch in every class that
-// wants to define how to process a series of nodes. dispatch.h provides 4
-// classes that can be inherited providing a means to override functions on a
-// per-node basis. There are currently 4 provided dispatch mechanisms:
-//
-// OptOutDispatch:
-//
-// provides the functions:
-// virtual void handle(ValType* irnode){}
-//
-// This provides a mechanisms to override this handle for particular node
-// types. For example if we only wanted to actually run a function on
-// BinaryOps, we could inherit OptOutDispatch and simply override: void
-// handle(BinaryOp*) { doSomething; } Then we could run through all our
-// Statement* and call OptOutDispatch::handle(statement). When a BinaryOp is
-// encountered our override function will be called. For every other node,
-// nothing will be done.
-//
-// OptInDispatch:
-//
-// This class is similar to OptOutDispatch, however if we encounter a node
-// that we haven't specified an override for in the derived class, an error
-// will be thrown. This is useful if we create a class that is expected to
-// handle any type of node it encounters.
-//
-// OptOutMutator:
-//
-// This class is similar to OptOutDispatch except the functions provided are of
-// type: virtual Statement* mutate(Statement*) this is useful for when we want
-// to have an IR node result from our overloaded functions.
-//
-// OptInMutator:
-//
-// This class is similar to OptInDispatch except the functions provided are of
-// type: virtual Statement* mutate(Statement*) this is useful for when we want
-// to have an IR node result from our overloaded functions.
+/*
+ * dispatch.h prevents the need from adding manual dispatch in every class that
+ * wants to define how to process a series of nodes. dispatch.h provides 4
+ * classes that can be inherited providing a means to override functions on a
+ * per-node basis. There are currently 4 provided dispatch mechanisms:
+ *
+ * OptOutDispatch:
+ *
+ * provides the functions:
+ * virtual void handle(ValType* irnode){}
+ *
+ * This provides a mechanisms to override this handle for particular node
+ * types. For example if we only wanted to actually run a function on
+ * BinaryOps, we could inherit OptOutDispatch and simply override: void
+ * handle(BinaryOp*) { doSomething; } Then we could run through all our
+ * Statement* and call OptOutDispatch::handle(statement). When a BinaryOp is
+ * encountered our override function will be called. For every other node,
+ * nothing will be done.
+ *
+ * OptInDispatch:
+ *
+ * This class is similar to OptOutDispatch, however if we encounter a node
+ * that we haven't specified an override for in the derived class, an error
+ * will be thrown. This is useful if we create a class that is expected to
+ * handle any type of node it encounters.
+ *
+ * OptOutMutator:
+ *
+ * This class is similar to OptOutDispatch except the functions provided are of
+ * type: virtual Statement* mutate(Statement*) this is useful for when we want
+ * to have an IR node result from our overloaded functions.
+ *
+ * OptInMutator:
+ *
+ * This class is similar to OptInDispatch except the functions provided are of
+ * type: virtual Statement* mutate(Statement*) this is useful for when we want
+ * to have an IR node result from our overloaded functions.
+ */
 
 namespace torch {
 namespace jit {
@@ -61,7 +61,8 @@ class IterDomain;
 class TensorDomain;
 class TensorView;
 class Bool;
-class Double;
+class Float;
+class Half;
 class Int;
 class NamedScalar;
 
@@ -72,16 +73,51 @@ class UnaryOp;
 class BinaryOp;
 class TernaryOp;
 class ReductionOp;
-class WelfordOp;
 class BroadcastOp;
-class TransposeOp;
-class ShiftOp;
-class GatherOp;
 
-// By default, all IR nodes are handled in this dispatch, and will call an empty
-// function on all nodes.
-class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
+// Kernel IR
+namespace kir {
+
+class Bool;
+class Float;
+class Half;
+class Int;
+class NamedScalar;
+
+class IterDomain;
+class TensorDomain;
+class TensorView;
+
+class UnaryOp;
+class BinaryOp;
+class TernaryOp;
+class ReductionOp;
+class BroadcastOp;
+
+class TensorIndex;
+class Allocate;
+class ForLoop;
+class IfThenElse;
+class GridReduction;
+class Sync;
+
+} // namespace kir
+
+/*
+ * By default, all IR nodes are handled in this dispatch, and will call an empty
+ * function on all nodes.
+ */
+class TORCH_CUDA_CU_API OptOutConstDispatch {
  public:
+  virtual ~OptOutConstDispatch() = default;
+  OptOutConstDispatch() = default;
+
+  OptOutConstDispatch(const OptOutConstDispatch& other) = default;
+  OptOutConstDispatch& operator=(const OptOutConstDispatch& other) = default;
+
+  OptOutConstDispatch(OptOutConstDispatch&& other) = default;
+  OptOutConstDispatch& operator=(OptOutConstDispatch&& other) = default;
+
   // Hierarchal dispatch functions for handle
   virtual void handle(const Statement*);
   virtual void handle(const Expr*);
@@ -92,7 +128,8 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
   virtual void handle(const TensorDomain*) {}
   virtual void handle(const TensorView*) {}
   virtual void handle(const Bool*) {}
-  virtual void handle(const Double*) {}
+  virtual void handle(const Float*) {}
+  virtual void handle(const Half*) {}
   virtual void handle(const Int*) {}
   virtual void handle(const NamedScalar*) {}
 
@@ -103,15 +140,44 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
   virtual void handle(const BinaryOp*) {}
   virtual void handle(const TernaryOp*) {}
   virtual void handle(const ReductionOp*) {}
-  virtual void handle(const WelfordOp*) {}
   virtual void handle(const BroadcastOp*) {}
-  virtual void handle(const TransposeOp*) {}
-  virtual void handle(const ShiftOp*) {}
-  virtual void handle(const GatherOp*) {}
+
+  // Kernel IR nodes
+  virtual void handle(const kir::Bool*) {}
+  virtual void handle(const kir::Float*) {}
+  virtual void handle(const kir::Half*) {}
+  virtual void handle(const kir::Int*) {}
+  virtual void handle(const kir::NamedScalar*) {}
+
+  virtual void handle(const kir::IterDomain*) {}
+  virtual void handle(const kir::TensorDomain*) {}
+  virtual void handle(const kir::TensorView*) {}
+
+  virtual void handle(const kir::UnaryOp*) {}
+  virtual void handle(const kir::BinaryOp*) {}
+  virtual void handle(const kir::TernaryOp*) {}
+  virtual void handle(const kir::ReductionOp*) {}
+  virtual void handle(const kir::BroadcastOp*) {}
+
+  virtual void handle(const kir::TensorIndex*) {}
+  virtual void handle(const kir::GridReduction*) {}
+  virtual void handle(const kir::ForLoop*) {}
+  virtual void handle(const kir::IfThenElse*) {}
+  virtual void handle(const kir::Allocate*) {}
+  virtual void handle(const kir::Sync*) {}
 };
 
-class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
+class TORCH_CUDA_CU_API OptOutDispatch {
  public:
+  virtual ~OptOutDispatch() = default;
+  OptOutDispatch() = default;
+
+  OptOutDispatch(const OptOutDispatch& other) = default;
+  OptOutDispatch& operator=(const OptOutDispatch& other) = default;
+
+  OptOutDispatch(OptOutDispatch&& other) = default;
+  OptOutDispatch& operator=(OptOutDispatch&& other) = default;
+
   // Hierarchal dispatch functions for handle
   virtual void handle(Statement*);
   virtual void handle(Expr*);
@@ -122,7 +188,8 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
   virtual void handle(TensorDomain*) {}
   virtual void handle(TensorView*) {}
   virtual void handle(Bool*) {}
-  virtual void handle(Double*) {}
+  virtual void handle(Float*) {}
+  virtual void handle(Half*) {}
   virtual void handle(Int*) {}
   virtual void handle(NamedScalar*) {}
 
@@ -133,15 +200,44 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
   virtual void handle(BinaryOp*) {}
   virtual void handle(TernaryOp*) {}
   virtual void handle(ReductionOp*) {}
-  virtual void handle(WelfordOp*) {}
   virtual void handle(BroadcastOp*) {}
-  virtual void handle(TransposeOp*) {}
-  virtual void handle(ShiftOp*) {}
-  virtual void handle(GatherOp*) {}
+
+  // Kernel IR nodes
+  virtual void handle(kir::Bool*) {}
+  virtual void handle(kir::Float*) {}
+  virtual void handle(kir::Half*) {}
+  virtual void handle(kir::Int*) {}
+  virtual void handle(kir::NamedScalar*) {}
+
+  virtual void handle(kir::IterDomain*) {}
+  virtual void handle(kir::TensorDomain*) {}
+  virtual void handle(kir::TensorView*) {}
+
+  virtual void handle(kir::UnaryOp*) {}
+  virtual void handle(kir::BinaryOp*) {}
+  virtual void handle(kir::TernaryOp*) {}
+  virtual void handle(kir::ReductionOp*) {}
+  virtual void handle(kir::BroadcastOp*) {}
+
+  virtual void handle(kir::TensorIndex*) {}
+  virtual void handle(kir::GridReduction*) {}
+  virtual void handle(kir::ForLoop*) {}
+  virtual void handle(kir::IfThenElse*) {}
+  virtual void handle(kir::Allocate*) {}
+  virtual void handle(kir::Sync*) {}
 };
 
-class TORCH_CUDA_CU_API OptInConstDispatch : public PolymorphicBase {
+class TORCH_CUDA_CU_API OptInConstDispatch {
  public:
+  virtual ~OptInConstDispatch() = default;
+  OptInConstDispatch() = default;
+
+  OptInConstDispatch(const OptInConstDispatch& other) = default;
+  OptInConstDispatch& operator=(const OptInConstDispatch& other) = default;
+
+  OptInConstDispatch(OptInConstDispatch&& other) = default;
+  OptInConstDispatch& operator=(OptInConstDispatch&& other) = default;
+
   // Hierarchal dispatch functions for handle
   virtual void handle(const Statement*);
   virtual void handle(const Expr*);
@@ -160,8 +256,11 @@ class TORCH_CUDA_CU_API OptInConstDispatch : public PolymorphicBase {
   virtual void handle(const Bool*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool.");
   }
-  virtual void handle(const Double*) {
-    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Double.");
+  virtual void handle(const Float*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Float.");
+  }
+  virtual void handle(const Half*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Half.");
   }
   virtual void handle(const Int*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int.");
@@ -183,9 +282,6 @@ class TORCH_CUDA_CU_API OptInConstDispatch : public PolymorphicBase {
   virtual void handle(const BinaryOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BinaryOp.");
   }
-  virtual void handle(const WelfordOp*) {
-    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for WelfordOp.");
-  }
   virtual void handle(const TernaryOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TernaryOp.");
   }
@@ -195,19 +291,87 @@ class TORCH_CUDA_CU_API OptInConstDispatch : public PolymorphicBase {
   virtual void handle(const BroadcastOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp.");
   }
-  virtual void handle(const TransposeOp*) {
-    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TransposeOp.");
+
+  // Kernel IR
+  //
+  // TODO: move to a specialized visitor
+  //
+
+  virtual void handle(const kir::Bool*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Bool.");
+  }
+  virtual void handle(const kir::Float*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Float.");
+  }
+  virtual void handle(const kir::Half*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Half.");
+  }
+  virtual void handle(const kir::Int*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Int.");
+  }
+  virtual void handle(const kir::NamedScalar*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::NamedScalar.");
+  }
+
+  virtual void handle(const kir::IterDomain*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::IterDomain.");
+  }
+  virtual void handle(const kir::TensorDomain*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TensorDomain.");
+  }
+  virtual void handle(const kir::TensorView*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TensorView.");
+  }
+
+  virtual void handle(const kir::UnaryOp*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::UnaryOp.");
+  }
+  virtual void handle(const kir::BinaryOp*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::BinaryOp.");
+  }
+  virtual void handle(const kir::TernaryOp*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TernaryOp.");
+  }
+  virtual void handle(const kir::ReductionOp*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::ReductionOp.");
+  }
+  virtual void handle(const kir::BroadcastOp*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::BroadcastOp.");
+  }
+
+  virtual void handle(const kir::GridReduction*) {
+    TORCH_INTERNAL_ASSERT(
+        false, "Handle not overriden for kir::GridReduction.");
+  }
+  virtual void handle(const kir::ForLoop*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::ForLoop.");
   }
-  virtual void handle(const ShiftOp*) {
-    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ShiftOp.");
+  virtual void handle(const kir::Allocate*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Allocate.");
   }
-  virtual void handle(const GatherOp*) {
-    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for GatherOp.");
+  virtual void handle(const kir::Sync*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Sync.");
+  }
+  virtual void handle(const kir::IfThenElse*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::IfThenElse.");
+  }
+
+  virtual void handle(const kir::TensorIndex*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TensorIndex.");
   }
 };
 
-class TORCH_CUDA_CU_API OptInDispatch : public PolymorphicBase {
+class TORCH_CUDA_CU_API OptInDispatch {
  public:
+  virtual ~OptInDispatch() = default;
+  OptInDispatch() = default;
+
+  OptInDispatch(const OptInDispatch& other) = default;
+  OptInDispatch& operator=(const OptInDispatch& other) = default;
+
+  OptInDispatch(OptInDispatch&& other) = default;
+  OptInDispatch& operator=(OptInDispatch&& other) = default;
+
   // Hierarchal dispatch functions for handle
   virtual void handle(Statement* s);
   virtual void handle(Expr* e);
@@ -226,8 +390,11 @@ class TORCH_CUDA_CU_API OptInDispatch : public PolymorphicBase {
   virtual void handle(Bool*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool.");
   }
-  virtual void handle(Double*) {
-    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Double.");
+  virtual void handle(Float*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Float.");
+  }
+  virtual void handle(Half*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Half.");
   }
   virtual void handle(Int*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int.");
@@ -255,26 +422,92 @@ class TORCH_CUDA_CU_API OptInDispatch : public PolymorphicBase {
   virtual void handle(ReductionOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ReductionOp.");
   }
-  virtual void handle(WelfordOp*) {
-    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for WelfordOp.");
-  }
   virtual void handle(BroadcastOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp.");
   }
-  virtual void handle(TransposeOp*) {
-    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TransposeOp.");
+
+  // Kernel IR
+  //
+  // TODO: move to a specialized visitor
+  //
+
+  virtual void handle(kir::Bool*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool.");
+  }
+  virtual void handle(kir::Float*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Float.");
+  }
+  virtual void handle(kir::Half*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Half.");
+  }
+  virtual void handle(kir::Int*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int.");
+  }
+  virtual void handle(kir::NamedScalar*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::NamedScalar.");
+  }
+  virtual void handle(kir::TensorIndex*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TensorIndex.");
+  }
+
+  virtual void handle(kir::IterDomain*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::IterDomain.");
+  }
+  virtual void handle(kir::TensorDomain*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TensorDomain.");
+  }
+  virtual void handle(kir::TensorView*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TensorView.");
   }
-  virtual void handle(ShiftOp*) {
-    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ShiftOp.");
+
+  virtual void handle(kir::UnaryOp*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::UnaryOp.");
+  }
+  virtual void handle(kir::BinaryOp*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::BinaryOp.");
+  }
+  virtual void handle(kir::TernaryOp*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TernaryOp.");
+  }
+  virtual void handle(kir::ReductionOp*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::ReductionOp.");
+  }
+  virtual void handle(kir::BroadcastOp*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::BroadcastOp.");
+  }
+
+  virtual void handle(kir::GridReduction*) {
+    TORCH_INTERNAL_ASSERT(
+        false, "Handle not overriden for kir::GridReduction.");
+  }
+  virtual void handle(kir::ForLoop*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::ForLoop.");
   }
-  virtual void handle(GatherOp*) {
-    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for GatherOp.");
+  virtual void handle(kir::Allocate*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Allocate.");
+  }
+  virtual void handle(kir::Sync*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Sync.");
+  }
+  virtual void handle(kir::IfThenElse*) {
+    TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::IfThenElse.");
   }
 };
 
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
-class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
+class TORCH_CUDA_CU_API OptOutMutator {
  public:
+  virtual ~OptOutMutator() = default;
+  OptOutMutator() = default;
+
+  OptOutMutator(const OptOutMutator& other) = default;
+  OptOutMutator& operator=(const OptOutMutator& other) = default;
+
+  OptOutMutator(OptOutMutator&& other) = default;
+  OptOutMutator& operator=(OptOutMutator&& other) = default;
+
+  virtual void mutate(Fusion* fusion);
+
   // Hierarchal dispatch functions for handle
   virtual Statement* mutate(Statement* s);
   virtual Statement* mutate(Expr* e);
@@ -305,8 +538,10 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   virtual Statement* mutate(IterDomain*);
   virtual Statement* mutate(TensorDomain*);
   virtual Statement* mutate(TensorView*);
+  virtual Statement* mutate(kir::TensorIndex*);
   virtual Statement* mutate(Bool*);
-  virtual Statement* mutate(Double*);
+  virtual Statement* mutate(Float*);
+  virtual Statement* mutate(Half*);
   virtual Statement* mutate(Int*);
   virtual Statement* mutate(NamedScalar*);
 
@@ -317,19 +552,26 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   virtual Statement* mutate(BinaryOp*);
   virtual Statement* mutate(TernaryOp*);
   virtual Statement* mutate(ReductionOp*);
-  virtual Statement* mutate(WelfordOp*);
+  virtual Statement* mutate(kir::GridReduction*);
   virtual Statement* mutate(BroadcastOp*);
-  virtual Statement* mutate(TransposeOp*);
-  virtual Statement* mutate(ShiftOp*);
-  virtual Statement* mutate(GatherOp*);
+  virtual Statement* mutate(kir::ForLoop*);
+  virtual Statement* mutate(kir::IfThenElse*);
+  virtual Statement* mutate(kir::Allocate*);
+  virtual Statement* mutate(kir::Sync*);
 };
 
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
-class TORCH_CUDA_CU_API OptInMutator : public PolymorphicBase {
+class TORCH_CUDA_CU_API OptInMutator {
  public:
-  std::unordered_map<Val*, Val*> mutations;
+  virtual ~OptInMutator() = default;
+  OptInMutator() = default;
+
+  OptInMutator(const OptInMutator& other) = default;
+  OptInMutator& operator=(const OptInMutator& other) = default;
+
+  OptInMutator(OptInMutator&& other) = default;
+  OptInMutator& operator=(OptInMutator&& other) = default;
 
- public:
   void registerMutation(Val* val, Val* mutation) {
     TORCH_INTERNAL_ASSERT(
         mutations.find(val) == mutations.end(),
@@ -338,6 +580,8 @@ class TORCH_CUDA_CU_API OptInMutator : public PolymorphicBase {
     mutations[val] = mutation;
   }
 
+  std::unordered_map<Val*, Val*> mutations;
+
   // Hierarchal dispatch functions for mutate
   virtual Statement* mutate(Statement*);
   virtual Statement* mutate(Expr*);
@@ -353,9 +597,15 @@ class TORCH_CUDA_CU_API OptInMutator : public PolymorphicBase {
   virtual Statement* mutate(TensorView*) {
     TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TensorView.");
   }
+  virtual Statement* mutate(kir::TensorIndex*) {
+    TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TensorIndex.");
+  }
   virtual Statement* mutate(Bool*) {
     TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Bool.");
   }
+  virtual Statement* mutate(Float*) {
+    TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Float.");
+  }
   virtual Statement* mutate(Int*) {
     TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Int.");
   }
@@ -382,20 +632,23 @@ class TORCH_CUDA_CU_API OptInMutator : public PolymorphicBase {
   virtual Statement* mutate(ReductionOp*) {
     TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ReductionOp.");
   }
-  virtual Statement* mutate(WelfordOp*) {
-    TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for WelfordOp.");
+  virtual Statement* mutate(kir::GridReduction*) {
+    TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for GridReduction.");
   }
   virtual Statement* mutate(BroadcastOp*) {
     TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for BroadcastOp.");
   }
-  virtual Statement* mutate(TransposeOp*) {
-    TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TransposeOp.");
+  virtual Statement* mutate(kir::ForLoop*) {
+    TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ForLoop.");
+  }
+  virtual Statement* mutate(kir::Allocate*) {
+    TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Allocate.");
   }
-  virtual Statement* mutate(ShiftOp*) {
-    TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ShiftOp.");
+  virtual Statement* mutate(kir::Sync*) {
+    TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Sync.");
   }
-  virtual Statement* mutate(GatherOp*) {
-    TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for GatherOp.");
+  virtual Statement* mutate(kir::IfThenElse*) {
+    TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for IfThenElse.");
   }
 };
 
index 2eaf01c..776af69 100644 (file)
@@ -1,15 +1,11 @@
-
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-
 #include <torch/csrc/jit/codegen/cuda/codegen.h>
 #include <torch/csrc/jit/codegen/cuda/executor_kernel_arg.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
+
+#include <torch/csrc/jit/codegen/cuda/executor.h>
 
 #include <ATen/core/LegacyTypeDispatch.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAStream.h>
 #include <c10/util/irange.h>
 
-#include <fstream>
+#include <cstdlib>
 
 namespace torch {
 namespace jit {
 namespace fuser {
 namespace cuda {
 
-int FusionExecutor::fusion_id_counter_ = 0; // NOLINT
-
-namespace {
-
-static const char* defineIndexMode(KernelIndexMode index_mode) {
-  switch (index_mode) {
-    case KernelIndexMode::INT32:
-      return "typedef int nvfuser_index_t;\n";
-    case KernelIndexMode::INT64:
-      return "typedef int64_t nvfuser_index_t;\n";
-    default:
-      break;
-  }
-
-  TORCH_INTERNAL_ASSERT(false, "unknow indexing mode");
-  return "";
-}
-
-static const char* defineIntegerTypes() {
-  return R"(
-typedef unsigned char uint8_t;
-typedef signed char int8_t;
-typedef short int int16_t;
-typedef unsigned int uint32_t;
-typedef long long int int64_t;
-typedef unsigned long long int uint64_t;
-)";
-}
-
-} // namespace
+int FusionExecutor::fusion_id_counter_ = 0;
 
 std::string FusionExecutor::getStructuredCode(const std::string& kernel) {
   // generating cuda code;
@@ -67,30 +34,20 @@ std::string FusionExecutor::getStructuredCode(const std::string& kernel) {
 #endif
 #endif
   code += std::string("namespace ") + FusionExecutor::kernelNamespace() +
-      " {\n" + defineIntegerTypes() + defineIndexMode(options_.index_mode) +
-      executor_utils::kernelPreamble() + kernel + "}\n";
-
-  if (isDebugDumpEnabled(DebugDumpOption::CudaKernel)) {
-    std::cout << "\n======= Codegen output for kernel: " << kernelName()
-              << " =======\n\n"
-              << kernel << "\n======================================\n\n";
-  } else if (isDebugDumpEnabled(DebugDumpOption::CudaFull)) {
-    std::cout << "\n======= Codegen output for kernel: " << kernelName()
-              << " =======\n\n"
-              << code << "\n======================================\n\n";
-  } else if (isDebugDumpEnabled(DebugDumpOption::CudaToFile)) {
-    std::stringstream file_name;
-    file_name << "__tmp_kernel" << fusion_id_ << ".cu";
-    std::cout << "PRINTING: " << file_name.str() << std::endl;
-    std::ofstream out(file_name.str());
-    out << code << std::endl;
-    out.close();
+      " {\n" + executor_utils::kernelPreamble() + kernel + "}\n";
+
+  const char* debug_env = std::getenv("PYTORCH_CUDA_FUSER_DEBUG");
+  if (debug_env && atoi(debug_env)) {
+    std::cout << "\n==== codegen output for kernel: " << kernelName()
+              << " ====" << std::endl
+              << code << std::endl
+              << "======================================\n"
+              << std::endl;
   }
 
   return code;
 }
 
-// TODO: come up with a more user friendly interface
 void FusionExecutor::debugCompileFusionFromStr(
     Fusion* fusion,
     const std::string& code,
@@ -101,13 +58,8 @@ void FusionExecutor::debugCompileFusionFromStr(
   FusionGuard fg(&fusion_);
   options_ = options;
 
-  if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) {
-    fusion->print();
-  } else if (isDebugDumpEnabled(DebugDumpOption::FusionIrMath)) {
-    fusion->printMath();
-  }
-
-  if (isDebugDumpEnabled(DebugDumpOption::CudaFull)) {
+  const char* debug_env = std::getenv("PYTORCH_CUDA_FUSER_DEBUG");
+  if (debug_env && atoi(debug_env)) {
     std::cout << "\n==== codegen output for kernel: " << kernelName()
               << " ====" << std::endl
               << code << std::endl
@@ -121,16 +73,20 @@ void FusionExecutor::debugCompileFusionFromStr(
   lowered_ = GpuLower(&fusion_);
   const auto kernel = lowered_.kernel();
 
-  if (isDebugDumpEnabled(DebugDumpOption::KernelIr)) {
+  const char* dump_kir_env = std::getenv("PYTORCH_CUDA_FUSER_DUMP_KIR");
+  if (dump_kir_env && atoi(dump_kir_env)) {
     kernel->print();
   }
 
   const auto& kernel_summary = kernel->summary();
+  has_block_reductions = kernel_summary.has_block_reductions;
+  has_grid_reductions = kernel_summary.has_grid_reductions;
+  has_block_broadcasts = kernel_summary.has_block_broadcasts;
 
   if (!kernel_summary.static_smem_allocations.empty()) {
-    kir::ExpressionEvaluator static_evaluator;
+    StatefulExpressionEvaluator static_evaluator(&fusion_);
     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    const auto static_smem_size = computeSharedMemory(
+    unsigned static_smem_size = computeSharedMemory(
         static_evaluator, kernel_summary.static_smem_allocations);
     TORCH_INTERNAL_ASSERT(
         static_smem_size < max_device_smem,
@@ -142,11 +98,7 @@ void FusionExecutor::debugCompileFusionFromStr(
       fusion_id_ > 0, "assign a fusion_id_ <= 0 is not accepted.");
 }
 
-void FusionExecutor::compileFusion(
-    Fusion* fusion,
-    CompileOptions options,
-    const at::ArrayRef<IValue>& inputs,
-    const LaunchParams& launch_constraints) {
+void FusionExecutor::compileFusion(Fusion* fusion, CompileOptions options) {
   FUSER_PERF_SCOPE("compileFusion");
 
   TORCH_INTERNAL_ASSERT(
@@ -158,17 +110,10 @@ void FusionExecutor::compileFusion(
         "Output types from fusions that are not tensors are not supported at this point.");
   }
 
-  if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) {
-    fusion->print();
-  } else if (isDebugDumpEnabled(DebugDumpOption::FusionIrMath)) {
-    fusion->printMath();
-  }
-
   // Clone the fusion so we can store it
   fusion_ = *fusion;
   FusionGuard fg(&fusion_);
   options_ = options;
-  c10::DeviceGuard dg(options_.device);
 
   TORCH_INTERNAL_ASSERT(
       options.device.is_cuda(), "Provided device to CUDA fuser is the CPU.");
@@ -181,7 +126,8 @@ void FusionExecutor::compileFusion(
   lowered_ = GpuLower(&fusion_);
   const auto kernel = lowered_.kernel();
 
-  if (isDebugDumpEnabled(DebugDumpOption::KernelIr)) {
+  const char* dump_kir_env = std::getenv("PYTORCH_CUDA_FUSER_DUMP_KIR");
+  if (dump_kir_env && atoi(dump_kir_env)) {
     kernel->print();
   }
 
@@ -189,46 +135,24 @@ void FusionExecutor::compileFusion(
   const auto structured_code = getStructuredCode(kernel_code);
 
   const auto& kernel_summary = kernel->summary();
+  has_block_reductions = kernel_summary.has_block_reductions;
+  has_grid_reductions = kernel_summary.has_grid_reductions;
+  has_block_broadcasts = kernel_summary.has_block_broadcasts;
 
   if (!kernel_summary.static_smem_allocations.empty()) {
-    kir::ExpressionEvaluator static_evaluator;
+    StatefulExpressionEvaluator static_evaluator(&fusion_);
     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    const auto static_smem_size = computeSharedMemory(
+    unsigned static_smem_size = computeSharedMemory(
         static_evaluator, kernel_summary.static_smem_allocations);
     TORCH_INTERNAL_ASSERT(
         static_smem_size < max_device_smem,
         "The static shared memory allocation is larger than available memory.");
   }
 
-  if (kernel_summary.has_dynamic_local_memory_allocations) {
-    std::stringstream ss;
-    ss << "Allocations must be based on constant integers for local memory. However, found: ";
-    for (auto alloc : kernel_summary.dynamic_lmem_allocations) {
-      ss << toString(alloc->buffer(), false) << ", ";
-    }
-    ss << " have dynamic allocations but are placed in local memory.";
-    TORCH_INTERNAL_ASSERT(false, ss.str());
-  }
-
-  TORCH_CHECK(
-      !kernel_summary.has_grid_reduction_in_loop,
-      "Grid reduction must not be placed inside a loop.");
-
-  // TODO: pass block_size here;
-  c10::optional<int> block_size = c10::nullopt;
-  if (!inputs.empty()) {
-    auto expr_eval = executor_utils::bindKernelInputs(inputs, kernel);
-    auto launch_params = computeLaunchParams(launch_constraints, expr_eval);
-    block_size = launch_params.nThreads();
-    TORCH_INTERNAL_ASSERT(
-        block_size > 0, "launch param inferred block size < 0");
-  }
-
   compiled_kernel_ = executor_utils::nvrtcCompile(
       structured_code,
       (kernelNamespace() + "::" + kernelName()).c_str(),
-      fusion_id_,
-      block_size);
+      fusion_id_);
   TORCH_INTERNAL_ASSERT(
       fusion_id_ > 0, "failed to assign a fusion_id_ after compilation.");
 }
@@ -236,36 +160,34 @@ void FusionExecutor::compileFusion(
 namespace {
 
 at::Tensor inferAndAlloc(
-    const kir::TensorView* tv,
-    const std::vector<kir::Val*>& sizes,
-    kir::ExpressionEvaluator& expr_eval,
+    const TensorView* tv,
+    StatefulExpressionEvaluator& see,
     const CompileOptions& options,
     bool zero_init = false) {
   FUSER_PERF_SCOPE("inferAndAlloc");
 
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  std::vector<int64_t> inferred_sizes;
-
-  for (const auto size : sizes) {
-    const auto inferred_val = expr_eval.evaluate(size);
+  std::vector<int64_t> sizes;
+  for (auto id : TensorDomain::noReductions(tv->getMaybeRFactorDomain())) {
+    auto inferred_val = see.inferValue(id->rawExtent());
     TORCH_INTERNAL_ASSERT(
         inferred_val.has_value(),
         "Could not launch kernel as program could not infer ",
-        kir::toString(size),
+        id->rawExtent(),
         " for the buffer ",
-        kir::toString(tv));
-    inferred_sizes.push_back(inferred_val.value());
+        tv);
+    sizes.push_back(inferred_val.value());
   }
 
-  const auto at_type = data_type_to_aten(tv->dtype());
+  auto at_type = data_type_to_aten(tv->getDataType().value());
 
   if (zero_init) {
-    const auto tensor_options =
+    auto tensor_options =
         at::TensorOptions().dtype(at_type).device(options.device);
-    c10::IntArrayRef isizes(inferred_sizes);
+    c10::IntArrayRef isizes(sizes);
     return at::zeros(isizes, tensor_options);
   } else {
-    c10::IntArrayRef isizes(inferred_sizes);
+    c10::IntArrayRef isizes(sizes);
     // Non Variable type guard for empty_cuda call
     at::AutoDispatchBelowADInplaceOrView non_variable_type_mode;
     return at::native::empty_cuda(
@@ -273,33 +195,11 @@ at::Tensor inferAndAlloc(
   }
 }
 
-at::Tensor inferAndAllocOutput(
-    const kir::TensorView* tv,
-    kir::ExpressionEvaluator& expr_eval,
-    const CompileOptions& options,
-    bool zero_init = false) {
-  const auto domain = tv->domain();
-  const auto maybe_rfactor_domain =
-      domain->hasRFactor() ? domain->rfactorDomain() : domain->rootDomain();
-
-  std::vector<kir::Val*> sizes;
-
-  for (const auto id : maybe_rfactor_domain) {
-    if (id->isReduction() ||
-        id->iterType() == IterType::BroadcastWithoutStride) {
-      continue;
-    }
-    sizes.push_back(id->extent());
-  }
-
-  return inferAndAlloc(tv, sizes, expr_eval, options, zero_init);
-}
-
 } // namespace
 
 uint64_t FusionExecutor::computeSharedMemory(
-    kir::ExpressionEvaluator& expr_eval,
-    const std::vector<const kir::Allocate*>& buffers,
+    StatefulExpressionEvaluator& see,
+    const std::vector<kir::Allocate*>& buffers,
     bool align_padding,
     uint64_t total) {
   FUSER_PERF_SCOPE("computeSharedMemory");
@@ -307,9 +207,9 @@ uint64_t FusionExecutor::computeSharedMemory(
     // If this buffer aliases another buffer,
     // then do not allocate memory for this buffer.
     if (smem_alloc->alias() == nullptr) {
-      const auto inferred_val = expr_eval.evaluate(smem_alloc->size());
+      auto inferred_val = see.inferValue(smem_alloc->size());
       if (inferred_val.has_value()) {
-        const uint64_t data_size = dataTypeSize(smem_alloc->buffer()->dtype());
+        const uint64_t data_size = dataTypeSize(smem_alloc->buffer_type());
         // Add padding to align dynamic shared memory
         if (align_padding) {
           total = ceilDiv(total, data_size) * data_size;
@@ -330,25 +230,24 @@ uint64_t FusionExecutor::computeSharedMemory(
 
 LaunchParams FusionExecutor::computeLaunchParams(
     const LaunchParams& launch_constraints,
-    kir::ExpressionEvaluator& expr_eval) {
-  FUSER_PERF_SCOPE("FusionExecutor::ComputeLaunchParams");
+    StatefulExpressionEvaluator& see) {
+  FUSER_PERF_SCOPE("computeLaunchParams");
 
   LaunchParams launch_params;
 
   // Lets collect all IterDomains that are bound to a thread binding
-  std::unordered_map<ParallelType, std::vector<const kir::Val*>, TypeHash>
+  std::unordered_map<ParallelType, std::vector<IterDomain*>, TypeHash>
       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-      parallel_iter_extents;
+      parallel_iter_domains;
   for (auto tv : getUsedTVs()) {
     for (auto id : tv->domain()->domain()) {
       if (id->isThread() && !id->isBroadcast()) {
-        // TODO(kir): we should rewrite this logic based on the Kernel object
-        auto kir_extent = lowered_.lowerValue(id->extent());
-        const auto it = parallel_iter_extents.find(id->getParallelType());
-        if (it != parallel_iter_extents.end()) {
-          it->second.push_back(kir_extent);
+        if (parallel_iter_domains.find(id->getParallelType()) !=
+            parallel_iter_domains.end()) {
+          parallel_iter_domains.at(id->getParallelType()).push_back(id);
         } else {
-          parallel_iter_extents[id->getParallelType()] = {kir_extent};
+          parallel_iter_domains[id->getParallelType()] =
+              std::vector<IterDomain*>({id});
         }
       }
     }
@@ -357,47 +256,50 @@ LaunchParams FusionExecutor::computeLaunchParams(
   // If any dimension was set in launch constraints we need to run through
   // IterDomains that have been parallelized, and bind those values. Or make
   // sure if they could be inferred the inference matches what was set.
-  for (auto& entry : parallel_iter_extents) {
-    auto p_type = entry.first;
-    if (launch_constraints.hasDim(p_type)) {
-      auto parallel_extents = entry.second;
-      for (auto extent : parallel_extents) {
-        auto inferred_val = expr_eval.evaluate(extent);
-        if (inferred_val.has_value()) {
-          // This value could have been inferred, make sure it was set right.
-          bool valid =
-              inferred_val.value() == launch_constraints.getDim(p_type) ||
-              launch_constraints.getRawVal(p_type) == -1;
-          if (!useFallback() && !valid) {
-            TORCH_WARN_ONCE(
-                "Cannot validate parallelization scheme, "
-                "this may be due to mixed broadcast axes that are parallelized.");
+  if (launch_constraints.nBlocks() * launch_constraints.nThreads() != -1) {
+    for (auto& entry : parallel_iter_domains) {
+      auto p_type = entry.first;
+      if (launch_constraints.hasDim(p_type)) {
+        auto parallel_ids = entry.second;
+        for (auto parallel_id : parallel_ids) {
+          auto inferred_val = see.inferValue(parallel_id->rawExtent());
+          if (inferred_val.has_value()) {
+            // This value could have been inferred, make sure it was set right.
+            TORCH_CHECK(
+                inferred_val.value() == launch_constraints.getDim(p_type) ||
+                    launch_constraints.getRawVal(p_type) == -1,
+                "inferred that ",
+                p_type,
+                " should be set to ",
+                inferred_val.value(),
+                " but launch constraints specified ",
+                launch_constraints.getDim(p_type));
+          } else {
+            // Bind the launch constraint into our evaluation context
+            see.safeBind(
+                parallel_id->rawExtent(),
+                launch_constraints.getDim(entry.first),
+                &lowered_);
+            launch_params.bind(launch_constraints.getDim(p_type), p_type);
           }
-        } else {
-          // Bind the launch constraint into our evaluation context
-          expr_eval.bind(extent, launch_constraints.getDim(p_type));
-          launch_params.bind(launch_constraints.getDim(p_type), p_type);
         }
       }
     }
   }
 
   // Run through the rest of the parallel IterDomains and infer their size
-  for (auto& entry : parallel_iter_extents) {
+  for (auto& entry : parallel_iter_domains) {
     auto p_type = entry.first;
-    auto parallel_extents = entry.second;
-    // Select the maxmimum value out of all the parallel extents
-    int64_t maximum_value = std::numeric_limits<int64_t>::min();
-    for (auto extent : parallel_extents) {
-      const auto val = expr_eval.evaluate(extent);
+    auto parallel_ids = entry.second;
+    for (auto parallel_id : parallel_ids) {
+      auto val = see.inferValue(parallel_id->rawExtent());
       TORCH_INTERNAL_ASSERT(
-          val.has_value(),
+          val,
           "Tried to evaluate the extent of ",
-          p_type,
+          parallel_id,
           " to set launch bounds but could not.");
-      maximum_value = std::max(maximum_value, *val);
+      launch_params.bind(val.value(), p_type);
     }
-    launch_params.bind(maximum_value, p_type);
   }
 
   const auto kernel = lowered_.kernel();
@@ -406,34 +308,23 @@ LaunchParams FusionExecutor::computeLaunchParams(
   // Calculate Dynamic Shared Memory Size
   // Add workspace for reduction and broadcast
   uint64_t reduction_broadcast_workspace = 0;
-  const bool has_workspace = kernel_summary.has_block_reductions ||
-      kernel_summary.number_of_grid_reductions > 0 ||
-      kernel_summary.has_block_broadcasts;
-  if (has_workspace &&
-      kernel_summary.largest_smem_data_type != DataType::Null) {
+  if (has_block_reductions || has_grid_reductions || has_block_broadcasts) {
     // Not using nThreads here since it does not handle uninitialized value
-
-    // TODO: here is an optimization opportunity since welford uses int64_t for
-    // N while the data type is not neccessarily double. But it may need more
-    // work on the alignment
-    const int welford_factor =
-        kernel_summary.has_block_welford || kernel_summary.has_grid_welford ? 3
-                                                                            : 1;
     reduction_broadcast_workspace =
-        dataTypeSize(kernel_summary.largest_smem_data_type) * welford_factor *
+        dataTypeSize(kernel_summary.largest_smem_data_type) *
         launch_params.bdimx() * launch_params.bdimy() * launch_params.bdimz();
   }
 
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   const uint64_t dynamic_smem_size = computeSharedMemory(
-      expr_eval,
+      see,
       kernel_summary.dynamic_smem_allocations,
       true,
       reduction_broadcast_workspace);
 
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   const uint64_t static_smem_size =
-      computeSharedMemory(expr_eval, kernel_summary.static_smem_allocations);
+      computeSharedMemory(see, kernel_summary.static_smem_allocations);
 
   TORCH_INTERNAL_ASSERT(
       (dynamic_smem_size + static_smem_size) < max_device_smem,
@@ -444,27 +335,26 @@ LaunchParams FusionExecutor::computeLaunchParams(
 }
 
 FusionExecutor::GlobalBuffers FusionExecutor::allocGlobalVals(
-    kir::ExpressionEvaluator& expr_eval) {
-  FUSER_PERF_SCOPE("FusionExecutor::AllocGlobalVals");
+    StatefulExpressionEvaluator& see) {
+  FUSER_PERF_SCOPE("allocGlobalVals");
   GlobalBuffers global_buffers;
-  const auto kernel = lowered_.kernel();
   const auto& kernel_summary = lowered_.kernel()->summary();
   for (auto alloc : kernel_summary.global_allocations) {
     TORCH_INTERNAL_ASSERT(
-        alloc->buffer()->isA<kir::TensorView>(),
+        alloc->buffer()->getValType() == ValType::KirTensorView,
         "Cannot allocate global buffers that are not tensors.");
-    auto tv = alloc->buffer()->as<kir::TensorView>();
-    if (kernel->isOutput(tv)) {
-      continue;
-    }
-    if (alloc->zeroInit()) {
-      global_buffers.buffers.push_back(
-          inferAndAlloc(tv, alloc->shape(), expr_eval, options_, true));
-      global_buffers.zero_init.push_back(true);
+    if (!alloc->zeroInit()) {
+      global_buffers.empty_buffers.push_back(inferAndAlloc(
+          alloc->buffer()->as<kir::TensorView>()->fuserTv(),
+          see,
+          options_,
+          false));
     } else {
-      global_buffers.buffers.push_back(
-          inferAndAlloc(tv, alloc->shape(), expr_eval, options_, false));
-      global_buffers.zero_init.push_back(false);
+      global_buffers.zero_buffers.push_back(inferAndAlloc(
+          alloc->buffer()->as<kir::TensorView>()->fuserTv(),
+          see,
+          options_,
+          true));
     }
   }
 
@@ -472,35 +362,29 @@ FusionExecutor::GlobalBuffers FusionExecutor::allocGlobalVals(
 }
 
 std::vector<at::Tensor> FusionExecutor::allocOutputs(
-    kir::ExpressionEvaluator& expr_eval,
-    const std::unordered_set<int>& alias_indices) {
-  FUSER_PERF_SCOPE("FusionExecutor::AllocOutputs");
-  const auto kernel = lowered_.kernel();
+    StatefulExpressionEvaluator& see) {
+  FUSER_PERF_SCOPE("allocOutputs");
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   std::vector<at::Tensor> outputs;
-  for (size_t i = 0; i < kernel->outputs().size(); ++i) {
+  for (auto output : fusion_.outputs()) {
     TORCH_INTERNAL_ASSERT(
-        kernel->outputs()[i]->isA<kir::TensorView>(),
+        output->getValType() == ValType::TensorView,
         "Cannot allocate outputs that are not tensors.");
-    auto output = kernel->outputs()[i]->as<kir::TensorView>();
-    if (alias_indices.count(i) == 0) {
-      outputs.push_back(
-          inferAndAllocOutput(output, expr_eval, options_, false));
-    } else {
-      // aliasing to inputs, no need to allocate real output
-      outputs.push_back(inferAndAlloc(output, {}, expr_eval, options_, false));
-    }
+    outputs.push_back(
+        inferAndAlloc(output->as<TensorView>(), see, options_, false));
   }
   return outputs;
 }
 
 void FusionExecutor::setUsedTVs() {
-  auto used_vals = fusion_.usedMathVals();
-  auto used_tvs = ir_utils::filterByType<TensorView>(used_vals);
   used_tvs_.clear();
-
-  for (auto tv : used_tvs)
-    used_tvs_.push_back(tv);
+  auto used_vals = DependencyCheck::getAllValsBetween(
+      {fusion_.inputs().begin(), fusion_.inputs().end()}, fusion_.outputs());
+  for (auto val : used_vals) {
+    if (val->getValType().value() == ValType::TensorView) {
+      used_tvs_.push_back(val->as<TensorView>());
+    }
+  }
 }
 
 std::vector<at::Tensor> FusionExecutor::runFusion(
@@ -508,7 +392,7 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
     const std::vector<at::Tensor>& outputs,
     const LaunchParams& launch_constraints,
     const c10::optional<size_t>& opt_code) {
-  FUSER_PERF_SCOPE("FusionExecutor::RunFusion");
+  FUSER_PERF_SCOPE("runFusion");
 
   TORCH_INTERNAL_ASSERT(
       fusion_id_ > 0, "Cannot run fusion, it was not compiled.");
@@ -527,97 +411,64 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
 
   LaunchParams launch_params;
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  std::vector<at::Tensor> allocated_outputs = outputs;
+  std::vector<at::Tensor> alloced_outputs = outputs;
   GlobalBuffers global_buffers;
   uint64_t rand_offset = 0;
 
   if (executor_entry && executor_entry->init) {
     {
-      // context manager to disable auto grad for `empty_cuda` calls later
+      // context manager to disable auto grad for `empty_cuda` calls later;
       at::AutoDispatchBelowADInplaceOrView non_variable_type_mode;
-      // take the short-cut for launch if we see a recorded input set again
+      // take the short-cut for launch if we see a recorded input set again;
       launch_params = executor_entry->launch_params;
-      // only allocate outputs when not given
-      if (outputs.empty()) {
-        FUSER_PERF_SCOPE("ExecutorRunFusion::OutputAlloc");
-        for (const auto i : c10::irange(executor_entry->output_sizes.size())) {
-          allocated_outputs.push_back(at::native::empty_cuda(
-              executor_entry->output_sizes[i],
-              executor_entry->output_types[i],
-              c10::nullopt,
-              options_.device,
-              c10::nullopt));
-        }
-        for (const auto& entry : executor_entry->io_alias_indices) {
-          TORCH_INTERNAL_ASSERT(
-              inputs[entry.second].isTensor(), "alias io only supports tensor");
-          allocated_outputs[entry.first] = inputs[entry.second].toTensor();
-        }
-      } else {
-        TORCH_INTERNAL_ASSERT(
-            outputs.size() == fusion_.outputs().size(),
-            __func__,
-            " provided number of outputs does match fusion output");
+      for (const auto i : c10::irange(executor_entry->output_sizes.size())) {
+        alloced_outputs.push_back(at::native::empty_cuda(
+            executor_entry->output_sizes[i],
+            executor_entry->output_types[i],
+            c10::nullopt,
+            options_.device,
+            c10::nullopt));
       }
-      {
-        FUSER_PERF_SCOPE("ExecutorRunFusion::IntermediateBufferAlloc");
-        for (const auto i : c10::irange(executor_entry->buffer_sizes.size())) {
-          if (executor_entry->buffer_zero_init[i]) {
-            global_buffers.buffers.push_back(at::zeros(
-                executor_entry->buffer_sizes[i],
-                at::TensorOptions()
-                    .dtype(executor_entry->buffer_types[i])
-                    .device(options_.device)));
-          } else {
-            global_buffers.buffers.push_back(at::native::empty_cuda(
-                executor_entry->buffer_sizes[i],
-                executor_entry->buffer_types[i],
-                c10::nullopt,
-                options_.device,
-                c10::nullopt));
-          }
-        }
+      for (const auto i :
+           c10::irange(executor_entry->empty_buffer_sizes.size())) {
+        global_buffers.empty_buffers.push_back(at::native::empty_cuda(
+            executor_entry->empty_buffer_sizes[i],
+            executor_entry->empty_buffer_types[i],
+            c10::nullopt,
+            options_.device,
+            c10::nullopt));
       }
     }
+    for (const auto i : c10::irange(executor_entry->zero_buffer_sizes.size())) {
+      auto tensor_options = at::TensorOptions()
+                                .dtype(executor_entry->zero_buffer_types[i])
+                                .device(options_.device);
+      global_buffers.zero_buffers.push_back(
+          at::zeros(executor_entry->zero_buffer_sizes[i], tensor_options));
+    }
     rand_offset = executor_entry->rand_offset;
   } else {
-    FUSER_PERF_SCOPE("ExecutorRunFusion::ValidateAndInitialize");
     // code path to take when either:
-    //   1. no opt_code is provided or
+    //   1. no opt_code is provided or;
     //   2. `executor_entry` is not initialized
     executor_utils::validateKernelInputs(&fusion_, inputs, options_.device);
 
-    const auto kernel = lowered_.kernel();
-
-    auto expr_eval = executor_utils::bindKernelInputs(inputs, kernel);
-
-    launch_params = computeLaunchParams(launch_constraints, expr_eval);
-
-    executor_utils::validateVectorizedTensors(
-        &fusion_, inputs, outputs, lowered_, expr_eval);
+    StatefulExpressionEvaluator evaluator =
+        executor_utils::statefulBindInputs(inputs, &fusion_, &lowered_);
 
-    auto alias_indices = fusion_.getInputAliasIndices();
+    launch_params = computeLaunchParams(launch_constraints, evaluator);
 
-    // ditch pre-allocated outputs if the number doesn't match.
     // NOLINTNEXTLINE(bugprone-branch-clone)
-    if (outputs.empty()) {
-      allocated_outputs =
-          allocOutputs(expr_eval, fusion_.getOutputAliasIndices());
-
-      for (const auto& entry : alias_indices) {
-        TORCH_INTERNAL_ASSERT(
-            inputs[entry.second].isTensor(), "alias io only supports tensor");
-        allocated_outputs[entry.first] = inputs[entry.second].toTensor();
-      }
+    if (outputs.empty() || outputs.size() != fusion_.outputs().size()) {
+      alloced_outputs = allocOutputs(evaluator);
     } else {
-      // TODO: Update this as well;
       executor_utils::validateKernelOutputs(
-          &fusion_, allocated_outputs, options_.device);
+          &fusion_, alloced_outputs, options_.device);
     }
 
-    global_buffers = allocGlobalVals(expr_eval);
+    global_buffers = allocGlobalVals(evaluator);
 
-    if (kernel->summary().is_stochastic) {
+    if (lowered_.kernel()->summary().is_stochastic) {
       // NOTE: this is how we map offset to PW kernels in order to have
       // identical random number generator to match native PyTorch results.
       // But it doesn't really work as it takes assumption how threads are
@@ -626,7 +477,7 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
       // works.
       rand_offset = 4 *
           (std::ceil(
-               allocated_outputs[0].numel() /
+               alloced_outputs[0].numel() /
                (4.0 * 128 * launch_params.gdimx())) + // NOLINT
            1);
     }
@@ -634,75 +485,36 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
     // This is the entry when we have provided `opt_code` but the entry has not
     // been initialized yet.
     if (executor_entry) {
-      FUSER_PERF_SCOPE("ExecutorRunFusion::FillCacheEntry");
       // record the the short-cut executor entry for the given input set;
       executor_entry->launch_params = launch_params;
-      executor_entry->io_alias_indices = alias_indices;
-      for (const auto& output : allocated_outputs) {
+      for (const auto& output : alloced_outputs) {
         executor_entry->output_sizes.push_back(output.sizes().vec());
         executor_entry->output_types.push_back(output.scalar_type());
       }
-
-      for (const auto& i : c10::irange(global_buffers.buffers.size())) {
-        executor_entry->buffer_sizes.push_back(
-            global_buffers.buffers[i].sizes().vec());
-        executor_entry->buffer_types.push_back(
-            global_buffers.buffers[i].scalar_type());
-        executor_entry->buffer_zero_init.push_back(global_buffers.zero_init[i]);
+      for (const auto& buffer : global_buffers.empty_buffers) {
+        executor_entry->empty_buffer_sizes.push_back(buffer.sizes().vec());
+        executor_entry->empty_buffer_types.push_back(buffer.scalar_type());
+      }
+      for (const auto& buffer : global_buffers.zero_buffers) {
+        executor_entry->zero_buffer_sizes.push_back(buffer.sizes().vec());
+        executor_entry->zero_buffer_types.push_back(buffer.scalar_type());
       }
       executor_entry->rand_offset = rand_offset;
       executor_entry->init = true;
     }
   }
 
-  KernelArgumentHolder kernel_arguments(options_.index_mode);
-  {
-    FUSER_PERF_SCOPE("ExecutorRunFusion::FillKernelArgStructure");
-    kernel_arguments.push(inputs);
-    kernel_arguments.push(allocated_outputs);
-    kernel_arguments.push(global_buffers.buffers);
-    if (lowered_.kernel()->summary().is_stochastic) {
-      kernel_arguments.appendPhiloxRNGSeed(rand_offset);
-    }
-  }
-
-  if (isDebugDumpEnabled(DebugDumpOption::LaunchParam)) {
-    launch_params.print();
-  }
-
-  if (isDebugDumpEnabled(DebugDumpOption::PrintRuntimeArgs)) {
-    std::cout << "Arguments for kernel" << fusion_id_ << ":" << std::endl
-              << "Inputs:" << std::endl;
-    for (const auto& input : inputs) {
-      if (input.isTensor()) {
-        std::cout << input.toTensor().scalar_type() << " "
-                  << input.toTensor().sizes() << std::endl;
-      }
-    }
-    std::cout << "Outputs:" << std::endl;
-    for (const auto& output : allocated_outputs) {
-      std::cout << "  " << output.scalar_type() << " " << output.sizes()
-                << std::endl;
-    }
-    std::cout << "Reduction and semaphore buffers:" << std::endl;
-    for (const auto& buffer : global_buffers.buffers) {
-      std::cout << "  " << buffer.scalar_type() << " " << buffer.sizes()
-                << std::endl;
-    }
-  }
-
-  cudaEvent_t start_event = {};
-  cudaEvent_t finish_event = {};
-
-  if (measure_kernel_time_ ||
-      isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth)) {
-    cudaEventCreate(&start_event);
-    cudaEventCreate(&finish_event);
-    cudaEventRecord(start_event);
+  KernelArgumentHolder kernel_arguments;
+  kernel_arguments.push(inputs);
+  kernel_arguments.push(alloced_outputs);
+  kernel_arguments.push(global_buffers.empty_buffers);
+  kernel_arguments.push(global_buffers.zero_buffers);
+  if (lowered_.kernel()->summary().is_stochastic) {
+    kernel_arguments.appendPhiloxRNGSeed(rand_offset);
   }
 
-  if (execute_kernel_) {
-    FUSER_PERF_SCOPE("ExecutorRunFusion::cuLaunchKernel");
+  {
+    FUSER_PERF_SCOPE("cuLaunchKernel");
     AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLaunchKernel(
         compiled_kernel_.function,
         launch_params.gdimx(),
@@ -715,78 +527,10 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
         stream,
         kernel_arguments.getBuffer(),
         nullptr));
+    at::cuda::stream_synchronize(stream);
   }
 
-  if (measure_kernel_time_ ||
-      isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth)) {
-    cudaEventRecord(finish_event);
-    cudaEventSynchronize(start_event);
-    cudaEventSynchronize(finish_event);
-    cudaEventElapsedTime(&kernel_time_ms_, start_event, finish_event);
-    cudaEventDestroy(start_event);
-    cudaEventDestroy(finish_event);
-
-    if (isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth)) {
-      size_t bytes = 0;
-      // Figure how many bytes are inputs, outputs, and temporary buffers
-      for (auto input : inputs) {
-        if (input.isTensor()) {
-          bytes += input.toTensor().numel() *
-              dataTypeSize(aten_to_data_type(input.toTensor().scalar_type()));
-        }
-      }
-      for (auto output : allocated_outputs) {
-        bytes += output.numel() *
-            dataTypeSize(aten_to_data_type(output.scalar_type()));
-      }
-      double gb_per_s =
-          ((double)bytes / ((double)kernel_time_ms_ / 1000)) / (double)1.0e9;
-      std::cout << "kernel" << fusion_id_ << " run in " << kernel_time_ms_
-                << " ms, achieved: " << gb_per_s << " GB/s" << std::endl;
-    }
-  }
-
-  return allocated_outputs;
-}
-
-void FusionExecutor::compileRtc(
-    const std::string& code,
-    const std::string& name,
-    bool structured) {
-  FUSER_PERF_SCOPE("ExecutorRunFusion::compileRtc");
-  std::string scode;
-  if (!structured) {
-    scode = getStructuredCode(code);
-  } else {
-    scode = code;
-  }
-  fusion_id_ = 1;
-  options_ = CompileOptions();
-  compiled_kernel_ = executor_utils::nvrtcCompile(scode, name, fusion_id_);
-}
-
-void FusionExecutor::runRtc(
-    const LaunchParams& launch_params,
-    const std::vector<at::Tensor>& args) {
-  FUSER_PERF_SCOPE("runFusion");
-
-  c10::DeviceGuard dg(options_.device);
-  auto stream = at::cuda::getCurrentCUDAStream();
-
-  KernelArgumentHolder kernel_arguments(options_.index_mode);
-  kernel_arguments.push(args);
-  AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLaunchKernel(
-      compiled_kernel_.function,
-      launch_params.gdimx(),
-      launch_params.gdimy(),
-      launch_params.gdimz(),
-      launch_params.bdimx(),
-      launch_params.bdimy(),
-      launch_params.bdimz(),
-      launch_params.smem(),
-      stream,
-      kernel_arguments.getBuffer(),
-      nullptr));
+  return alloced_outputs;
 }
 
 } // namespace cuda
index 084ba59..334c49a 100644 (file)
@@ -1,11 +1,11 @@
 #pragma once
 #include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
 #include <torch/csrc/jit/codegen/cuda/executor_utils.h>
+#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
 #include <torch/csrc/jit/codegen/cuda/ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/utils.h>
 
@@ -19,7 +19,6 @@ namespace cuda {
 // TODO: Should this actually be in launch params?
 struct TORCH_CUDA_CU_API CompileOptions {
   c10::Device device = c10::Device(c10::DeviceType::CUDA, 0);
-  KernelIndexMode index_mode = KernelIndexMode::INT64;
 };
 
 class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
@@ -33,11 +32,7 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
       int id,
       CompileOptions options = CompileOptions());
 
-  void compileFusion(
-      Fusion* fusion,
-      CompileOptions options = CompileOptions(),
-      const at::ArrayRef<IValue>& inputs = {},
-      const LaunchParams& launch_constraints = LaunchParams());
+  void compileFusion(Fusion* fusion, CompileOptions options = CompileOptions());
 
   std::vector<at::Tensor> runFusion(
       const at::ArrayRef<IValue>& inputs,
@@ -62,66 +57,32 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
     executor_entry_lookup_.erase(cache_id);
   }
 
-  // struct used to hold necessary information to launch compiled kernel on a
-  // given input set.
-  //
   // TODO: strides would also be important when we handle permutations in
   //       codegen.
-  //
+  // struct used to hold necessary information to launch compiled kernel on a
+  // given input set.
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
   struct ExecutorEntry {
     bool init = false;
     LaunchParams launch_params;
-    std::vector<std::pair<int, int>> io_alias_indices;
     std::vector<std::vector<int64_t>> output_sizes;
     std::vector<at::ScalarType> output_types;
-    std::vector<std::vector<int64_t>> buffer_sizes;
-    std::vector<at::ScalarType> buffer_types;
-    std::vector<bool> buffer_zero_init;
+    std::vector<std::vector<int64_t>> empty_buffer_sizes;
+    std::vector<at::ScalarType> empty_buffer_types;
+    std::vector<std::vector<int64_t>> zero_buffer_sizes;
+    std::vector<at::ScalarType> zero_buffer_types;
     uint64_t rand_offset;
   };
 
-  kir::Kernel* kernel() const {
+  Kernel* kernel() const {
     return lowered_.kernel();
   }
 
-  //! Internal knob used for debugging/profiling only
-  void setExecuteKernelFlag(bool execute_kernel) {
-    execute_kernel_ = execute_kernel;
-  }
-
-  //! Internal knob used for debugging/profiling only
-  void setMeasureKernelTimeFlag(bool measure_kernel_time) {
-    measure_kernel_time_ = measure_kernel_time;
-  }
-
-  //! Returns the last kernel execution time, in milliseconds
-  //!
-  //! \note The kernel time is only tracked if enabled by calling
-  //!    setMeasureKernelTimeFlag(true)
-  //!
-  float kernelTimeMs() const {
-    return measure_kernel_time_ ? kernel_time_ms_ : 0;
-  }
-
-  //! Internal tests only. Compiles CUDA code with NVRTC directly from
-  //! string. This util provides a path to test runtime code, i.e. the resource
-  //! strings.
-  void compileRtc(
-      const std::string& code,
-      const std::string& name,
-      bool structured = false);
-
-  //! Internal tests only. Runs the compiled CUDA kernel from compileRtc.
-  void runRtc(
-      const LaunchParams& launch_params,
-      const std::vector<at::Tensor>& args);
-
  private:
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
   struct GlobalBuffers {
-    std::vector<at::Tensor> buffers;
-    std::vector<bool> zero_init;
+    std::vector<at::Tensor> empty_buffers;
+    std::vector<at::Tensor> zero_buffers;
   };
 
   std::string kernelName() const {
@@ -139,24 +100,19 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
 
   LaunchParams computeLaunchParams(
       const LaunchParams& launch_constraints,
-      kir::ExpressionEvaluator& expr_eval);
+      StatefulExpressionEvaluator& see);
 
   uint64_t computeSharedMemory(
-      kir::ExpressionEvaluator& expr_eval,
-      const std::vector<const kir::Allocate*>& buffers,
+      StatefulExpressionEvaluator& see,
+      const std::vector<kir::Allocate*>& buffers,
       bool align_padding = false,
       uint64_t total = 0);
 
   // return a pair of vector of tensors, where tensors in the first vector are
   // not initialized, while the second vector contains zero-initiliazed tensors
-  GlobalBuffers allocGlobalVals(kir::ExpressionEvaluator& expr_eval);
+  GlobalBuffers allocGlobalVals(StatefulExpressionEvaluator& see);
 
-  // alias_index: index of outputs that are aliases to inputs, hence we should
-  // skip allocating real storage for those, but still maintain its spot to
-  // maintain the indexing from output aliases to inputs
-  std::vector<at::Tensor> allocOutputs(
-      kir::ExpressionEvaluator& expr_eval,
-      const std::unordered_set<int>& alias_indices = {});
+  std::vector<at::Tensor> allocOutputs(StatefulExpressionEvaluator& see);
 
   void setUsedTVs();
 
@@ -167,6 +123,11 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
  private:
   Fusion fusion_;
 
+  // TODO(kir): caching the values here is no longer needed
+  bool has_block_reductions = false;
+  bool has_grid_reductions = false;
+  bool has_block_broadcasts = false;
+
   CompileOptions options_;
   size_t max_device_smem = std::numeric_limits<size_t>().max();
   executor_utils::NvrtcFunction compiled_kernel_;
@@ -183,16 +144,6 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
   // lookup table to take short cut to retrieve recorded information in order to
   // launch kernels without re-inference parameters.
   std::unordered_map<size_t, ExecutorEntry> executor_entry_lookup_;
-
-  // Profiling support: knob to control wheter we actually execute the
-  // kernel on the GPU or not
-  bool execute_kernel_ = true;
-
-  // Profiling support: knob to enable measuring kernel execution time
-  bool measure_kernel_time_ = false;
-
-  // The last kernel execution time, if measure_kernel_time_ is true
-  float kernel_time_ms_ = 0;
 };
 
 } // namespace cuda
index c749aa9..230b285 100644 (file)
@@ -11,82 +11,18 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-namespace {
-
-template <typename T, typename nvfuser_index_t>
-std::unique_ptr<TensorArgAbstract> getTensorArg(int nDims) {
-  switch (nDims) {
-    case (0):
-      return std::make_unique<TensorArg<
-          TensorArgCodegen<T, 0, nvfuser_index_t>,
-          nvfuser_index_t>>();
-    case (1):
-      return std::make_unique<TensorArg<
-          TensorArgCodegen<T, 1, nvfuser_index_t>,
-          nvfuser_index_t>>();
-    case (2):
-      return std::make_unique<TensorArg<
-          TensorArgCodegen<T, 2, nvfuser_index_t>,
-          nvfuser_index_t>>();
-    case (3):
-      return std::make_unique<TensorArg<
-          TensorArgCodegen<T, 3, nvfuser_index_t>,
-          nvfuser_index_t>>();
-    case (4):
-      return std::make_unique<TensorArg<
-          TensorArgCodegen<T, 4, nvfuser_index_t>,
-          nvfuser_index_t>>();
-    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-    case (5):
-      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-      return std::make_unique<TensorArg<
-          TensorArgCodegen<T, 5, nvfuser_index_t>,
-          nvfuser_index_t>>();
-    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-    case (6):
-      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-      return std::make_unique<TensorArg<
-          TensorArgCodegen<T, 6, nvfuser_index_t>,
-          nvfuser_index_t>>();
-    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-    case (7):
-      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-      return std::make_unique<TensorArg<
-          TensorArgCodegen<T, 7, nvfuser_index_t>,
-          nvfuser_index_t>>();
-    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-    case (8):
-      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-      return std::make_unique<TensorArg<
-          TensorArgCodegen<T, 8, nvfuser_index_t>,
-          nvfuser_index_t>>();
-    default:
-      TORCH_INTERNAL_ASSERT(
-          false,
-          "Tried to gerneate a tensor to run a generated kernel with ",
-          nDims,
-          " dimensions, however it must be a 1-8 dimensional tensor.");
-  }
-  return nullptr;
-}
-
-template <typename INDEX_MODE>
 std::unique_ptr<TensorArgAbstract> getTensorArg(
     c10::ScalarType dtype,
     int nDims) {
   switch (dtype) {
-    case c10::ScalarType::Double:
-      return getTensorArg<double, INDEX_MODE>(nDims);
     case c10::ScalarType::Float:
-      return getTensorArg<float, INDEX_MODE>(nDims);
+      return getTensorArg<float>(nDims);
     case c10::ScalarType::Half:
-      return getTensorArg<at::Half, INDEX_MODE>(nDims);
+      return getTensorArg<at::Half>(nDims);
     case c10::ScalarType::Bool:
-      return getTensorArg<bool, INDEX_MODE>(nDims);
+      return getTensorArg<bool>(nDims);
     case c10::ScalarType::Long:
-      return getTensorArg<int64_t, INDEX_MODE>(nDims);
-    case c10::ScalarType::Int:
-      return getTensorArg<int32_t, INDEX_MODE>(nDims);
+      return getTensorArg<int64_t>(nDims);
     default:
       TORCH_CHECK(
           false,
@@ -96,33 +32,13 @@ std::unique_ptr<TensorArgAbstract> getTensorArg(
   }
 }
 
-} // namespace
-
-std::unique_ptr<TensorArgAbstract> getTensorArg(
-    c10::ScalarType dtype,
-    int nDims,
-    KernelIndexMode index_mode) {
-  switch (index_mode) {
-    case KernelIndexMode::INT32:
-      return getTensorArg<int>(dtype, nDims);
-    case KernelIndexMode::INT64:
-      return getTensorArg<int64_t>(dtype, nDims);
-    default:
-      break;
-  }
-
-  TORCH_INTERNAL_ASSERT(false, "unknown index mode");
-  return nullptr;
-}
-
 // Push a tensor to the arguments
 void KernelArgumentHolder::push(const at::Tensor& tensor) {
   changed_ = true;
   int nDims = tensor.ndimension();
 
   c10::ScalarType dtype = tensor.scalar_type();
-  std::unique_ptr<TensorArgAbstract> tensor_arg =
-      getTensorArg(dtype, nDims, index_mode_);
+  std::unique_ptr<TensorArgAbstract> tensor_arg = getTensorArg(dtype, nDims);
   tensor_arg->setPointer(tensor.data_ptr());
   for (const auto i : c10::irange(nDims)) {
     tensor_arg->setSize(i, tensor.sizes()[i]);
@@ -138,17 +54,13 @@ void KernelArgumentHolder::push(const IValue& val) {
       val.isScalar(),
       "Tried to push an arg to run in a fused kernel, expected a scalar but got, ",
       val);
-  auto scalar_val = val.toScalar();
-  switch (scalar_val.type()) {
+  switch (val.toScalar().type()) {
     // NOLINTNEXTLINE(bugprone-branch-clone)
     case c10::ScalarType::Double:
-      arguments_.push_back(std::make_unique<DoubleArg>(scalar_val.toDouble()));
+      arguments_.push_back(std::make_unique<FloatArg>((float)val.toDouble()));
       return;
     case c10::ScalarType::Long:
-      arguments_.push_back(std::make_unique<LongArg>(scalar_val.toLong()));
-      return;
-    case c10::ScalarType::Bool:
-      arguments_.push_back(std::make_unique<BoolArg>(scalar_val.toBool()));
+      arguments_.push_back(std::make_unique<LongArg>(val.toInt()));
       return;
     default:
       TORCH_INTERNAL_ASSERT(
@@ -160,8 +72,8 @@ void KernelArgumentHolder::push(const IValue& val) {
       " Tried to create argument to send to a fused kernel, but got a non-scalar type.");
 }
 
-void KernelArgumentHolder::push(const at::PhiloxCudaState& val) {
-  arguments_.push_back(std::make_unique<PhiloxCudaStateArg>(val));
+void KernelArgumentHolder::push(const uint64_t& val) {
+  arguments_.push_back(std::make_unique<ULongArg>(val));
 }
 
 // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers
@@ -197,16 +109,17 @@ void KernelArgumentHolder::push(const std::vector<at::Tensor>& tensors) {
 }
 
 void KernelArgumentHolder::appendPhiloxRNGSeed(uint64_t rand_offset) {
-  at::PhiloxCudaState philox_engine_inputs;
+  std::pair<uint64_t, uint64_t> philox_engine_inputs;
   auto gen = at::cuda::detail::getDefaultCUDAGenerator();
   {
     // See Note [Acquire lock when using random generators]
     std::lock_guard<std::mutex> lock(gen.mutex());
     philox_engine_inputs =
-        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(
+        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
             rand_offset);
   }
-  push(philox_engine_inputs);
+  push(philox_engine_inputs.first);
+  push(philox_engine_inputs.second);
 }
 
 } // namespace cuda
index 7df1cc4..1c6aaeb 100644 (file)
@@ -1,6 +1,5 @@
 #pragma once
 
-#include <ATen/CUDAGeneratorImpl.h>
 #include <ATen/core/ivalue.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/ir/ir.h>
@@ -11,31 +10,31 @@ namespace fuser {
 namespace cuda {
 
 // This should match the tensor used in the code generation (almost exactly)
-template <typename T, int N, typename nvfuser_index_t>
+template <typename T, int N>
 struct TensorArgCodegen {
-  T& operator[](nvfuser_index_t ind) {
+  T& operator[](int64_t ind) {
     return data[ind];
   };
 
   T* data;
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
-  nvfuser_index_t size[N];
+  int64_t size[N];
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
-  nvfuser_index_t stride[N];
+  int64_t stride[N];
   constexpr int nDims() {
     return N;
   }
-  void setSize(int i, nvfuser_index_t s) {
+  void setSize(int i, int64_t s) {
     size[i] = s;
   }
-  void setStride(int i, nvfuser_index_t s) {
+  void setStride(int i, int64_t s) {
     stride[i] = s;
   }
 };
 
-template <typename T, typename nvfuser_index_t>
-struct TensorArgCodegen<T, 0, nvfuser_index_t> {
-  T& operator[](nvfuser_index_t ind) {
+template <typename T>
+struct TensorArgCodegen<T, 0> {
+  T& operator[](int64_t ind) {
     return data[ind];
   };
 
@@ -43,10 +42,10 @@ struct TensorArgCodegen<T, 0, nvfuser_index_t> {
   constexpr int nDims() {
     return 0;
   }
-  void setSize(int, nvfuser_index_t) {
+  void setSize(int, int64_t) {
     TORCH_INTERNAL_ASSERT(false, "Tried to set size of a 0-dim tensor");
   }
-  void setStride(int, nvfuser_index_t) {
+  void setStride(int, int64_t) {
     TORCH_INTERNAL_ASSERT(false, "Tried to set stride of a 0-dim tensor");
   }
 };
@@ -56,38 +55,34 @@ struct ArgAbstract {
   virtual void* arg() = 0;
 };
 
-struct PhiloxCudaStateArg : public ArgAbstract {
-  at::PhiloxCudaState val_;
-  PhiloxCudaStateArg(at::PhiloxCudaState _val) : val_(_val){};
-  // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
-  void* arg() {
+struct ULongArg : public ArgAbstract {
+  uint64_t val_;
+  ULongArg(uint64_t _val) : val_(_val) {}
+  void* arg() override {
     return &val_;
   }
 };
 
 struct LongArg : public ArgAbstract {
   int64_t val_;
-  explicit LongArg(int64_t _val) : val_(_val){};
-  // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
-  void* arg() {
+  LongArg(int64_t _val) : val_(_val) {}
+  void* arg() override {
     return &val_;
   }
 };
 
-struct DoubleArg : public ArgAbstract {
-  double val_;
-  explicit DoubleArg(double _val) : val_(_val){};
-  // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
-  void* arg() {
+struct IntArg : public ArgAbstract {
+  int val_;
+  IntArg(int _val) : val_(_val) {}
+  void* arg() override {
     return &val_;
   }
 };
 
-struct BoolArg : public ArgAbstract {
-  bool val_;
-  explicit BoolArg(bool _val) : val_(_val){};
-  // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
-  void* arg() {
+struct FloatArg : public ArgAbstract {
+  float val_;
+  FloatArg(float _val) : val_(_val) {}
+  void* arg() override {
     return &val_;
   }
 };
@@ -99,16 +94,16 @@ struct TensorArgAbstract : ArgAbstract {
 };
 
 // This should match the tensor used in the code generation (almost exactly)
-template <typename TENSOR_TYPE, typename nvfuser_index_t>
+template <typename TENSOR_TYPE>
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 struct TensorArg : public TensorArgAbstract {
   TENSOR_TYPE instance_;
 
   void setSize(int i, int64_t size) override {
-    instance_.setSize(i, (nvfuser_index_t)size);
+    instance_.setSize(i, size);
   }
   void setStride(int i, int64_t stride) override {
-    instance_.setStride(i, (nvfuser_index_t)stride);
+    instance_.setStride(i, stride);
   }
   void setPointer(void* ptr) override {
     instance_.data = static_cast<decltype(TENSOR_TYPE::data)>(ptr);
@@ -119,22 +114,49 @@ struct TensorArg : public TensorArgAbstract {
   }
 };
 
+template <typename T>
+std::unique_ptr<TensorArgAbstract> getTensorArg(int nDims) {
+  switch (nDims) {
+    case (0):
+      return std::make_unique<TensorArg<TensorArgCodegen<T, 0>>>();
+    case (1):
+      return std::make_unique<TensorArg<TensorArgCodegen<T, 1>>>();
+    case (2):
+      return std::make_unique<TensorArg<TensorArgCodegen<T, 2>>>();
+    case (3):
+      return std::make_unique<TensorArg<TensorArgCodegen<T, 3>>>();
+    case (4):
+      return std::make_unique<TensorArg<TensorArgCodegen<T, 4>>>();
+    case (5):
+      return std::make_unique<TensorArg<TensorArgCodegen<T, 5>>>();
+    case (6):
+      return std::make_unique<TensorArg<TensorArgCodegen<T, 6>>>();
+    case (7):
+      return std::make_unique<TensorArg<TensorArgCodegen<T, 7>>>();
+    case (8):
+      return std::make_unique<TensorArg<TensorArgCodegen<T, 8>>>();
+    default:
+      TORCH_INTERNAL_ASSERT(
+          false,
+          "Tried to gerneate a tensor to run a generated kernel with ",
+          nDims,
+          " dimensions, however it must be a 1-8 dimensional tensor.");
+  }
+}
+
 std::unique_ptr<TensorArgAbstract> getTensorArg(
     c10::ScalarType dtype,
     int nDims);
 
 class KernelArgumentHolder {
  public:
-  explicit KernelArgumentHolder(KernelIndexMode index_mode)
-      : index_mode_(index_mode) {}
-
   // Push a tensor to the arguments
   void push(const at::Tensor& tensor);
 
   // Push a scalar or integer to the arguments
   void push(const IValue& val);
 
-  void push(const at::PhiloxCudaState& val);
+  void push(const uint64_t& val);
 
   // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers
   // in the buffer
@@ -150,7 +172,6 @@ class KernelArgumentHolder {
   std::vector<std::unique_ptr<ArgAbstract>> arguments_;
   std::vector<void*> void_ptrs_;
   bool changed_ = true;
-  KernelIndexMode index_mode_ = KernelIndexMode::INT64;
 };
 
 } // namespace cuda
index 6a2c478..387233c 100644 (file)
@@ -1,34 +1,10 @@
 #include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
 
-#include <ATen/cuda/CUDAContext.h>
-
 namespace torch {
 namespace jit {
 namespace fuser {
 namespace cuda {
 
-void LaunchParams::assertValid() {
-  TORCH_INTERNAL_ASSERT(
-      bdimx() * bdimz() * bdimz() > 0 &&
-          bdimx() * bdimz() * bdimz() <=
-              (int64_t)at::cuda::getCurrentDeviceProperties()
-                  ->maxThreadsPerMultiProcessor,
-      "Selected invalid number of threads for cuda: ",
-      bdimx() * bdimz() * bdimz());
-  TORCH_INTERNAL_ASSERT(
-      gdimx() > 0 && gdimx() < (std::int64_t(1) << 32) - 1,
-      "Invalid number of blocks in x direction: ",
-      gdimx());
-  TORCH_INTERNAL_ASSERT(
-      gdimy() > 0 && gdimy() <= 65535,
-      "Invalid number of blocks in y direction: ",
-      gdimy());
-  TORCH_INTERNAL_ASSERT(
-      gdimz() > 0 && gdimz() <= 65535,
-      "Invalid number of blocks in z direction: ",
-      gdimz());
-}
-
 void LaunchParams::bind(int64_t val, ParallelType p_type) {
   switch (p_type) {
     case ParallelType::TIDx:
@@ -55,7 +31,6 @@ void LaunchParams::bind(int64_t val, ParallelType p_type) {
           "Tried to bind invalid parallel type in launch config: ",
           p_type);
   }
-  assertValid();
 }
 
 int64_t LaunchParams::getDim(ParallelType p_type) const {
@@ -111,23 +86,6 @@ bool LaunchParams::operator==(const LaunchParams& other) const {
       bdimx_ == other.bdimx_ && bdimy_ == other.bdimy_ && smem_ == other.smem_;
 }
 
-void LaunchParams::print() const {
-  std::cout << toString();
-}
-
-std::string LaunchParams::toString() const {
-  std::stringstream ss;
-  ss << "Launch Parameters \n"
-     << "BlockDim.x = " << bdimx() << "\n"
-     << "BlockDim.y = " << bdimy() << "\n"
-     << "BlockDim.z = " << bdimz() << "\n"
-     << "GridDim.x = " << gdimx() << "\n"
-     << "GridDim.y = " << gdimy() << "\n"
-     << "GridDim.z = " << gdimz() << "\n"
-     << "Smem Size = " << smem() << "\n";
-  return ss.str();
-}
-
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
index 66bafb2..97399eb 100644 (file)
@@ -22,11 +22,7 @@ class TORCH_CUDA_CU_API LaunchParams {
         gdimz_(gdimz),
         bdimx_(bdimx),
         bdimy_(bdimy),
-        bdimz_(bdimz) {
-    assertValid();
-  }
-
-  void assertValid();
+        bdimz_(bdimz) {}
 
   void setSmem(int64_t smem) {
     smem_ = smem;
@@ -37,11 +33,11 @@ class TORCH_CUDA_CU_API LaunchParams {
   }
 
   int64_t nBlocks() const {
-    return std::abs(gdimx_ * gdimy_ * gdimz_);
+    return gdimx_ * gdimy_ * gdimz_;
   }
 
   int64_t nThreads() const {
-    return std::abs(bdimx_ * bdimy_ * bdimz_);
+    return bdimx_ * bdimy_ * bdimz_;
   }
 
   int64_t bdimx() const {
@@ -92,7 +88,6 @@ class TORCH_CUDA_CU_API LaunchParams {
     if (class_val == UNINITIALIZED_VAL) {
       class_val = incoming_val;
     }
-    assertValid();
   }
 
   // Binds dim assocaited with p_type to val
@@ -109,10 +104,6 @@ class TORCH_CUDA_CU_API LaunchParams {
 
   bool operator==(const LaunchParams& other) const;
 
-  void print() const;
-
-  std::string toString() const;
-
  private:
   // Spell them out because I want signed ints to know if they were initialized
   // or not.
@@ -129,7 +120,6 @@ class TORCH_CUDA_CU_API LaunchParams {
   // TODO: Fill in output sizes
   std::vector<std::vector<int64_t>> output_sizes;
 };
-
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
index 7efe7cb..db69c28 100644 (file)
@@ -8,26 +8,16 @@
 #include <torch/csrc/jit/codegen/cuda/executor_utils.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
 #include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>
 #include <torch/csrc/jit/resource_guard.h>
 
-#include <nvfuser_resources/PhiloxCudaStateRaw.h>
 #include <nvfuser_resources/block_reduction.h>
-#include <nvfuser_resources/block_sync_atomic.h>
-#include <nvfuser_resources/block_sync_default.h>
 #include <nvfuser_resources/broadcast.h>
 #include <nvfuser_resources/fp16_support.h>
 #include <nvfuser_resources/grid_reduction.h>
 #include <nvfuser_resources/helpers.h>
 #include <nvfuser_resources/random_numbers.h>
 #include <nvfuser_resources/tensor.h>
-#include <nvfuser_resources/welford.h>
-
-#ifndef USE_ROCM
-#include <cuda_occupancy.h>
-#endif
 
 #include <fstream>
 
@@ -62,16 +52,9 @@ std::string kernelPreamble() {
   ss << nvfuser_resources::tensor_cu;
   ss << nvfuser_resources::random_numbers_cu;
   ss << nvfuser_resources::helpers_cu;
-  if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
-    ss << nvfuser_resources::block_sync_atomic_cu;
-  } else {
-    ss << nvfuser_resources::block_sync_default_cu;
-  }
   ss << nvfuser_resources::block_reduction_cu;
   ss << nvfuser_resources::grid_reduction_cu;
   ss << nvfuser_resources::broadcast_cu;
-  ss << nvfuser_resources::welford_cu;
-  ss << nvfuser_resources::PhiloxCudaStateRaw_cu;
 
   return ss.str();
 }
@@ -117,21 +100,12 @@ bool validateKernelArgTensor(
   DataType param_data_type = *param->getDataType();
   bool match = false;
   switch (arg_data_type) {
-    case at::ScalarType::Double:
-      match = param_data_type == DataType::Double;
-      break;
     case at::ScalarType::Half:
       match = param_data_type == DataType::Half;
       break;
     case at::ScalarType::Float:
       match = param_data_type == DataType::Float;
       break;
-    case at::ScalarType::Long:
-      match = param_data_type == DataType::Int;
-      break;
-    case at::ScalarType::Int:
-      match = param_data_type == DataType::Int32;
-      break;
     case at::ScalarType::Bool:
       match = param_data_type == DataType::Bool;
       break;
@@ -148,33 +122,32 @@ bool validateKernelArgTensor(
 
 // Return false if  arg_type doesn't match the type in param
 bool validateKernelArgScalar(
-    const c10::IValue& arg,
+    const c10::TypePtr& arg_type,
     const Val* param,
     std::stringstream& msg) {
-  if (!arg.isScalar()) {
+  if (!param->isScalar()) {
     msg << "Argument is a scalar, but the parameter is not."
         << "\n";
     return false;
   }
   DataType param_type = *param->getDataType();
   bool match = false;
-  switch (arg.toScalar().type()) {
-    case c10::ScalarType::Long:
-      match = param_type == DataType::Int || param_type == DataType::Int32;
+  switch (arg_type->kind()) {
+    case c10::TypeKind::IntType:
+      match = param_type == DataType::Int;
       break;
-    case c10::ScalarType::Double:
-      match = param_type == DataType::Double || param_type == DataType::Float ||
-          param_type == DataType::Half;
+    case c10::TypeKind::FloatType:
+      match = param_type == DataType::Float;
       break;
-    case c10::ScalarType::Bool:
+    case c10::TypeKind::BoolType:
       match = param_type == DataType::Bool;
       break;
     default:
       match = false;
   }
   if (!match) {
-    msg << "Argument type is " << arg.toScalar().type()
-        << ", but the parameter is " << param_type << "\n";
+    msg << "Argument type is " << *arg_type << ", but the parameter is "
+        << param_type << "\n";
   }
   return match;
 }
@@ -189,80 +162,7 @@ bool validateKernelArg(
   if (arg.isTensor()) {
     return validateKernelArgTensor(arg.toTensor(), param, device, msg);
   } else {
-    return validateKernelArgScalar(arg, param, msg);
-  }
-}
-
-// Return true if all the tensors have the same stride, assumes all tensors are
-// contiguous
-bool checkSameStride(const std::vector<c10::IValue>& tensors) {
-  if (tensors.size() < 2) {
-    return true;
-  }
-  for (size_t idx = 0; idx < tensors.size() - 1; ++idx) {
-    auto current = tensors[idx];
-    auto next = tensors[idx + 1];
-    if (!current.isTensor() || !next.isTensor()) {
-      return false;
-    }
-
-    const auto& current_tensor = current.toTensor();
-    const auto& next_tensor = next.toTensor();
-    if (current_tensor.ndimension() != next_tensor.ndimension()) {
-      return false;
-    }
-
-    for (int64_t i = 0; i < current_tensor.ndimension(); ++i) {
-      if (current_tensor.stride(i) != next_tensor.stride(i)) {
-        return false;
-      }
-    }
-  }
-  return true;
-}
-
-// Return true if all the tensors are contiguous and have the same striding
-bool checkSameContiguity(const std::vector<c10::IValue>& tensors) {
-  auto reference = tensors.front();
-  if (!reference.isTensor()) {
-    return false;
-  }
-
-  // Determine if the reference tensor is contiguous
-  const auto& reference_tensor = reference.toTensor();
-  int64_t expected_stride = 1;
-  for (int64_t i = 1; i <= reference_tensor.ndimension(); ++i) {
-    int64_t ind = reference_tensor.ndimension() - i;
-    if (reference_tensor.size(ind) == 1) {
-      continue;
-    }
-    if (reference_tensor.stride(ind) != expected_stride) {
-      return false;
-    }
-    expected_stride *= reference_tensor.size(ind);
-  }
-
-  // Check if all the tensors have the same contiguity
-  return checkSameStride(tensors);
-}
-
-bool checkValidMisalignedTensors(
-    const std::unordered_set<TensorView*>& inp_tv,
-    const std::unordered_set<TensorView*>& out_tv,
-    const std::vector<c10::IValue>& inp_tensors,
-    const std::vector<c10::IValue>& out_tensors) {
-  if (out_tv.empty()) {
-    // Only check input tensors
-    return checkSameStride(inp_tensors);
-  } else if (!out_tv.empty() && out_tensors.empty()) {
-    // Assume out tensors are contiguous
-    return checkSameContiguity(inp_tensors);
-  } else {
-    // Only check input and output tensors
-    std::vector<c10::IValue> tensors;
-    tensors.insert(tensors.end(), inp_tensors.begin(), inp_tensors.end());
-    tensors.insert(tensors.end(), out_tensors.begin(), out_tensors.end());
-    return checkSameStride(tensors);
+    return validateKernelArgScalar(arg.type(), param, msg);
   }
 }
 
@@ -272,7 +172,7 @@ void validateKernelInputs(
     Fusion* fusion,
     const at::ArrayRef<IValue>& inputs,
     const c10::Device& device) {
-  FUSER_PERF_SCOPE("executor_utils::ValidateKernelInputs");
+  FUSER_PERF_SCOPE("validateKernelInputs");
 
   // This is necessary as we were traversing the fusion graph later in the check
   FusionGuard fg(fusion);
@@ -296,7 +196,7 @@ void validateKernelOutputs(
     Fusion* fusion,
     const std::vector<at::Tensor>& outputs,
     const c10::Device& device) {
-  FUSER_PERF_SCOPE("executor_utils::ValidateKernelOutputs");
+  FUSER_PERF_SCOPE("validateKernelOutputs");
 
   TORCH_INTERNAL_ASSERT(
       fusion->outputs().size() != 0,
@@ -317,278 +217,24 @@ void validateKernelOutputs(
       !mismatch, "Found one or more invalid arguments: ", msg.str());
 }
 
-bool canVectorize(const IValue& aten_val, int word_size) {
-  if (!aten_val.isTensor()) {
-    return false;
-  }
-
-  const auto& aten_tensor = aten_val.toTensor();
-
-  if (reinterpret_cast<size_t>(aten_tensor.data_ptr()) %
-          (word_size * aten_tensor.dtype().itemsize()) !=
-      0) {
-    return false;
-  }
-
-  for (size_t i = aten_tensor.ndimension(); i > 0; i--) {
-    if (aten_tensor.size(i - 1) != 1) {
-      if (aten_tensor.size(aten_tensor.ndimension() - 1) % word_size != 0 ||
-          aten_tensor.stride(aten_tensor.ndimension() - 1) != 1) {
-        return false;
-      }
-      break;
-    }
-  }
-
-  for (auto stride : aten_tensor.strides()) {
-    if (stride != 1 && stride % word_size != 0) {
-      return false;
-    }
-  }
-
-  return true;
-}
-
-bool canVectorize(
-    TensorView* fusion_tv,
-    int word_size,
-    GpuLower& lower,
-    kir::ExpressionEvaluator& expr_eval) {
-  IterDomain* last_root_dim = nullptr;
-  // TODO: Should this be rfactor instead of root??
-  for (size_t i = fusion_tv->getRootDomain().size(); i > 0; i--) {
-    auto r_id = fusion_tv->getRootDomain()[i - 1];
-    if (r_id->isReduction() || r_id->isBroadcast()) {
-      continue;
-    }
-    last_root_dim = r_id;
-    break;
-  }
-
-  if (last_root_dim == nullptr) {
-    return false;
-  }
-
-  auto last_dim_size =
-      expr_eval.evaluate(lower.lowerValue(last_root_dim->extent()));
-
-  if (!last_dim_size.has_value()) {
-    return false;
-  }
-
-  if (last_dim_size.value() % word_size != 0) {
-    return false;
-  }
-
-  return true;
-}
-
-// Misaligned vectorization check. Currently misaligned vectorization is limited
-// to global-register and register-global load/store patterns. However, this
-// could be improved to include shared memory.
-void validateVectorizedTensors(
-    Fusion* fusion,
-    const at::ArrayRef<IValue>& inputs,
-    const std::vector<at::Tensor>& outputs,
-    GpuLower& lower,
-    kir::ExpressionEvaluator& expr_eval) {
-  std::unordered_set<TensorView*> global_inp_misaligned_tv;
-  std::unordered_set<TensorView*> global_out_misaligned_tv;
-  std::unordered_map<TensorView*, int> tv_to_vector_word_size;
-  // Find all vectorized tensors and their word size
-  for (auto expr : fusion->exprs()) {
-    if (!expr->isA<UnaryOp>() ||
-        expr->as<UnaryOp>()->getUnaryOpType() != UnaryOpType::Set) {
-      continue;
-    }
-    auto uop = expr->as<UnaryOp>();
-    if (!uop->out()->isA<TensorView>() || !uop->in()->isA<TensorView>()) {
-      continue;
-    }
-    auto out_tv = uop->out()->as<TensorView>();
-    auto in_tv = uop->in()->as<TensorView>();
-    IterDomain* vector_dim = nullptr;
-    for (auto id : out_tv->domain()->domain()) {
-      if (id->getParallelType() == ParallelType::Vectorize ||
-          id->getParallelType() == ParallelType::MisalignedVectorize) {
-        TORCH_INTERNAL_ASSERT(
-            vector_dim == nullptr,
-            "Found multiple vectorized dimensions on tensor ",
-            out_tv);
-        vector_dim = id;
-      }
-    }
-    if (vector_dim == nullptr) {
-      continue;
-    }
-    auto vector_word_size =
-        expr_eval.evaluate(lower.lowerValue(vector_dim->extent()));
-    TORCH_INTERNAL_ASSERT(
-        vector_word_size.has_value(),
-        "Non constant vector dimension found in ",
-        out_tv);
-    tv_to_vector_word_size[out_tv] = vector_word_size.value();
-    tv_to_vector_word_size[in_tv] = vector_word_size.value();
-
-    if (vector_dim->getParallelType() == ParallelType::MisalignedVectorize) {
-      if (out_tv->getMemoryType() == MemoryType::Global &&
-          in_tv->getMemoryType() == MemoryType::Local) {
-        global_out_misaligned_tv.insert(out_tv);
-      } else if (
-          in_tv->getMemoryType() == MemoryType::Global &&
-          out_tv->getMemoryType() == MemoryType::Local) {
-        global_inp_misaligned_tv.insert(in_tv);
-      } else {
-        TORCH_INTERNAL_ASSERT(
-            false,
-            "Unsupported memory configuration for misaligned vectorization.");
-      }
-    }
-  }
-
-  // Check striding information on input and outputs as well as size information
-  // of all
-  std::vector<c10::IValue> inp_misaligned_tensors;
-  std::vector<c10::IValue> out_misaligned_tensors;
-  for (auto entry : tv_to_vector_word_size) {
-    auto tv = entry.first;
-    auto word_size = entry.second;
-    if (tv->isFusionInput()) {
-      auto inp_it =
-          std::find(fusion->inputs().begin(), fusion->inputs().end(), tv);
-      TORCH_INTERNAL_ASSERT(
-          inp_it != fusion->inputs().end(),
-          "Could not find ",
-          tv,
-          " in fusion inputs.");
-      auto inp_pos = std::distance(fusion->inputs().begin(), inp_it);
-      auto aten_inp = inputs[inp_pos];
-
-      if (global_inp_misaligned_tv.find(tv) != global_inp_misaligned_tv.end()) {
-        inp_misaligned_tensors.emplace_back(aten_inp);
-      } else {
-        TORCH_INTERNAL_ASSERT(
-            canVectorize(aten_inp, word_size),
-            "Error vectorizing, ",
-            tv,
-            " as input provided does not allowed vectorization by word size, ",
-            word_size);
-      }
-    } else if (tv->isFusionOutput() && outputs.size() > 0) {
-      auto out_it =
-          std::find(fusion->outputs().begin(), fusion->outputs().end(), tv);
-      TORCH_INTERNAL_ASSERT(
-          out_it != fusion->outputs().end(),
-          "Could not find ",
-          tv,
-          " in provided fusion outputs.");
-      auto out_pos = std::distance(fusion->outputs().begin(), out_it);
-      auto aten_out = outputs[out_pos];
-
-      if (global_out_misaligned_tv.find(tv) != global_out_misaligned_tv.end()) {
-        out_misaligned_tensors.emplace_back(aten_out);
-      } else {
-        TORCH_INTERNAL_ASSERT(
-            canVectorize(aten_out, word_size),
-            "Error vectorizing, ",
-            tv,
-            " as output provided does not allowed vectorization by word size, ",
-            word_size);
-      }
-    } else {
-      if (!tv_to_vector_word_size.count(tv)) {
-        TORCH_INTERNAL_ASSERT(
-            canVectorize(tv, word_size, lower, expr_eval),
-            "Could not vectorize ",
-            tv,
-            " it's inner most dim is not a multiple of ",
-            word_size);
-      }
-    }
-  }
-
-  // If input stride is non-contiguous + no outputs, return false
-  TORCH_INTERNAL_ASSERT(
-      checkValidMisalignedTensors(
-          global_inp_misaligned_tv,
-          global_out_misaligned_tv,
-          inp_misaligned_tensors,
-          out_misaligned_tensors),
-      "All global tensors must have the same stride for misaligned vectorization.");
-}
-
-kir::ExpressionEvaluator bindKernelInputs(
-    const at::ArrayRef<IValue>& aten_inputs,
-    kir::Kernel* kernel) {
-  FUSER_PERF_SCOPE("executor_utils::BindKernelInputs");
-
-  TORCH_INTERNAL_ASSERT(
-      kernel->inputs().size() == aten_inputs.size(),
-      "Something went wrong configuring launch. Inputs no longer match.");
-
-  kir::ExpressionEvaluator expr_eval;
-  const auto& inputs = kernel->inputs();
-
-  for (size_t i = 0; i < inputs.size(); i++) {
-    const auto input = inputs[i];
-
-    if (auto tensor_input = dynamic_cast<kir::TensorView*>(input)) {
-      TORCH_INTERNAL_ASSERT(
-          aten_inputs[i].isTensor(),
-          "Something went wrong configuring launch. Inputs no longer match.");
-
-      const auto aten_tensor = aten_inputs[i].toTensor();
-      const auto root_domain =
-          kir::TensorDomain::noReductions(tensor_input->domain()->rootDomain());
-      TORCH_INTERNAL_ASSERT(
-          aten_tensor.ndimension() == static_cast<int>(root_domain.size()),
-          "Something went wrong configuring launch. Inputs no longer match.");
-
-      for (size_t dim = 0; dim < root_domain.size(); dim++) {
-        const auto extent = root_domain[dim]->extent();
-        const auto value = aten_tensor.sizes()[dim];
-        const auto prev_value = expr_eval.evaluate(extent);
-        if (prev_value.has_value()) {
-          TORCH_CHECK(
-              *prev_value == value,
-              "Attempting to bind ",
-              kir::toString(extent),
-              " to ",
-              value,
-              "but it's already set to ",
-              *prev_value);
-        } else {
-          expr_eval.bind(extent, value);
-        }
-      }
-      // NOLINTNEXTLINE: https://bugs.llvm.org/show_bug.cgi?id=48525
-    } else if (input->isScalar() && input->dtype() == DataType::Int) {
-      TORCH_INTERNAL_ASSERT(
-          aten_inputs[i].type()->kind() == c10::TypeKind::IntType);
-      expr_eval.bind(input, aten_inputs[i].toInt());
-    }
-  }
-
-  return expr_eval;
-}
-
-ExpressionEvaluator bindFusionInputs(
+StatefulExpressionEvaluator statefulBindInputs(
     const at::ArrayRef<IValue>& aten_inputs,
-    Fusion* fusion) {
-  FUSER_PERF_SCOPE("executor_utils::BindFusionInputs");
+    Fusion* fusion,
+    GpuLower* lower) {
+  FUSER_PERF_SCOPE("statefulBindInputs");
 
   TORCH_INTERNAL_ASSERT(
       fusion->inputs().size() == aten_inputs.size(),
       "Something went wrong configuring launch. Inputs no longer match.");
 
-  ExpressionEvaluator evaluator(fusion);
-  auto inputs = fusion->inputs();
+  auto fusion_inputs = fusion->inputs();
+  StatefulExpressionEvaluator evaluator(fusion);
 
   // This should probably move to EvaluationContext as we may want to bind
   // input values frequently. Bind fusion input values to runtime values.
   for (const auto i : c10::irange(fusion->inputs().size())) {
-    if (inputs[i]->getValType() == ValType::TensorView) {
-      TensorView* cg_tensor = inputs[i]->as<TensorView>();
+    if (fusion->inputs()[i]->getValType() == ValType::TensorView) {
+      TensorView* cg_tensor = fusion->inputs()[i]->as<TensorView>();
 
       TORCH_INTERNAL_ASSERT(
           aten_inputs[i].isTensor(),
@@ -601,28 +247,15 @@ ExpressionEvaluator bindFusionInputs(
           "Something went wrong configuring launch. Inputs no longer match.");
 
       for (const auto dim : c10::irange(root_dom.size())) {
-        const auto extent = root_dom[dim]->extent();
-        const auto value = aten_tensor.sizes()[dim];
-        const auto prev_value = evaluator.evaluate(extent);
-        if (prev_value.has_value()) {
-          TORCH_CHECK(
-              *prev_value == value,
-              "Attempting to bind ",
-              extent,
-              " to ",
-              value,
-              "but it's already set to ",
-              *prev_value);
-        } else {
-          evaluator.bind(extent, value);
-        }
+        evaluator.safeBind(
+            root_dom[dim]->extent(), aten_tensor.sizes()[dim], lower);
       }
     } else if (
-        inputs[i]->getValType().value() == ValType::Scalar &&
-        inputs[i]->getDataType().value() == DataType::Int) {
+        fusion->inputs()[i]->getValType().value() == ValType::Scalar &&
+        fusion->inputs()[i]->getDataType().value() == DataType::Int) {
       TORCH_INTERNAL_ASSERT(
           aten_inputs[i].type()->kind() == c10::TypeKind::IntType);
-      evaluator.bind(inputs[i], aten_inputs[i].toInt());
+      evaluator.safeBind(fusion->inputs()[i], aten_inputs[i].toInt(), lower);
     }
   }
   return evaluator;
@@ -631,9 +264,8 @@ ExpressionEvaluator bindFusionInputs(
 NvrtcFunction nvrtcCompile(
     const std::string& code,
     const std::string& func_name,
-    int id,
-    c10::optional<int> opt_block_size) {
-  FUSER_PERF_SCOPE("executor_utils::NVRTC");
+    int id) {
+  FUSER_PERF_SCOPE("NVRTC");
 
   // lazily construct context if non-existing yet;
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@@ -654,13 +286,13 @@ NvrtcFunction nvrtcCompile(
   nvrtcProgram program; // NOLINT(cppcoreguidelines-init-variables)
 
   {
-    FUSER_PERF_SCOPE("executor_utils::NvrtcCreateProgram");
+    FUSER_PERF_SCOPE("nvrtcCreateProgram");
     AT_CUDA_NVRTC_CHECK(at::globalContext().getNVRTC().nvrtcCreateProgram(
         &program, code.c_str(), nullptr, 0, nullptr, nullptr));
   }
 
   ResourceGuard holdProgram([&] {
-    FUSER_PERF_SCOPE("executor_utils::NvrtcDestroyProgram");
+    FUSER_PERF_SCOPE("nvrtcDestroyProgram");
     AT_CUDA_NVRTC_CHECK(
         at::globalContext().getNVRTC().nvrtcDestroyProgram(&program));
   });
@@ -671,34 +303,26 @@ NvrtcFunction nvrtcCompile(
   args.push_back("-hip-pch");
 #endif
 #else
-#if CUDA_VERSION < 11010
-  // compile to sass is not allowed prior to CUDA 11.1
-  compile_to_sass = false;
-#endif
-  // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
-  // which gives better backwards compatibility to work on older driver,
-  // (since older driver doesn't necessrily recognize PTX emitted by new
-  // toolkit);
-  // Meanwhile, for forward compatibility (future device with
-  // `unsupported_arch==True`), since SASS are not necessarily compatible,
-  // we fallback to PTX instead.
   const std::string compute = std::string("--gpu-architecture=") +
-      (compile_to_sass ? "sm_" : "compute_") + std::to_string(major) +
-      std::to_string(minor);
+#if CUDA_VERSION >= 11010
+      // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
+      // which gives better backwards compatibility to work on older driver,
+      // (since older driver doesn't necessrily recognize PTX emitted by new
+      // toolkit);
+      // Meanwhile, for forward compatibility (future device with
+      // `unsupported_arch==True`), since SASS are not necessarily compatible,
+      // we fallback to PTX instead.
+      (compile_to_sass ? "sm_" : "compute_") +
+#else
+      "compute_" +
+#endif
+      std::to_string(major) + std::to_string(minor);
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   std::vector<const char*> args = {
       "--std=c++14", compute.c_str(), "-default-device"};
 #endif
 
-  const char* disable_fastmath = getenv("PYTORCH_NVFUSER_DISABLE_FASTMATH");
-  if (!disable_fastmath || (atoi(disable_fastmath) == 0)) {
-    args.push_back("--use_fast_math");
-  } else {
-    TORCH_WARN_ONCE(
-        "fast math disabled in nvfuser, try set `PYTORCH_NVFUSER_DISABLE_FASTMATH=0`");
-  }
-
-  const char* disable_fma = getenv("PYTORCH_NVFUSER_DISABLE_FMA");
+  const char* disable_fma = getenv("PYTORCH_CUDA_FUSER_DISABLE_FMA");
   // int disable_fma_flag = disable_fma ? atoi(disable_fma) : 0;
   if (disable_fma && atoi(disable_fma)) {
 #ifdef __HIP_PLATFORM_HCC__
@@ -709,103 +333,34 @@ NvrtcFunction nvrtcCompile(
 #endif
   }
 
-#ifndef NDEBUG
-  // Add line info to generated kernels
-  args.push_back("-lineinfo");
-#else
-  // Avoid excessive register usage from assertion
-  args.push_back("-DNDEBUG");
-#endif
-
-  const char* ptxas_opt_level = getenv("PYTORCH_NVFUSER_JIT_OPT_LEVEL");
-  std::string jit_opt_level = "-O";
+  const char* ptxas_opt_level = getenv("PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL");
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  uint32_t jit_opt_level;
 
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   std::vector<CUjit_option> options;
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   std::vector<void*> option_vals;
-  std::vector<char> info_log;
-  unsigned int log_size = 8196;
-
-  if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) {
-    // show register usage in compilation log
-    if (compile_to_sass) {
-      args.push_back("--ptxas-options");
-      args.push_back("--verbose");
-    } else {
-      options.push_back(CU_JIT_LOG_VERBOSE);
-      option_vals.push_back((void*)1);
-      info_log.reserve(log_size);
-
-      options.push_back(CU_JIT_INFO_LOG_BUFFER);
-      option_vals.push_back((void*)info_log.data());
-
-      options.push_back(CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES);
-      option_vals.push_back((void*)(long)log_size);
-    }
-  }
 
   if (ptxas_opt_level) {
     int val = atoi(ptxas_opt_level);
     if (val <= 4 && val >= 0) {
-      if (compile_to_sass) {
-        jit_opt_level += std::to_string(val);
-        args.push_back("--ptxas-options");
-        args.push_back(jit_opt_level.c_str());
-      } else {
-        options.push_back(CU_JIT_OPTIMIZATION_LEVEL);
-        option_vals.push_back((void*)(intptr_t)val);
-      }
+      jit_opt_level = static_cast<uint32_t>(val);
+      options.push_back(CU_JIT_OPTIMIZATION_LEVEL);
+      option_vals.emplace_back(&jit_opt_level);
     } else {
       TORCH_WARN_ONCE(
-          "acceptable range for PYTORCH_NVFUSER_JIT_OPT_LEVEL is between 0 and 4, but received ",
-          val,
+          "acceptable range for PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL is between 0 and 4, but received ",
+          jit_opt_level,
           ", ignoring the option");
     }
   }
 
-#ifndef USE_ROCM
-  // keeping the string outside the loop for lifetime
-  std::string max_register_usage = "--maxrregcount=";
-  uint32_t max_register = 0;
-  if (opt_block_size.has_value() && opt_block_size.value() > 0) {
-    int num_partition = 0;
-    int reg_allocation_granularity = 0;
-    int max_regs_per_thread = 0;
-    cudaOccDeviceProp occ_prop(*prop);
-    cudaOccSubPartitionsPerMultiprocessor(&num_partition, &occ_prop);
-    cudaOccRegAllocationGranularity(&reg_allocation_granularity, &occ_prop);
-    cudaOccRegAllocationMaxPerThread(&max_regs_per_thread, &occ_prop);
-    int warp_size = prop->warpSize;
-    int num_warps = ceilDiv(opt_block_size.value(), warp_size);
-
-    // warps could be distributed unevenly across partition
-    int max_warps_per_sm_partition = ceilDiv(num_warps, num_partition);
-    // registers are evenly distributed across partitions, partition with most
-    // wraps determins the maximum register available per warp
-    int max_reg_per_warp =
-        prop->regsPerBlock / num_partition / max_warps_per_sm_partition;
-    // clamp down to register allocation granularity at warp level
-    int effective_max_reg_per_warp = max_reg_per_warp /
-        reg_allocation_granularity * reg_allocation_granularity;
-    max_register = static_cast<uint32_t>(
-        std::min(effective_max_reg_per_warp / warp_size, max_regs_per_thread));
-
-    if (compile_to_sass) {
-      max_register_usage += std::to_string(max_register);
-      args.push_back(max_register_usage.c_str());
-    } else {
-      options.push_back(CU_JIT_MAX_REGISTERS);
-      option_vals.push_back((void*)(intptr_t)max_register);
-    }
-  }
-#endif
-
   at::globalContext().getNVRTC().nvrtcAddNameExpression(
       program, func_name.c_str());
 
   {
-    FUSER_PERF_SCOPE("executor_utils::Nvrtc::CompileProgram");
+    FUSER_PERF_SCOPE("nvrtcCompileProgram");
 
     const auto result = at::globalContext().getNVRTC().nvrtcCompileProgram(
         program, args.size(), args.data());
@@ -820,14 +375,6 @@ NvrtcFunction nvrtcCompile(
 
       TORCH_INTERNAL_ASSERT(
           false, code.c_str(), "\nCUDA NVRTC compile error: ", log.data());
-    } else if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) {
-      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-      size_t logsize;
-      at::globalContext().getNVRTC().nvrtcGetProgramLogSize(program, &logsize);
-      std::vector<char> log(logsize);
-      at::globalContext().getNVRTC().nvrtcGetProgramLog(program, log.data());
-
-      std::cout << log.data() << std::endl;
     }
 
     AT_CUDA_NVRTC_CHECK(result);
@@ -842,7 +389,7 @@ NvrtcFunction nvrtcCompile(
   std::vector<char> ptx;
 
   {
-    FUSER_PERF_SCOPE("executor_utils::Nvrtc::GetPTX");
+    FUSER_PERF_SCOPE("get PTX");
 #if CUDA_VERSION >= 11010
     // compile_to_sass determines whether we are generating SASS or PTX, hence
     // the different API.
@@ -865,86 +412,64 @@ NvrtcFunction nvrtcCompile(
 
   // TODO: We do go through different code path, should investigate whether this
   // has an impact on generated binary.
+  const char* prefix_env = getenv("PYTORCH_CUDA_FUSER_CUBIN");
 #ifndef __HIP_PLATFORM_HCC__
-  const char* prefix_env = getenv("PYTORCH_NVFUSER_CUBIN");
   if (prefix_env) {
-    FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadCUBIN");
+#if CUDA_VERSION >= 11010
+    TORCH_CHECK(
+        !compile_to_sass,
+        "PYTORCH_NVFUSER_CUBIN cannot be used when compile direct to SASS. Please set PYTORCH_NVFUSER_CUBIN to empty");
+#endif
+    FUSER_PERF_SCOPE("load CUBIN");
 
     // Output ptx file
-    std::stringstream output_file_name;
-    output_file_name << prefix_env << "_" << id
-                     << (compile_to_sass ? ".cubin" : ".ptx");
-    std::ofstream outputFile(output_file_name.str().c_str(), std::ios::out);
-    if (outputFile.is_open()) {
-      outputFile.write(ptx.data(), ptx.size());
-      outputFile.close();
+    std::stringstream ptx_file_name;
+    ptx_file_name << prefix_env << "_" << id << ".ptx";
+    std::ofstream myPtxFile(ptx_file_name.str().c_str(), std::ios::out);
+    if (myPtxFile.is_open()) {
+      myPtxFile.write(ptx.data(), ptx.size());
+      myPtxFile.close();
     }
 
-    if (compile_to_sass) {
-      FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadPTX");
+    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+    CUlinkState linkState;
 
-      // load sass directly
-      AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx(
-          &(compiled_kernel_.module),
-          ptx.data(),
-          options.size(),
-          options.data(),
-          option_vals.data()));
-    } else {
-      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-      CUlinkState linkState;
-
-      AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkCreate(
-          // 0, nullptr, nullptr, &linkState));
-          options.size(),
-          options.data(),
-          option_vals.data(),
-          &linkState));
-
-      AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkAddData(
-          linkState,
-          CU_JIT_INPUT_PTX,
-          ptx.data(),
-          ptx_size,
-          "compiling PTX",
-          0,
-          nullptr,
-          nullptr));
-
-      if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) {
-        std::cout << info_log.data() << std::endl;
-      }
+    AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkCreate(
+        0, nullptr, nullptr, &linkState));
 
-      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-      size_t cubinSize;
-      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-      void* cubin;
-      AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkComplete(
-          linkState, &cubin, &cubinSize));
+    AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkAddData(
+        linkState,
+        CU_JIT_INPUT_PTX,
+        ptx.data(),
+        ptx_size,
+        "compiling PTX",
+        options.size(),
+        options.data(),
+        option_vals.data()));
 
-      // Output binary file
-      std::stringstream cubin_file_name;
-      cubin_file_name << prefix_env << "_" << id << ".cubin";
+    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+    size_t cubinSize;
+    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+    void* cubin;
+    AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkComplete(
+        linkState, &cubin, &cubinSize));
 
-      std::ofstream myCubinFile(
-          cubin_file_name.str().c_str(), std::ios::out | std::ios::binary);
+    // Output binary file
+    std::stringstream cubin_file_name;
+    cubin_file_name << prefix_env << "_" << id << ".cubin";
 
-      if (myCubinFile.is_open()) {
-        myCubinFile.write(static_cast<const char*>(cubin), cubinSize);
-        myCubinFile.close();
-      }
-      // load compiled cubin
-      // AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData(
-      //     &(compiled_kernel_.module), cubin));
-      AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx(
-          &(compiled_kernel_.module),
-          cubin,
-          options.size(),
-          options.data(),
-          option_vals.data()));
+    std::ofstream myCubinFile(
+        cubin_file_name.str().c_str(), std::ios::out | std::ios::binary);
+
+    if (myCubinFile.is_open()) {
+      myCubinFile.write(static_cast<const char*>(cubin), cubinSize);
+      myCubinFile.close();
     }
+    // load compiled cubin
+    AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData(
+        &(compiled_kernel_.module), cubin));
   } else {
-    FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadPTX");
+    FUSER_PERF_SCOPE("load PTX");
 
     // load ptx directly
     AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx(
@@ -953,11 +478,6 @@ NvrtcFunction nvrtcCompile(
         options.size(),
         options.data(),
         option_vals.data()));
-
-    if (!compile_to_sass &&
-        isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) {
-      std::cout << info_log.data() << std::endl;
-    }
   }
 #else
   // load ptx directly
index 299c5ff..28a702b 100644 (file)
@@ -5,15 +5,13 @@
 #include <c10/core/DeviceType.h>
 #include <c10/util/Exception.h>
 
-#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
 
 #include <torch/csrc/jit/ir/ir.h>
 
 #include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 
 #include <string>
@@ -28,44 +26,20 @@ namespace executor_utils {
 // Include all the functions we might need in generated code
 std::string kernelPreamble();
 
-// TODO(kir): rewrite in terms of Kernel inputs
 void validateKernelInputs(
     Fusion* fusion,
     const at::ArrayRef<IValue>& inputs,
     const c10::Device& device);
 
-// TODO(kir): rewrite in terms of Kernel outputs
 void validateKernelOutputs(
     Fusion* fusion,
     const std::vector<at::Tensor>& outputs,
     const c10::Device& device);
 
-// Returns if vectorizing the aten value by word size is possible
-bool canVectorize(const IValue& aten_val, int word_size);
-
-// Returns if vectorizing the aten value by word size is possible
-bool canVectorize(
-    TensorView* fusion_tv,
-    int word_size,
-    GpuLower& lower,
-    kir::ExpressionEvaluator& expr_eval);
-
-// TODO(kir): rewrite in terms of Kernel tensors
-void validateVectorizedTensors(
-    Fusion* fusion,
-    const at::ArrayRef<IValue>& inputs,
-    const std::vector<at::Tensor>& outputs,
-    GpuLower& lower,
-    kir::ExpressionEvaluator& expr_eval);
-
-//! Bind kernel input values to runtime values
-kir::ExpressionEvaluator bindKernelInputs(
+StatefulExpressionEvaluator statefulBindInputs(
     const at::ArrayRef<IValue>& aten_inputs,
-    kir::Kernel* kernel);
-
-//! Bind fusion input values to runtime values
-TORCH_CUDA_CU_API ExpressionEvaluator
-bindFusionInputs(const at::ArrayRef<IValue>& aten_inputs, Fusion* fusion);
+    Fusion* fusion,
+    GpuLower* lower = nullptr);
 
 struct NvrtcFunction {
   CUmodule module = CUmodule();
@@ -75,8 +49,7 @@ struct NvrtcFunction {
 NvrtcFunction nvrtcCompile(
     const std::string& code,
     const std::string& func_name,
-    int id,
-    c10::optional<int> opt_block_size = c10::nullopt);
+    int id);
 
 } // namespace executor_utils
 } // namespace cuda
index 1c00da6..21e018e 100644 (file)
@@ -1,4 +1,3 @@
-
 #include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
@@ -12,67 +11,173 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-void ExpressionEvaluator::bind(Val* value, Int::ScalarType concrete_value) {
-  TORCH_CHECK(value->isAnInt());
-  auto val = value->getInt();
-  if (val.has_value() && val.value() == concrete_value) {
-    return;
+void StatefulExpressionEvaluator::safeBind(
+    Val* value,
+    Int::ScalarType concrete_value,
+    GpuLower* lower) {
+  auto already_concrete_val = getValue(value);
+
+  if (already_concrete_val.has_value()) {
+    TORCH_INTERNAL_ASSERT(
+        concrete_value == already_concrete_val.value(),
+        "Tried to bind ",
+        value,
+        " to ",
+        " concrete value, but it's already set to ",
+        already_concrete_val.value());
+  } else {
+    TORCH_INTERNAL_ASSERT(
+        value->getOrigin() == nullptr,
+        "Tried to bind to a value that is computed in the fusion IR. ",
+        "Can only bind to symbolic values to the fusion that do not have an origin expr.");
+
+    bindings_[value] = concrete_value;
   }
-  TORCH_CHECK(!value->isConstScalar(), "Tried to bind to a constant value");
-  TORCH_CHECK(
-      value->definition() == nullptr,
-      "Tried to bind to a value that is computed in the fusion IR");
-  known_values_[value] = concrete_value;
-}
 
-c10::optional<Int::ScalarType> ExpressionEvaluator::evaluate(Val* value) {
-  FUSER_PERF_SCOPE("ExpressionEvaluator::evaluate");
-  auto maybe_concrete_value = getValue(value);
-  if (!maybe_concrete_value.has_value()) {
-    if (value->definition() != nullptr) {
-      OptOutDispatch::handle(value->definition());
-      maybe_concrete_value = getValue(value);
+  if (lower != nullptr) {
+    // TODO(kir): we should not need to lower (or mutate the IR in any way)
+    //  during expression evaluation
+    auto lowered_val = lower->getLowerValue(value);
+    already_concrete_val = getValue(lowered_val);
+
+    if (already_concrete_val.has_value()) {
+      TORCH_INTERNAL_ASSERT(
+          concrete_value == already_concrete_val.value(),
+          "Tried to bind ",
+          lowered_val,
+          " to ",
+          " concrete value, but it's already set to ",
+          already_concrete_val.value());
+    } else {
+      TORCH_INTERNAL_ASSERT(
+          lowered_val->getOrigin() == nullptr,
+          "Tried to bind to a value that is computed in the fusion IR. ",
+          "Can only bind to symbolic values to the fusion that do not have an origin expr.");
+
+      bindings_[lowered_val] = concrete_value;
     }
   }
-  return maybe_concrete_value;
 }
 
-void ExpressionEvaluator::print() const {
+c10::optional<Int::ScalarType> StatefulExpressionEvaluator::inferValue(
+    Val* value) {
+  FUSER_PERF_SCOPE("inferValue");
+  return maybeHandle(value);
+}
+
+void StatefulExpressionEvaluator::print() const {
   std::cout << "\nEvaluation context\n";
   std::cout << "--------------------\n";
-  for (const auto& kv : known_values_) {
-    TORCH_INTERNAL_ASSERT(!kv.first->isConstScalar());
-    std::cout << kv.first << " = " << kv.second << " ; "
-              << *kv.first->getValType() << "\n";
+  for (const auto& kv : bindings_) {
+    std::cout << kv.first << " = " << kv.second;
+    if (kv.first->isConstScalar()) {
+      std::cout << " ; original value = "
+                << kv.first->as<Int>()->value().value();
+    }
+    std::cout << " ; " << *kv.first->getValType() << "\n";
   }
   std::cout << "--------------------\n\n";
 }
 
-c10::optional<Int::ScalarType> ExpressionEvaluator::getValue(Val* value) {
+c10::optional<Int::ScalarType> StatefulExpressionEvaluator::getValue(
+    Val* value) {
   TORCH_INTERNAL_ASSERT(
       value->isAnInt(),
       "Expression Evaluation does not support values other than integers at this time.");
 
-  if (value->getValType().value() == ValType::Scalar) {
-    if (value->as<Int>()->value().has_value()) {
-      return value->as<Int>()->value();
+  switch (value->getValType().value()) {
+    case ValType::Scalar:
+      if (value->as<Int>()->value().has_value()) {
+        return value->as<Int>()->value();
+      }
+      break;
+    case ValType::KirScalar:
+      if (value->as<kir::Int>()->value().has_value()) {
+        return value->as<kir::Int>()->value();
+      }
+      break;
+    default:
+      break;
+  }
+
+  const auto it = bindings_.find(value);
+  return it != bindings_.end() ? c10::optional<Int::ScalarType>(it->second)
+                               : c10::nullopt;
+}
+
+c10::optional<Int::ScalarType> StatefulExpressionEvaluator::maybeHandle(
+    Val* val) {
+  auto maybe_concrete_value = getValue(val);
+  if (!maybe_concrete_value.has_value()) {
+    auto origin = val->getOrigin();
+    if (origin != nullptr) {
+      handle(origin);
+      maybe_concrete_value = getValue(val);
     }
   }
+  return maybe_concrete_value;
+}
 
-  const auto it = known_values_.find(value);
-  return it != known_values_.end() ? c10::optional<Int::ScalarType>(it->second)
-                                   : c10::nullopt;
+void StatefulExpressionEvaluator::handle(UnaryOp* uop) {
+  const auto in = maybeHandle(uop->in());
+  if (in.has_value()) {
+    switch (uop->getUnaryOpType()) {
+      case UnaryOpType::Neg:
+        bindings_[uop->out()] = -*in;
+        break;
+      case UnaryOpType::Cast:
+        bindings_[uop->out()] = *in;
+        break;
+      default:
+        TORCH_CHECK(!"Unexpected operator type");
+    }
+  }
+}
+
+void StatefulExpressionEvaluator::handle(BinaryOp* bop) {
+  const auto lhs = maybeHandle(bop->lhs());
+  const auto rhs = maybeHandle(bop->rhs());
+  if (lhs.has_value() && rhs.has_value()) {
+    switch (bop->getBinaryOpType()) {
+      case BinaryOpType::Add:
+        bindings_[bop->out()] = *lhs + *rhs;
+        break;
+      case BinaryOpType::Sub:
+        bindings_[bop->out()] = *lhs - *rhs;
+        break;
+      case BinaryOpType::Mul:
+        bindings_[bop->out()] = *lhs * *rhs;
+        break;
+      case BinaryOpType::Div:
+        TORCH_CHECK(*rhs != 0);
+        bindings_[bop->out()] = *lhs / *rhs;
+        break;
+      case BinaryOpType::Mod:
+        TORCH_CHECK(*rhs != 0);
+        bindings_[bop->out()] = *lhs % *rhs;
+        break;
+      case BinaryOpType::CeilDiv:
+        TORCH_CHECK(*rhs != 0);
+        bindings_[bop->out()] = (*lhs + *rhs - 1) / *rhs;
+        break;
+      case BinaryOpType::And:
+        bindings_[bop->out()] = Int::ScalarType(*lhs && *rhs);
+        break;
+      default:
+        TORCH_CHECK(!"Unexpected operator type");
+    }
+  }
 }
 
-void ExpressionEvaluator::handle(UnaryOp* uop) {
-  const auto in = evaluate(uop->in());
+void StatefulExpressionEvaluator::handle(kir::UnaryOp* uop) {
+  const auto in = maybeHandle(uop->in());
   if (in.has_value()) {
     switch (uop->getUnaryOpType()) {
       case UnaryOpType::Neg:
-        known_values_[uop->out()] = -*in;
+        bindings_[uop->out()] = -*in;
         break;
       case UnaryOpType::Cast:
-        known_values_[uop->out()] = *in;
+        bindings_[uop->out()] = *in;
         break;
       default:
         TORCH_CHECK(!"Unexpected operator type");
@@ -80,34 +185,34 @@ void ExpressionEvaluator::handle(UnaryOp* uop) {
   }
 }
 
-void ExpressionEvaluator::handle(BinaryOp* bop) {
-  const auto lhs = evaluate(bop->lhs());
-  const auto rhs = evaluate(bop->rhs());
+void StatefulExpressionEvaluator::handle(kir::BinaryOp* bop) {
+  const auto lhs = maybeHandle(bop->lhs());
+  const auto rhs = maybeHandle(bop->rhs());
   if (lhs.has_value() && rhs.has_value()) {
     switch (bop->getBinaryOpType()) {
       case BinaryOpType::Add:
-        known_values_[bop->out()] = *lhs + *rhs;
+        bindings_[bop->out()] = *lhs + *rhs;
         break;
       case BinaryOpType::Sub:
-        known_values_[bop->out()] = *lhs - *rhs;
+        bindings_[bop->out()] = *lhs - *rhs;
         break;
       case BinaryOpType::Mul:
-        known_values_[bop->out()] = *lhs * *rhs;
+        bindings_[bop->out()] = *lhs * *rhs;
         break;
       case BinaryOpType::Div:
         TORCH_CHECK(*rhs != 0);
-        known_values_[bop->out()] = *lhs / *rhs;
+        bindings_[bop->out()] = *lhs / *rhs;
         break;
       case BinaryOpType::Mod:
         TORCH_CHECK(*rhs != 0);
-        known_values_[bop->out()] = *lhs % *rhs;
+        bindings_[bop->out()] = *lhs % *rhs;
         break;
       case BinaryOpType::CeilDiv:
         TORCH_CHECK(*rhs != 0);
-        known_values_[bop->out()] = (*lhs + *rhs - 1) / *rhs;
+        bindings_[bop->out()] = (*lhs + *rhs - 1) / *rhs;
         break;
       case BinaryOpType::And:
-        known_values_[bop->out()] = Int::ScalarType(*lhs && *rhs);
+        bindings_[bop->out()] = Int::ScalarType(*lhs && *rhs);
         break;
       default:
         TORCH_CHECK(!"Unexpected operator type");
index 8632e1c..a316aa7 100644 (file)
@@ -3,6 +3,7 @@
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
+#include <torch/csrc/jit/codegen/cuda/lower2device.h>
 
 #include <c10/util/Optional.h>
 
@@ -13,34 +14,67 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-//! Calculate Fusion IR expressions
-class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch {
+class TORCH_CUDA_CU_API StatefulExpressionEvaluator : private OptOutDispatch {
  public:
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
-  explicit ExpressionEvaluator(Fusion* fusion) : fusion_(fusion) {}
+  explicit StatefulExpressionEvaluator(Fusion* fusion) : fusion_(fusion) {}
 
-  //! Returns the associated fusion object
   Fusion* fusion() const {
     return fusion_;
   }
 
-  //! Bind a concrete value to an IR variable
-  void bind(Val* value, Int::ScalarType concrete_value);
+  void safeBind(
+      Val* value,
+      Int::ScalarType concrete_value,
+      GpuLower* lower = nullptr);
 
-  //! Try to evaluate a Fusion IR value
-  c10::optional<Int::ScalarType> evaluate(Val* value);
+  // Returns value if found in mapping, otherwise returns c10::nullopt
+  c10::optional<Int::ScalarType> getValue(Val* value);
+
+  // Checks if value is already infered, returns infered value if so, otherwise
+  // runs traversal on value. Warning: should not be called in traversal.
+  c10::optional<Int::ScalarType> inferValue(Val* value);
 
-  //! Debugging helper, prints all the currently known values
+  // Debugging helper, prints all the currently set values
   void print() const;
 
  private:
-  c10::optional<Int::ScalarType> getValue(Val* value);
+  using OptOutDispatch::handle;
+
+  void handle(Expr* expr) override {
+    switch (expr->getExprType().value()) {
+      case ExprType::UnaryOp:
+        handle(expr->as<UnaryOp>());
+        break;
+      case ExprType::BinaryOp:
+        handle(expr->as<BinaryOp>());
+        break;
+      case ExprType::KirUnaryOp:
+        handle(expr->as<kir::UnaryOp>());
+        break;
+      case ExprType::KirBinaryOp:
+        handle(expr->as<kir::BinaryOp>());
+        break;
+      default:
+        TORCH_INTERNAL_ASSERT(
+            false,
+            "Cannot handle Expr type: ",
+            expr->getExprType().value(),
+            " in stateful expression evaluator.");
+    }
+  }
+
+  void handle(UnaryOp*) override;
+  void handle(BinaryOp*) override;
+
+  // TODO(kir): remove this
+  void handle(kir::UnaryOp*) override;
+  void handle(kir::BinaryOp*) override;
 
-  void handle(UnaryOp*) override final;
-  void handle(BinaryOp*) override final;
+  c10::optional<Int::ScalarType> maybeHandle(Val*);
 
  private:
-  std::unordered_map<const Val*, Int::ScalarType> known_values_;
+  std::unordered_map<const Val*, Int::ScalarType> bindings_;
   Fusion* fusion_ = nullptr;
 };
 
index 7ff2453..4a6fc58 100644 (file)
@@ -1,15 +1,15 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/codegen.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
+
+#include <torch/csrc/jit/codegen/cuda/codegen.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
 #include <torch/csrc/jit/codegen/cuda/ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
+#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 
+// TODO(kir): only needed until we can fix Fusion::origin()
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
 
 namespace torch {
@@ -17,7 +17,7 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-static thread_local Fusion* ACTIVE_FUSION = nullptr; // NOLINT
+static thread_local Fusion* ACTIVE_FUSION = nullptr;
 
 FusionGuard::FusionGuard(Fusion* fusion) {
   prev_fusion = ACTIVE_FUSION;
@@ -32,7 +32,7 @@ Fusion* FusionGuard::getCurFusion() {
   return ACTIVE_FUSION;
 }
 
-TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept {
+void swap(Fusion& a, Fusion& b) noexcept {
   FUSER_PERF_SCOPE("Fusion swap");
 
   using std::swap;
@@ -45,11 +45,12 @@ TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept {
   swap(a.val_type_name_map_, b.val_type_name_map_);
   swap(a.expr_name_counter_, b.expr_name_counter_);
 
+  swap(a.origin_, b.origin_);
+  swap(a.uses_, b.uses_);
+
   swap(a.inputs_, b.inputs_);
   swap(a.outputs_, b.outputs_);
 
-  swap(a.io_alias_, b.io_alias_);
-
   // Fixup the Statement::fusion_ links for a
   for (auto val : a.val_set_) {
     val->fusion_ = &a;
@@ -65,54 +66,63 @@ TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept {
   for (auto expr : b.expr_set_) {
     expr->fusion_ = &b;
   }
-}
 
-Fusion::Fusion(const Fusion& other) {
-  FUSER_PERF_SCOPE("Fusion copy");
-  Fusion::copy(&other, this);
-}
+  // Lowered IR nodes
+  swap(a.lowered_val_set_, b.lowered_val_set_);
+  swap(a.lowered_expr_set_, b.lowered_expr_set_);
+  swap(a.lowered_origin_, b.lowered_origin_);
 
-std::unique_ptr<SegmentedFusion> Fusion::segment(
-    const at::ArrayRef<IValue>& inputs) {
-  FUSER_PERF_SCOPE("Segment Fusion");
-  return SegmentCandidateFinder::segment(this, inputs);
+  for (auto val : a.lowered_val_set_) {
+    val->fusion_ = &a;
+  }
+  for (auto expr : a.lowered_expr_set_) {
+    expr->fusion_ = &a;
+  }
+  for (auto val : b.lowered_val_set_) {
+    val->fusion_ = &b;
+  }
+  for (auto expr : b.lowered_expr_set_) {
+    expr->fusion_ = &b;
+  }
 }
 
-IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
-  to->clear();
-  IrCloner ir_cloner(to);
+Fusion::Fusion(const Fusion& other) {
+  FUSER_PERF_SCOPE("Fusion copy");
 
-  for (auto val : from->val_set_) {
-    to->val_set_.insert(ir_cloner.clone(val));
-  }
+  IrCloner ir_cloner(this);
 
-  for (auto expr : from->expr_set_) {
-    to->expr_set_.insert(ir_cloner.clone(expr));
+  for (auto val : other.val_set_) {
+    val_set_.insert(ir_cloner.clone(val));
   }
 
-  for (auto val : from->val_deque_) {
-    to->val_deque_.push_back(ir_cloner.clone(val));
+  for (auto expr : other.expr_set_) {
+    expr_set_.insert(ir_cloner.clone(expr));
   }
 
-  for (auto val : from->val_set_) {
-    ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
-    ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
+  for (auto val : other.val_deque_) {
+    val_deque_.push_back(ir_cloner.clone(val));
   }
 
-  to->val_type_name_map_ = from->val_type_name_map_;
-  to->expr_name_counter_ = from->expr_name_counter_;
+  val_type_name_map_ = other.val_type_name_map_;
+  expr_name_counter_ = other.expr_name_counter_;
 
-  to->inputs_ = ir_cloner.clone(from->inputs_);
-  to->outputs_ = ir_cloner.clone(from->outputs_);
+  for (const auto& kv : other.origin_) {
+    auto val = ir_cloner.clone(kv.first);
+    auto expr = ir_cloner.clone(kv.second);
+    origin_.insert({val, expr});
+  }
 
-  // TODO: put this into ir_cloner instead
-  for (const auto& entry : from->io_alias_) {
-    Val* copied_output = ir_cloner.clone(entry.first);
-    Val* copied_input = ir_cloner.clone(entry.second);
-    to->io_alias_[copied_output] = copied_input;
+  for (const auto& kv : other.uses_) {
+    auto val = ir_cloner.clone(kv.first);
+    std::unordered_set<Expr*> val_uses;
+    for (auto expr : kv.second) {
+      val_uses.insert(ir_cloner.clone(expr));
+    }
+    uses_.insert({val, std::move(val_uses)});
   }
 
-  return ir_cloner;
+  inputs_ = ir_cloner.clone(other.inputs_);
+  outputs_ = ir_cloner.clone(other.outputs_);
 }
 
 Fusion::Fusion(Fusion&& other) noexcept {
@@ -162,10 +172,22 @@ void Fusion::clear() noexcept {
 
   expr_name_counter_ = 0;
 
+  origin_.clear();
+  uses_.clear();
+
   inputs_.clear();
   outputs_.clear();
 
-  io_alias_.clear();
+  // Lowered IR nodes
+  for (auto ptr : lowered_val_set_) {
+    delete ptr;
+  }
+  for (auto ptr : lowered_expr_set_) {
+    delete ptr;
+  }
+  lowered_val_set_.clear();
+  lowered_expr_set_.clear();
+  lowered_origin_.clear();
 }
 
 void Fusion::removeExpr(Expr* expr) {
@@ -174,16 +196,16 @@ void Fusion::removeExpr(Expr* expr) {
   // that removing something that doesn't exist simply does nothing. For now,
   // we're going with the strictest model which errors.
 
-  for (auto out : expr->outputs()) {
-    out->setDefinition(nullptr);
-  }
+  for (auto out : expr->outputs())
+    if (origin_.find(out) != origin_.end())
+      if (origin_.find(out)->second == expr)
+        origin_.erase(out);
 
   for (auto inp : expr->inputs()) {
-    auto uses_copy = inp->uses();
-    auto it = std::find(uses_copy.begin(), uses_copy.end(), expr);
-    if (it != uses_copy.end()) {
-      uses_copy.erase(it);
-      inp->setUses(uses_copy);
+    if (uses_.find(inp) != uses_.end()) {
+      if (uses_.find(inp)->second.find(expr) != uses_.find(inp)->second.end()) {
+        uses_.find(inp)->second.erase(expr);
+      }
     }
   }
 
@@ -195,16 +217,17 @@ void Fusion::removeExpr(Expr* expr) {
 void Fusion::removeVal(Val* val) {
   assertInFusion(val, "Cannot remove val ");
 
-  TORCH_CHECK(
-      !val->isFusionInput(),
-      "Cannot remove val as it is an input of the fusion.");
-  TORCH_CHECK(
-      !val->isFusionOutput(),
-      "Cannot remove val as it is an output of the fusion.");
+  for (Val* inp : inputs())
+    if (val->sameAs(inp))
+      TORCH_CHECK(false, "Cannot remove val as it is an input of the fusion.");
+
+  for (Val* out : outputs())
+    if (val->sameAs(out))
+      TORCH_CHECK(false, "Cannot remove val as it is an output of the fusion.");
 
-  Expr* orig = val->definition();
+  Expr* orig = origin(val);
   if (orig != nullptr)
-    removeExpr(val->definition());
+    removeExpr(origin(val));
 
   for (Expr* use : unordered_uses(val))
     removeExpr(use);
@@ -225,13 +248,22 @@ void Fusion::addInput(Val* input) {
 
   if (input->getValType().value() == ValType::TensorView) {
     auto tv = input->as<TensorView>();
+    if (tv->hasReduction()) {
+      TORCH_WARN_ONCE(
+          "Registered input ",
+          input,
+          " has a reduction axis, but this does nothing in the fusion.");
+    }
     tv->setMemoryType(MemoryType::Global);
   }
 
+  TORCH_INTERNAL_ASSERT(
+      input->getOrigin() == nullptr,
+      input,
+      " cannot be registered as an input as it is used as an output of an expression (",
+      input->getOrigin(),
+      ").");
   inputs_.push_back(input);
-  input->setIsFusionInput(true);
-
-  all_tv_uses_valid_ = false;
 }
 
 void Fusion::addOutput(Val* output) {
@@ -241,66 +273,33 @@ void Fusion::addOutput(Val* output) {
     tv->setMemoryType(MemoryType::Global);
   }
   outputs_.push_back(output);
-  output->setIsFusionOutput(true);
-
-  all_tv_uses_valid_ = false;
 }
 
-void Fusion::addOutput(WelfordResult& wr) {
-  // Want to always make sure the avg gets added last
-  //  since avg will be the out() value of welfordOp,
-  //  and want to make it the top of the computeAt chain
-  addOutput(wr.var_sum);
-  addOutput(wr.n);
-  addOutput(wr.avg);
-}
+bool Fusion::inFusion(const Statement* stmt) const {
+  bool in_fusion = stmt->fusion() == this;
+  Statement* nonconst_stmt = const_cast<Statement*>(stmt); // NOLINT
 
-void Fusion::removeInput(Val* input) {
-  auto find_input = std::find(inputs_.begin(), inputs_.end(), input);
-  if (find_input != inputs_.end()) {
-    inputs_.erase(find_input);
+  if (stmt->isExpr()) {
+    in_fusion &= expr_set_.find(nonconst_stmt->as<Expr>()) != expr_set_.end();
   }
-  input->setIsFusionInput(false);
-  all_tv_uses_valid_ = false;
-}
-
-void Fusion::removeOutput(Val* output) {
-  auto find_output = std::find(outputs_.begin(), outputs_.end(), output);
-  if (find_output != outputs_.end()) {
-    outputs_.erase(find_output);
+  if (stmt->isVal()) {
+    in_fusion &= val_set_.find(nonconst_stmt->as<Val>()) != val_set_.end();
   }
-  output->setIsFusionOutput(false);
-  all_tv_uses_valid_ = false;
-}
-
-void Fusion::replaceOutput(Val* output, Val* replacement) {
-  auto find_output = std::find(outputs_.begin(), outputs_.end(), output);
-  TORCH_CHECK(find_output != outputs_.end(), "Unable to find output in Fusion");
 
-  if (find_output != outputs_.end()) {
-    *find_output = replacement;
-
-    if (replacement->getValType().value() == ValType::TensorView) {
-      replacement->setIsFusionOutput(true);
-      replacement->as<TensorView>()->setMemoryType(MemoryType::Global);
-    }
-    if (output->getValType().value() == ValType::TensorView) {
-      output->setIsFusionOutput(false);
-      output->as<TensorView>()->setMemoryType(MemoryType::Local);
-    }
-    resetTvUses();
-  }
+  return in_fusion;
 }
 
-bool Fusion::inFusion(const Statement* stmt) const {
+bool Fusion::inKernelIr(const Statement* stmt) const {
   bool in_fusion = stmt->fusion() == this;
   Statement* nonconst_stmt = const_cast<Statement*>(stmt); // NOLINT
 
   if (stmt->isExpr()) {
-    in_fusion &= expr_set_.find(nonconst_stmt->as<Expr>()) != expr_set_.end();
+    in_fusion &= lowered_expr_set_.find(nonconst_stmt->as<Expr>()) !=
+        lowered_expr_set_.end();
   }
   if (stmt->isVal()) {
-    in_fusion &= val_set_.find(nonconst_stmt->as<Val>()) != val_set_.end();
+    in_fusion &= lowered_val_set_.find(nonconst_stmt->as<Val>()) !=
+        lowered_val_set_.end();
   }
 
   return in_fusion;
@@ -308,14 +307,20 @@ bool Fusion::inFusion(const Statement* stmt) const {
 
 void Fusion::assertInFusion(const Statement* stmt, const std::string& msg)
     const {
-  TORCH_CHECK(inFusion(stmt), msg, " it was not found in the active fusion.");
+  if (inFusion(stmt)) {
+    return;
+  }
+  if (inKernelIr(stmt)) {
+    return;
+  }
+  TORCH_CHECK(false, msg, " it was not found in the active fusion.");
 }
 
-std::vector<Expr*> Fusion::exprs() {
-  return ExprSort::getExprs(this);
+std::vector<Expr*> Fusion::exprs(bool from_outputs_only) {
+  return ExprSort::getExprs(this, from_outputs_only);
 }
 
-std::vector<Val*> Fusion::inputsOf(Val* val) {
+std::unordered_set<Val*> Fusion::inputsOf(Val* val) {
   return InputsOf::output(this, val);
 }
 
@@ -327,13 +332,12 @@ void Fusion::validateInputs() {
     }
   }
   for (Val* input : all_inputs) {
-    if (!input->isConstScalar()) {
+    if (!input->isConstScalar())
       TORCH_CHECK(
-          hasInput(input) || inFusion(input),
+          hasInput(input),
           "Could not figure out how ",
           input,
           " is generated, however it was not specified as an input.");
-    }
   }
 }
 
@@ -341,13 +345,12 @@ void Fusion::print() {
   FUSER_PERF_SCOPE("Fusion::print");
 
   FusionGuard fg(this);
-  std::cout << "\n%kernel {\n";
+  std::cout << "%kernel {\n";
   IrMathPrinter op_exprs(std::cout);
   op_exprs.handle(this);
-  std::cout << "\nTransformPrinter : \n";
   IrTransformPrinter t_exprs(std::cout);
   t_exprs.handle(this);
-  std::cout << "}\n\n";
+  std::cout << "}\n";
 }
 
 void Fusion::printKernel() {
@@ -355,38 +358,12 @@ void Fusion::printKernel() {
   std::cout << codegen::generateCudaKernel(GpuLower(this).kernel());
 }
 
-void Fusion::printMath(bool from_outputs_only) {
+void Fusion::printMath() {
   FUSER_PERF_SCOPE("Fusion::printMath");
 
   FusionGuard fg(this);
-  auto exprs_for_print = exprs();
-  std::cout << "Inputs:" << std::endl;
-  for (auto inp : inputs()) {
-    std::cout << "  " << inp << ", " << inp->getDataType().value() << std::endl;
-  }
-
-  std::cout << "Outputs:" << std::endl;
-  for (auto out : outputs()) {
-    std::cout << "  " << out << ", " << out->getDataType().value() << std::endl;
-  }
-
-  // If we want everything in the fusion, grab all values without uses to
-  // traverse from.
-  if (!from_outputs_only) {
-    std::vector<Val*> leaf_vals;
-    for (auto val : deterministic_vals()) {
-      if (val->uses().empty()) {
-        leaf_vals.push_back(val);
-      }
-    }
-    exprs_for_print = ExprSort::getExprs(this, leaf_vals);
-  }
-
-  std::cout << "\n%kernel_math {\n";
-  for (auto expr : exprs_for_print) {
+  for (auto expr : exprs(true))
     std::cout << expr;
-  }
-  std::cout << "}\n\n";
 }
 
 void Fusion::printTransforms() {
@@ -398,6 +375,8 @@ void Fusion::printTransforms() {
 }
 
 StmtNameType Fusion::registerVal(Val* val) {
+  TORCH_CHECK(!inKernelIr(val));
+
   if (val->fusion()) {
     if (val->fusion() != this) {
       TORCH_CHECK(false, val, " was not found in the active fusion.");
@@ -413,6 +392,8 @@ StmtNameType Fusion::registerVal(Val* val) {
 }
 
 StmtNameType Fusion::registerExpr(Expr* expr) {
+  TORCH_CHECK(!inKernelIr(expr));
+
   if (expr->fusion()) {
     if (expr->fusion() != this) {
       TORCH_CHECK(false, expr, " was not found in the active fusion.");
@@ -424,25 +405,26 @@ StmtNameType Fusion::registerExpr(Expr* expr) {
 
   for (Val* input : expr->inputs()) {
     assertInFusion(input, "Input to expr is invalid, ");
-    auto uses_copy = input->uses();
-    if (std::find(uses_copy.begin(), uses_copy.end(), expr) ==
-        uses_copy.end()) {
-      uses_copy.push_back(expr);
-      input->setUses(uses_copy);
+    TORCH_CHECK(!inKernelIr(input));
+    if (uses_.find(input) == uses_.end()) {
+      uses_[input] = {expr};
+    } else {
+      uses_.find(input)->second.emplace(expr);
     }
   }
 
   for (Val* output : expr->outputs()) {
     assertInFusion(output, "Output to expr is invalid, ");
-    if (output->definition() != nullptr) {
-      removeExpr(output->definition());
+    TORCH_CHECK(!inKernelIr(output));
+    auto it = origin_.find(output);
+    if (it != origin_.end()) {
+      removeExpr(it->second); // will also remove origin entry
     }
-    output->setDefinition(expr);
+
+    origin_[output] = expr;
   }
 
   expr_set_.emplace(expr);
-
-  resetTvUses();
   return getExprName();
 }
 
@@ -459,37 +441,39 @@ StmtNameType Fusion::registerStatement(Statement* stmt) {
   TORCH_INTERNAL_ASSERT(
       false,
       "Could not register statement as Fusion could not recognize its type.");
-  return kInvalidStmName;
+  return UNINITIALIZED_STMTNAMETYPE;
 }
 
-void Fusion::resetTvUses() {
-  FUSER_PERF_SCOPE("Fusion::resetTvUses");
-  is_during_update_uses_ = true;
+StmtNameType Fusion::registerLoweredVal(Val* val) {
+  TORCH_INTERNAL_ASSERT(val->fusion() == this);
+  TORCH_INTERNAL_ASSERT(!inFusion(val));
+  TORCH_INTERNAL_ASSERT(!inKernelIr(val));
+  lowered_val_set_.insert(val);
+  return getValName(*val->getValType());
+}
 
-  // getExprs only uses definition, so even if we've modified uses already to
-  // remove dead exprs, this could reinsert them. getExprs is also boundeds by
-  // inputs as registered inputs will return nullptr as their definition.
-  const auto all_tvs = ir_utils::filterByType<TensorView>(val_set_);
-  const auto used_exprs = ExprSort::getExprs(this);
+StmtNameType Fusion::registerLoweredExpr(Expr* expr) {
+  TORCH_INTERNAL_ASSERT(expr->fusion() == this);
+  TORCH_INTERNAL_ASSERT(!inFusion(expr));
+  TORCH_INTERNAL_ASSERT(!inKernelIr(expr));
 
-  for (auto tv : all_tvs) {
-    tv->setUses({});
+  for (Val* input : expr->inputs()) {
+    TORCH_CHECK(inKernelIr(input));
   }
 
-  // Same as in register expr
-  for (auto expr : used_exprs) {
-    for (Val* input : expr->inputs()) {
-      auto uses_copy = input->uses();
-      if (std::find(uses_copy.begin(), uses_copy.end(), expr) ==
-          uses_copy.end()) {
-        uses_copy.push_back(expr);
-        input->setUses(uses_copy);
-      }
-    }
+  for (Val* output : expr->outputs()) {
+    TORCH_CHECK(inKernelIr(output));
+    TORCH_CHECK(lowered_origin_.insert({output, expr}).second);
   }
 
-  all_tv_uses_valid_ = true;
-  is_during_update_uses_ = false;
+  lowered_expr_set_.insert(expr);
+  return getExprName();
+}
+
+bool Fusion::used(Val* val) const {
+  assertInFusion(val, "Cannot detect if val was used, ");
+  return (uses_.find(val) != uses_.end()) &&
+      (uses_.find(val)->second.size() > 0);
 }
 
 const std::unordered_set<Val*>& Fusion::vals() const noexcept {
@@ -500,65 +484,46 @@ const std::deque<Val*>& Fusion::deterministic_vals() const noexcept {
   return val_deque_;
 }
 
-std::vector<Val*> Fusion::usedMathVals() {
-  // Note that using fusion->inputs() as the argument for the first
-  // parameter of getAllValsBetween does not grab all used vals as
-  // there can be vals that are created inside a fusion without using
-  // anything from inputs. See, for example, tv0 in the
-  // FusionOuterSplit test.
-  const auto inputs = InputsOf::outputs(this, outputs());
-  auto used_math_vals = DependencyCheck::getAllValsBetween(
-      {inputs.begin(), inputs.end()}, outputs());
-  // When an expre has multiple outputs and only some of them are
-  // used, the rest aren't included in used_math_vals as they are not
-  // used. However, we want them to be included as they must show up
-  // in the fusion.
-  std::vector<Val*> vals_to_add;
-  std::unordered_set<Val*> added_vals;
-
-  for (auto val : used_math_vals) {
-    auto def = val->definition();
-    if (def == nullptr || def->outputs().size() < 2) {
-      continue;
-    }
-    for (auto out : def->outputs()) {
-      if (std::find(used_math_vals.begin(), used_math_vals.end(), out) ==
-          used_math_vals.end()) {
-        if (!added_vals.count(out)) {
-          vals_to_add.push_back(out);
-          added_vals.insert(out);
-        }
-      }
-    }
-  }
-
-  used_math_vals.insert(
-      used_math_vals.end(), vals_to_add.begin(), vals_to_add.end());
-
-  return used_math_vals;
-}
-
 const std::unordered_set<Expr*>& Fusion::unordered_exprs() const noexcept {
   return expr_set_;
 }
 
 std::unordered_set<Expr*> Fusion::unordered_uses(Val* val) const {
-  return std::unordered_set<Expr*>(val->uses().begin(), val->uses().end());
+  assertInFusion(val, "Cannot detect where val was used, ");
+  if (uses_.find(val) != uses_.end()) {
+    auto ret = uses_.find(val)->second;
+    return ret;
+  }
+  return std::unordered_set<Expr*>();
 }
 
-Expr* Fusion::definition(const Val* val) const {
-  assertInFusion(val, "Cannot detect the definition of val, ");
-  return val->definition();
+Expr* Fusion::origin(const Val* val) const {
+  // TODO(kir): remove the lowered branch
+  if (kir::isLoweredVal(val)) {
+    TORCH_INTERNAL_ASSERT(inKernelIr(val));
+    auto it = lowered_origin_.find(val);
+    return it != lowered_origin_.end() ? it->second : nullptr;
+  } else {
+    assertInFusion(val, "Cannot detect the origin of val, ");
+    auto it = origin_.find(val);
+    return it != origin_.end() ? it->second : nullptr;
+  }
 }
 
 bool Fusion::hasInput(const Val* val) const {
-  assertInFusion(val, "Cannot check if val is an input, ");
-  return val->isFusionInput();
+  return std::find(inputs_.begin(), inputs_.end(), val) != inputs_.end();
 }
 
 bool Fusion::hasOutput(const Val* val) const {
-  assertInFusion(val, "Cannot check if val is an output, ");
-  return val->isFusionOutput();
+  return std::find(outputs_.begin(), outputs_.end(), val) != outputs_.end();
+}
+
+void Fusion::replaceInput(Val* replace, Val* with) {
+  std::replace(inputs_.begin(), inputs_.end(), replace, with);
+}
+
+void Fusion::replaceOutput(Val* replace, Val* with) {
+  std::replace(outputs_.begin(), outputs_.end(), replace, with);
 }
 
 StmtNameType Fusion::getValName(ValType vtype) {
@@ -571,7 +536,7 @@ StmtNameType Fusion::getExprName() {
 
 // Indicate to kernel to set itself up to generate random numbers
 bool Fusion::isStochastic() {
-  for (auto expr : exprs())
+  for (auto expr : exprs(true))
     if (expr->getExprType() == ExprType::UnaryOp)
       if (expr->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::RandLike)
         return true;
@@ -581,7 +546,7 @@ bool Fusion::isStochastic() {
 bool Fusion::hasReduction() {
   FUSER_PERF_SCOPE("Fusion::hasReduction");
 
-  for (auto expr : exprs())
+  for (auto expr : exprs(true))
     for (auto out : expr->outputs())
       if (out->getValType() == ValType::TensorView)
         if (out->as<TensorView>()->hasReduction())
@@ -590,129 +555,75 @@ bool Fusion::hasReduction() {
   return false;
 }
 
-bool Fusion::hasWelford() {
-  FUSER_PERF_SCOPE("Fusion::hasWelford");
-  for (auto expr : exprs()) {
-    if (expr->isA<WelfordOp>()) {
-      return true;
-    }
-  }
+bool Fusion::hasBlockReduction() {
+  FUSER_PERF_SCOPE("Fusion::hasBlockReduction");
+
+  for (auto expr : exprs(true))
+    for (auto out : expr->outputs())
+      if (out->getValType() == ValType::TensorView)
+        if (out->as<TensorView>()->hasBlockReduction())
+          return true;
+
   return false;
 }
 
-std::vector<Val*> Fusion::getTerminatingOutputs() {
-  FUSER_PERF_SCOPE("getTerminatingOutputs");
+bool Fusion::hasGridReduction() {
+  FUSER_PERF_SCOPE("Fusion::hasGridReduction");
 
-  auto is_reachable_to_output = [](Val* val) {
-    // traverse to consumers of val and see if there is an output
-    std::deque<Val*> consumers;
-    for (auto use : val->uses()) {
-      for (auto consumer : use->outputs()) {
-        consumers.push_back(consumer);
-      }
-    }
-    while (!consumers.empty()) {
-      auto consumer = consumers.back();
-      consumers.pop_back();
-      if (consumer->isFusionOutput()) {
-        return true;
-      }
-      // consumer is not an output; proceed to its consumers
-      for (auto use : consumer->uses()) {
-        for (auto consumer_of_consumer : use->outputs()) {
-          consumers.push_back(consumer_of_consumer);
-        }
-      }
-    }
-    return false;
-  };
+  for (auto expr : exprs(true))
+    for (auto out : expr->outputs())
+      if (out->getValType() == ValType::TensorView)
+        if (out->as<TensorView>()->hasGridReduction())
+          return true;
 
-  std::vector<Val*> terminating_outputs;
+  return false;
+}
 
-  for (auto out : outputs()) {
-    // If there is another output reachable from this output, it's not
-    // terminating.
-    if (is_reachable_to_output(out)) {
-      continue;
+bool Fusion::hasBlockBroadcast() {
+  for (auto expr : exprs(true)) {
+    for (auto out : expr->outputs()) {
+      if (out->getValType() == ValType::TensorView) {
+        if (out->as<TensorView>()->hasBlockBroadcast()) {
+          return true;
+        }
+      }
     }
-    terminating_outputs.push_back(out);
   }
-
-  return terminating_outputs;
+  return false;
 }
 
-bool Fusion::isAliasCompatible(Val* left, Val* right) {
-  // Nullptr check
-  if (left == nullptr || right == nullptr) {
-    return false;
-  }
-
-  // DataType check
-  if (!left->getDataType().has_value() || !right->getDataType().has_value() ||
-      left->getDataType().value() != right->getDataType().value()) {
-    return false;
-  }
-
-  // ValType check
-  if (!left->getValType().has_value() || !right->getValType().has_value() ||
-      left->getValType().value() != right->getValType().value()) {
-    return false;
-  }
+bool Fusion::hasBroadcast() {
+  for (auto expr : exprs(true))
+    for (auto out : expr->outputs())
+      if (out->getValType() == ValType::TensorView)
+        if (out->as<TensorView>()->hasBroadcast())
+          return true;
 
-  // Check same number of dimensions if both values are TensorViews
-  if (ir_utils::isTV(left) && ir_utils::isTV(right)) {
-    return left->as<TensorView>()->nDims() == right->as<TensorView>()->nDims();
-  }
   return false;
 }
 
-void Fusion::aliasOutputToInput(Val* output, Val* input) {
-  TORCH_INTERNAL_ASSERT(
-      isAliasCompatible(input, output),
-      "The input and output values are not alias-compatible.");
-  io_alias_[output] = input;
-}
+std::vector<Val*> Fusion::getTerminatingOutputs() {
+  FUSER_PERF_SCOPE("getTerminatingOutputs");
 
-std::unordered_set<int> Fusion::getOutputAliasIndices() const {
-  if (io_alias_.empty()) {
-    return {};
-  }
+  FusionGuard fg(this);
 
-  std::unordered_set<int> alias_indices;
+  std::unordered_set<Val*> used_vals;
 
-  for (size_t i = 0; i < outputs_.size(); i++) {
-    if (io_alias_.count(outputs_[i]) != 0) {
-      alias_indices.insert(i);
-    }
-  }
-  return alias_indices;
-}
+  const auto exprs = ExprSort::getExprs(
+      this, std::vector<Val*>(outputs().begin(), outputs().end()));
 
-std::vector<std::pair<int, int>> Fusion::getInputAliasIndices() const {
-  if (io_alias_.empty()) {
-    return {};
+  for (auto expr : exprs) {
+    for (auto inp : expr->inputs())
+      used_vals.emplace(inp);
   }
 
-  std::vector<std::pair<int, int>> alias_indices;
-  for (size_t i = 0; i < outputs_.size(); i++) {
-    if (io_alias_.count(outputs_[i]) != 0) {
-      bool found = false;
-      for (size_t j = 0; j < inputs_.size(); j++) {
-        if (io_alias_.at(outputs_[i]) == inputs_[j]) {
-          alias_indices.emplace_back(i, j);
-          found = true;
-          break;
-        }
-      }
-      TORCH_INTERNAL_ASSERT(
-          found,
-          "io_alias_ mapping failure, alias output is not present in inputs");
-    }
+  std::vector<Val*> terminating_outputs;
+  for (auto out : outputs()) {
+    if (used_vals.find(out) != used_vals.end())
+      continue;
+    terminating_outputs.push_back(out);
   }
-  // can't assert here, we could have segmented fusion where not all alias
-  // outputs are present
-
-  return alias_indices;
+  return terminating_outputs;
 }
 
 } // namespace cuda
index f858c93..cf27ccb 100644 (file)
@@ -1,11 +1,9 @@
 #pragma once
 
-#include <ATen/core/ivalue.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
 
 #include <unordered_map>
 #include <unordered_set>
@@ -16,50 +14,42 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-//! Usage: FusionGuard and Fusion are required user interfaces for any operation
-//! underlying the code generator. In order to create values, expressions, and
-//! generate code a Fusion instance must be active. It is the responsibility of
-//! the user to create a Fusion instance and register it with the fusion guard.
-//! The simplest example of this is:
-//!
-//!     Fusion fusion;
-//!     FusionGuard fg(&fusion);
-//!
-//! Once a fusion is active all values and operations will be registered with
-//! it.
-//!
-//! FusionGuard and Fusion are critical to the lifetime model of the IR system.
-//! FusionGuard is a convenient way to set what base container instance holds
-//! the defined IR. Statements that are defined are registered through the
-//! FusionGuard with a particular Fusion. FusionGuard provides convenient
-//! methods to access the active fusion so it doesn't need to be passed around
-//! constantly. Any IR node derived classes from Statement must register with
-//! Fusion to avoid memory leaks.
-//!
-//! Fusion is generally thought of as a translated fusion group from the JIT. It
-//! is likely a single kernel, although, we don't have to stick to this in the
-//! future and could in theory generate multiple kernels with an executor to run
-//! them.
-//!
-//! Fusion also allows users to set input/output values that will allow us to
-//! figure out how to hook up runtime data to and from the JIT as well as
-//! provide us mechanisms for dependency analysis and DCE including safety
-//! checks.
+/*
+ * Usage: FusionGuard and Fusion are required user interfaces for any operation
+ * underlying the code generator. In order to create values, expressions, and
+ * generate code a Fusion instance must be active. It is the responsibility of
+ * the user to create a Fusion instance and register it with the fusion guard.
+ * The simplest example of this is: Fusion fusion; FusionGuard fg(&fusion); Once
+ * a fusion is active all values and operations will be registered with it.
+ *
+ * FusionGuard and Fusion are critical to the lifetime model of the IR system.
+ * FusionGuard is a convenient way to set what base container instance holds the
+ * defined IR. Statements that are defined are registered through the
+ * FusionGuard with a particular Fusion. FusionGuard provides convenient methods
+ * to access the active fusion so it doesn't need to be passed around
+ * constantly. Any IR node derived classes from Statement must register with
+ * Fusion to avoid memory leaks.
+ *
+ * Fusion is generally thought of as a translated fusion group from the JIT. It
+ * is likely a single kernel, although, we don't have to stick to this in the
+ * future and could in theory generate multiple kernels with an executor to run
+ * them.
+ *
+ * Fusion also allows users to set input/output values that will allow us to
+ * figure out how to hook up runtime data to and from the JIT as well as provide
+ * us mechanisms for dependency analysis and DCE including safety checks.
+ */
 
 class Fusion;
 class TensorView;
-class WelfordResult;
 
-class SegmentCandidateFinder;
-class SegmentedFusion;
-
-//! Fusion Guard is our "context manager". It holds the actrive fusion and
-//! allows it to be accessed anywhere through FusionGuard::getCurFusion()
+// Fusion Guard is our "context manager". It holds the actrive fusion and allows
+// it to be accessed anywhere through FusionGuard::getCurFusion().
 class TORCH_CUDA_CU_API FusionGuard {
  public:
   Fusion* prev_fusion;
 
-  //! Set the active fusion so it can be manipulated.
+  // Set the active fusion so it can be manipulated.
   explicit FusionGuard(Fusion* fusion);
 
   ~FusionGuard();
@@ -67,14 +57,15 @@ class TORCH_CUDA_CU_API FusionGuard {
   static Fusion* getCurFusion();
 };
 
-//! Fusion is mutable but unique. Nodes cannot be copied in any way from one
-//! Fusion to another. If anything like that is desired, it would require
-//! duplicating all associated values and exprs. Fusion is considered to SSA,
-//! though this could also change in the future if there is a good reason to do
-//! so.
-//!
-//! The Fusion owns the whole IR graph (Vals and Exprs)
-//!
+/*
+ * Fusion is mutable but unique. Nodes cannot be copied in any way from one
+ * Fusion to another. If anything like that is desired, it would require
+ * duplicating all associated values and exprs. Fusion is considered to SSA,
+ * though this could also change in the future if there is a good reason to do
+ * so.
+ *
+ * The Fusion owns the whole IR graph (Vals and Exprs)
+ */
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 class TORCH_CUDA_CU_API Fusion final {
  public:
@@ -92,118 +83,104 @@ class TORCH_CUDA_CU_API Fusion final {
 
   void clear() noexcept;
 
-  //! Break dependency chains associated with Expr, remove references to expr
-  //! delete expr
+  // Break dependency chains associated with Expr, remove references to expr
+  // delete expr.
   void removeExpr(Expr* expr);
 
-  //! Completely remove val from the fusion, break all dependencies associated
-  //! with it
+  // Completely remove val from the fusion, break all dependencies associated
+  // with it.
   void removeVal(Val* val);
 
-  //! Register input as an input of the fusion
-  // TODO: Rename to register
+  // Register input as an input of the fusion
   void addInput(Val* input);
 
-  //! Register output as an output of the fusion
-  // TODO: Rename to register
+  // Register output as an output of the fusion
   void addOutput(Val* output);
 
-  //! Register output as an output of the fusion
-  // TODO: Rename to register
-  void addOutput(WelfordResult& output);
-
-  //! Deregister input as an input of the fusion
-  // TODO: Rename to register
-  void removeInput(Val* input);
-
-  //! Deregister output as an output of the fusion
-  // TODO: Rename to register
-  void removeOutput(Val* output);
-
-  //! Replace output with another value
-  void replaceOutput(Val* output, Val* replacement);
-
-  //! Clear Expr's from TV uses that are not required to produce outputs from
-  //! inputs
-  void resetTvUses();
-
-  //! Check if stmt is properly registered with this fusion
+  // Check if stmt is properly registered with this fusion
   bool inFusion(const Statement* stmt) const;
 
-  //! Throw an error if stmt is not in this fusion
+  // Throw an error if stmt is not in this fusion. Message will be:
+  // msg + " it was not found in the active fusion."
   void assertInFusion(const Statement* stmt, const std::string& msg = "") const;
 
-  //! Assert that all leaves found from outputs are registered as an input
+  /*
+   * Return a list of topologically sorted expressions. We can start
+   * by only traversing back from registered outputs, or from all terminating
+   * Vals.
+   *
+   * from_outputs_only:
+   *   True - Sort from DAG associated with registered outputs
+   *   False - Sort from all terminating Vals.
+   */
+  std::vector<Expr*> exprs(bool from_outputs_only = false);
+
+  // Return a vector of fusion inputs that feed this Val
+  std::unordered_set<Val*> inputsOf(Val* val);
+
+  // Assert that all leaves found from outputs are registered as an input.
   void validateInputs();
 
-  //! Print this fusion to the console
+  // Print this fusion to cout.
   void print();
 
-  //! Print Arith exprs
-  //! \param from_outputs_only Only print exprs reachable from outputs
-  void printMath(bool from_outputs_only = true);
+  // Print Arith exprs used in outputs
+  void printMath();
 
-  //! Print transformations used in fusion (can be very verbose)
+  // Print transformations used in fusion (can be very verbose)
   void printTransforms();
 
-  //! Lower the fusion and print a kernel
+  // Lower the fusion and print a kernel
   void printKernel();
 
-  //! Register the Val with this fusion
+  // Register the Val with this fusion
   StmtNameType registerVal(Val* val);
 
-  //! Register expr with this fusion.
-  //! When we register an expression, we want to update the dependency tracking
-  //! of Vals. We add expr to our general expr_set_,
+  // Register expr with this fusion.
+  // When we register an expression, we want to update the dependency tracking
+  // of Vals. We add expr to our general expr_set_, we add use tracking for
+  // inputs and origin tracking for outputs.
   StmtNameType registerExpr(Expr* expr);
 
-  //! Register stmt with this fusion
+  // Register stmt with this fusion.
   StmtNameType registerStatement(Statement* stmt);
 
-  //! Return a list of topologically sorted expressions. This only includes
-  //! exprs required to genereate registered outputs.
-  std::vector<Expr*> exprs();
+  // Lowered nodes
+  // TODO(kir): to be removed
+  StmtNameType registerLoweredVal(Val* val);
+  StmtNameType registerLoweredExpr(Expr* expr);
 
-  //! Return a vector of fusion inputs that feed this Val
-  std::vector<Val*> inputsOf(Val* val);
+  // Lowered counterpart to inFusion()
+  // TODO(kir): to be removed
+  bool inKernelIr(const Statement* stmt) const;
 
-  //! Return the set of Vals registered with this fusion
-  const std::unordered_set<Val*>& vals() const noexcept;
+  // Check if val is used in this fusion. Not equivelent to DCE
+  bool used(Val* val) const;
 
-  //! Return in insertion order
+  // Return the set of Vals registered with this fusion
+  const std::unordered_set<Val*>& vals() const noexcept;
+  // Return in insertion order
   const std::deque<Val*>& deterministic_vals() const noexcept;
 
-  //! Return all Vals in math expressions that cannot be eliminated.
-  //!
-  //! It is generally equivalent to vals that are used to generate
-  //! outputs, however, when a multi-output expression exists, and only
-  //! some of the outputs are used, the remaining unused outputs are
-  //! also included as they must show up in the final code.
-  std::vector<Val*> usedMathVals();
-
-  //! Return the set of Exprs registered with this fusion. Warning: This will
-  //! return exprs outside inputs/outputs, so can be unsafe for use with
-  //! segmented fusions.
+  // Return the set of Exprs registered with this fusion
   const std::unordered_set<Expr*>& unordered_exprs() const noexcept;
 
-  //! Return all Exprs that use val
+  // Return all Exprs that use val
   std::unordered_set<Expr*> unordered_uses(Val* val) const;
 
-  //! Return the Expr that produces val
-  Expr* definition(const Val* val) const;
+  // Return the Expr that produces val
+  Expr* origin(const Val* val) const;
 
-  //! Indicate to kernel to set itself up to generate random numbers
+  // Indicate to kernel to set itself up to generate random numbers
   bool isStochastic();
 
-  //! Indicate that the fusion contains reduction operations
+  // TODO(kir): revisit to see how many of these are still needed
   bool hasReduction();
-
-  //! Indicate that the fusion contains welford operations
-  bool hasWelford();
-
-  //! Run fusion segmentation algorithm to create a segmented fusion
-  std::unique_ptr<SegmentedFusion> segment(
-      const at::ArrayRef<at::IValue>& inputs);
+  bool hasBlockReduction();
+  bool hasGridReduction();
+  bool hasBlockBroadcast();
+  bool hasBroadcast();
+  size_t gridReductionTempBufferSize();
 
   const auto& inputs() const {
     return inputs_;
@@ -218,31 +195,8 @@ class TORCH_CUDA_CU_API Fusion final {
   bool hasInput(const Val* val) const;
   bool hasOutput(const Val* val) const;
 
-  // Aliasing output to input value, this is a WAR to allow inplace update on
-  // input tensor.
-  // Note: this is not always safe and should be used with extra caution.
-  // Currently the only place it's used is in the running stats update for batch
-  // normalization.
-  // TODO: alias should be made aware to segmentation, so we'll always include
-  // the input tensor to the section where output is produced.
-  void aliasOutputToInput(Val* output, Val* input);
-  std::unordered_set<int> getOutputAliasIndices() const;
-  std::vector<std::pair<int, int>> getInputAliasIndices() const;
-
-  bool isTVUseInfoValid() {
-    return all_tv_uses_valid_;
-  }
-
-  bool isUpdatingTVUseInfo() {
-    return is_during_update_uses_;
-  }
-
- protected:
-  friend SegmentCandidateFinder;
-  friend SegmentedFusion;
-  friend class TranslateApplicableWelford;
-
-  static IrCloner copy(const Fusion* from, Fusion* to);
+  void replaceInput(Val* replace, Val* with);
+  void replaceOutput(Val* replace, Val* with);
 
  private:
   // Return an int that monotonically increases for each val/expr, some are
@@ -250,10 +204,6 @@ class TORCH_CUDA_CU_API Fusion final {
   StmtNameType getValName(ValType vtype);
   StmtNameType getExprName();
 
-  // Determine if the two values are compatible for aliasing
-  // Same DataType, ValType, and number of dimensions
-  bool isAliasCompatible(Val* left, Val* right);
-
  private:
   // Sets of all Vals/Exprs registered with this fusion
   // (val_deque_ is not owning the objects)
@@ -267,17 +217,18 @@ class TORCH_CUDA_CU_API Fusion final {
   // Expression names counter
   StmtNameType expr_name_counter_ = 0;
 
+  // Dependency tracking for Vals. Where did it come from? Where is it used?
+  std::unordered_map<const Val*, Expr*> origin_;
+  std::unordered_map<Val*, std::unordered_set<Expr*>> uses_;
+
   // Fusion inputs and outputs
   std::vector<Val*> inputs_;
   std::vector<Val*> outputs_;
 
-  // io alias pointing from output to input
-  std::unordered_map<Val*, Val*> io_alias_;
-
-  // Records if the current use data in the IR nodes are valid
-  //  the states are either all valid or all invalid
-  bool all_tv_uses_valid_ = false;
-  bool is_during_update_uses_ = false;
+  // Lowered IR
+  std::unordered_set<Val*> lowered_val_set_;
+  std::unordered_set<Expr*> lowered_expr_set_;
+  std::unordered_map<const Val*, Expr*> lowered_origin_;
 };
 
 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp
deleted file mode 100644 (file)
index d780c72..0000000
+++ /dev/null
@@ -1,3010 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
-#include <torch/csrc/jit/codegen/cuda/ir_graphviz.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-
-#include <sstream>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::getNeighborGroups() {
-  std::vector<NeighborGroup> neighbors;
-  for (auto inp : producer_edges) {
-    if (inp->val->isFusionOutput()) {
-      // Don't fuse across output nodes, would need to find another path.
-      continue;
-    }
-    neighbors.emplace_back(inp->from, inp);
-  }
-  for (auto out : consumer_edges) {
-    if (out->val->isFusionOutput()) {
-      // Don't fuse across output nodes, would need to find another path.
-      continue;
-    }
-    neighbors.emplace_back(out->to, out);
-  }
-  return neighbors;
-}
-
-std::vector<SegmentedGroup*> SegmentedGroup::getNeighbors() {
-  std::vector<SegmentedGroup*> neighbors;
-  auto neighbors_pair = getNeighborGroups();
-
-  std::transform(
-      neighbors_pair.begin(),
-      neighbors_pair.end(),
-      std::back_inserter(neighbors),
-      [](auto& neighbor_group) { return neighbor_group.group; });
-  return neighbors;
-}
-
-std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::
-    getMergeCandidates() {
-  // Don't look for candidates if already merged
-  if (merged_) {
-    return {};
-  }
-
-  std::vector<NeighborGroup> neighbors = getNeighborGroups();
-
-  // Can this node be merged with another? Check if neighbors are merged, if
-  // so and merged neighbor is within 1 level or node merged with neighbor is
-  // within 1 level, can't merge this node with anything else.
-  bool can_merge_this = true;
-  for (auto& neighbor : neighbors) {
-    if (!neighbor.group->merged_) {
-      continue;
-    }
-    if (std::abs(neighbor.group->level_ - level_) <= 1) {
-      can_merge_this = false;
-    }
-    if (std::abs(neighbor.group->merge_with_->level_ - level_) <= 1) {
-      can_merge_this = false;
-    }
-  }
-  if (!can_merge_this) {
-    return {};
-  }
-
-  std::vector<bool> can_merge(true, neighbors.size());
-
-  // Find neighbors with a level that is only 1 differant than this groups level
-  for (size_t i = 0; i < neighbors.size(); i++) {
-    if (std::abs(neighbors[i].group->level_ - level_) > 1) {
-      can_merge[i] = false;
-    }
-  }
-
-  // Check neighbor of neighbors we're considering, if any of them are merged
-  // with another node, make sure the resulting edge wouldn't have a level
-  // difference of 1
-  for (size_t i = 0; i < neighbors.size(); i++) {
-    if (!can_merge[i]) {
-      continue;
-    }
-
-    for (auto neighbor_neighbor : neighbors[i].group->getNeighbors()) {
-      // Don't check self
-      if (neighbor_neighbor == neighbors[i].group) {
-        continue;
-      }
-      if (neighbor_neighbor->merged_) {
-        // check neighbor_neighbor level
-        if (std::abs(neighbor_neighbor->level_ - level_) <= 1) {
-          can_merge[i] = false;
-        }
-        if (std::abs(neighbor_neighbor->level_ - neighbors[i].group->level_) <=
-            1) {
-          can_merge[i] = false;
-        }
-
-        // check neighbor_neighber->merged_->level_
-        if (std::abs(neighbor_neighbor->merge_with_->level_ - level_) <= 1) {
-          can_merge[i] = false;
-        }
-        if (std::abs(
-                neighbor_neighbor->merge_with_->level_ -
-                neighbors[i].group->level_) <= 1) {
-          can_merge[i] = false;
-        }
-      }
-    }
-  }
-
-  std::vector<NeighborGroup> merge_candidates;
-  for (size_t i = 0; i < neighbors.size(); i++) {
-    if (can_merge[i]) {
-      merge_candidates.push_back(neighbors[i]);
-    }
-  }
-  return merge_candidates;
-}
-
-void SegmentedGroup::clearTraversalInfo() {
-  level_ = -1;
-  visited_ = false;
-  merge_with_ = nullptr;
-  merge_through_ = nullptr;
-  merged_ = false;
-}
-
-std::vector<Val*> SegmentedGroup::edgesToVals(
-    const std::vector<SegmentedEdge*>& se_v) {
-  std::vector<Val*> ret_v;
-  ret_v.reserve(se_v.size());
-
-  std::transform(
-      se_v.cbegin(),
-      se_v.cend(),
-      std::back_inserter(ret_v),
-      [](SegmentedEdge* se) { return se->val; });
-  return ret_v;
-}
-
-template <typename PREDICATE>
-void insertUniquePredicated(
-    std::vector<Val*>& v,
-    const std::vector<SegmentedEdge*>& e,
-    PREDICATE pred) {
-  std::unordered_set<Val*> to_add;
-  std::transform(
-      e.cbegin(),
-      e.cend(),
-      std::inserter(to_add, to_add.end()),
-      [](SegmentedEdge* se) { return se->val; });
-  std::copy_if(
-      to_add.begin(), to_add.end(), std::back_inserter(v), [pred](Val* val) {
-        return pred(val);
-      });
-}
-
-void SegmentedGroup::finalize() {
-  // Move all the edges to group input/output
-  // Inputs
-  insertUniquePredicated(
-      input_vals, producer_edges, [](Val* v) { return !v->isFusionInput(); });
-
-  std::unordered_set<Val*> input_set(input_vals.begin(), input_vals.end());
-
-  for (auto expr : exprs_) {
-    for (auto i : expr->inputs()) {
-      if (i->isAnInt() && i->definition() == nullptr && !i->isConstScalar() &&
-          !i->isFusionInput() && !input_set.count(i)) {
-        input_set.insert(i);
-        input_vals.push_back(i);
-      }
-    }
-  }
-
-  // Outputs
-  insertUniquePredicated(
-      output_vals, consumer_edges, [](Val* v) { return !v->isFusionOutput(); });
-
-  // alias aware segmentation. we add inputs that are aliased by output
-  // generated in this SegmentedGroup
-  for (auto output : output_vals) {
-    if (auto aliased_input = segmented_fusion_->findAlias(output)) {
-      // aliasing currently only supported as output to input
-      TORCH_INTERNAL_ASSERT(
-          aliased_input->isFusionInput(),
-          "aliased input is not found in the complete fusion");
-      if (!input_set.count(aliased_input)) {
-        input_set.insert(aliased_input);
-        input_vals.push_back(aliased_input);
-      }
-    }
-  }
-}
-
-std::ostream& operator<<(std::ostream& os, const SegmentedGroup* group) {
-  os << "g{";
-  auto expr_to_print = group->exprs();
-  std::sort(
-      expr_to_print.begin(),
-      expr_to_print.end(),
-      [](auto expr_a, auto expr_b) -> bool {
-        return expr_a->name() < expr_b->name();
-      });
-  for (size_t i = 0; i < expr_to_print.size(); i++) {
-    os << expr_to_print[i]->name();
-    if (i + 1 != expr_to_print.size())
-      os << ", ";
-  }
-  os << "}\n";
-  return os;
-}
-
-void SegmentedGroup::print() const {
-  std::cout << this << "\n";
-}
-
-std::string toString(const SegmentedGroup* group) {
-  std::stringstream ss;
-  ss << group;
-  return ss.str();
-}
-
-std::ostream& operator<<(std::ostream& os, const SegmentedEdge* edge) {
-  os << "e{ " << edge->from << " -> " << edge->to << "(";
-  IrPrinter irp(os);
-  irp.handle(edge->val);
-  os << ") }\n";
-  return os;
-}
-
-void SegmentedEdge::print() const {
-  std::cout << this << "\n";
-}
-
-std::string toString(const SegmentedEdge* edge) {
-  std::stringstream ss;
-  ss << edge;
-  return ss.str();
-}
-
-SegmentedFusion::SegmentedFusion(std::unique_ptr<Fusion> fusion)
-    : impl_(this), complete_fusion_(std::move(fusion)) {
-  segmented_fusion_name_ = segmentedFusionName();
-  annotateFP16IntermediateTensors();
-}
-
-SegmentedGroup* SegmentedFusion::Impl::makeGroup() {
-  groups_.emplace_back(std::make_unique<SegmentedGroup>(owning_fusion_));
-  return groups_.back().get();
-}
-
-SegmentedGroup* SegmentedFusion::Impl::makeGroup(Expr* expr) {
-  groups_.emplace_back(std::make_unique<SegmentedGroup>(expr, owning_fusion_));
-  return groups_.back().get();
-}
-
-SegmentedEdge* SegmentedFusion::Impl::makeEdge(
-    SegmentedGroup* from,
-    SegmentedGroup* to,
-    Val* val) {
-  edges_.emplace_back(std::make_unique<SegmentedEdge>(from, to, val));
-  return edges_.back().get();
-}
-
-void SegmentedFusion::Impl::cleanUnused() {
-  std::unordered_set<SegmentedGroup*> g_used(
-      owning_fusion_->groups().begin(), owning_fusion_->groups().end());
-  std::unordered_set<SegmentedEdge*> e_used(
-      owning_fusion_->edges().begin(), owning_fusion_->edges().end());
-
-  groups_.erase(
-      std::remove_if(
-          groups_.begin(),
-          groups_.end(),
-          [&g_used](auto& g) { return g_used.count(g.get()) == 0; }),
-      groups_.end());
-
-  edges_.erase(
-      std::remove_if(
-          edges_.begin(),
-          edges_.end(),
-          [&e_used](auto& e) { return e_used.count(e.get()) == 0; }),
-      edges_.end());
-}
-
-SegmentedGroup* SegmentedFusion::newGroup() {
-  SegmentedGroup* g = impl_.makeGroup();
-  groups_.push_back(g);
-  return g;
-}
-
-SegmentedGroup* SegmentedFusion::newGroup(Expr* expr) {
-  SegmentedGroup* g = impl_.makeGroup(expr);
-  groups_.push_back(g);
-  return g;
-}
-
-SegmentedEdge* SegmentedFusion::newEdge(
-    SegmentedGroup* from,
-    SegmentedGroup* to,
-    Val* val) {
-  SegmentedEdge* e = impl_.makeEdge(from, to, val);
-  edges_.push_back(e);
-  return e;
-}
-
-void SegmentedFusion::draw() {
-  size_t group_index = 0;
-  std::unordered_map<const Expr*, size_t> expr_color_map;
-
-  for (auto group : groups()) {
-    for (auto expr : group->exprs()) {
-      if (ir_utils::isTVOp(expr)) {
-        expr_color_map[expr] = group_index;
-      }
-    }
-    group_index++;
-  }
-
-  std::stringstream sstream;
-  sstream << "segmented_fusion" << segmented_fusion_name_ << ".dot";
-  auto filename = sstream.str();
-
-  IrGraphGenerator::print(
-      completeFusion(),
-      filename.c_str(),
-      IrGraphGenerator::DetailLevel::ComputeOnly,
-      &expr_color_map);
-}
-
-namespace {
-
-std::vector<Val*> uniqueValConcat(
-    const std::vector<std::vector<Val*>>& val_vecs) {
-  std::vector<Val*> unique_vals;
-  std::unordered_set<Val*> added;
-  for (const auto& vec : val_vecs) {
-    for (auto val : vec) {
-      if (added.find(val) == added.end()) {
-        unique_vals.push_back(val);
-        added.emplace(val);
-      }
-    }
-  }
-  return unique_vals;
-}
-
-// Concat's producer edges of sg1 and sg2, but removes any edges from/to sg1/sg2
-std::vector<SegmentedEdge*> getMergedProducerEdges(
-    const SegmentedGroup* sg1,
-    const SegmentedGroup* sg2) {
-  TORCH_INTERNAL_ASSERT(
-      sg1 != nullptr && sg2 != nullptr,
-      "This function doesn't handle trivial.");
-
-  auto producer_edges = sg1->producer_edges;
-
-  producer_edges.insert(
-      producer_edges.end(),
-      sg2->producer_edges.begin(),
-      sg2->producer_edges.end());
-
-  // Register producers into sg2
-  std::unordered_set<Val*> sg2_vals;
-  for (auto se : sg2->producer_edges) {
-    sg2_vals.emplace(se->val);
-  }
-
-  producer_edges.erase(
-      std::remove_if(
-          producer_edges.begin(),
-          producer_edges.end(),
-          [&sg1, &sg2, &sg2_vals](SegmentedEdge* se) {
-            // remove edges in between the groups and common uses
-            return (se->to == sg1 && se->from == sg2) ||
-                (se->to == sg2 && se->from == sg1) ||
-                (se->to == sg1 && sg2_vals.count(se->val));
-          }),
-      producer_edges.end());
-
-  // Remove Duplicate Edges
-
-  return producer_edges;
-}
-
-// Concat's consumer edges of sg1 and sg2, but removes any edges from/to sg1/sg2
-std::vector<SegmentedEdge*> getMergedConsumerEdges(
-    const SegmentedGroup* sg1,
-    const SegmentedGroup* sg2) {
-  TORCH_INTERNAL_ASSERT(
-      sg1 != nullptr && sg2 != nullptr,
-      "This function doesn't handle trivial.");
-
-  auto consumer_edges = sg1->consumer_edges;
-  consumer_edges.insert(
-      consumer_edges.end(),
-      sg2->consumer_edges.begin(),
-      sg2->consumer_edges.end());
-
-  consumer_edges.erase(
-      std::remove_if(
-          consumer_edges.begin(),
-          consumer_edges.end(),
-          [&sg1, &sg2](SegmentedEdge* se) {
-            return (se->to == sg1 && se->from == sg2) ||
-                (se->to == sg2 && se->from == sg1);
-          }),
-      consumer_edges.end());
-
-  return consumer_edges;
-}
-
-// Returns a determinstic, unique set of inputs of the segment group, sg1, or
-// the combined group sg1 + sg2
-std::vector<Val*> getAllInputs(
-    const SegmentedGroup* sg1,
-    const SegmentedGroup* sg2 = nullptr) {
-  std::vector<SegmentedEdge*> merged_producer_edges;
-
-  if (sg1 != nullptr && sg2 != nullptr) {
-    merged_producer_edges = getMergedProducerEdges(sg1, sg2);
-  } else if (sg1 != nullptr) {
-    merged_producer_edges = sg1->producer_edges;
-  } else if (sg2 != nullptr) {
-    merged_producer_edges = sg2->producer_edges;
-  }
-
-  std::vector<Val*> producer_edge_vals;
-
-  std::transform(
-      merged_producer_edges.begin(),
-      merged_producer_edges.end(),
-      std::back_inserter(producer_edge_vals),
-      [](SegmentedEdge* se) { return se->val; });
-
-  return uniqueValConcat(
-      {sg1 == nullptr ? std::vector<Val*>() : sg1->input_vals,
-       sg2 == nullptr ? std::vector<Val*>() : sg2->input_vals,
-       producer_edge_vals});
-}
-
-// Returns a determinstic, unique set of outputs of the segment group, sg1, or
-// the combined group sg1 + sg2
-std::vector<Val*> getAllOutputs(
-    const SegmentedGroup* sg1,
-    const SegmentedGroup* sg2 = nullptr) {
-  std::vector<SegmentedEdge*> merged_consumer_edges;
-
-  if (sg1 != nullptr && sg2 != nullptr) {
-    merged_consumer_edges = getMergedConsumerEdges(sg1, sg2);
-  } else if (sg1 != nullptr) {
-    merged_consumer_edges = sg1->consumer_edges;
-  } else if (sg2 != nullptr) {
-    merged_consumer_edges = sg2->consumer_edges;
-  }
-
-  std::vector<Val*> consumer_edge_vals;
-
-  std::transform(
-      merged_consumer_edges.begin(),
-      merged_consumer_edges.end(),
-      std::back_inserter(consumer_edge_vals),
-      [](SegmentedEdge* se) { return se->val; });
-
-  auto output_vals = uniqueValConcat(
-      {sg1 == nullptr ? std::vector<Val*>() : sg1->output_vals,
-       sg2 == nullptr ? std::vector<Val*>() : sg2->output_vals,
-       consumer_edge_vals});
-
-  return output_vals;
-}
-
-// Set version of getting merged input or output if segmented_groups were
-//  merged
-//  outputs respects order in segmented_groups for deterministic
-//  merge trace
-//  will get input if get_inputs otherwise will get ouputs
-//  TODO: merge with the binary counter parts
-std::vector<Val*> allInputsIfTrueElseOutputs(
-    const std::vector<SegmentedGroup*>& segmented_groups,
-    bool get_inputs = true) {
-  // Helper to distinguish if we are getting inputs or outputs
-  using EdgeVec = std::vector<SegmentedEdge*>;
-  using ValVec = std::vector<Val*>;
-
-  // Get producer edges to get inputs, consumer edges to get outputs
-  auto edges_to_process_from_or_to_group =
-      [get_inputs](SegmentedGroup* group) -> EdgeVec& {
-    return get_inputs ? group->producer_edges : group->consumer_edges;
-  };
-
-  // Get the group that is connected to current group
-  auto global_vals_from_or_to_group =
-      [get_inputs](SegmentedGroup* group) -> ValVec& {
-    return get_inputs ? group->input_vals : group->output_vals;
-  };
-
-  // Get the group that is connected to current group by given edge
-  auto opposite_end_of_edge = [get_inputs](SegmentedEdge* edge) {
-    return get_inputs ? edge->from : edge->to;
-  };
-
-  // Keep track of value and order to ensure deterministic result
-  std::vector<Val*> merged_vals;
-  std::unordered_set<Val*> merged_vals_set;
-
-  // Put groups in a set for quick look up
-  std::unordered_set<SegmentedGroup*> segmented_groups_set(
-      segmented_groups.begin(), segmented_groups.end());
-
-  // Collect vals associated with edges
-  for (auto group : segmented_groups) {
-    for (auto edge : edges_to_process_from_or_to_group(group)) {
-      if (
-          // Need to de-duplicate values so we don't get multiple of any input
-          !merged_vals_set.count(edge->val) &&
-          // One side of this edge will be `group`, if the other end is
-          //  also in segmented_groups, then this is an internal edge
-          //  that we don't want.
-          !segmented_groups_set.count(opposite_end_of_edge(edge))) {
-        merged_vals.push_back(edge->val);
-        merged_vals_set.insert(edge->val);
-      }
-    }
-  }
-
-  // Collect original fusion's inputs/outputs and append at the end
-  for (auto group : segmented_groups) {
-    for (auto global_val : global_vals_from_or_to_group(group)) {
-      // de-duplicate
-      if (!merged_vals_set.count(global_val)) {
-        merged_vals.push_back(global_val);
-        merged_vals_set.insert(global_val);
-      }
-    }
-  }
-
-  return merged_vals;
-}
-
-// A sorting utility used for debug printing only
-//  sorts the given vector of expressions in topological
-//  order, with equal cases respecting the original order
-//  in the vector.
-std::vector<Expr*> groupExprPrintSorting(const std::vector<Expr*>& exprs) {
-  std::vector<Expr*> exprs_to_print(exprs.begin(), exprs.end());
-  std::unordered_set<Expr*> exprs_to_print_set(exprs.begin(), exprs.end());
-  std::unordered_set<Expr*> exprs_visited;
-  std::vector<Expr*> sorted_list;
-  while (sorted_list.size() != exprs_to_print.size()) {
-    bool expr_added_to_sorted_list = false;
-    for (auto expr : exprs_to_print) {
-      if (!exprs_visited.count(expr)) {
-        bool add_this_expr = true;
-        // Check if any of the inputs of current
-        //  expression within the group
-        //  hasn't been visited
-        for (auto input : expr->inputs()) {
-          if (input->definition() &&
-              exprs_to_print_set.count(input->definition()) &&
-              !exprs_visited.count(input->definition())) {
-            add_this_expr = false;
-            break;
-          }
-        }
-
-        // Append the current group to sorted list
-        //  and mark visited
-        if (add_this_expr) {
-          expr_added_to_sorted_list = true;
-          exprs_visited.insert(expr);
-          sorted_list.push_back(expr);
-          break;
-        }
-      }
-    }
-    TORCH_INTERNAL_ASSERT(
-        expr_added_to_sorted_list,
-        "group debug print failed, exprs within given vector not a DAG");
-  }
-  return sorted_list;
-}
-
-// Utility function to list all expressions in a group
-void detailGroupPrint(std::ostream& os, const SegmentedGroup* group) {
-  IrPrinter irp(os);
-
-  auto sort_val_by_name = [](std::vector<Val*> vals_to_sort) {
-    std::sort(vals_to_sort.begin(), vals_to_sort.end(), [](Val* a, Val* b) {
-      return a->name() < b->name();
-    });
-    return vals_to_sort;
-  };
-
-  os << "g{"
-     << "(" << toString(group->heuristic()) << ")\n";
-  os << "inputs: \n";
-  for (auto input : sort_val_by_name(getAllInputs(group))) {
-    os << input << " " << input->getDataType().value() << "\n";
-  }
-  os << "outputs: \n";
-  for (auto output : sort_val_by_name(getAllOutputs(group))) {
-    os << output << " " << output->getDataType().value() << "\n";
-  }
-
-  os << "\n\n";
-
-  auto expr_to_print = groupExprPrintSorting(group->exprs());
-
-  for (size_t i = 0; i < expr_to_print.size(); i++) {
-    irp.handle(expr_to_print[i]);
-  }
-  os << "}\n\n";
-}
-
-//! Insert casts for an intermediate tensorview, i.e. ones
-//!  that are in segmentedEdges. The insertion is done on
-//!  the complete fusion, which should be owned by a segmented
-//!  fusion so that only one segmented fusion will be affected.
-//!  The replacement pattern is:
-//!                 TV0
-//!     replaced as:
-//!       fp16_tv = cast(TV0)
-//!       fp32_tv = cast(fp16_tv)
-//!
-//!  All segmented groups that take TV0 as input will then
-//!   take fp16_tv instead and the cast to fp32 will be
-//!   automatically included in each of the groups.
-TensorView* castIntermediateValueInCompleteFusion(
-    Fusion* fusion,
-    TensorView* original_tv) {
-  FusionGuard fg(fusion);
-
-  // A utility lambda that creates consumer tensordomain of
-  //  the given tv and create a new tensorview around the
-  //  new tensordomain with the given data type.
-  auto make_consumer_tv = [&](TensorView* from, DataType data_type) {
-    // Keep broadcast axes and remove reduction axes
-    size_t i = 0;
-    auto no_reduction_root_domain =
-        TensorDomain::noReductions(original_tv->getRootDomain());
-    std::vector<IterDomain*> new_root_domain(no_reduction_root_domain.size());
-    for (const auto& dom : no_reduction_root_domain) {
-      new_root_domain[i++] = dom->clone();
-    }
-
-    // Create the actual domain and tv.
-    return new TensorView(
-        new TensorDomain(
-            new_root_domain, std::vector<bool>(new_root_domain.size(), true)),
-        data_type);
-  };
-
-  // create the tv's to cast
-  auto fp16_tv = make_consumer_tv(original_tv, DataType::Half);
-  auto fp32_tv = make_consumer_tv(original_tv, DataType::Float);
-
-  // replace uses of original tv with fp32_tv in the complete
-  //  fusion
-  for (auto expr : fusion->unordered_uses(original_tv)) {
-    ir_utils::replaceValInExpr(expr, original_tv, fp32_tv);
-  }
-
-  // Insert the cast ops.
-  new UnaryOp(UnaryOpType::Cast, fp16_tv, original_tv);
-  new UnaryOp(UnaryOpType::Cast, fp32_tv, fp16_tv);
-
-  // Return the new tv to replace original tv with
-  //  on the segmented edges.
-  return fp16_tv;
-}
-
-} // namespace
-
-void SegmentedFusion::finalize() {
-  impl_.cleanUnused();
-
-  // Insert casts for the tensorviews that are on
-  //  segmented edges and also on the force_to_fp16 list
-  //
-  // Note:
-  //  The cast is inserted after the segmenter canSchedule check, which
-  //  shouldn't cause problem short-term. The reason we put the cast here
-  //  is  we don't want to keep making copies of the original fusion
-  //  during segmentation. Could consider making the cast insertion
-  //  reversible if we do have to test canSchedule with the casts inserted
-  //  during segmentation process in the future.
-
-  // Keep track of groups that need to update expr list,
-  //  including both the producer and consumer of the selected tv's that
-  //  we cast to fp16.
-  std::unordered_set<SegmentedGroup*> affected_group_set;
-
-  // A map to keep track of the tv's that have been inserted cast
-  //  and its fp16 version.
-  std::unordered_map<TensorView*, TensorView*> fp32_to_fp16_cast_map;
-
-  // Go through all edges of the segmented fusion.
-  for (auto edge : edges()) {
-    auto edge_tv = edge->val->as<TensorView>();
-    // Only look at ones that need to cast to fp16
-    if (force_fp16_tv_set_.count(edge_tv)) {
-      auto cast_tv_it = fp32_to_fp16_cast_map.find(edge->val->as<TensorView>());
-      TensorView* cast_tv = nullptr;
-      // Insert cast ops for this tv if we haven't done so.
-      if (cast_tv_it == fp32_to_fp16_cast_map.end()) {
-        cast_tv = castIntermediateValueInCompleteFusion(
-            complete_fusion_.get(), edge_tv);
-        fp32_to_fp16_cast_map[edge->val->as<TensorView>()] = cast_tv;
-      } else {
-        cast_tv = cast_tv_it->second;
-      }
-
-      // Update the edge to use the fp16 version
-      edge->val = cast_tv;
-
-      // Mark the groups for update later
-      affected_group_set.insert(edge->from);
-      affected_group_set.insert(edge->to);
-    }
-  }
-
-  // Reset expression lists of all affected groups
-  // TODO : this could have been a general operation that
-  //  the group supports. Could consider moving this into
-  //  segmentedGroup in a follow up.
-  for (auto group : affected_group_set) {
-    auto input_group_vec = getAllInputs(group);
-    std::unordered_set<Val*> input_group_set(
-        input_group_vec.begin(), input_group_vec.end());
-
-    auto expr_set = DependencyCheck::getAllExprsBetween(
-        input_group_set, getAllOutputs(group));
-    group->exprs_ = std::vector<Expr*>(expr_set.begin(), expr_set.end());
-  }
-}
-
-//! An utility class to compute and maintain the "producers of"
-//!   relationship in a segmented graph. Space heavy and should
-//!   avoid use on very large graphs.
-//!
-//!  Currently trying to move as far as possible with only a
-//!   producer map, without transposing it to make a consumer map.
-//!  Making it NonCopyable because we should never need to
-//!   copy an instance of this class.
-//!  TODO: Space efficiency of this class will be important,
-//!        because we need it in the pre-merging of segmentedGroups,
-//!        currently O(n^2). O(nlogn) would be a reasonable
-//!        goal to achieve.
-class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
-  using GroupSet = std::unordered_set<SegmentedGroup*>;
-  using GroupSetOwningPtr = std::unique_ptr<GroupSet>;
-  using DependencyMap = std::unordered_map<SegmentedGroup*, GroupSetOwningPtr>;
-
- public:
-  //! Populate producers of all groups in segmented fusion
-  explicit GroupDependencyAnalysis(const SegmentedFusion* segmented_fusion)
-      : segmented_fusion_(segmented_fusion) {
-    computeAllProducers();
-  }
-
-  //! Checks if group is consumer of any group in groups_to_check
-  //!  TODO: refactor this similar to isConsumerOf
-  bool isConsumerOfAny(
-      SegmentedGroup* group,
-      const std::vector<SegmentedGroup*>& groups_to_check) {
-    auto& producers_of_group = getAllKnownProducersSet(group);
-    for (const auto& potential_producer : groups_to_check) {
-      if (producers_of_group->count(potential_producer)) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  bool isConsumerOf(SegmentedGroup* a, SegmentedGroup* b) {
-    auto it = known_producers_of_.find(a);
-    if (it == known_producers_of_.end()) {
-      return false;
-    }
-    return it->second->count(b);
-  }
-
-  bool isProducerOf(SegmentedGroup* a, SegmentedGroup* b) {
-    return isConsumerOf(b, a);
-  }
-
-  //! Finds the common producers of given set of groups
-  GroupSet getCommonProducersOf(std::vector<SegmentedGroup*> groups);
-
-  //! Update the map when the given two groups have been merged to create `ab`
-  //! this method is for book keeping and query only, doesn't implicitly check
-  //!  for DAG
-  void mergeGroups(SegmentedGroup* a, SegmentedGroup* b, SegmentedGroup* ab);
-
-  //! Update the map when the given two groups have been merged to create
-  //! `merged` this method is for book keeping and query only, doesn't
-  //! implicitly check
-  //!  for DAG
-  void mergeGroups(const GroupSet& groups, SegmentedGroup* merged);
-
-  //! Populate all values that is on a path from producer to consumer
-  //!  efficiency can be important here. (TODO)
-  GroupSet valuesBetween(SegmentedGroup* producer, SegmentedGroup* consumer) {
-    if (producer == consumer) {
-      return {};
-    }
-
-    GroupSet values_between;
-    auto& all_producers_of_consumer = known_producers_of_.at(consumer);
-    TORCH_INTERNAL_ASSERT(
-        all_producers_of_consumer->count(producer),
-        "Fusion segment: Trying to compute path between two nodes that are not producer-consumer pairs");
-
-    std::copy_if(
-        all_producers_of_consumer->begin(),
-        all_producers_of_consumer->end(),
-        std::inserter(values_between, values_between.end()),
-        [this, producer](SegmentedGroup* producer_of_consumer) {
-          // Checks if producer is on the producer path of this intermediate
-          // node
-          return known_producers_of_.at(producer_of_consumer)->count(producer);
-        });
-
-    return values_between;
-  }
-
-  //! Checks if the segmented fusion this class tracks is still a DAG
-  //!  used for generating assertions after transforms
-  bool isproducerMapDAG() const {
-    for (auto& it : known_producers_of_) {
-      if (it.second->count(it.first)) {
-        return false;
-      }
-    }
-    return true;
-  }
-
- private:
-  //! Collect initial producer info using
-  //!  a work list algorithm through forward traversal
-  //!  a backward DFS would do the same
-  void computeAllProducers();
-
-  //! Add all consumers of `producer` to `to_visit`
-  void addConsumersToWorkList(SegmentedGroup* producer, GroupSet& to_visit) {
-    for (auto e : producer->consumer_edges) {
-      // A consumer wouldn't have been worked before any of its producer
-      to_visit.insert(e->to);
-    }
-  }
-
-  //! Propagate all known producers of `from` into `into`, used to keep track
-  //! of:
-  //!  1. `from` is a producer of `into`
-  //!  2. `from` has been merged with other group to create `into`
-  void mergeAllKnownProducersIntoFrom(
-      SegmentedGroup* into,
-      SegmentedGroup* from) {
-    auto& producer_set_to_merge = *getAllKnownProducersSet(from);
-    for (auto group : producer_set_to_merge) {
-      getAllKnownProducersSet(into)->insert(group);
-    }
-  }
-
-  //! Utility to access known producers of a group so far
-  GroupSetOwningPtr& getAllKnownProducersSet(SegmentedGroup* group) {
-    auto& producer_set_ptr = known_producers_of_[group];
-    if (!producer_set_ptr) {
-      producer_set_ptr = std::make_unique<GroupSet>();
-    }
-    return producer_set_ptr;
-  }
-
-  // utility to compute the set intersection of group sets a,b
-  GroupSet groupSetIntersection(const GroupSet& a, const GroupSet& b) {
-    bool a_is_smaller = a.size() < b.size();
-    const auto& smaller_group_set = a_is_smaller ? a : b;
-    const auto& bigger_group_set = a_is_smaller ? b : a;
-
-    GroupSet intersection;
-    for (auto group : smaller_group_set) {
-      if (bigger_group_set.count(group)) {
-        intersection.insert(group);
-      }
-    }
-    return intersection;
-  }
-
- private:
-  const SegmentedFusion* segmented_fusion_;
-  DependencyMap known_producers_of_;
-};
-
-//! Finds the common producers of given set of groups
-GroupDependencyAnalysis::GroupSet GroupDependencyAnalysis::getCommonProducersOf(
-    std::vector<SegmentedGroup*> groups) {
-  if (groups.empty()) {
-    return {};
-  }
-
-  // Optimization: start with the smallest producer set
-  std::sort(
-      groups.begin(),
-      groups.end(),
-      [this](SegmentedGroup* a, SegmentedGroup* b) {
-        return known_producers_of_.at(a)->size() <
-            known_producers_of_.at(b)->size();
-      });
-
-  // Get intersection of producers
-  GroupSet common_producers = *(known_producers_of_.at(groups[0]));
-  for (size_t i = 1; i < groups.size(); i++) {
-    common_producers = groupSetIntersection(
-        common_producers, *(known_producers_of_.at(groups[i])));
-  }
-
-  return common_producers;
-}
-
-//! Update the map when the given two groups have been merged to create `ab`
-//! this method is for book keeping and query only, doesn't implicitly check
-//!  for DAG
-void GroupDependencyAnalysis::mergeGroups(
-    SegmentedGroup* a,
-    SegmentedGroup* b,
-    SegmentedGroup* ab) {
-  // Access/Create the producer set of ab
-  auto& ab_set = getAllKnownProducersSet(ab);
-
-  // propagate a's and b's known producers into ab
-  mergeAllKnownProducersIntoFrom(ab, a);
-  mergeAllKnownProducersIntoFrom(ab, b);
-
-  // a, b are now merged, so no longer exist
-  ab_set->erase(a);
-  ab_set->erase(b);
-
-  // a, b no longer exist, remove their producer sets
-  known_producers_of_.erase(a);
-  known_producers_of_.erase(b);
-
-  // update producer maps of other groups
-  for (auto& it : known_producers_of_) {
-    // for all groups that are produced by either a or b
-    if (it.second->count(a) || it.second->count(b)) {
-      // insert ab as the new producer
-      it.second->insert(ab);
-      // all producers of both a and b are now producers of `it`
-      mergeAllKnownProducersIntoFrom(it.first, ab);
-    }
-    // a, b no longer exist, remove them from `it`
-    it.second->erase(a);
-    it.second->erase(b);
-  }
-}
-
-//! Update the map when the given two groups have been merged to create
-//! `merged` this method is for book keeping and query only, doesn't
-//! implicitly check
-//!  for DAG
-void GroupDependencyAnalysis::mergeGroups(
-    const GroupSet& groups,
-    SegmentedGroup* merged) {
-  // Access/Create the producer set of merged
-  auto& merged_set = getAllKnownProducersSet(merged);
-
-  // Populate all producers of groups and
-  //  write into producer map of merged
-  std::for_each(
-      groups.begin(), groups.end(), [this, merged](SegmentedGroup* group) {
-        mergeAllKnownProducersIntoFrom(merged, group);
-      });
-
-  // Erase all groups that was merged from producer map
-  std::for_each(
-      groups.begin(), groups.end(), [this, &merged_set](SegmentedGroup* group) {
-        // erase inter dependencies
-        merged_set->erase(group);
-        // erase producer map tracking merged entires
-        known_producers_of_.erase(group);
-      });
-
-  // Update producer relationships with other groups in producer map
-  for (auto& it : known_producers_of_) {
-    auto producer_intersection = groupSetIntersection(*(it.second), groups);
-    // if current node has any producer that was merged
-    if (producer_intersection.size() > 0) {
-      for (auto merged_producer : producer_intersection) {
-        // delete all disappearing producers
-        it.second->erase(merged_producer);
-      }
-      // insert the new group as producer
-      it.second->insert(merged);
-    }
-  }
-}
-
-//! Collect initial producer info using
-//!  a work list algorithm through forward traversal
-//!  a backward DFS would do the same
-void GroupDependencyAnalysis::computeAllProducers() {
-  GroupSet visited;
-  GroupSet to_visit;
-
-  // Collect source nodes, with no producers we are guaranteed
-  //  a source node on a DAG
-  std::copy_if(
-      segmented_fusion_->cgroups().begin(),
-      segmented_fusion_->cgroups().end(),
-      std::inserter(visited, visited.end()),
-      [](SegmentedGroup* group) { return group->producer_edges.empty(); });
-
-  // visited now only contain source nodes
-  //  they can go backward to nowhere
-  for (auto group : visited) {
-    addConsumersToWorkList(group, to_visit);
-  }
-
-  while (!to_visit.empty()) {
-    SegmentedGroup* to_update = nullptr;
-    for (auto visiting_group : to_visit) {
-      if (std::all_of(
-              visiting_group->producer_edges.begin(),
-              visiting_group->producer_edges.end(),
-              [&visited](SegmentedEdge* e) {
-                return visited.count(e->from);
-              })) {
-        // filter multi-edges
-        GroupSet producers_of_visiting_group;
-        for (auto edge : visiting_group->producer_edges) {
-          producers_of_visiting_group.insert(edge->from);
-        }
-
-        // populate all possible paths
-        // from producer backward, including
-        // the producer
-        for (auto producer : producers_of_visiting_group) {
-          getAllKnownProducersSet(visiting_group)->insert(producer);
-          mergeAllKnownProducersIntoFrom(visiting_group, producer);
-        }
-        to_update = visiting_group;
-        break;
-      }
-    }
-    if (to_update) {
-      addConsumersToWorkList(to_update, to_visit);
-      to_visit.erase(to_update);
-      visited.insert(to_update);
-    } else {
-      TORCH_INTERNAL_ASSERT(false, "unreachable, original graph not a DAG");
-    }
-  }
-}
-
-std::ostream& operator<<(
-    std::ostream& os,
-    const SegmentedFusion* segmented_fusion) {
-  // Topologically sort groups
-  GroupDependencyAnalysis dependency(segmented_fusion);
-  std::vector<SegmentedGroup*> groups_to_print(
-      segmented_fusion->cgroups().begin(), segmented_fusion->cgroups().end());
-  std::vector<SegmentedGroup*> sorted_groups_to_print;
-
-  // Sort groups topologically from producer to consumer before printing
-  while (!groups_to_print.empty()) {
-    auto group_it_to_append = groups_to_print.begin();
-    for (auto group_it_to_compare = groups_to_print.begin();
-         group_it_to_compare != groups_to_print.end();
-         group_it_to_compare++) {
-      if (dependency.isProducerOf(*group_it_to_compare, *group_it_to_append)) {
-        group_it_to_append = group_it_to_compare;
-      }
-    }
-    sorted_groups_to_print.push_back(*group_it_to_append);
-    groups_to_print.erase(group_it_to_append);
-  }
-
-  // Do a reverse look up to check the order of sorted groups
-  std::unordered_map<SegmentedGroup*, size_t> group_order;
-  for (size_t i = 0; i < sorted_groups_to_print.size(); i++) {
-    group_order[sorted_groups_to_print[i]] = i;
-  }
-
-  // Sort edges to print
-  std::vector<SegmentedEdge*> sorted_edges_to_print(
-      segmented_fusion->cedges().begin(), segmented_fusion->cedges().end());
-  std::sort(
-      sorted_edges_to_print.begin(),
-      sorted_edges_to_print.end(),
-      [&group_order](SegmentedEdge* edge_a, SegmentedEdge* edge_b) {
-        return group_order.at(edge_a->from) < group_order.at(edge_b->from);
-      });
-
-  os << "Segmented_Fusion{ \n";
-  os << "groups: \n";
-  for (const auto g : sorted_groups_to_print) {
-    os << g << "\n";
-  }
-  os << "edges: \n";
-  for (const auto e : sorted_edges_to_print) {
-    os << e << "\n";
-  }
-  os << "\ngroup details:\n";
-  for (const auto g : sorted_groups_to_print) {
-    detailGroupPrint(os, g);
-  }
-  os << "} //Segmented_Fusion\n";
-  return os;
-}
-
-void SegmentedFusion::print() const {
-  std::cout << this << "\n";
-}
-
-std::string toString(SegmentedFusion* segmented_fusion) {
-  std::stringstream ss;
-  ss << segmented_fusion;
-  return ss.str();
-}
-
-std::unique_ptr<Fusion> SegmentedFusion::makeFusion(SegmentedGroup* sg) {
-  std::unique_ptr<Fusion> fusion_segment = std::make_unique<Fusion>();
-
-  auto complete_to_segment_map =
-      Fusion::copy(completeFusion(), fusion_segment.get());
-
-  std::vector<Val*> input_list(
-      fusion_segment->inputs().begin(), fusion_segment->inputs().end());
-  for (auto inp : input_list) {
-    fusion_segment->removeInput(inp);
-  }
-
-  std::vector<Val*> output_list(
-      fusion_segment->outputs().begin(), fusion_segment->outputs().end());
-  for (auto out : output_list) {
-    fusion_segment->removeOutput(out);
-  }
-
-  for (auto inp : getAllInputs(sg)) {
-    fusion_segment->addInput(complete_to_segment_map.clone(inp));
-  }
-
-  for (auto out : getAllOutputs(sg)) {
-    fusion_segment->addOutput(complete_to_segment_map.clone(out));
-  }
-
-  return fusion_segment;
-}
-
-void SegmentCandidateFinder::resetTraversal() {
-  for (auto group : groups()) {
-    // Start traversal at input groups
-    if (group->producer_edges.empty()) {
-      to_visit_.push_back(group);
-    }
-    group->visited_ = false;
-    group->level_ = 0;
-  }
-}
-
-void SegmentCandidateFinder::resetLevels() {
-  while (!to_visit_.empty()) {
-    auto visit = to_visit_.front();
-    to_visit_.pop_front();
-
-    // All inputs processed?
-    bool ready = true;
-    if (!visit->producer_edges.empty()) {
-      ready = std::all_of(
-          visit->producer_edges.begin(),
-          visit->producer_edges.end(),
-          [&](SegmentedEdge* dep) { return dep->from->visited_; });
-    }
-
-    if (!ready) {
-      // In case traversal doesn't complete because there's an error in the
-      // DAG topology.
-      next_to_visit_.push_back(visit);
-      continue;
-    }
-
-    visit->visited_ = true;
-
-    to_visit_.insert(
-        to_visit_.end(), next_to_visit_.begin(), next_to_visit_.end());
-    next_to_visit_.clear();
-
-    for (auto out : visit->consumer_edges) {
-      to_visit_.push_back(out->to);
-    }
-
-    visit->level_ = 0;
-    for (auto inp : visit->producer_edges) {
-      visit->level_ = std::max(visit->level_, inp->from->level_ + 1);
-    }
-  }
-  TORCH_INTERNAL_ASSERT(
-      next_to_visit_.empty(), "Error in graph, is not a DAG.");
-}
-
-// Disconect group from neighbors, and return edges that were disconnected
-std::unordered_set<SegmentedEdge*> SegmentCandidateFinder::disconnectGroup(
-    SegmentedGroup* group) {
-  std::unordered_set<SegmentedEdge*> removed_edges(
-      group->producer_edges.begin(), group->producer_edges.end());
-
-  for (auto edge : group->producer_edges) {
-    auto from = edge->from;
-    auto& from_edges = from->consumer_edges;
-    auto from_edge_it = std::find(from_edges.begin(), from_edges.end(), edge);
-    TORCH_INTERNAL_ASSERT(
-        from_edge_it != from_edges.end(), "Could not find edge to remove.");
-    from_edges.erase(from_edge_it);
-  }
-
-  for (auto edge : group->consumer_edges) {
-    removed_edges.insert(edge);
-    auto to = edge->to;
-    auto& to_edges = to->producer_edges;
-    auto to_edge_it = std::find(to_edges.begin(), to_edges.end(), edge);
-    TORCH_INTERNAL_ASSERT(
-        to_edge_it != to_edges.end(), "Could not find edge to remove.");
-    to_edges.erase(to_edge_it);
-  }
-
-  group->producer_edges.clear();
-  group->consumer_edges.clear();
-
-  return removed_edges;
-}
-
-void SegmentCandidateFinder::eraseGroups(
-    std::unordered_set<SegmentedGroup*>& groups_to_erase) {
-  std::unordered_set<SegmentedEdge*> edges_to_erase;
-  for (auto group : groups_to_erase) {
-    auto disconnected_edges = disconnectGroup(group);
-    edges_to_erase.insert(disconnected_edges.begin(), disconnected_edges.end());
-  }
-
-  edges().erase(
-      std::remove_if(
-          edges().begin(),
-          edges().end(),
-          [&edges_to_erase](SegmentedEdge* edge) {
-            if (edges_to_erase.find(edge) != edges_to_erase.end()) {
-              return true;
-            };
-            return false;
-          }),
-      edges().end());
-
-  groups().erase(
-      std::remove_if(
-          groups().begin(),
-          groups().end(),
-          [&groups_to_erase](SegmentedGroup* group) {
-            if (groups_to_erase.find(group) != groups_to_erase.end()) {
-              return true;
-            };
-            return false;
-          }),
-      groups().end());
-}
-
-SegmentedGroup* SegmentCandidateFinder::mergeNodes() {
-  SegmentedGroup* last_merged = nullptr;
-  auto it = to_merge_.begin();
-  TORCH_INTERNAL_ASSERT(to_merge_.size() % 2 == 0);
-  while (it != to_merge_.end()) {
-    auto group1 = *it++;
-    auto group2 = *it++;
-
-    clean_up_groups_.emplace(group1);
-    clean_up_groups_.emplace(group2);
-
-    // Make the new joined node
-    auto joined_group = segmented_fusion_->newGroup();
-
-    joined_group->input_vals =
-        uniqueValConcat({group1->input_vals, group2->input_vals});
-
-    joined_group->output_vals =
-        uniqueValConcat({group1->output_vals, group2->output_vals});
-
-    joined_group->exprs_ = group1->exprs_;
-    joined_group->exprs_.insert(
-        joined_group->exprs_.end(),
-        group2->exprs_.begin(),
-        group2->exprs_.end());
-
-    auto producer_edges = getMergedProducerEdges(group1, group2);
-    // Connect joined group to resulting neighbors
-    for (auto edge : producer_edges) {
-      auto from = edge->from;
-      auto val = edge->val;
-
-      auto new_edge = segmented_fusion_->newEdge(from, joined_group, val);
-      joined_group->producer_edges.push_back(new_edge);
-      from->consumer_edges.push_back(new_edge);
-    }
-
-    auto consumer_edges = getMergedConsumerEdges(group1, group2);
-
-    for (auto edge : consumer_edges) {
-      auto to = edge->to;
-      auto val = edge->val;
-
-      auto new_edge = segmented_fusion_->newEdge(joined_group, to, val);
-      joined_group->consumer_edges.push_back(new_edge);
-      edge->to->producer_edges.push_back(new_edge);
-    }
-
-    joined_group->setHeuristic(deriveHeuristic(joined_group));
-    // Need to maintain the group dependency data if it has been intialized
-    //  by previous merging
-    if (group_dependency_) {
-      group_dependency_->as<GroupDependencyAnalysis>()->mergeGroups(
-          group1, group2, joined_group);
-    }
-    last_merged = joined_group;
-  }
-
-  to_merge_.clear();
-  for (auto group : clean_up_groups_) {
-    auto disconnected_edges = disconnectGroup(group);
-    clean_up_edges_.insert(
-        disconnected_edges.begin(), disconnected_edges.end());
-  }
-
-  edges().erase(
-      std::remove_if(
-          edges().begin(),
-          edges().end(),
-          [this](SegmentedEdge* edge) {
-            if (this->clean_up_edges_.find(edge) !=
-                this->clean_up_edges_.end()) {
-              return true;
-            };
-            return false;
-          }),
-      edges().end());
-
-  groups().erase(
-      std::remove_if(
-          groups().begin(),
-          groups().end(),
-          [this](SegmentedGroup* group) {
-            if (this->clean_up_groups_.find(group) !=
-                this->clean_up_groups_.end()) {
-              return true;
-            };
-            return false;
-          }),
-      groups().end());
-
-  clean_up_edges_.clear();
-  clean_up_groups_.clear();
-
-  return last_merged;
-}
-
-// Logic largely parallels mergeNodes, but they are used
-//  in different phases of segmentation. Should consider
-//  a clean up and share the implementations.
-SegmentedGroup* SegmentCandidateFinder::mergeAllGivenGroups(
-    const std::vector<SegmentedGroup*>& groups_to_merge) {
-  TORCH_INTERNAL_ASSERT(
-      !groups_to_merge.empty(),
-      "fusion segment :(mergeAllGivenGroups) tried to merge no groups")
-
-  // Make a set to detect internal edges
-  std::unordered_set<SegmentedGroup*> group_set(
-      groups_to_merge.begin(), groups_to_merge.end());
-
-  // Sets to de-duplicate multiple uses of
-  //  input/edge values and re-computations of exprs
-  std::unordered_set<Val*> used_edge_vals_set;
-  std::unordered_set<Val*> used_input_vals_set;
-  std::unordered_set<Expr*> exprs_set;
-
-  // Create new group
-  auto joined_group = segmented_fusion_->newGroup();
-
-  // Populate edges, exprs, global vals
-  //  from each of the groups
-  for (auto group : groups_to_merge) {
-    // Populate complete fusion inputs to the group
-    for (auto input_val : group->input_vals) {
-      if (!used_input_vals_set.count(input_val)) {
-        used_input_vals_set.insert(input_val);
-        joined_group->input_vals.push_back(input_val);
-      }
-    }
-
-    // Populate complete fusion outputs from the group
-    for (auto output_val : group->output_vals) {
-      joined_group->output_vals.push_back(output_val);
-    }
-
-    // Populate producer edges to the group
-    for (auto edge : group->producer_edges) {
-      if (
-          // Check this is not internal edge
-          !group_set.count(edge->from) &&
-          // Check this val has been added or not
-          !used_edge_vals_set.count(edge->val)) {
-        used_edge_vals_set.insert(edge->val);
-        auto new_producer_edge =
-            segmented_fusion_->newEdge(edge->from, joined_group, edge->val);
-        joined_group->producer_edges.push_back(new_producer_edge);
-        edge->from->consumer_edges.push_back(new_producer_edge);
-      }
-    }
-
-    // Populate consumer edges from the group
-    for (auto edge : group->consumer_edges) {
-      if (
-          // Check this is not internal edge
-          !group_set.count(edge->to)) {
-        auto new_consumer_edge =
-            segmented_fusion_->newEdge(joined_group, edge->to, edge->val);
-        joined_group->consumer_edges.push_back(new_consumer_edge);
-        edge->to->producer_edges.push_back(new_consumer_edge);
-      }
-    }
-
-    // Populate exprs
-    for (auto expr : group->exprs_) {
-      if (!exprs_set.count(expr)) {
-        joined_group->exprs_.push_back(expr);
-        exprs_set.insert(expr);
-      }
-    }
-  }
-
-  // Clean up original groups from segmented fusion
-  for (auto group : groups_to_merge) {
-    auto disconnected_edges = disconnectGroup(group);
-    clean_up_edges_.insert(
-        disconnected_edges.begin(), disconnected_edges.end());
-  }
-
-  edges().erase(
-      std::remove_if(
-          edges().begin(),
-          edges().end(),
-          [this](SegmentedEdge* edge) { return clean_up_edges_.count(edge); }),
-      edges().end());
-
-  groups().erase(
-      std::remove_if(
-          groups().begin(),
-          groups().end(),
-          [&group_set](SegmentedGroup* group) -> bool {
-            return group_set.count(group);
-          }),
-      groups().end());
-
-  clean_up_edges_.clear();
-
-  joined_group->setHeuristic(deriveHeuristic(joined_group));
-  return joined_group;
-}
-namespace {
-
-// Guard to temporarily change the inputs and outputs of a fusion. On
-// destruction will return fusion to original state.
-// Not used temporarily but will be useful when adding more mergin heuristics
-class FusionSegmentGuard : public NonCopyable {
- public:
-  FusionSegmentGuard() = delete;
-
-  FusionSegmentGuard(
-      Fusion* fusion,
-      std::vector<Val*> inputs,
-      std::vector<Val*> outputs)
-      : fusion_(fusion),
-        old_inputs_(fusion->inputs()),
-        old_outputs_(fusion->outputs()),
-        new_inputs_(std::move(inputs)),
-        new_outputs_(std::move(outputs)) {
-    TORCH_INTERNAL_ASSERT(fusion_ != nullptr);
-    for (auto old_inp : old_inputs_) {
-      fusion_->removeInput(old_inp);
-    }
-
-    for (auto old_out : old_outputs_) {
-      fusion_->removeOutput(old_out);
-    }
-
-    for (auto new_inp : new_inputs_) {
-      fusion_->addInput(new_inp);
-    }
-
-    for (auto new_out : new_outputs_) {
-      fusion_->addOutput(new_out);
-    }
-  }
-
-  ~FusionSegmentGuard() {
-    FUSER_PERF_SCOPE("~Segmenter::FusionSegmentGuard");
-
-    if (fusion_ == nullptr) {
-      return;
-    }
-    for (auto new_inp : new_inputs_) {
-      fusion_->removeInput(new_inp);
-    }
-
-    for (auto new_out : new_outputs_) {
-      fusion_->removeOutput(new_out);
-    }
-
-    for (auto old_inp : old_inputs_) {
-      fusion_->addInput(old_inp);
-    }
-
-    for (auto old_out : old_outputs_) {
-      fusion_->addOutput(old_out);
-    }
-  }
-
- private:
-  Fusion* const fusion_ = nullptr;
-  const std::vector<Val*> old_inputs_;
-  const std::vector<Val*> old_outputs_;
-  const std::vector<Val*> new_inputs_;
-  const std::vector<Val*> new_outputs_;
-};
-
-c10::optional<ScheduleHeuristic> tryMerge(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    SegmentedGroup* a,
-    SegmentedGroup* b = nullptr) {
-  FusionSegmentGuard fsg(fusion, getAllInputs(a, b), getAllOutputs(a, b));
-
-  return SchedulerEntry::proposeHeuristics(fusion, runtime_info);
-}
-
-c10::optional<ScheduleHeuristic> tryMerge(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    const std::vector<SegmentedGroup*>& segmented_groups) {
-  FusionSegmentGuard fsg(
-      fusion,
-      allInputsIfTrueElseOutputs(segmented_groups, true),
-      allInputsIfTrueElseOutputs(segmented_groups, false));
-  return SchedulerEntry::proposeHeuristics(fusion, runtime_info);
-}
-
-// This function is for cleanup and
-//  easier debugging. It shouldn't affect functionality
-//  since segmented fusions are compiled with fusion
-//  guard on the edges instead of actually looking
-//  at the exprs.
-void deDuplicateScalarExprs(std::vector<Expr*>& exprs) {
-  // Exprs in SegmentedGroup are not ordered
-  // so it is ok to insert them from unordered
-  // set
-  std::unordered_set<Expr*> scalar_expr_set;
-
-  std::copy_if(
-      exprs.begin(),
-      exprs.end(),
-      std::inserter(scalar_expr_set, scalar_expr_set.end()),
-      [](Expr* expr) { return ir_utils::isScalarOp(expr); });
-
-  if (!scalar_expr_set.empty()) {
-    exprs.erase(
-        std::remove_if(
-            exprs.begin(),
-            exprs.end(),
-            [&scalar_expr_set](Expr* expr) {
-              return scalar_expr_set.count(expr);
-            }),
-        exprs.end());
-    exprs.insert(exprs.end(), scalar_expr_set.begin(), scalar_expr_set.end());
-  }
-}
-
-} // namespace
-
-c10::optional<std::unique_ptr<SchedulerEntry>> SegmentedGroup::
-    getMaybeSchedulerEntry(SchedulerRuntimeInfo& runtime_info) {
-  FUSER_PERF_SCOPE("SegmentedGroup::getMaybeSchedulerEntry");
-  auto fusion = segmented_fusion_->completeFusion();
-  auto data_cache = segmented_fusion_->getCachedHeuristicDataFor(this);
-  FusionSegmentGuard fsg(fusion, getAllInputs(this), getAllOutputs(this));
-  if (!SchedulerEntry::canSchedule(
-          heuristic(), fusion, runtime_info, data_cache)) {
-    return c10::nullopt;
-  }
-  return SchedulerEntry::makeEntry(
-      heuristic(), fusion, runtime_info, data_cache);
-}
-
-// Custom merge node passes:
-//  These passes are added at the beginning or the end of
-//  the node merging process to direct the heuristics of
-//  node merging process
-//
-//  Should consider generalization and make a proper interface
-//   if we have more merge node heuristics like this
-
-//! Translate Welford
-//!
-//! This pass can be inserted at any stages of segmentation,
-//!  and it tries to replace welford ops with persistent
-//!  mean and var ops.
-//!
-//! The checking of feasibility of persistent kernels
-//!  is through normalization schedulers. The general idea
-//!  is to first try to translate on a copy, and see if
-//!  normalization scheduler is willing to produce a
-//!  persistent kernel.
-//!
-//! For complete fusion this pass checks if all the
-//!  welford ops can be translated simultaneously to
-//!  produce a persistent normalization kernel and
-//!  will perform translation if checks pass.
-//!
-//! For segmented fusion, same check is performed within
-//!  each segmented group to collect applicable welford ops,
-//!  and actual translations are performed on the complete
-//!  fusion after all the checks are done.
-class TranslateApplicableWelford {
- public:
-  //! Try translation on each segmented group of
-  //!  given segmented fusion
-  //!  returns true if any welford has been translated
-  static bool run(
-      SegmentedFusion* segmented_fusion,
-      const at::ArrayRef<IValue>& runtime_inputs) {
-    TranslateApplicableWelford translate_welford(
-        segmented_fusion, runtime_inputs);
-    return translate_welford.translated_any_welford_;
-  }
-
-  //! Try translation on complete fusion,
-  //!  returns true if any welford has been translated
-  static bool run(Fusion* fusion, const at::ArrayRef<IValue>& runtime_inputs) {
-    TranslateApplicableWelford translate_welford(fusion, runtime_inputs);
-    return translate_welford.translated_any_welford_;
-  }
-
- private:
-  explicit TranslateApplicableWelford(
-      SegmentedFusion* segmented_fusion,
-      const at::ArrayRef<IValue>& runtime_inputs);
-
-  explicit TranslateApplicableWelford(
-      Fusion* fusion,
-      const at::ArrayRef<IValue>& runtime_inputs);
-
-  //! Given vector of welford ops from the same fusion,
-  //!  checks if translating all of them result in a
-  //!  persistent normalization kernel by try-runs on
-  //!  a test copy of the original fusion.
-  //!
-  //! Supported use cases are either un-segmented fusion,
-  //!  or all the given welfords are within the same
-  //!  segmented group. In the latter case, the segmented
-  //!  group containing all the welford ops needs to be
-  //!  provided.
-  bool wouldTranslateToPersistent(
-      const std::vector<WelfordOp*>& orignal_welfords,
-      SegmentedGroup* group = nullptr);
-
-  //! Translate the given welford op into separate
-  //! average and standard deviation calculation.
-  void translateSingleWelford(WelfordOp* welford);
-
-  //! Utility to test if a translated fusion
-  //!  gives a persistent kernel. Uses normalization
-  //!  scheduler to do the test.
-  bool isValidPersistentFusion(
-      Fusion* translated_fusion,
-      SchedulerRuntimeInfo& runtime_info);
-
-  //! Update expression list of groups containing
-  //!  welford ops that have been translated.
-  void updateGroupExprs(SegmentedGroup* group);
-
- private:
-  //! Indicates any translation happened.
-  bool translated_any_welford_ = false;
-
-  //! a reference to global fusion runtime inputs
-  const at::ArrayRef<IValue>& runtime_inputs_;
-
-  //! For translation within group only,
-  //!  group boundary at test copy
-  //! (see wouldTranslateToPersistent implementation )
-  std::vector<Val*> test_group_inputs_;
-  std::vector<Val*> test_group_outputs_;
-};
-
-TranslateApplicableWelford::TranslateApplicableWelford(
-    Fusion* fusion,
-    const at::ArrayRef<IValue>& runtime_inputs)
-    : runtime_inputs_(runtime_inputs) {
-  std::vector<WelfordOp*> orignal_welfords(
-      ir_utils::filterByType<WelfordOp>(fusion->unordered_exprs()).begin(),
-      ir_utils::filterByType<WelfordOp>(fusion->unordered_exprs()).end());
-
-  if (wouldTranslateToPersistent(orignal_welfords)) {
-    for (auto welford : orignal_welfords) {
-      translateSingleWelford(welford);
-    }
-    translated_any_welford_ = true;
-  }
-}
-
-TranslateApplicableWelford::TranslateApplicableWelford(
-    SegmentedFusion* segmented_fusion,
-    const at::ArrayRef<IValue>& runtime_inputs)
-    : runtime_inputs_(runtime_inputs) {
-  std::vector<SegmentedGroup*> translated_groups;
-  std::vector<WelfordOp*> welford_to_translate;
-  // Find welfords that can be translated in each group
-  for (auto group : segmented_fusion->groups()) {
-    std::vector<WelfordOp*> welford_in_group(
-        ir_utils::filterByType<WelfordOp>(group->exprs()).begin(),
-        ir_utils::filterByType<WelfordOp>(group->exprs()).end());
-
-    if (wouldTranslateToPersistent(welford_in_group, group)) {
-      translated_groups.push_back(group);
-      welford_to_translate.insert(
-          welford_to_translate.end(),
-          welford_in_group.begin(),
-          welford_in_group.end());
-    }
-  }
-
-  // Actually translate the welford ops
-  // and record all the vals that have been
-  // replaced by the translation.
-  for (auto welford : welford_to_translate) {
-    translateSingleWelford(welford);
-  }
-
-  for (auto translated_group : translated_groups) {
-    // Update heuristics and expr list of translated groups
-    translated_group->heuristic_ = ScheduleHeuristic::Normalization;
-    updateGroupExprs(translated_group);
-  }
-}
-
-bool TranslateApplicableWelford::isValidPersistentFusion(
-    Fusion* translated_fusion,
-    SchedulerRuntimeInfo& runtime_info) {
-  if (!SchedulerEntry::canSchedule(
-          ScheduleHeuristic::Normalization, translated_fusion, runtime_info)) {
-    return false;
-  }
-
-  auto scheduler = SchedulerEntry::makeEntry(
-      ScheduleHeuristic::Normalization, translated_fusion, runtime_info);
-
-  return scheduler->reductionParams().persistent_kernel;
-}
-
-bool TranslateApplicableWelford::wouldTranslateToPersistent(
-    const std::vector<WelfordOp*>& orignal_welfords,
-    SegmentedGroup* group) {
-  if (orignal_welfords.empty()) {
-    return false;
-  }
-
-  // Make sure all welford ops come from the same complete fusion
-  auto fusion = orignal_welfords[0]->fusion();
-  TORCH_INTERNAL_ASSERT(
-      std::all_of(
-          orignal_welfords.begin(),
-          orignal_welfords.end(),
-          [fusion](WelfordOp* welford) { return welford->fusion() == fusion; }),
-      "Welfords in given vector not in the same fusion");
-
-  // Make initial `in-progress copy`
-  auto test_copy = std::make_unique<Fusion>();
-  auto original_to_test_map = Fusion::copy(fusion, test_copy.get());
-
-  std::vector<WelfordOp*> copied_welfords;
-  std::transform(
-      orignal_welfords.begin(),
-      orignal_welfords.end(),
-      std::back_inserter(copied_welfords),
-      [&original_to_test_map](auto welford) {
-        return original_to_test_map.clone(welford);
-      });
-
-  // Translate the welford ops
-  for (auto welford_to_translate : copied_welfords) {
-    translateSingleWelford(welford_to_translate);
-  }
-
-  SchedulerRuntimeInfo runtime_info(test_copy.get(), runtime_inputs_, true);
-  // If we are looking at a segment of fusion,
-  //  we maintain the segmented group boundary,
-  //  one set for in_progress copy and one set
-  //  for `test copy`
-  if (group != nullptr) {
-    auto original_inputs = getAllInputs(group);
-    auto original_outputs = getAllOutputs(group);
-    test_group_inputs_.clear();
-    test_group_outputs_.clear();
-    std::transform(
-        original_inputs.begin(),
-        original_inputs.end(),
-        std::back_inserter(test_group_inputs_),
-        [&original_to_test_map](Val* in) {
-          return original_to_test_map.clone(in);
-        });
-    std::transform(
-        original_outputs.begin(),
-        original_outputs.end(),
-        std::back_inserter(test_group_outputs_),
-        [&original_to_test_map](Val* out) {
-          return original_to_test_map.clone(out);
-        });
-
-    // Temporarily localize test copy around
-    //  the group boundary
-    FusionSegmentGuard fsg(
-        test_copy.get(), test_group_inputs_, test_group_outputs_);
-
-    // Test if the translated copy is persistent
-    return isValidPersistentFusion(test_copy.get(), runtime_info);
-  }
-  // In the case where we work on un-segmented
-  //  fusion, no group boundary logic, just
-  //  translate and test.
-  return isValidPersistentFusion(test_copy.get(), runtime_info);
-}
-
-void TranslateApplicableWelford::translateSingleWelford(WelfordOp* welford) {
-  auto fusion = welford->fusion();
-  FusionGuard fg(fusion);
-  // Only support translation of welford ops that
-  // doesn't take inputs that are already statistics,
-  // i.e. an r-factor product.
-  // This translation works on un-scheduled fusions so
-  //  shouldn't expect to see this.
-  TORCH_INTERNAL_ASSERT(welford->inN()->isOneInt());
-
-  // Grab the inputs and outputs of the welford
-  auto in_val = welford->in()->as<TensorView>();
-  auto out_avg = welford->outAvg()->as<TensorView>();
-  auto out_var = welford->outVar()->as<TensorView>();
-  auto out_N = welford->outN()->as<TensorView>();
-
-  fusion->removeExpr(welford);
-
-  // Create normalization based welford graph
-  //  largely taken from batchnorm cpp benchmark
-  auto& in_root = in_val->getRootDomain();
-  auto& out_root = out_avg->getRootDomain();
-  std::vector<int> red_axes;
-
-  // Create scalar version of the feature element
-  //  counting.
-  Val* num_features = new Double(1);
-  std::vector<bool> broadcast_mask(in_root.size(), false);
-  for (size_t i = 0; i < in_root.size(); i++) {
-    if (out_root[i]->isReduction()) {
-      red_axes.push_back(i);
-      broadcast_mask[i] = true;
-      num_features = mul(num_features, out_root[i]->extent());
-    }
-  }
-
-  // Build a normalization expression group that is
-  //  equivalent to a welford operation.
-  auto x_sum = sum(in_val, red_axes);
-  new BinaryOp(BinaryOpType::Div, out_avg, x_sum, num_features);
-  // welford.avg may be broadcast. Reuse it if found.
-  TensorView* x_avg_bcast = nullptr;
-  for (auto& use_expr : out_avg->uses()) {
-    if (auto bcast = dynamic_cast<BroadcastOp*>(use_expr)) {
-      if (bcast->getBroadcastDimFlags() == broadcast_mask) {
-        // Same broadcast found.
-        x_avg_bcast = bcast->out()->as<TensorView>();
-        break;
-      }
-    }
-  }
-
-  // x_mean_sub may already exist. Reuse it if found.
-  TensorView* x_mean_sub = nullptr;
-  if (x_avg_bcast != nullptr) {
-    for (auto& use_expr : x_avg_bcast->uses()) {
-      if (auto bop = dynamic_cast<BinaryOp*>(use_expr)) {
-        if (bop->getBinaryOpType() == BinaryOpType::Sub) {
-          if (bop->lhs() == in_val && bop->rhs() == x_avg_bcast) {
-            x_mean_sub = bop->out()->as<TensorView>();
-          }
-        }
-      }
-    }
-  }
-
-  if (x_avg_bcast == nullptr) {
-    x_avg_bcast = broadcast(out_avg, broadcast_mask);
-  }
-
-  if (x_mean_sub == nullptr) {
-    x_mean_sub = sub(in_val, x_avg_bcast);
-  }
-
-  auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub);
-  new ReductionOp(BinaryOpType::Add, new Double(0.0), out_var, x_mean_sub_pow);
-  new UnaryOp(UnaryOpType::Set, out_N, num_features);
-
-  // out_avg, out_N are now outputs of a pointwise ops and we
-  //  need to clear out its reduction domains.
-  out_avg->clearReductionIterDomains();
-  out_N->clearReductionIterDomains();
-}
-
-void TranslateApplicableWelford::updateGroupExprs(SegmentedGroup* group) {
-  // Re-evaluate expression list of the translated group
-  auto input_vec = getAllInputs(group);
-  auto output_vec = getAllOutputs(group);
-
-  if (input_vec.empty() || output_vec.empty()) {
-    return;
-  }
-
-  std::unordered_set<Val*> input_set(input_vec.begin(), input_vec.end());
-  auto expr_set = DependencyCheck::getAllExprsBetween(input_set, output_vec);
-  group->exprs_ = std::vector<Expr*>(expr_set.begin(), expr_set.end());
-}
-
-bool SegmentCandidateFinder::TranslateWelfordInFusion(
-    Fusion* fusion,
-    const at::ArrayRef<IValue>& runtime_inputs) {
-  return TranslateApplicableWelford::run(fusion, runtime_inputs);
-}
-
-//! CombineReductions:
-//!  This pass works before the main merge node process
-//!    It identifies reduction operations that can be combined
-//!    together to form a normalization kernel.
-//!  Two reductions are considered the same type if they have
-//!   the same root domain length, and the reduction axis are the same.
-//!   This pass tries to merge nodes with the same reduction type based
-//!   on the graph structure.
-class CombineReductions {
-  using GroupSet = std::unordered_set<SegmentedGroup*>;
-  using GroupVec = std::vector<SegmentedGroup*>;
-  class ReductionSignature;
-
- public:
-  static void run(SegmentCandidateFinder* segment_candidate_finder) {
-    CombineReductions combine_reductions(segment_candidate_finder);
-  }
-  static bool shouldRun(SegmentCandidateFinder* segment_candidate_finder);
-
- private:
-  CombineReductions(SegmentCandidateFinder* segment_candidate_finder)
-      : segment_candidate_finder_(segment_candidate_finder) {
-    // Run pass over the segments
-
-    // Collect segmented groups with reductions in them,
-    //  Assuming running before any merge happened, so
-    //  should see exactly one non-trivial reduction in each group
-    for (auto group : segment_candidate_finder_->groups()) {
-      if (auto rop_signature =
-              ReductionSignature::makeReductionSignature(group)) {
-        // Ignore pure squeeze operations in this analysis
-        if (!rop_signature->hasNonTrivialReduction()) {
-          continue;
-        }
-
-        groups_with_reductions_.push_back(group);
-        // Check if this reduction signature is one that we have seen before
-        auto signature_match_it = std::find_if(
-            known_reduction_signatures_.begin(),
-            known_reduction_signatures_.end(),
-            [&rop_signature](auto& know_signature) {
-              return know_signature->sameAs(rop_signature.get());
-            });
-        // Unmatched: Create a new signature entry if not known
-        if (signature_match_it == known_reduction_signatures_.end()) {
-          group_reduction_signature_map_[group] = rop_signature.get();
-          known_reduction_signatures_.emplace_back(std::move(rop_signature));
-        } else {
-          // Matched known signature: Mark that this groups belongs to know
-          // signature
-          group_reduction_signature_map_[group] = signature_match_it->get();
-        }
-      }
-    }
-
-    // Keep trying to merge groups with compatible reductions and compatible
-    // paths
-    //  until no more merge opportunity can be identified
-    bool merged_groups = true;
-    while (merged_groups) {
-      merged_groups = false;
-
-      // Merge one pair of reduction groups at a time, and need
-      //  the pass to update dependency info along the way to avoid cycles
-      for (size_t first_group_index = 0;
-           first_group_index < groups_with_reductions_.size();
-           first_group_index++) {
-        if (merged_groups) {
-          // Need to break and re-enter this loop because
-          // groups_with_reductions_ will be updated
-          break;
-        }
-
-        // Select one of the group to merge and get its reduction signature
-        auto first_group = groups_with_reductions_[first_group_index];
-        auto first_group_signature =
-            group_reduction_signature_map_.at(first_group);
-
-        for (size_t second_group_index = first_group_index + 1;
-             second_group_index < groups_with_reductions_.size();
-             second_group_index++) {
-          if (merged_groups) {
-            // Need to break and re-enter this loop because
-            // groups_with_reductions_ will be updated
-            break;
-          }
-          auto second_group = groups_with_reductions_[second_group_index];
-          auto second_group_signature =
-              group_reduction_signature_map_.at(second_group);
-
-          // Cannot merge if their signatures are not the same
-          if (!first_group_signature->sameAs(second_group_signature)) {
-            continue;
-          }
-
-          // first try a vertical merge
-          merged_groups =
-              verticalReductionMerge(first_group, second_group) != nullptr;
-          if (!merged_groups) {
-            // vertical merge didn't happen, try a horizontal merge
-            merged_groups =
-                horizontalReductionMerge(first_group, second_group) != nullptr;
-          }
-        }
-      }
-    }
-  }
-
-  //! Merge a vertical pair of producers and consumers,
-  //!  the resulting group will include all nodes that are
-  //!  also consumers of producer and producers of consumer,
-  //!  i.e. values between the given producer-consumer pair.
-  //!  Can be proven that:
-  //!   1. Including all of these nodes will be cycle-free
-  //!   2. These nodes are the minimal set of nodes to include if
-  //!  for producer-consumer pair to be in the same group cycle-free
-  //!
-  //!  Returns nullptr if such merge cannot be achieved.
-  //!  Reasons for not merging will include:
-  //!   1. Given groups do not form producer-consumer pair
-  //!   2. Merge will create cycle on the graph
-  //!   3. The merged joined group cannot be scheduled
-  SegmentedGroup* verticalReductionMerge(
-      SegmentedGroup* first_group,
-      SegmentedGroup* second_group) {
-    // This is part of ReductionCombine pass, and we should only call this
-    // function on a pair of
-    //  reduction/normalization groups
-    TORCH_INTERNAL_ASSERT(
-        group_reduction_signature_map_.at(first_group)
-            ->sameAs(group_reduction_signature_map_.at(second_group)));
-    TORCH_INTERNAL_ASSERT(first_group != second_group);
-    // Get the group dependency data from segment finder
-    auto dependency_analysis = segment_candidate_finder_->getGroupDependency();
-
-    // Check producer-consumer relationship
-    SegmentedGroup* producer = nullptr;
-    SegmentedGroup* consumer = nullptr;
-    if (dependency_analysis->isConsumerOf(first_group, second_group)) {
-      producer = second_group;
-      consumer = first_group;
-    } else if (dependency_analysis->isProducerOf(first_group, second_group)) {
-      producer = first_group;
-      consumer = second_group;
-    } else {
-      // Given groups aren't producer-consumer pair, won't merge
-      return nullptr;
-    }
-
-    // Collect all groups that we need to merge along with the producer and
-    // consumer
-    auto all_groups_to_merge =
-        getValidMinVerticalMergedGroupSet(producer, consumer);
-
-    if (all_groups_to_merge.empty()) {
-      // The vertical paths from producer to consumer have in-compatible
-      // reductions
-      //   so this vertical merge cannot be done.
-      return nullptr;
-    }
-
-    // TODO: this step would not be deterministic, because valuesBetween isn't
-    //       could fix this by a topological order
-    std::vector<SegmentedGroup*> all_groups_to_merge_vec(
-        all_groups_to_merge.begin(), all_groups_to_merge.end());
-
-    // Final sanity check: the merged group can actually be scheduled
-    Fusion* fusion =
-        segment_candidate_finder_->segmented_fusion_->completeFusion();
-    if (!tryMerge(
-            fusion,
-            segment_candidate_finder_->runtimeInfo(),
-            all_groups_to_merge_vec)) {
-      return nullptr;
-    }
-
-    // Merge this group
-    auto joined_group =
-        segment_candidate_finder_->mergeAllGivenGroups(all_groups_to_merge_vec);
-
-    // Update dependency analysis
-    dependency_analysis->mergeGroups(all_groups_to_merge, joined_group);
-
-    // Update the reduction groups that are merged
-    groups_with_reductions_.push_back(joined_group);
-    group_reduction_signature_map_[joined_group] =
-        group_reduction_signature_map_.at(first_group);
-    groups_with_reductions_.erase(
-        std::remove_if(
-            groups_with_reductions_.begin(),
-            groups_with_reductions_.end(),
-            [&all_groups_to_merge](SegmentedGroup* group) {
-              return all_groups_to_merge.count(group);
-            }),
-        groups_with_reductions_.end());
-
-    return joined_group;
-  }
-
-  //! Horizontal reduction merging:
-  //!  merge two horizontal groups with reduction expressions to make a joined
-  //!  normalization group. A pair of horizontal groups are ones that are not
-  //!  a producer-consumer pair, and share either a common producer or a common
-  //!  consumer.
-  //!
-  //!  TODO: This implementation looks at common producers only, since common
-  //!  consumers
-  //!          are not computed easily with current dependency analysis.
-  SegmentedGroup* horizontalReductionMerge(
-      SegmentedGroup* first_group,
-      SegmentedGroup* second_group) {
-    // This is part of ReductionCombine pass, and we should only call this
-    // function on a pair of
-    //  reduction/normalization groups
-    TORCH_INTERNAL_ASSERT(
-        group_reduction_signature_map_.at(first_group)
-            ->sameAs(group_reduction_signature_map_.at(second_group)));
-    TORCH_INTERNAL_ASSERT(first_group != second_group);
-
-    auto dependency_analysis = segment_candidate_finder_->getGroupDependency();
-
-    // Check that the two groups are not producer-consumer's
-    if (dependency_analysis->isConsumerOf(first_group, second_group) ||
-        dependency_analysis->isProducerOf(first_group, second_group)) {
-      // This merge pass will not handle producer-consumer pairs
-      return nullptr;
-    }
-
-    // Get common producers of the two group
-    auto common_producers_set =
-        dependency_analysis->getCommonProducersOf({first_group, second_group});
-    if (common_producers_set.empty()) {
-      // The given pair doesn't have a common producer.
-      //  Either they have a common consumer, which we don't handle for now,
-      //  or maybe the two given groups are not connected.
-      return nullptr;
-    }
-
-    // We are looking for a very specific patterns here. The cases that this
-    //  pattern will not capture are ones that reductions of different
-    //  signatures are so interleaved that we cannot find a clear cut as
-    //  explained below, without graph rewriting. Some graph re-writing on the
-    //  segmented groups level could provide extra merging opportunities for
-    //  free, which could be part of next step.
-    //
-    // The specific pattern we look for contains a common producer P with
-    // immediate consumers C1, C2 such that all paths from C1 to first_group and
-    // all paths from C2
-    //  to second_group won't hit a reduction with a different signature.
-
-    // Topologically sort the common producers and start with the topologically
-    // minimal,
-    //  i.e. one that are closest to the two groups. This will cut the search
-    //  space.
-    std::vector<SegmentedGroup*> common_producers(
-        common_producers_set.begin(), common_producers_set.end());
-    std::sort(
-        common_producers.begin(),
-        common_producers.end(),
-        [&dependency_analysis](SegmentedGroup* a, SegmentedGroup* b) {
-          return dependency_analysis->isConsumerOf(a, b);
-        });
-
-    // Use a visited filter to prune search space.
-    GroupSet visited_common_producers;
-
-    // Visit the common producers found, starting from topologically minimum,
-    // i.e. the ones closer to the groups
-    for (auto common_producer : common_producers) {
-      // Visit this common producer
-      // Use a double loop in case the schedulers like some patterns
-      //  better than the other
-      for (auto first_consumer_edge : common_producer->consumer_edges) {
-        auto producer_of_first_group = first_consumer_edge->to;
-        if (visited_common_producers.count(producer_of_first_group)) {
-          // We have visited this node as common producer before and it
-          //  had conflicts. It'd hit the same conflict again if we continued
-          //  to pursue this edge.
-          continue;
-        }
-        auto to_merge_with_first_group = getValidMinVerticalMergedGroupSet(
-            producer_of_first_group, first_group);
-        if (to_merge_with_first_group.empty()) {
-          // There's no valid merge path from this consumer of common producer,
-          //  either due to a conflicting reduction signature, or simply there's
-          //  no path to first group
-          continue;
-        }
-        for (auto second_consumer_edge : common_producer->consumer_edges) {
-          auto producer_of_second_group = second_consumer_edge->to;
-          if (visited_common_producers.count(producer_of_second_group)) {
-            // We have visited this node as common producer before and it
-            //  had conflicts. It'd hit the same conflict again if we continued
-            //  to pursue this edge.
-            continue;
-          }
-          auto to_merge_with_second_group = getValidMinVerticalMergedGroupSet(
-              producer_of_second_group, second_group);
-          if (to_merge_with_second_group.empty()) {
-            // There's no valid merge path from this consumer of common
-            // producer,
-            //  either due to a conflicting reduction signature, or simply
-            //  there's no path to second group
-            continue;
-          }
-
-          // At this point we should have a pair of valid candidates,final check
-          // is to see if the combined group
-          //  can be scheduled by schedulers
-          // merge the two paths and de-duplicate,
-          //  re-using container here with to_merge_with_second_group
-          auto& groups_to_merge_set = to_merge_with_second_group;
-          groups_to_merge_set.insert(
-              to_merge_with_first_group.begin(),
-              to_merge_with_first_group.end());
-          std::vector<SegmentedGroup*> groups_to_merge_vec(
-              groups_to_merge_set.begin(), groups_to_merge_set.end());
-          Fusion* fusion =
-              segment_candidate_finder_->segmented_fusion_->completeFusion();
-          if (tryMerge(
-                  fusion,
-                  segment_candidate_finder_->runtimeInfo(),
-                  groups_to_merge_vec)) {
-            // Found a valid horizontal merge, want to proceed with merging here
-            auto joined_group = segment_candidate_finder_->mergeAllGivenGroups(
-                groups_to_merge_vec);
-            dependency_analysis->mergeGroups(groups_to_merge_set, joined_group);
-
-            groups_with_reductions_.push_back(joined_group);
-            group_reduction_signature_map_[joined_group] =
-                group_reduction_signature_map_.at(first_group);
-            groups_with_reductions_.erase(
-                std::remove_if(
-                    groups_with_reductions_.begin(),
-                    groups_with_reductions_.end(),
-                    [&groups_to_merge_set](SegmentedGroup* group) {
-                      return groups_to_merge_set.count(group);
-                    }),
-                groups_with_reductions_.end());
-
-            return joined_group;
-          }
-        }
-      }
-      // Here we should have searched all consumer edges of this common producer
-      // and
-      //  found no valid pattern. Should just add it to the visted list.
-      visited_common_producers.insert(common_producer);
-    }
-
-    // Searched all possibilities and there is no valid horizontal merge pattern
-    //  found.
-    return nullptr;
-  }
-
-  //! This is a utility method that is used in both vertical merging and
-  //! horizontal merging.
-  //!  It is used to identify the smallest set of groups to merge vertically
-  //!  involving the
-  //!   two given nodes.
-  //!  Given a pair of nodes this utility distinguishes 3 cases:
-  //!   1. if maybe_producer is the same as maybe_consumer, then returns
-  //!   {maybe_producer}
-  //!   2. if maybe_producer is actually a producer of consumer, returns a set
-  //!   containing
-  //!     the smallest merged group that would contain producer and consumer and
-  //!     would not introduce a cycle. Returns empty set if such group has
-  //!     a conflicting reduction signature.
-  //!   3. returns empty set if neither conditions above apply.
-  GroupSet getValidMinVerticalMergedGroupSet(
-      SegmentedGroup* maybe_producer,
-      SegmentedGroup* maybe_consumer) {
-    auto dependency_analysis = segment_candidate_finder_->getGroupDependency();
-    if (maybe_consumer == maybe_producer) {
-      // maybe producer is the same as maybe_consumer
-      return {maybe_consumer};
-    } else if (dependency_analysis->isConsumerOf(
-                   maybe_consumer, maybe_producer)) {
-      auto groups_to_check =
-          dependency_analysis->valuesBetween(maybe_producer, maybe_consumer);
-      groups_to_check.insert(maybe_producer);
-      groups_to_check.insert(maybe_consumer);
-
-      // Check that either no group has a reduction or all groups have the same
-      // reduction signature
-      ReductionSignature* reduction_signature = nullptr;
-
-      // Iterate through the minimal group set to see if any conflicts
-      for (auto group : groups_to_check) {
-        // Check that this group does not involve a output edge contraction
-        //  This pass is intended to be a pre-merging pass. Since contracting an
-        //   output edge does not generate much saving of global memory access
-        //   we want to postpone merging these edges till the very final pass
-        for (auto producer_edge_of_group : group->producer_edges) {
-          if (groups_to_check.count(producer_edge_of_group->from) &&
-              producer_edge_of_group->val->isFusionOutput()) {
-            return {};
-          }
-        }
-        for (auto consumer_edge_of_group : group->consumer_edges) {
-          if (groups_to_check.count(consumer_edge_of_group->to) &&
-              consumer_edge_of_group->val->isFusionOutput()) {
-            return {};
-          }
-        }
-
-        // Check that this group does not have a conflicting reduction signature
-        if (group_reduction_signature_map_.count(group)) {
-          if (reduction_signature != nullptr) {
-            if (!group_reduction_signature_map_.at(group)->sameAs(
-                    reduction_signature)) {
-              // Found a conflict in reduction signature, cannot do a vertical
-              // merge
-              return {};
-            }
-          } else {
-            reduction_signature = group_reduction_signature_map_.at(group);
-          }
-        }
-      }
-      return groups_to_check;
-    }
-    // maybe producer is not a producer of maybe consumer
-    return {};
-  }
-
- private:
-  SegmentCandidateFinder* segment_candidate_finder_;
-
-  // Wrapper class for reduction type
-  //  Assuming there wouldn't be too many of them
-  //  so won't need to create a hash
-  // TODO:
-  //   Want to reconsider this for transpose operations,
-  //   need refactoring to handle reduction fusions across a transpose operation
-  class ReductionSignature {
-   public:
-    bool sameAs(const ReductionSignature* reduction_signature) {
-      if (reduction_signature == this) {
-        return true;
-      }
-
-      if (root_domain_size_ != reduction_signature->root_domain_size_ ||
-          has_nontrivial_reduction_ !=
-              reduction_signature->has_nontrivial_reduction_ ||
-          reduction_axes_.size() !=
-              reduction_signature->reduction_axes_.size()) {
-        return false;
-      }
-
-      for (size_t i = 0; i < reduction_axes_.size(); i++) {
-        if (reduction_axes_[i] != reduction_signature->reduction_axes_[i]) {
-          return false;
-        }
-      }
-
-      return true;
-    }
-
-    bool sameAs(const ReductionSignature& reduction_signature) {
-      return sameAs(&reduction_signature);
-    }
-
-    bool hasNonTrivialReduction() const {
-      return has_nontrivial_reduction_;
-    }
-
-    static std::unique_ptr<ReductionSignature> makeReductionSignature(
-        SegmentedGroup* group) {
-      std::unique_ptr<ReductionSignature> signature = nullptr;
-
-      for (auto expr : group->exprs()) {
-        std::unique_ptr<ReductionSignature> new_signature = nullptr;
-
-        if (auto rop = dynamic_cast<ReductionOp*>(expr)) {
-          new_signature = std::make_unique<ReductionSignature>(rop);
-        }
-        if (auto wop = dynamic_cast<WelfordOp*>(expr)) {
-          new_signature = std::make_unique<ReductionSignature>(wop);
-        }
-
-        if (new_signature != nullptr) {
-          TORCH_INTERNAL_ASSERT(
-              signature == nullptr || !signature->has_nontrivial_reduction_ ||
-                  !new_signature->has_nontrivial_reduction_ ||
-                  signature->sameAs(new_signature.get()),
-              "Conflicting signature found in this group");
-          signature = std::move(new_signature);
-        }
-      }
-      return signature;
-    }
-
-    template <typename REDUCTION = ReductionOp>
-    ReductionSignature(REDUCTION* rop) {
-      auto out_tv = rop->out()->template as<TensorView>();
-      has_nontrivial_reduction_ = out_tv->hasReduction();
-      TORCH_INTERNAL_ASSERT(out_tv != nullptr);
-      auto& root_domain = out_tv->getRootDomain();
-      root_domain_size_ = root_domain.size();
-
-      // Trivial reduction i.e. squeeze is tricky here:
-      //  this pass doesn't want to touch any pure squeeze, i.e.:
-      //    T0 [R(1), I(i0), I(i1)]
-      //  meanwhile, for two reductions having
-      //  squeezes, we do require they have squeeze at the
-      //  same position so that they can be easily root domain mapped
-      //  So T0 and T1 are the same signature,
-      //    T0 [R(1), R(i0), I(i1)]
-      //    T1 [R(1), R(i0), I(i1)]
-      //  but T2 and T3 below are not
-      //    T0 [R(1), R(1), R(i0), I(i1)]
-      //    T1 [R(1), R(i0), I(i1)]
-      for (size_t i = 0; i < root_domain_size_; i++) {
-        if (root_domain[i]->isReduction()) {
-          reduction_axes_.push_back(i);
-        }
-        if (!root_domain[i]->isTrivialReduction()) {
-          has_nontrivial_reduction_ = true;
-        }
-      }
-    }
-
-   private:
-    size_t root_domain_size_ = 0;
-    std::vector<int> reduction_axes_;
-    bool has_nontrivial_reduction_ = false;
-  };
-
-  //! Keeps track of groups with reduction expressions,
-  //!  using a vector here to maintain a deterministic ordering
-  GroupVec groups_with_reductions_;
-
-  //! Maps groups to their corresponding signature type
-  std::unordered_map<SegmentedGroup*, ReductionSignature*>
-      group_reduction_signature_map_;
-
-  //! Maintains all reduction signatures seen in the segmented fusion
-  std::vector<std::unique_ptr<ReductionSignature>> known_reduction_signatures_;
-};
-
-//! This is to be checked
-bool CombineReductions::shouldRun(
-    SegmentCandidateFinder* segment_candidate_finder) {
-  std::vector<std::unique_ptr<ReductionSignature>> known_reductions;
-  // Iterate over group segments we have before segment candidate finder
-  //  tries to merge any groups
-  for (auto group : segment_candidate_finder->groups()) {
-    if (auto reduction_signature =
-            ReductionSignature::makeReductionSignature(group)) {
-      if (reduction_signature->hasNonTrivialReduction() &&
-          std::any_of(
-              known_reductions.begin(),
-              known_reductions.end(),
-              [&reduction_signature](auto& know_signature) {
-                return know_signature->sameAs(reduction_signature.get());
-              })) {
-        // Found two reductions with the same signature, run pass
-        return true;
-      }
-      known_reductions.emplace_back(std::move(reduction_signature));
-    }
-  }
-  return false;
-}
-
-bool SegmentCandidateFinder::codeGenSupportedMerge(SegmentedEdge* edge) {
-  Fusion* fusion = segmented_fusion_->completeFusion();
-  auto h = tryMerge(fusion, runtime_info_, edge->from, edge->to);
-  return h.has_value();
-}
-
-// TODO: consider caching the heuristics value so tryMerge doesn't have to be
-//       called twice
-ScheduleHeuristic SegmentCandidateFinder::deriveHeuristic(
-    SegmentedGroup* group) {
-  Fusion* fusion = segmented_fusion_->completeFusion();
-  auto h = tryMerge(fusion, runtime_info_, group);
-  TORCH_INTERNAL_ASSERT(h.has_value());
-  return h.value();
-}
-
-SegmentCandidateFinder::SegmentCandidateFinder(
-    std::unique_ptr<Fusion> fusion,
-    const at::ArrayRef<IValue>& inputs,
-    SegmentCandidateFinderOptions options)
-    : options_(options),
-      runtime_info_(fusion.get(), inputs, true),
-      runtime_inputs_(inputs) {
-  segmented_fusion_ = std::make_unique<SegmentedFusion>(std::move(fusion));
-  findSegments();
-}
-
-void SegmentCandidateFinder::findSegments() {
-  FUSER_PERF_SCOPE("Finding valid fusion segment solutions");
-  // TODO: Make traversal items local to this function.
-
-  // Need this for initialization of the DAG that is process
-  std::unordered_map<Expr*, SegmentedGroup*> expr2group;
-
-  // Keep track of complete fusion input use
-  std::unordered_map<Val*, SegmentedGroup*> input2group;
-
-  // Initialize DAG, convert each expr to a segment group
-  auto exprs = completeFusion()->exprs();
-  for (auto expr : exprs) {
-    if (!ir_utils::isScalarOp(expr)) {
-      auto new_group = segmented_fusion_->newGroup(expr);
-      expr2group.insert(std::make_pair(expr, new_group));
-    }
-  }
-
-  // Insert auxiliary groups to use group dependency on inputs as well
-  // TODO: these groups should never merged into any other groups, but are
-  //       just there to support the dependency analysis. Later re-factor should
-  //       avoid introducing them explicitly on the segmented fusion.
-  for (auto input : completeFusion()->inputs()) {
-    // These groups are used to represent input as a common
-    //  producer in horizontal merges, and should never be
-    //  seen as a candidate for vertical merge
-    auto new_group = segmented_fusion_->newGroup();
-    input2group.insert({input, new_group});
-  }
-
-  // Create edges between the Exprs. Mark inputs and outputs of the fusion.
-  for (auto expr : exprs) {
-    // No group created for scalar ops
-    if (ir_utils::isScalarOp(expr)) {
-      continue;
-    }
-
-    auto expr_group = expr2group.at(expr);
-    for (auto inp : expr->inputs()) {
-      if (inp->isFusionInput()) {
-        expr_group->input_vals.push_back(inp);
-        auto aux_group = input2group.at(inp);
-        auto new_edge = segmented_fusion_->newEdge(aux_group, expr_group, inp);
-        expr_group->producer_edges.push_back(new_edge);
-        aux_group->consumer_edges.push_back(new_edge);
-        continue;
-      }
-
-      // Could be something like a constant scalar, definition is nullptr, but
-      // isn't an "input" to the fusion. At least not one provided by an
-      // external source.
-      if (inp->definition() == nullptr) {
-        continue;
-      }
-
-      // No group created for scalar ops since they may need to be duplicated
-      //  to avoid scalar edges. They are handled in resolveScalarsInGroup
-      if (inp->isScalar()) {
-        continue;
-      }
-
-      auto def_group = expr2group.at(inp->definition());
-      auto new_edge = segmented_fusion_->newEdge(def_group, expr_group, inp);
-      expr_group->producer_edges.push_back(new_edge);
-      def_group->consumer_edges.push_back(new_edge);
-    }
-    for (auto out : expr->outputs()) {
-      if (out->isFusionOutput()) {
-        expr_group->output_vals.push_back(out);
-      }
-    }
-  }
-
-  if (options_.run_translate_welford &&
-      segmented_fusion_->completeFusion()->hasWelford()) {
-    TranslateApplicableWelford::run(segmented_fusion_.get(), runtime_inputs_);
-  }
-
-  for (auto group : groups()) {
-    // Set heuristics in case single reduction kernels were left out
-    group->setHeuristic(deriveHeuristic(group));
-  }
-
-  // Remove all scalar edges since they do not represent actual
-  //  dependency among segmented groups.
-  removeScalarEdges();
-
-  // Run pre-merge heuristics
-  if (options_.run_combine_reductions && CombineReductions::shouldRun(this)) {
-    CombineReductions::run(this);
-  }
-
-  // All merges will be vertical beyond this point for now, so
-  //  we can remove the input auxiliary groups. Should make the vertical
-  //  merges avoid auxiliary group once we start general horizontal merges
-  std::unordered_set<SegmentedGroup*> input_groups;
-  for (auto input : completeFusion()->inputs()) {
-    input_groups.insert(input2group.at(input));
-  }
-  eraseGroups(input_groups);
-
-  if (options_.run_herrmann_merge) {
-    bool merged_nodes = true;
-    // Initial merge iteration
-    while (merged_nodes) {
-      // Reset stateful traversal details in SegmentedGroups
-      resetTraversal();
-
-      resetLevels();
-
-      for (auto& group : groups()) {
-        if (group->merged_) {
-          continue;
-        }
-        auto candidates = group->getMergeCandidates();
-        if (candidates.empty()) {
-          continue;
-        }
-
-        auto candidate_it = candidates.begin();
-        while (candidate_it != candidates.end() &&
-               !codeGenSupportedMerge(candidate_it->edge)) {
-          candidate_it++;
-        }
-        if (candidate_it == candidates.end()) {
-          continue;
-        }
-
-        to_merge_.emplace_back(group);
-        to_merge_.emplace_back(candidate_it->group);
-
-        group->merged_ = true;
-        group->merge_with_ = candidate_it->group;
-        group->merge_through_ = candidate_it->edge;
-
-        candidate_it->group->merged_ = true;
-        candidate_it->group->merge_with_ = group;
-        candidate_it->group->merge_through_ = candidate_it->edge;
-      }
-
-      if (to_merge_.empty()) {
-        merged_nodes = false;
-      }
-
-      mergeNodes();
-    }
-  }
-
-  if (options_.run_final_merge) {
-    // TODO: consider interleaving herrmman merge and bruteforce merge, as
-    // bruteforce merge can introduce
-    //  opportunities for more herrmann merge
-    finalMerge();
-  }
-
-  finalize();
-  if (isDebugDumpEnabled(DebugDumpOption::FusionSegmentsDrawing)) {
-    segmented_fusion_->draw();
-  }
-}
-
-void SegmentCandidateFinder::finalMerge() {
-  auto producer_check = getGroupDependency();
-
-  bool merged_nodes = true;
-  while (merged_nodes) {
-    // Iterate all groups and check if a group
-    //  can merge with one of its consumers
-    for (auto producer_group : groups()) {
-      // Populate consumers and their corresponding consumer edges
-      std::unordered_map<SegmentedGroup*, SegmentedEdge*> consumer_edge_map;
-      std::vector<SegmentedGroup*> all_consumers_of_producer_group;
-      for (auto consumer : producer_group->consumer_edges) {
-        // Since this is the last fusion pass, we can enable fusion through
-        // outputs. Priority of this was decreased because if the only
-        // connection between groups is an output node, best case scenario we
-        // can save a single pass in memory. Where if it wasn't an output it
-        // would be two passes.
-        consumer_edge_map.insert({consumer->to, consumer});
-      }
-      // Populate all consumers from the map to avoid duplicate
-      std::transform(
-          consumer_edge_map.begin(),
-          consumer_edge_map.end(),
-          std::back_inserter(all_consumers_of_producer_group),
-          [](auto& it) { return it.first; });
-
-      for (auto consumer : all_consumers_of_producer_group) {
-        if (!producer_check->isConsumerOfAny(
-                consumer, all_consumers_of_producer_group) &&
-            codeGenSupportedMerge(consumer_edge_map.at(consumer))) {
-          to_merge_.emplace_back(producer_group);
-          to_merge_.emplace_back(consumer);
-          producer_group->merged_ = true;
-          producer_group->merge_with_ = consumer;
-          producer_group->merge_through_ = consumer_edge_map.at(consumer);
-          consumer->merged_ = true;
-          consumer->merge_with_ = producer_group;
-          consumer->merge_through_ = producer_group->merge_through_;
-          break;
-        }
-      }
-
-      // Only want to merge one pair at a time so break if found any
-      if (!to_merge_.empty()) {
-        break;
-      }
-    }
-
-    if (to_merge_.empty()) {
-      merged_nodes = false;
-    } else {
-      TORCH_INTERNAL_ASSERT(
-          to_merge_.size() == 2, "merging more than 2 nodes in final iter");
-      mergeNodes();
-    }
-  }
-}
-
-void SegmentCandidateFinder::resolveScalarsInGroup(SegmentedGroup* group) {
-  std::vector<Val*> to_visit;
-  std::unordered_set<Val*> visited;
-
-  // Collect all scalar uses in the group
-  for (auto expr : group->exprs()) {
-    for (auto input : expr->inputs()) {
-      if (input->isScalar()) {
-        to_visit.push_back(input);
-      }
-    }
-  }
-
-  // Keep track of composite fusion inputs used in this group
-  std::unordered_set<Val*> input_set(
-      group->input_vals.begin(), group->input_vals.end());
-
-  // Record and append all missing scalar exprs at the end.
-  std::vector<Expr*> exprs_to_add;
-
-  // Do a stack based traversal of the scalar ops to avoid
-  //  combinatorial duplication of exprs.
-  while (!to_visit.empty()) {
-    auto stack_top_val = to_visit.back();
-    if (visited.count(stack_top_val)) {
-      to_visit.pop_back();
-    } else if (stack_top_val->definition() == nullptr) {
-      // A scalar without def can be a scalar, a tensor dim,
-      //  or a composite fusion input
-      // The first two cases are handled in finalize(),
-      //  the last case needs to add new input_val to this group.
-      visited.insert(stack_top_val);
-      // If this is a composite fusion scalar input, make sure this group has it
-      if (stack_top_val->isFusionInput() && !input_set.count(stack_top_val)) {
-        group->input_vals.push_back(stack_top_val);
-        input_set.insert(stack_top_val);
-      }
-      to_visit.pop_back();
-    } else {
-      // A scalar with an actual definition
-      auto definition_expr = stack_top_val->definition();
-      bool all_inputs_visited = true;
-      // If any of the inputs are not visited, visit them first
-      for (auto input : definition_expr->inputs()) {
-        if (!visited.count(input)) {
-          all_inputs_visited = false;
-          to_visit.push_back(input);
-        }
-      }
-      // This node is ready to be visited
-      if (all_inputs_visited) {
-        // Collect the defining expr to insert into group
-        exprs_to_add.push_back(definition_expr);
-        visited.insert(stack_top_val);
-        to_visit.pop_back();
-      }
-    }
-  }
-
-  // Add all the defining expr to the group
-  for (auto expr : exprs_to_add) {
-    group->exprs_.push_back(expr);
-  }
-}
-
-void SegmentCandidateFinder::removeScalarEdges() {
-  // Remove all scalar edges between groups
-  //  They may have been created by welford
-  //   translation.
-  //  we will not need them after scalar
-  //  resolution
-  auto remove_scalar_edges_from_vec = [](std::vector<SegmentedEdge*>& edges) {
-    edges.erase(
-        std::remove_if(
-            edges.begin(),
-            edges.end(),
-            [](SegmentedEdge* segmented_edge) {
-              return segmented_edge->val->isScalar();
-            }),
-        edges.end());
-  };
-
-  remove_scalar_edges_from_vec(edges());
-  for (auto group : groups()) {
-    remove_scalar_edges_from_vec(group->producer_edges);
-    remove_scalar_edges_from_vec(group->consumer_edges);
-  }
-}
-
-void SegmentCandidateFinder::finalize() {
-  // Remove unconnected groups
-  groups().erase(
-      std::remove_if(
-          groups().begin(),
-          groups().end(),
-          [](SegmentedGroup* sg) { return !sg->isConnected(); }),
-      groups().end());
-
-  // Add group labeling
-  int i = 0;
-  for (auto it = groups().begin(); it != groups().end(); it++, i++) {
-    deDuplicateScalarExprs((*it)->exprs_);
-    (*it)->setID(i);
-  }
-
-  // TODO: too many things are currently abstracted under the term
-  //  finalize. Need to re-structure in a follow up.
-
-  // Finalize connections between segmented groups
-  segmented_fusion_->finalize();
-
-  // Resolve all the scalar expressions needed in each group
-  for (auto group : segmented_fusion_->groups()) {
-    resolveScalarsInGroup(group);
-  }
-
-  // Finalize each group, fill in the missing inputs, i.e. tensor dims.
-  for (auto g : groups()) {
-    g->finalize();
-  }
-}
-
-GroupDependencyAnalysis* SegmentCandidateFinder::getGroupDependency() {
-  if (!group_dependency_) {
-    group_dependency_ =
-        std::make_unique<GroupDependencyAnalysis>(segmented_fusion_.get());
-  }
-  return group_dependency_->as<GroupDependencyAnalysis>();
-}
-
-FusionKernelRuntime::SchedulerEntryPtr SegmentedFusion::
-    makeInitialSchedulerEntry(
-        SegmentedGroup* sg,
-        SchedulerRuntimeInfo& runtime_info) {
-  auto local_fusion = completeFusion();
-  FusionSegmentGuard fsg(local_fusion, getAllInputs(sg), getAllOutputs(sg));
-  // This will be the first time each group is scheduled. So we'd want to
-  //  construct the cache data here.
-  auto data_cache_ptr = std::make_unique<HeuristicSummary>(
-      local_fusion, sg->heuristic(), runtime_info);
-  auto data_cache = data_cache_ptr.get();
-  setCachedHeuristicDataFor(sg, std::move(data_cache_ptr));
-  return SchedulerEntry::makeEntry(
-      sg->heuristic(), local_fusion, runtime_info, data_cache);
-}
-
-std::unique_ptr<FusionHeuristics> SegmentedFusion::makeInitialHeuristics(
-    const at::ArrayRef<IValue>& inputs) {
-  auto ret = std::make_unique<FusionHeuristics>();
-  SchedulerRuntimeInfo runtime_info(completeFusion(), inputs, true);
-  for (auto g : groups()) {
-    ret->emplaceBack(makeInitialSchedulerEntry(g, runtime_info));
-  }
-  return ret;
-}
-
-HeuristicSummary* SegmentedFusion::getCachedHeuristicDataFor(
-    SegmentedGroup* group) {
-  auto data_it = heuristic_summary_cache_.find(group);
-  if (data_it == heuristic_summary_cache_.end()) {
-    return nullptr;
-  }
-  return data_it->second.get();
-}
-
-void SegmentedFusion::setCachedHeuristicDataFor(
-    SegmentedGroup* group,
-    std::unique_ptr<HeuristicSummary> data) {
-  TORCH_INTERNAL_ASSERT(!heuristic_summary_cache_.count(group));
-  heuristic_summary_cache_[group] = std::move(data);
-}
-
-namespace {
-
-//! A thin traversal class that collects all the tensorviews
-//!  that could cast to fp16 if they were segmented edges.
-//!  The selected values are currently defined as all the
-//!  tensorviews that
-//!     1. are not complete fusion input/output,
-//!     2. have a use chain that ends with a fp16
-//!         complete fusion output
-//!     3. are fp32 datatype
-class ForceFP16Annotation : public IterVisitor {
- public:
-  static std::unordered_set<TensorView*> getAnnotatedSet(Fusion* fusion) {
-    ForceFP16Annotation annotation;
-    std::vector<Val*> fp16_outputs;
-
-    std::copy_if(
-        fusion->outputs().begin(),
-        fusion->outputs().end(),
-        std::back_inserter(fp16_outputs),
-        [](auto* val) {
-          return val->template isA<TensorView>() &&
-              val->getDataType().has_value() &&
-              val->getDataType().value() == DataType::Half;
-        });
-
-    annotation.traverseFrom(fusion, fp16_outputs);
-    return annotation.force_fp16_tv_set_;
-  }
-
- private:
-  using IterVisitor::handle;
-
-  void handle(TensorView* tv) override {
-    auto dtype = tv->getDataType();
-    if (dtype.has_value() && dtype.value() == DataType::Float &&
-        !tv->isFusionOutput() && !tv->isFusionInput()) {
-      force_fp16_tv_set_.insert(tv);
-    }
-  }
-
-  std::unordered_set<TensorView*> force_fp16_tv_set_;
-};
-
-} // namespace
-
-void SegmentedFusion::annotateFP16IntermediateTensors() {
-  force_fp16_tv_set_ =
-      ForceFP16Annotation::getAnnotatedSet(complete_fusion_.get());
-}
-
-TORCH_CUDA_CU_API std::string toString(
-    const SegmentCandidateFinderOptions& segment_options) {
-  std::stringstream ss;
-  ss << "segmentation phases {\n";
-  if (segment_options.run_combine_reductions) {
-    ss << "combine reductions\n";
-  }
-  if (segment_options.run_herrmann_merge) {
-    ss << "herrmann merging\n";
-  }
-  if (segment_options.run_final_merge) {
-    ss << "final merging\n";
-  }
-  ss << "\n}\n";
-  return ss.str();
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/fusion_segmenter.h b/torch/csrc/jit/codegen/cuda/fusion_segmenter.h
deleted file mode 100644 (file)
index ae11d38..0000000
+++ /dev/null
@@ -1,598 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-
-#include <deque>
-#include <list>
-#include <unordered_set>
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class SegmentedGroup;
-class SegmentCandidateFinder;
-
-// A directed edge on DAG,
-// Wrapper for values, edges between segmented groups which are made up
-// of Exprs. Multiple edges can exist between segmented groups.
-struct SegmentedEdge {
-  SegmentedEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val)
-      : from(from), to(to), val(val) {}
-
-  SegmentedGroup* from;
-  SegmentedGroup* to;
-  Val* val;
-
-  void print() const;
-};
-
-std::ostream& operator<<(std::ostream& os, const SegmentedEdge* edge);
-
-//! Groups together expressions which create a segmented group
-//! Can be used to produce fusions
-class TORCH_CUDA_CU_API SegmentedGroup {
- public:
-  SegmentedGroup(SegmentedFusion* segmented_fusion)
-      : segmented_fusion_(segmented_fusion) {}
-
-  SegmentedGroup(Expr* expr, SegmentedFusion* segmented_fusion)
-      : segmented_fusion_(segmented_fusion) {
-    exprs_.push_back(expr);
-  }
-
-  //! Checks if this group takes original fusion's input
-  bool isInputGroup() {
-    return !input_vals.empty();
-  };
-
-  //! Checks if this group is used any where in the segmented fusion
-  bool isConnected() const {
-    return !producer_edges.empty() || !consumer_edges.empty() ||
-        !output_vals.empty();
-  }
-
-  //! returns the id assigned by segment pass
-  int groupId() const {
-    return group_id_;
-  }
-
-  //! Returns inputs that this group shares with the original fusion
-  const auto& inputs() const {
-    return input_vals;
-  }
-
-  //! Returns outputs that this group shares with the original fusion
-  const auto& outputs() const {
-    return output_vals;
-  }
-
-  //! Returns the schedule heuristic associated with this group
-  ScheduleHeuristic heuristic() const {
-    return heuristic_;
-  }
-
-  //! Returns the exprs that make up this group
-  const auto& exprs() const {
-    return exprs_;
-  }
-
-  //! Debug print function
-  void print() const;
-
-  //! Returns the segmented fusion that this group is in
-  SegmentedFusion* segmentedFusion() const {
-    return segmented_fusion_;
-  }
-
-  //! Try to get a scheduler entry for this group with
-  //!  the given runtime info.
-  //! Returns a new scheduler with the same heuristics
-  //!  for this group if possible.
-  //!  Note that the schedule params can be different.
-  //! Returns a nullopt if this group cannot be scheduled
-  //!  with the same heuristics.
-  c10::optional<std::unique_ptr<SchedulerEntry>> getMaybeSchedulerEntry(
-      SchedulerRuntimeInfo& runtime_info);
-
- public:
-  //! "Ancestor nodes", towards inputs of segmentedDAG
-  std::vector<SegmentedEdge*> producer_edges;
-
-  //! "Descendent nodes", towards outputs of segmentedDAG
-  std::vector<SegmentedEdge*> consumer_edges;
-
-  //! Composite Fusion inputs in this group
-  std::vector<Val*> input_vals;
-
-  //! Composite Fusion outputs in this group
-  std::vector<Val*> output_vals;
-
- private:
-  friend class SegmentCandidateFinder;
-  friend class SegmentedFusion;
-  friend class FusionKernelRuntime;
-  friend class TranslateApplicableWelford;
-
-  //! unique identifier of group in the segmented fusion
-  int group_id_ = -1;
-
-  //! The scheduler to use for compiling this group
-  ScheduleHeuristic heuristic_ = ScheduleHeuristic::PointWise;
-
-  //! Exprs that make up the group
-  std::vector<Expr*> exprs_;
-
-  //! Maximum path distance from an input segmented group required for
-  //! Theorem 4.2
-  int level_ = -1;
-
-  //! traversal marker, has this node already been processed
-  bool visited_ = false;
-
-  //! Did we select another group to merge with
-  SegmentedGroup* merge_with_ = nullptr;
-
-  //! if we selected another group to merge, which edge is to be contracted
-  SegmentedEdge* merge_through_ = nullptr;
-
-  //! Has this node been merged?
-  bool merged_ = false;
-
- private:
-  //! Utility to convert edge vector to value vector
-  std::vector<Val*> edgesToVals(const std::vector<SegmentedEdge*>& se_v);
-
-  //! Reset method to call at begining of each
-  //!  merge node iteration
-  void clearTraversalInfo();
-
-  //! To be called at the very end of segment fusion
-  //!  no more segment merging should be done beyond
-  void finalize();
-
-  //! Return all segmented groups connected with *this
-  std::vector<SegmentedGroup*> getNeighbors();
-
-  //! Utility struct to represent a group connection
-  //!  both the group to connect with and the edge
-  //!  to connect through
-  struct NeighborGroup {
-    NeighborGroup(SegmentedGroup* g, SegmentedEdge* e) : group(g), edge(e) {}
-    SegmentedGroup* group;
-    SegmentedEdge* edge;
-  };
-
-  //! TODO: May want to sort this based on size of connections between this and
-  //! neighbors as well as if the connection is an output of the fusion (has to
-  //! be saved to gmem anyways)
-  std::vector<NeighborGroup> getNeighborGroups();
-
-  //! Look at all neighbors of this and return who this could merge with based
-  //! on level values of this, neighbors, and merged neighbors of neighbors
-  std::vector<NeighborGroup> getMergeCandidates();
-
-  //! Assign schedule heuristic to this group
-  void setHeuristic(ScheduleHeuristic sh) {
-    heuristic_ = sh;
-  }
-
-  //! Assign Id for this group
-  void setID(int id) {
-    TORCH_INTERNAL_ASSERT(group_id_ == -1);
-    group_id_ = id;
-  }
-
-  //! SegmentedFusion this group belongs to
-  SegmentedFusion* segmented_fusion_;
-};
-
-std::ostream& operator<<(std::ostream& os, const SegmentedGroup* group);
-
-//! Auxiliary class for storing heuristics. The managed data is either
-//!  a single scheduler entry for complete fusion,
-//!  or a vector of schedulers, one for each segment, for segmented fusion.
-class TORCH_CUDA_CU_API FusionHeuristics {
-  using SchedulerEntryOwningPtr = std::unique_ptr<SchedulerEntry>;
-
- public:
-  //! Constructor for segmented fusion case. Created with empty list and
-  //!  uses emplaceBack for inserting heuristics in order
-  explicit FusionHeuristics() = default;
-
-  //! Constructor for complete fusion case, generates the scheduler entry
-  //!  for the fusion owning the given expression
-  explicit FusionHeuristics(
-      ScheduleHeuristic schedule_heuristic,
-      SchedulerRuntimeInfo& runtime_info,
-      HeuristicSummary* data_cache = nullptr) {
-    heuristics_.emplace_back(SchedulerEntry::makeEntry(
-        schedule_heuristic, runtime_info.fusion(), runtime_info, data_cache));
-    is_segmented_ = false;
-  }
-
-  //! Place a scheduler entry on the list. Applies to segmented fusion only.
-  void emplaceBack(SchedulerEntryOwningPtr&& pt) {
-    TORCH_INTERNAL_ASSERT(is_segmented_);
-    heuristics_.emplace_back(std::move(pt));
-  }
-
-  //! Returns list of schedulers for a segmneted fusion.
-  const std::vector<SchedulerEntryOwningPtr>& heuristicsList() const {
-    return heuristics_;
-  }
-
-  //! Returns the single scheduler for a complete fusion.
-  SchedulerEntry* singleKernelHeuristics() {
-    TORCH_INTERNAL_ASSERT(!is_segmented_);
-    return heuristics_.begin()->get();
-  }
-
- private:
-  std::vector<SchedulerEntryOwningPtr> heuristics_;
-  bool is_segmented_ = true;
-};
-
-//! Exported Interface for representing segmented fusion graph
-//!   this class owns the segmented groups
-class TORCH_CUDA_CU_API SegmentedFusion {
- public:
-  explicit SegmentedFusion(std::unique_ptr<Fusion> fusion);
-
-  //! Is the fusion segmented?
-  bool isSegmented() const {
-    return !groups_.empty();
-  }
-
-  std::vector<SegmentedGroup*>& groups() {
-    return groups_;
-  }
-
-  std::vector<SegmentedEdge*>& edges() {
-    return edges_;
-  }
-
-  const std::vector<SegmentedGroup*>& cgroups() const {
-    return groups_;
-  }
-
-  const std::vector<SegmentedEdge*>& cedges() const {
-    return edges_;
-  }
-
-  //! Returns the original un-segmented fusion
-  Fusion* completeFusion() {
-    return complete_fusion_.get();
-  }
-
-  const auto& inputs() const {
-    return complete_fusion_->inputs();
-  }
-
-  const auto& outputs() const {
-    return complete_fusion_->outputs();
-  }
-
-  Val* findAlias(Val* val) const {
-    Val* alias_val = nullptr;
-    if (complete_fusion_->io_alias_.count(val) != 0) {
-      alias_val = complete_fusion_->io_alias_[val];
-    }
-    return alias_val;
-  }
-
-  //! Make a clone of the group and convert to fusion
-  std::unique_ptr<Fusion> makeFusion(SegmentedGroup* sg);
-
-  //! Make heuristics for all groups in this segmented fusion
-  std::unique_ptr<FusionHeuristics> makeInitialHeuristics(
-      const at::ArrayRef<IValue>& inputs);
-
-  //! Inline Debug print for segmented fusion
-  std::string toString(int verbosity) const;
-
-  //! Debug drawing for graphviz
-  void draw();
-
-  //! Debug print for segmented fusions
-  void print() const;
-
-  //! API for adding groups
-  SegmentedGroup* newGroup();
-
-  //! API shortcut for adding a singleton group
-  SegmentedGroup* newGroup(Expr* expr);
-
-  //! API for adding edges
-  SegmentedEdge* newEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val);
-
-  //! Returns the set of potential intermediate tensors that
-  //!  will be cast to fp16 when written to global mem.
-  //!  These are not actual intermediate tensors,
-  //!  just the ones that will need to cast to fp16 if
-  //!  they end up being an intermediate tensor between
-  //!  segmented groups.
-  const auto& getForceToFP16Set() {
-    return force_fp16_tv_set_;
-  }
-
-  HeuristicSummary* getCachedHeuristicDataFor(SegmentedGroup* group);
-
- private:
-  //! Unique name for segmented fusion
-  int segmented_fusion_name_;
-
-  //! States representing segmentation
-  std::vector<SegmentedEdge*> edges_;
-  std::vector<SegmentedGroup*> groups_;
-
-  //! Owning object to explicitly manage groups and edges
-  class Impl {
-   public:
-    explicit Impl(SegmentedFusion* sf) : owning_fusion_(sf) {}
-
-    SegmentedGroup* makeGroup();
-    SegmentedGroup* makeGroup(Expr*);
-    SegmentedEdge* makeEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val);
-    void cleanUnused();
-
-   private:
-    using GroupPtr = std::unique_ptr<SegmentedGroup>;
-    using EdgePtr = std::unique_ptr<SegmentedEdge>;
-    std::vector<GroupPtr> groups_;
-    std::vector<EdgePtr> edges_;
-    SegmentedFusion* owning_fusion_;
-  };
-  Impl impl_;
-
-  //! A Copy of original full fusion
-  std::unique_ptr<Fusion> complete_fusion_;
-
-  //! A set of intermediate tensors that need to be cast to fp16
-  std::unordered_set<TensorView*> force_fp16_tv_set_;
-
-  //! Static traversal information to be used for fast heuristics lookup
-  std::unordered_map<SegmentedGroup*, std::unique_ptr<HeuristicSummary>>
-      heuristic_summary_cache_;
-
-  // TODO: this class needs cleanup
- protected:
-  friend class SegmentCandidateFinder;
-  //! Make a heuristics entry for a group and parameters
-  std::unique_ptr<SchedulerEntry> makeInitialSchedulerEntry(
-      SegmentedGroup* sg,
-      SchedulerRuntimeInfo& runtime_info);
-
-  //! Cleanup function to be call at the end of fusion
-  //!  segment pass
-  void finalize();
-
-  //! Collect all the intermediate tensors between segmented
-  //!  groups that will cast to fp16
-  void annotateFP16IntermediateTensors();
-
-  //! Keep heuristic checking intermediate data
-  void setCachedHeuristicDataFor(
-      SegmentedGroup* group,
-      std::unique_ptr<HeuristicSummary> data);
-
-  //! Utility to give unique name for each segmented fusion
-  static size_t segmentedFusionName() {
-    static size_t counter = 0;
-    return counter++;
-  }
-};
-
-//! This is a base class for segmenter analysis
-//!  provides the minimal implementation on header so that
-//!  a unique_ptr can use this base class
-//!  actual implementations of analyses are in the .cpp files
-//! TODO: In the next refactor PR, should put segment candidate
-//!  finder in .cpp file completely since API doesn't require these
-//!  details
-class SegmenterAnalysis : public PolymorphicBase {};
-class GroupDependencyAnalysis;
-
-// Manual node merging passes
-class CombineReductions;
-
-//! Options to configure/debug candidate finder
-struct TORCH_CUDA_CU_API SegmentCandidateFinderOptions {
-  bool run_translate_welford = true;
-  bool run_combine_reductions = true;
-  bool run_herrmann_merge = true;
-  bool run_final_merge = true;
-};
-
-//!  SegmentCandidateFinder
-//!    Responsible for going through DAG and proposing things we could try to
-//!    fuse together, calls "canGenerateCode" on these proposed segments to see
-//!    if they are valid and we can generate code for them.
-//!  FusionSegment
-//!    A group of exprs that are segmented together
-//!  FusionSegmentConnections
-//!    Holds vals and what they connect. In other words it's a val that is an
-//!    output of a FusionSegment "from" and an input of FusionSegment "to".
-//!    There's nothing preventing from a val being between segments twice.
-//!    TODO: make sure there's nothing wrong with segmentation on nodes that
-//!    have the same value input twice. i.e. (B = A*A)
-//! Selecting segments to propose is based on the theorem 4.2 in the paper which
-//! makes sure when segment the segmented graph will be a DAG (assumes Fusion is
-//! already a DAG). The segmentation code relies on assumptions of DAG-ness
-//! during segmentation, meaning proposed merging of groups must maintain the
-//! DAG property of the graph.
-//!
-//! Julien Herrmann, Yusuf Ã–zkaya, Bora Uçar, Kamer Kaya, Umit Catalyurek.
-//! Multilevel Algorithms for Acyclic Partitioning of Directed Acyclic Graphs.
-//! SIAM Journal on Scientific Computing, Society for Industrial and Applied
-//! Mathematics, 2019, 41 (4), pp.A2117-A2145. ff10.1137/18M1176865ff.
-//! ffhal02306566f
-class TORCH_CUDA_CU_API SegmentCandidateFinder {
- public:
-  // Perform segmentation on a copy of the given fusion
-  static std::unique_ptr<SegmentedFusion> segment(
-      const Fusion* fusion,
-      const at::ArrayRef<IValue>& inputs,
-      SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) {
-    auto fusion_copy = std::make_unique<Fusion>(*fusion);
-    SegmentCandidateFinder scf(std::move(fusion_copy), inputs, options);
-    return std::move(scf.segmented_fusion_);
-  }
-
-  // Perform segmentation on and take ownership of the given fusion
-  static std::unique_ptr<SegmentedFusion> segment(
-      std::unique_ptr<Fusion> fusion,
-      const at::ArrayRef<IValue>& inputs,
-      SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) {
-    SegmentCandidateFinder scf(std::move(fusion), inputs, options);
-    return std::move(scf.segmented_fusion_);
-  }
-
-  static bool TranslateWelfordInFusion(
-      Fusion* fusion,
-      const at::ArrayRef<IValue>& runtime_inputs);
-
- private:
-  // Perform segmentation on and take ownership of the given fusion
-  SegmentCandidateFinder(
-      std::unique_ptr<Fusion> fusion,
-      const at::ArrayRef<IValue>& inputs,
-      SegmentCandidateFinderOptions options);
-
-  void resetTraversal();
-
-  void resetLevels();
-
-  SegmentedGroup* mergeNodes();
-
-  bool codeGenSupportedMerge(SegmentedEdge* edge);
-
-  void findSegments();
-
-  std::unordered_set<SegmentedEdge*> disconnectGroup(SegmentedGroup* group);
-
-  std::vector<SegmentedGroup*>& groups() {
-    TORCH_INTERNAL_ASSERT(
-        segmented_fusion_ != nullptr, "Segment finder not owinging any fusion");
-    return segmented_fusion_->groups();
-  }
-
-  std::vector<SegmentedEdge*>& edges() {
-    TORCH_INTERNAL_ASSERT(
-        segmented_fusion_ != nullptr, "Segment finder not owinging any fusion");
-    return segmented_fusion_->edges();
-  }
-
-  Fusion* completeFusion() {
-    TORCH_INTERNAL_ASSERT(
-        segmented_fusion_ != nullptr, "Segment finder not owinging any fusion");
-    return segmented_fusion_->completeFusion();
-  }
-
-  SchedulerRuntimeInfo& runtimeInfo() {
-    return runtime_info_;
-  }
-
-  ExpressionEvaluator& expressionEvaluator() {
-    return runtime_info_.expressionEvaluator();
-  }
-
-  //! Additional merging iteration, clean up the rest of
-  //!  the merging opportunities
-  //!  Herrmann et al. is a fast and safe algorithm for finding merge candidates
-  //!  but can become too conservative in our use cases because we place
-  //!  additional qualifiers on valid merges other than having to generate DAGs,
-  //!  i.e. canSchedule. So we need a bruteforce final merging iteration as a
-  //!  clean up pass. Cost isn't expected to be high since the graph at this
-  //!  stage is already quite merged. Example cf. test_gpu.cpp:
-  //!  FusionDAGMerging_CUDA
-  //!
-  //!  This merging algorithm is based on Theorem 4.1 of Herrmann et al.,
-  //!   to check if a producer-consumer pair can be merged into one group,
-  //!   it's enough to check if any other consumer of the producer also
-  //!   produces the consumer.
-  void finalMerge();
-
-  //! Duplicate and add all exprs producing the used
-  //!  scalar values in group
-  void resolveScalarsInGroup(SegmentedGroup* group);
-
-  //! Remove all scalar edges in group
-  //!  (TODO: need structure better so we don't have to do this)
-  void removeScalarEdges();
-
-  //! Utility function to merge a vector of groups in one step,
-  //!  need to check for DAG condition before using this method
-  SegmentedGroup* mergeAllGivenGroups(
-      const std::vector<SegmentedGroup*>& groups);
-
-  //! Utility to remove a group and corresponding edges
-  //!  TODO: remove inline versions of this as much as possible
-  void eraseGroups(std::unordered_set<SegmentedGroup*>& groups_to_erase);
-
-  void finalize();
-
-  //! Return the resulting heuristic corresponding to the merged
-  //!  group built by merging the two groups connected by edge
-  ScheduleHeuristic deriveHeuristic(SegmentedGroup* edge);
-
-  GroupDependencyAnalysis* getGroupDependency();
-
- protected:
-  //! These are the merge node heuristic passes, should
-  //!  eventually should have a dedicated interface
-  //!  instead of keeping adding friends
-  friend class CombineReductions;
-
-  //! options to configure and debug the segment process
-  SegmentCandidateFinderOptions options_;
-
-  std::deque<SegmentedGroup*> to_visit_;
-  std::vector<SegmentedGroup*> next_to_visit_;
-
-  std::unordered_set<SegmentedGroup*> clean_up_groups_;
-  std::unordered_set<SegmentedEdge*> clean_up_edges_;
-
-  std::vector<SegmentedGroup*> to_merge_;
-
-  std::unique_ptr<SegmentedFusion> segmented_fusion_;
-
-  std::unique_ptr<SegmenterAnalysis> group_dependency_;
-
-  SchedulerRuntimeInfo runtime_info_;
-
-  //! Note:
-  //!  Segmenter should eventually rely only on runtime_info_ for
-  //!  safe caching. runtime_inputs_ is only used in translateWelford
-  //!  to initialize expression evaluators on copies of the original
-  //!  fusion, which doesn't use any un-cached info and is safe.
-  //!
-  //!  Directly using runtime_inputs_ in other cases is in general
-  //!   risky.
-  //!
-  //!  To get rid of runtime_inputs_ we need mechanisms
-  //!  to copy expression evaluator values from fusion
-  //!  to a copy, or even better to a copy of a
-  //!  sub-graph of original fusion.
-  //! TODO:
-  //!  implement the expression evaluator transfer and
-  //!  remove runtime_inputs_ in a follow up.
-  const at::ArrayRef<IValue>& runtime_inputs_;
-};
-
-TORCH_CUDA_CU_API std::string toString(const SegmentedGroup* group);
-TORCH_CUDA_CU_API std::string toString(const SegmentedEdge* edge);
-TORCH_CUDA_CU_API std::string toString(const SegmentedFusion* segmented_fusion);
-TORCH_CUDA_CU_API std::string toString(
-    const SegmentCandidateFinderOptions& segment_options);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index 48ea9b1..ebe3ef0 100644 (file)
@@ -7,12 +7,10 @@
 #include <torch/csrc/jit/codegen/cuda/partition.h>
 #include <torch/csrc/jit/frontend/ir_emitter.h>
 #include <torch/csrc/jit/ir/alias_analysis.h>
-#include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/pass_manager.h>
-#include <torch/csrc/jit/passes/remove_inplace_ops.h>
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include <torch/csrc/jit/runtime/autodiff.h>
 #include <torch/csrc/jit/runtime/custom_operator.h>
@@ -32,16 +30,6 @@ constexpr size_t NVRTC_KERNEL_ARG_LIMIT = 128;
 
 namespace {
 
-bool usedOnlyInDtype(Value* v) {
-  const auto& uses = v->uses();
-  if (uses.empty()) {
-    return false;
-  }
-  return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
-    return u.user->matches("prim::dtype(Tensor a) -> int");
-  });
-}
-
 Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
   AT_ASSERT(!sizes.empty());
   Graph* graph = sizes[0]->owningGraph();
@@ -51,44 +39,6 @@ Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
   return broadcast_n->output();
 }
 
-Value* createConditionalConstant(Node* profile_ivalue) {
-  TORCH_INTERNAL_ASSERT(profile_ivalue->kind() == prim::profile_ivalue);
-
-  auto graph = profile_ivalue->owningGraph();
-
-  IValue val; // default to None
-  if (profile_ivalue->hasAttribute(Symbol::attr("profiled_int_list"))) {
-    // int[]
-    val = IValue(profile_ivalue->is(Symbol::attr("profiled_int_list")));
-  } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_bool_list"))) {
-    // bool[]
-    auto int_list = profile_ivalue->is(Symbol::attr("profiled_bool_list"));
-    std::vector<bool> bool_list(int_list.begin(), int_list.end());
-    val = IValue(bool_list);
-  } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_size"))) {
-    // int[]
-    val = IValue(profile_ivalue->is(Symbol::attr("profiled_size")));
-  } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_bool"))) {
-    // bool
-    val = IValue(
-        static_cast<bool>(profile_ivalue->i(Symbol::attr("profiled_bool"))));
-  } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_int"))) {
-    // int
-    val = IValue(
-        static_cast<int>(profile_ivalue->i(Symbol::attr("profiled_int"))));
-  } else {
-    GRAPH_DEBUG("profile_ivalue: ", *profile_ivalue);
-    TORCH_WARN(
-        __func__,
-        " profile_node ",
-        *profile_ivalue,
-        " does not have profile information");
-    return nullptr;
-  }
-
-  return graph->insertConstant(val);
-}
-
 struct CudaGraphFuser {
   using FusionCallback = std::function<bool(Node*)>;
 
@@ -208,6 +158,7 @@ struct CudaGraphFuser {
     std::unordered_map<Value*, Value*> inputs_map;
     size_t i = 0;
     size_t tensor_insert_idx = 0;
+    AT_ASSERT(group->inputs().size() == subgraph.inputs().size());
     for (auto input : group->inputs()) {
       inputs_map[input] = subgraph.inputs()[i++];
       if (input->type()->isSubtypeOf(TensorType::get()))
@@ -231,7 +182,9 @@ struct CudaGraphFuser {
         } else if (
             // TODO: extend the supporting inputs here.
             (input->type()->isSubtypeOf(FloatType::get()) &&
-             input->node()->kind() != prim::Constant)) {
+             input->node()->kind() != prim::Constant) ||
+            (n->kind() == aten::_grad_sum_to_size &&
+             input->type()->isSubtypeOf(ListType::ofInts()))) {
           auto in_group = subgraph.addInput();
           in_group->setType(input->type());
           inputs_map[input] = in_group;
@@ -239,20 +192,8 @@ struct CudaGraphFuser {
         } else if (input->node()->kind() == prim::Constant) {
           // inline the constants directly in the body of the fused group.
           Node* in_const =
-              subgraph.createClone(input->node(), [&](Value* v) -> Value* {
-                if (v->node()->kind() != prim::profile_ivalue) {
-                  throw std::runtime_error(
-                      std::string(
-                          "merging constant with unexpected input from node") +
-                      v->node()->kind().toDisplayString());
-                }
-                group->addInput(v->node()->output());
-
-                // we are doing this just to keep alias_analysis silent with
-                // their checks
-                auto in_group = subgraph.addInput();
-                in_group->setType(v->type());
-                return in_group;
+              subgraph.createClone(input->node(), [](Value*) -> Value* {
+                throw std::runtime_error("unexpected input");
               });
           subgraph.insertNode(in_const);
           inputs_map[input] = in_const->output();
@@ -298,11 +239,9 @@ struct CudaGraphFuser {
     // have a valid mapping
     group->insertBefore(n);
     Node* mergedNode = mergeNodeIntoGroup(group, n);
-    for (size_t i = 0; i < n->outputs().size(); i++) {
-      getSubgraph(group).registerOutput(mergedNode->output(i));
-      auto sel = group->addOutput();
-      sel->copyMetadata(n->output(i));
-    }
+    getSubgraph(group).registerOutput(mergedNode->output());
+    auto sel = group->addOutput();
+    sel->copyMetadata(n->output());
     n->replaceAllUsesWith(group);
     n->destroy();
     return group;
@@ -317,7 +256,7 @@ struct CudaGraphFuser {
     // but this requires better handling of merging fusion groups so it is not
     // done now
     bool shouldFuse =
-        fuser::cuda::isFusibleCudaFusionGroup(consumer, producer->node()) &&
+        fuser::cuda::isFusableCudaFusionGroup(consumer, producer->node()) &&
         // Rearrange nodes such that all uses of producer are after the
         // consumer. Fusion will rewrite those later uses to use the version of
         // producer generated by the fused blob. In this case, producer becomes
@@ -343,21 +282,17 @@ struct CudaGraphFuser {
       mergeFusionGroups(group, producer->node());
       return group;
     }
+    AT_ASSERT(producer->node()->outputs().size() == 1);
     Node* merged = mergeNodeIntoGroup(group, producer->node());
     // remaining uses of this producer can occur because we allow
     // fusion in cases where uses remain after the consumer
     // if these exist, re-route them to the version of producer
     // created in FusionGroup
-
-    // We need to apply this to all outputs from producer->node();
-    auto producer_outputs = producer->node()->outputs();
-    for (size_t i = 0; i < producer_outputs.size(); i++) {
-      if (producer_outputs[i]->uses().size() != 0) {
-        getSubgraph(group).registerOutput(merged->outputs()[i]);
-        Value* new_producer = group->addOutput();
-        new_producer->copyMetadata(producer_outputs[i]);
-        producer_outputs[i]->replaceAllUsesWith(new_producer);
-      }
+    if (producer->uses().size() != 0) {
+      getSubgraph(group).registerOutput(merged->output());
+      Value* new_producer = group->addOutput();
+      new_producer->copyMetadata(producer);
+      producer->replaceAllUsesWith(new_producer);
     }
     producer->node()->destroy();
     return group;
@@ -545,7 +480,7 @@ struct CudaGraphFuser {
         chunk->inputs().begin(),
         chunk->inputs().end(),
         [&](Value* producer_for_chunk) {
-          return fuser::cuda::isFusibleCudaFusionGroup(
+          return fuser::cuda::isFusableCudaFusionGroup(
                      consumer, producer_for_chunk->node()) &&
               allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
         });
@@ -573,7 +508,6 @@ struct CudaGraphFuser {
       bchunk = promoteChunkToBroadcastingChunk(chunk);
     }
     size_t nchunks = bchunk->i(attr::chunks);
-    TORCH_INTERNAL_ASSERT(nchunks > 0, "number of chunks cannot be zero");
     WithInsertPoint guard(bchunk->next());
 
     std::vector<Value*> producer_chunk_outputs;
@@ -587,10 +521,6 @@ struct CudaGraphFuser {
     //  = Node* for chunk_output_idx'th output of the chunk(inputs[input_nr])
     std::vector<std::vector<Value*>> chunked_inputs;
 
-    // We have asserted single output earlier
-    auto producer_output_sizes =
-        producer_for_chunk_node->output()->type()->cast<TensorType>()->sizes();
-
     for (auto input : producer_for_chunk_node->inputs()) {
       // XXX: we only work with pointwise ops in here, so we know it is valid to
       // push the concat only through tensor arguments (and all other args can
@@ -619,61 +549,9 @@ struct CudaGraphFuser {
       // distinct from Node.
       bchunk->addInput(input);
       chunked_inputs.emplace_back(); // alas, to not be C++17
-
-      // properly compute strides for BroadcastingChunk
-      //
-      // We copy stride of each dimension from input to output for
-      // BroadcastingChunk. A note is that Chunk should not alter strides,
-      // However, broadcasted dimension should have a stride 0. We could have
-      // broadcasting happening on existing dimensions in input (case1), as well
-      // as extended dimension that does not exist in input (case2).
-      // e.g.
-      // If we look at an input tensor t0 with shape [3, 1] broadcasted to
-      // output tensor t1 with shape [4, 1, 3, 3],
-      // We set stride to zero in case of broadcast, which could happen in:
-      //   case1: t1.dim[3] (broadcasted as in the description above)
-      //   case2: t1.dim[0] (broadcasted implicitly)
-      std::vector<int64_t> strides;
-      auto input_type = input->type()->cast<TensorType>();
-      auto input_sizes = input_type->sizes();
-      auto input_strides = input_type->strides();
-      if (producer_output_sizes.isComplete() && input_sizes.isComplete() &&
-          input_strides.isComplete()) {
-        auto input_c_sizes = input_sizes.concrete_sizes().value();
-        auto input_c_strides = input_strides.concrete_sizes().value();
-        auto output_c_sizes = producer_output_sizes.concrete_sizes().value();
-        int output_index = int(output_c_sizes.size()) - 1;
-        strides.resize(output_index);
-        AT_ASSERT(output_index >= int(input_c_sizes.size()) - 1);
-        for (int input_index = int(input_c_sizes.size()) - 1; input_index >= 0;
-             input_index--, output_index--) {
-          // in braodcast case 1, we set stride to 0;
-          // otherwise, stride remain the same.
-          if (input_c_sizes[input_index] == 1 &&
-              output_c_sizes[output_index] != 1) {
-            strides[output_index] = 0;
-          } else {
-            strides[output_index] = input_c_strides[input_index];
-          }
-        }
-
-        // continue on expanding dimensions to set stride to 0 for case2
-        while (output_index >= 0) {
-          strides[output_index] =
-              output_c_sizes[output_index] == 1 ? strides[output_index + 1] : 0;
-          output_index--;
-        }
-      }
-
       for (auto chunk_sel : producer_chunk_outputs) {
         Value* input_chunk_sel = bchunk->addOutput();
-        auto chunk_sel_type = chunk_sel->type()->cast<TensorType>();
-        if (strides.empty() || !chunk_sel_type->sizes().isComplete()) {
-          input_chunk_sel->setType(chunk_sel_type);
-        } else {
-          input_chunk_sel->setType(chunk_sel_type->withSizesStrides(
-              chunk_sel_type->sizes().concrete_sizes().value(), strides));
-        }
+        input_chunk_sel->setType(chunk_sel->type());
         chunked_inputs.back().push_back(input_chunk_sel);
       }
     }
@@ -705,7 +583,6 @@ struct CudaGraphFuser {
     bchunk->removeInput(producer_index);
     // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
     for (const auto i : c10::irange(nchunks)) {
-      (void)i; // Suppress unused variable warning
       bchunk->eraseOutput(nchunks * producer_index);
     }
 
@@ -739,7 +616,7 @@ struct CudaGraphFuser {
 
   // returns where to continue scanning, and whether any fusion was made
   std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
-    if (fuser::cuda::isFusibleCudaFusionGroup(consumer)) {
+    if (fuser::cuda::isFusableCudaFusionGroup(consumer)) {
       // handle inputs in reverse topological order as well...
       // otherwise in f(a,a+b) it will appear a is used twice if we consider
       // the f-a fusion before the f-(a+b) fusion first.
@@ -795,18 +672,10 @@ struct CudaGraphFuser {
     }
   }
 
-  bool usedInDtype(Value* v) {
-    const auto& uses = v->uses();
-    return std::any_of(uses.begin(), uses.end(), [](const Use& u) {
-      return u.user->matches("prim::dtype(Tensor a) -> int");
-    });
-  }
-
-  bool usedOnlyInDtypeAndSize(Value* v) {
+  bool usedOnlyInSize(Value* v) {
     const auto& uses = v->uses();
     return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
-      return u.user->matches("prim::dtype(Tensor a) -> int") ||
-          u.user->matches("aten::size(Tensor self) -> int[]");
+      return u.user->matches("aten::size(Tensor self) -> int[]");
     });
   }
 
@@ -836,12 +705,10 @@ struct CudaGraphFuser {
     auto outputs = fusion_group->outputs();
     auto soutputs = subgraph->outputs();
     AT_ASSERT(outputs.size() == soutputs.size());
-    for (size_t i = 0; i < outputs.size(); ++i) {
-      if (usedOnlyInDtypeAndSize(outputs[i]))
+    for (const auto i : c10::irange(outputs.size())) {
+      if (usedOnlyInSize(outputs[i]))
         continue;
-      if (soutputs[i]->type()->isSubtypeOf(TensorType::get())) {
-        shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]});
-      }
+      shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]});
     }
 
     for (Node* n : subgraph->nodes()) {
@@ -858,9 +725,6 @@ struct CudaGraphFuser {
         continue;
       }
       if (n->kind() == prim::ConstantChunk) {
-        TORCH_INTERNAL_ASSERT(
-            shape_of.count(n->input()) > 0,
-            "buildShapeExpressions failed at accessing input shapes");
         Node* sizes_node = graph->insertNode(
             graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
         sizes_node->i_(attr::dim, n->i(attr::dim));
@@ -888,28 +752,17 @@ struct CudaGraphFuser {
             "only supports reduction axes and keepdim being constant");
 
         // hmmm, do I need to setInsertPoint...
-        const auto map_inputs = [&](Value* v) -> Value* {
-          // if constant ever has an input, it has to come from
-          // profile_ivalue dependency
-          if (v->node()->kind() == prim::Param &&
-              fusion_group->input(v->offset())->node()->kind() ==
-                  prim::profile_ivalue) {
-            // we need to map it along profile_ivalue dependency
-            return fusion_group->input(v->offset());
-          } else {
-            throw std::runtime_error(
-                std::string("unexpected input from node") +
-                v->node()->kind().toDisplayString());
-          }
-        };
-        Node* in1_const = graph->createClone(n->input(1)->node(), map_inputs);
+        Node* in1_const =
+            graph->createClone(n->input(1)->node(), [](Value*) -> Value* {
+              throw std::runtime_error("unexpected input");
+            });
         graph->insertNode(in1_const);
-        Node* in2_const = graph->createClone(n->input(2)->node(), map_inputs);
+        Node* in2_const =
+            graph->createClone(n->input(2)->node(), [](Value*) -> Value* {
+              throw std::runtime_error("unexpected input");
+            });
         graph->insertNode(in2_const);
 
-        TORCH_INTERNAL_ASSERT(
-            shape_of.count(n->input(0)) > 0,
-            "buildShapeExpressions failed at accessing input shapes");
         std::vector<Value*> inputs = {
             shape_of.at(n->input(0)), in1_const->output(), in2_const->output()};
         Node* size_node =
@@ -919,62 +772,14 @@ struct CudaGraphFuser {
         shape_of.emplace(n->output(), size);
         continue;
       }
-      // TODO: output(1) & output(2) should also be marked
-      if (n->kind() == aten::native_layer_norm) {
-        TORCH_INTERNAL_ASSERT(
-            shape_of.count(n->input(0)) > 0,
-            "buildShapeExpressions failed at accessing input shapes");
-        shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
-        continue;
-      }
-      // TODO: output(1) & output(2) should also be marked
-      if (n->kind() == aten::native_layer_norm_backward) {
-        TORCH_INTERNAL_ASSERT(
-            shape_of.count(n->input(0)) > 0,
-            "buildShapeExpressions failed at accessing input shapes");
-        shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
-        if (shape_of.count(n->input(5)) > 0) {
-          shape_of.emplace(n->output(1), shape_of.at(n->input(5)));
-        }
-        if (shape_of.count(n->input(6)) > 0) {
-          shape_of.emplace(n->output(2), shape_of.at(n->input(6)));
-        }
-        continue;
-      }
-      // TODO: output(1) & output(2) should also be marked
-      if (n->kind() == aten::native_batch_norm) {
-        TORCH_INTERNAL_ASSERT(
-            shape_of.count(n->input(0)) > 0,
-            "buildShapeExpressions failed at accessing input shapes");
-        shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
-        continue;
-      }
-      // TODO: output(1) & output(2) should also be marked
-      if (n->kind() == aten::native_batch_norm_backward) {
-        TORCH_INTERNAL_ASSERT(
-            shape_of.count(n->input(0)) > 0,
-            "buildShapeExpressions failed at accessing input shapes");
-        shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
-        if (shape_of.count(n->input(2)) > 0) {
-          shape_of.emplace(n->output(1), shape_of.at(n->input(2)));
-          // use shape of weight here for grad_bias
-          shape_of.emplace(n->output(2), shape_of.at(n->input(2)));
-        }
-        continue;
-      }
       auto tensor_inputs = filter(n->inputs(), [](Value* v) {
         return v->type()->isSubtypeOf(TensorType::get());
       });
-      auto shapes = fmap(tensor_inputs, [&](Value* v) {
-        TORCH_INTERNAL_ASSERT(
-            shape_of.count(v) > 0,
-            "buildShapeExpressions failed at accessing input shapes");
-        return shape_of.at(v);
-      });
+      auto shapes =
+          fmap(tensor_inputs, [&](Value* v) { return shape_of.at(v); });
       AT_ASSERT(!shapes.empty());
       shape_of.emplace(
-          n->output(0),
-          shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
+          n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
     }
     return shape_of;
   }
@@ -986,9 +791,7 @@ struct CudaGraphFuser {
 
     // TODO: failure in buildShapeExpressions should not break fusion execution,
     // we can add a try/catch here to bailout from removeOutputsUsedOnlyInSize.
-    GRAPH_DEBUG("before build shape expression: ", *graph_);
     auto shape_of = buildShapeExpressions(fusion_group);
-    GRAPH_DEBUG("after build shape expression: ", *graph_);
     auto outputs = fusion_group->outputs().vec();
     auto soutputs = subgraph->outputs().vec();
     // XXX: Iterating in this order is not only good for performance reasons!
@@ -997,30 +800,17 @@ struct CudaGraphFuser {
     for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
       auto output = outputs[i];
       auto soutput = soutputs[i];
-      if (usedOnlyInDtypeAndSize(output) && shape_of.count(soutput) > 0) {
-        bool has_dtype = usedInDtype(output);
+      if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
         auto uses = output->uses();
         for (Use u : uses) {
-          if (u.user->matches("aten::size(Tensor self) -> int[]")) {
-            u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
-            u.user->destroy();
-          } else if (u.user->matches("prim::dtype(Tensor a) -> int")) {
-            continue;
-          } else {
-            AT_ASSERT(
-                false,
-                "unrecognized consumer should not trigger removeOutputsUsedOnlyInSize");
-          }
-        }
-        // We only wipe the output when there's no more dtype consumer.
-        // This is to be removed by `removeOutputUsedOnlyInDtype`
-        if (!has_dtype) {
-          fusion_group->eraseOutput(i);
-          subgraph->eraseOutput(i);
+          AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
+          u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
+          u.user->destroy();
         }
+        fusion_group->eraseOutput(i);
+        subgraph->eraseOutput(i);
       }
     }
-    GRAPH_DEBUG("after build shape expression and re-wiring: ", *graph_);
   }
 
   void refreshAliasDb() {
@@ -1061,13 +851,12 @@ struct CudaGraphFuser {
       any_changed = false;
       refreshAliasDb();
       for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
-        bool changed = false;
+        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+        bool changed;
         std::tie(it, changed) = scanNode(*it);
         any_changed |= changed;
       }
     }
-
-    GRAPH_DEBUG("after scan and merge", *graph_);
     refreshAliasDb();
 
     // fuseConcats();
@@ -1083,12 +872,10 @@ struct CudaGraphFuser {
     //  it = scanNodeForChunks(*it);
     //}
 
-    GRAPH_DEBUG("before removeOutputsUsedOnlyInSize", *graph_);
     // Remove outputs that have been added only because we need their size
     for (Node* n : block_->nodes()) {
       removeOutputsUsedOnlyInSize(n);
     }
-    GRAPH_DEBUG("after removeOutputsUsedOnlyInSize", *graph_);
 
     for (Node* node : block_->nodes()) {
       for (Block* sub_block : node->blocks()) {
@@ -1098,75 +885,6 @@ struct CudaGraphFuser {
   }
 };
 
-void removeCudaFusionPathForGuardNode(Node* n) {
-  auto uses = n->output()->uses();
-  TORCH_INTERNAL_ASSERT(
-      uses.size() == 1,
-      "CudaFusionGuard should only be used by a single prim::If");
-  Node* if_node = uses[0].user;
-  TORCH_INTERNAL_ASSERT(
-      if_node->kind() == prim::If,
-      "CudaFusionGuard should only be used by prim::If");
-  auto fall_back_graph = if_node->blocks()[1];
-  Node* fallback_node = nullptr;
-  for (auto fb_n : fall_back_graph->nodes()) {
-    TORCH_INTERNAL_ASSERT(
-        fb_n->kind() == prim::FallbackGraph,
-        "CudaFusionGuard fallback path should only have single fallback node");
-    TORCH_INTERNAL_ASSERT(
-        fallback_node == nullptr,
-        "CudaFusionGuard fallback path should only have single fallback node");
-    fallback_node = fb_n;
-  }
-
-  TORCH_INTERNAL_ASSERT(
-      fallback_node != nullptr,
-      "CudaFusionGuard fallback path found no fallback node");
-  fallback_node->moveBefore(n);
-
-  TORCH_INTERNAL_ASSERT(
-      fallback_node->outputs().size() == if_node->outputs().size(),
-      "CudaFusionGuard fallback should have same number of outputs as with nesting if block");
-
-  if_node->replaceAllUsesWith(fallback_node);
-  if_node->destroy();
-  n->destroy();
-}
-
-bool missingCompleteTypes(const std::vector<TypePtr>& types) {
-  for (const auto& type : types) {
-    if (auto tensor_type = type->cast<TensorType>()) {
-      // if we found one missing value, we know that we are not going to able to
-      // generate a kernel, so we bail out;
-      if (!tensor_type->device().has_value() ||
-          !tensor_type->dim().has_value() ||
-          !tensor_type->scalarType().has_value()) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-void removeFusionWithMissingProfilingInformation(Block* block) {
-  FUSER_PERF_SCOPE("compileFusionRecursive");
-  std::vector<Node*> removeCudaFusionNodes;
-
-  for (auto node : block->nodes()) {
-    if (node->kind() == prim::CudaFusionGuard &&
-        missingCompleteTypes(node->tys(attr::types))) {
-      removeCudaFusionNodes.push_back(node);
-    }
-    for (auto sub_block : node->blocks()) {
-      removeFusionWithMissingProfilingInformation(sub_block);
-    }
-  }
-
-  for (auto node : removeCudaFusionNodes) {
-    removeCudaFusionPathForGuardNode(node);
-  }
-}
-
 void compileFusionRecursive(Block* block) {
   FUSER_PERF_SCOPE("compileFusionRecursive");
 
@@ -1275,35 +993,41 @@ void PeepholeOptimizeShapeExpressions(Block* block) {
 void guardFusionGroup(Node* fusion) {
   // Fixup types of the subgraph inputs
   std::vector<TypePtr> guard_types;
-  std::vector<Value*> tensor_inputs_to_check;
-  std::set<size_t> profiled_ivalue_indices;
-
-  for (size_t index = 0; index < fusion->inputs().size(); index++) {
-    Value* input = fusion->inputs()[index];
-    if (input->type()->cast<TensorType>()) {
-      // We only check inputs of the fusion group and expect NNC to infer
-      // intermediates and outputs shapes
-
-      // note: modified from original implementation, we are guarding fusion
-      //       outputs
-      if (input->node()->kind() == prim::Constant) {
-        continue;
-      }
-      tensor_inputs_to_check.push_back(input);
-      guard_types.push_back(input->type());
-    } else if (input->node()->kind() == prim::profile_ivalue) {
-      // Conditional constant from profiled_ivalue, should be guarded
-      profiled_ivalue_indices.insert(index);
+  std::vector<Value*> inputs_to_check;
+  for (Value* input : fusion->inputs()) {
+    // We only check inputs of the fusion group and expect NNC to infer
+    // intermediates and outputs shapes
+    if (!input->type()->cast<TensorType>()) {
+      continue;
     }
+
+    // note: modified from original implementation, we are guarding fusion
+    //       outputs
+    if (input->node()->kind() == prim::Constant) {
+      continue;
+    }
+    inputs_to_check.push_back(input);
+    guard_types.push_back(input->type());
+  }
+  if (!inputs_to_check.size()) {
+    return;
   }
-  // we should assert on non-tensor inputs
-  TORCH_INTERNAL_ASSERT(
-      tensor_inputs_to_check.size(),
-      "CudaFusionGuard expects at least one tensor input");
 
-  // insert the if block first;
+  Node* typecheck_node = fusion->owningGraph()
+                             ->create(prim::CudaFusionGuard, inputs_to_check, 1)
+                             ->insertBefore(fusion);
+  // fix output to BoolType
+  typecheck_node->output()->setType(BoolType::get());
+  Value* typecheck_result = typecheck_node->output();
+  typecheck_node->tys_(attr::types, guard_types);
+
+  std::unordered_map<Value*, Value*> typechecked_inputs;
+
+  // Insert if block
   auto versioning_if =
-      fusion->owningGraph()->create(prim::If, fusion->outputs().size());
+      fusion->owningGraph()
+          ->create(prim::If, {typecheck_result}, fusion->outputs().size())
+          ->insertAfter(typecheck_node);
   for (size_t idx = 0; idx < fusion->outputs().size(); ++idx) {
     versioning_if->output(idx)->setType(fusion->output(idx)->type());
     fusion->output(idx)->replaceAllUsesWith(versioning_if->output(idx));
@@ -1311,160 +1035,22 @@ void guardFusionGroup(Node* fusion) {
   auto true_block = versioning_if->addBlock();
   auto false_block = versioning_if->addBlock();
 
-  // insert typecheck_node;
-  Node* typecheck_node =
-      fusion->owningGraph()
-          ->create(prim::CudaFusionGuard, tensor_inputs_to_check, 1)
-          ->insertBefore(fusion);
-  // fix output to BoolType
-  typecheck_node->output()->setType(BoolType::get());
-  Value* typecheck_result = typecheck_node->output();
-  typecheck_node->tys_(attr::types, guard_types);
-
-  versioning_if->insertAfter(typecheck_node);
-
   // Fill in the false block. It should contain the unoptimized
-  // copy of the fused subgraph, unless we have conditional constants from
-  // profiled_ivalue;
-  auto fusion_graph = fusion->g(attr::Subgraph);
-  std::shared_ptr<Graph> fb_graph; // resource holder;
-  // Restore the dependency for constant introduced by profiled_ivalue within
-  // the graph.
-  if (!profiled_ivalue_indices.empty()) {
-    // This is necessary as it cleans up the fallback graph, which was copied
-    // from subgraph, since the two graph would differ as we cannot use
-    // conditional constant in fallback
-
-    // 1. RESTORE conditional constant dependency in fallback group;
-    fb_graph = fusion_graph->copy();
-    GRAPH_DEBUG("re-wiring fallback graph", *fb_graph);
-
-    for (const auto& offset : profiled_ivalue_indices) {
-      auto val = fb_graph->inputs()[offset];
-      auto uses = val->uses();
-      // since we are updating use of val in the loop, we have to copy
-      // val->uses() before hand.
-      for (const auto& use : uses) {
-        // re-wire inputs and remove conditional constant nodes;
-        TORCH_INTERNAL_ASSERT(
-            use.user->kind() == prim::Constant,
-            "profile_ivalue at index: ",
-            offset,
-            " can only be used by conditional constant, instead got: ",
-            use.user->kind().toDisplayString());
-        use.user->output()->replaceAllUsesWith(val);
-        use.user->destroy();
-      }
-    }
-
-    WithInsertPoint guard(false_block->return_node());
-    const auto subgraph_outputs =
-        insertGraph(*fusion->owningGraph(), *fb_graph, fusion->inputs());
-    for (Value* output : subgraph_outputs) {
-      false_block->registerOutput(output);
-    }
-    // types get copied to the fallback graph, so remove specializations before
-    // replacing
-    // TODO: this is not exposed here, I need to remove that before inserting
-    // the graph
-    // removeTensorTypeSpecializations(false_block);
-    replaceBlockWithFallbackGraph(false_block, fusion->inputs());
-
-    // 2. REMOVE conditional constant dependency in fusion group
-    size_t compensation = 0;
-
-    // get a constant false, which is used by `and` pattern later
-    auto const_true = fusion->owningGraph()->insertConstant(IValue(true));
-    const_true->node()->moveBefore(versioning_if);
-
-    for (const auto& original_offset : profiled_ivalue_indices) {
-      size_t offset = original_offset - compensation;
-
-      // step a. handle fusion
-      // remove inputs to fusion, and update check logic for fallback
-      auto profiled_ival = fusion->input(offset)->node()->input();
-      auto const_o = createConditionalConstant(fusion->input(offset)->node());
-      TORCH_INTERNAL_ASSERT(
-          const_o,
-          "profile_ivalue node are expected to have profile information, at node: ",
-          *fusion->input(offset)->node());
-      const_o->node()->moveBefore(versioning_if);
-      Value* ivalue_check = nullptr;
-
-      if (fusion->input(offset)->node()->hasAttribute(
-              Symbol::attr("profiled_bool"))) {
-        // aten::eq doesn't support comparison between two boolean
-        auto xor_n = fusion->owningGraph()
-                         ->create(aten::__xor__, {profiled_ival, const_o}, 1)
-                         ->insertBefore(versioning_if);
-        xor_n->output()->setType(BoolType::get());
-        ivalue_check =
-            fusion->owningGraph()
-                ->create(aten::__xor__, {xor_n->output(), const_true}, 1)
-                ->insertBefore(versioning_if)
-                ->output();
-      } else if (fusion->input(offset)->node()->hasAttribute(
-                     Symbol::attr("profiled_size"))) {
-        // TODO(profile_size): check sizes here with special size comparison op
-        // TORCH_INTERNAL_ASSERT(false, "not implemented yet");
-        ivalue_check =
-            fusion->owningGraph()
-                ->create(
-                    c10::Symbol::fromQualString("prim::CudaFusionSizeEq"),
-                    {profiled_ival, const_o},
-                    1)
-                ->insertBefore(versioning_if)
-                ->output();
-      } else {
-        ivalue_check = fusion->owningGraph()
-                           ->create(aten::eq, {profiled_ival, const_o}, 1)
-                           ->insertBefore(versioning_if)
-                           ->output();
-      }
-      ivalue_check->setType(BoolType::get());
-
-      typecheck_result =
-          fusion->owningGraph()
-              ->create(aten::__and__, {ivalue_check, typecheck_result}, 1)
-              ->insertBefore(versioning_if)
-              ->output();
-      typecheck_result->setType(BoolType::get());
-
-      // remove inputs to fusion;
-      fusion->removeInput(offset);
-
-      // step b. remove the extra dependency inside fusion;
-      for (const auto& use : fusion_graph->inputs()[offset]->uses()) {
-        TORCH_INTERNAL_ASSERT(
-            use.user->kind() == prim::Constant,
-            "profile_ivalue at index: ",
-            offset,
-            " can only be used by conditional constant, instead got: ",
-            use.user->kind().toDisplayString());
-        use.user->removeAllInputs();
-      }
-      fusion_graph->eraseInput(offset);
-      compensation++;
-    }
-    // update graph in fusion node
-    fusion->g_(attr::Subgraph, fusion_graph);
-  } else {
-    WithInsertPoint guard(false_block->return_node());
-    const auto subgraph_outputs =
-        insertGraph(*fusion->owningGraph(), *fusion_graph, fusion->inputs());
-    for (Value* output : subgraph_outputs) {
-      false_block->registerOutput(output);
-    }
-    // types get copied to the fallback graph, so remove specializations before
-    // replacing
-    // TODO: this is not exposed here, I need to remove that before inserting
-    // the graph
-    // removeTensorTypeSpecializations(false_block);
-    replaceBlockWithFallbackGraph(false_block, fusion->inputs());
+  // copy of the fused subgraph.
+  auto& subgraph = *fusion->g(attr::Subgraph);
+  WithInsertPoint guard(false_block->return_node());
+  const auto subgraph_outputs =
+      insertGraph(*fusion->owningGraph(), subgraph, fusion->inputs());
+  for (Value* output : subgraph_outputs) {
+    false_block->registerOutput(output);
   }
 
-  // wiring up if block
-  versioning_if->addInput(typecheck_result);
+  // types get copied to the fallback graph, so remove specializations before
+  // replacing
+  // TODO: this is not exposed here, I need to remove that before inserting the
+  //       graph
+  // removeTensorTypeSpecializations(false_block);
+  replaceBlockWithFallbackGraph(false_block, fusion->inputs());
 
   // Fill in the true block. It has all inputs type-checked and its
   // body should be the fusion group node.
@@ -1485,406 +1071,20 @@ void guardFusionGroups(Block* block) {
     }
   }
   for (Node* fusion : fusions) {
-    // step 1: a. add prim::CudaFusionGuard and fallback logic
-    //         b. insert guard logic of profile_ivalue with if block
-    //         c. restore conditional constant to non-constant for fallback
     guardFusionGroup(fusion);
   }
 }
 
-// rewire const integer index & empty byte-typed reserve space tensor outputs,
-// so `CudaFusionGroup` doesn't have to handle those
-void alterBatchNormImplIndex(Node* node) {
-  std::set<size_t> bn_index_out_indices;
-  std::set<size_t> bn_buffer_out_indices;
-
-  auto subgraph = node->g(attr::Subgraph);
-  for (size_t i = 0; i < subgraph->outputs().size(); i++) {
-    auto val = subgraph->outputs()[i];
-    if (val->node()->kind() == aten::_batch_norm_impl_index &&
-        val->offset() == 4) {
-      bn_index_out_indices.emplace(i);
-    } else if (
-        val->node()->kind() == aten::_batch_norm_impl_index &&
-        val->offset() == 3) {
-      bn_buffer_out_indices.emplace(i);
-    }
-  }
-
-  if (!bn_index_out_indices.empty()) {
-    // we output index to 0 so backwards go through native_batch_norm, which is
-    // what we support;
-    auto const_1 = node->owningGraph()->insertConstant(IValue(0));
-    const_1->node()->moveBefore(node);
-    for (auto i : bn_index_out_indices) {
-      node->outputs()[i]->replaceAllUsesWith(const_1);
-    }
-  }
-
-  if (!bn_buffer_out_indices.empty()) {
-    auto graph = node->owningGraph();
-    std::vector<int64_t> sizes{0}; // empty tensor with no size;
-    // std::vector<int64_t> sizes; // empty tensor with no size;
-    auto const_size_0 = node->owningGraph()->insertConstant(IValue(sizes));
-    const_size_0->node()->moveBefore(node);
-    auto const_0 = node->owningGraph()->insertConstant(IValue(0));
-    const_0->node()->moveBefore(node);
-    auto none_val = node->owningGraph()->insertConstant(IValue());
-    none_val->node()->moveBefore(node);
-    auto device =
-        graph->insertNode(graph->create(prim::device, {node->inputs()[0]}, 1));
-    device->moveBefore(node);
-    device->output()->setType(DeviceObjType::get());
-    auto empty_tensor = graph->insertNode(graph->create(
-        aten::empty,
-        {const_size_0, const_0, none_val, device->output(), none_val, none_val},
-        1));
-    empty_tensor->moveBefore(node);
-    for (auto i : bn_buffer_out_indices) {
-      node->outputs()[i]->replaceAllUsesWith(empty_tensor->output());
-    }
-  }
-
-  bn_index_out_indices.insert(
-      bn_buffer_out_indices.begin(), bn_buffer_out_indices.end());
-  for (auto iter = bn_index_out_indices.crbegin();
-       iter != bn_index_out_indices.crend();
-       ++iter) {
-    subgraph->eraseOutput(*iter);
-    node->eraseOutput(*iter);
-  }
-}
-
-// rewire empty byte-typed reserve space tensor input to an empty float-typed
-// tensor, because `CudaFusionGroup` doesn't support byte-typed tensor, nor does
-// it use reserve space.
-void alterBatchNormImplIndexBackward(Node* node) {
-  std::set<size_t> bn_buffer_in_indices;
-
-  auto subgraph = node->g(attr::Subgraph);
-  for (auto n : subgraph->nodes()) {
-    if (n->kind() == aten::_batch_norm_impl_index_backward) {
-      // 11th inputs are `reserve`, which is not used by codegen kernel and its
-      // type is not supported `Byte`. So we disconnect it here to avoid codegen
-      // error
-      auto byte_input = n->inputs()[11];
-      // TODO: let's check the data type for buffer and skip if it's good
-      // TODO: we can actually support it by adding an extra inputs to the
-      // subgraph
-      // TODO: assert on empty buffer
-      TORCH_INTERNAL_ASSERT(
-          byte_input->node() == subgraph->param_node(),
-          "Assumption that reserve input to aten::_batch_norm_impl_index_backward comes from forward graph is broken");
-      bn_buffer_in_indices.emplace(byte_input->offset());
-    }
-  }
-
-  if (!bn_buffer_in_indices.empty()) {
-    auto graph = node->owningGraph();
-    std::vector<int64_t> sizes{0}; // empty tensor with no size;
-    // std::vector<int64_t> sizes{}; // empty tensor with no size;
-    auto const_size_0 = node->owningGraph()->insertConstant(IValue(sizes));
-    const_size_0->node()->moveBefore(node);
-    auto const_0 = node->owningGraph()->insertConstant(IValue(6));
-    const_0->node()->moveBefore(node);
-    auto none_val = node->owningGraph()->insertConstant(IValue());
-    none_val->node()->moveBefore(node);
-    auto device =
-        graph->insertNode(graph->create(prim::device, {node->inputs()[1]}, 1));
-    device->moveBefore(node);
-    device->output()->setType(DeviceObjType::get());
-    auto empty_tensor = graph->insertNode(graph->create(
-        aten::empty,
-        {const_size_0, const_0, none_val, device->output(), none_val, none_val},
-        1));
-    empty_tensor->moveBefore(node);
-
-    for (auto iter = bn_buffer_in_indices.begin();
-         iter != bn_buffer_in_indices.end();
-         ++iter) {
-      subgraph->inputs()[*iter]->setType(
-          node->inputs()[*iter]->type()->cast<TensorType>()->withScalarType(
-              at::ScalarType::Float));
-      node->replaceInput(*iter, empty_tensor->output());
-    }
-  }
-}
-
-void alterBatchNormImpls(Block* block) {
-  std::vector<Node*> fusions;
-  for (Node* n : block->nodes()) {
-    for (Block* b : n->blocks()) {
-      alterBatchNormImpls(b);
-    }
-    if (n->kind() == prim::CudaFusionGroup) {
-      fusions.push_back(n);
-    }
-  }
-  for (Node* fusion : fusions) {
-    // remove index & reserve from outputs;
-    alterBatchNormImplIndex(fusion);
-    // remove reserve from inputs;
-    alterBatchNormImplIndexBackward(fusion);
-  }
-}
-
-// We absorb `prim::dtype` node into CudaFusion structure. The structure below
-//
-// %1 = prim::CudaFusionGuard(...)
-// %2, %3 = prim::If(...)
-//   block0():
-//     %4, %5 = prim::CudaFusionGroup(...)
-//     -> (%4, %5)
-//   block1():
-//     %6, %7 = prim::FallbackGraph(...)
-//     -> (%6, %7)
-// %4 = prim::dtype(%3)
-//   ... (uses %2, %4, but never reference to %3 any more)
-//
-// is updated to:
-//
-// %1 = prim::CudaFusionGuard(...)
-// %2, %3 = prim::If(...)
-//   block0():
-//     %4 = prim::CudaFusionGroup(...)  # %5 is also removed from subgraph
-//     %8 = prim::Constant[value=...]()
-//     -> (%4, %8)
-//   block1():
-//     %6, %7 = prim::FallbackGraph(...)
-//     %9 = prim::dtype(%7)
-//     -> (%6, %9)
-// # %4 = prim::dtype(%3) is removed. All reference to %4 is replaced with %3
-//   ... (uses %2, %4, but never reference to %3 any more)
-void removeOutputUsedOnlyInDtype(Node* fusion_node) {
-  auto fusion_block = fusion_node->owningBlock();
-  TORCH_INTERNAL_ASSERT(
-      fusion_block->owningNode() &&
-          fusion_block->owningNode()->kind() == prim::If,
-      "CudaFusionGroup should be inside `prim::CudaFusionGuard` / `prim::If`");
-
-  auto if_node = fusion_block->owningNode();
-  auto fusion_node_graph = fusion_node->g(attr::Subgraph);
-  auto fallback_block = if_node->blocks()[1];
-
-  bool updated = false;
-  // Iterating in this order is crucial for correctness (i has to reflect the
-  // current true index of outputs[i])!
-  for (int64_t i = static_cast<int64_t>(if_node->outputs().size()) - 1; i >= 0;
-       --i) {
-    auto output = if_node->outputs()[i];
-    // output only used in dtype, we eliminate the output and rely on
-    // profiled/static scalar type inference to save on memory IO.
-    if (usedOnlyInDtype(output)) {
-      updated = true;
-      {
-        // update fusion_block to output profiled scalar type
-        auto fusion_output = fusion_block->outputs()[i];
-        auto tensor_type = fusion_output->type()->cast<TensorType>();
-        TORCH_INTERNAL_ASSERT(
-            tensor_type, "non tensor fed to dtype is not supported");
-        auto scalar_type = tensor_type->scalarType();
-        TORCH_INTERNAL_ASSERT(
-            scalar_type.has_value(),
-            "ScalarType should be static for Tensors in fusion for amp optimization");
-        auto type_const =
-            fusion_block->owningGraph()->insertConstant(IValue(scalar_type));
-        type_const->setType(IntType::get());
-        type_const->node()->moveBefore(fusion_block->return_node());
-        fusion_block->replaceOutput(i, type_const);
-
-        // remove the dangling output tensor in CudaFusionGroup
-        fusion_node->eraseOutput(i);
-        fusion_node_graph->eraseOutput(i);
-      }
-
-      {
-        // update fallback_block to output dtype instead of tensor
-        auto tensor_output = fallback_block->outputs()[i];
-        auto dtype_node = fallback_block->owningGraph()->create(
-            prim::dtype, tensor_output, 1);
-        dtype_node->output()->setType(IntType::get());
-        fallback_block->appendNode(dtype_node);
-        fallback_block->replaceOutput(i, dtype_node->output());
-      }
-
-      // we just shot-cut the `dtype` node since we are already outputing dtype
-      auto uses = output->uses();
-      for (Use u : uses) {
-        AT_ASSERT(u.user->matches("prim::dtype(Tensor a) -> int"));
-        u.user->output()->replaceAllUsesWith(output);
-        u.user->destroy();
-      }
-      output->setType(IntType::get());
-    }
-  }
-
-  if (updated) {
-    fusion_node->g_(attr::Subgraph, fusion_node_graph);
-  }
-}
-
-// For output tensors in fusion group that is only used by dtype node, with
-// CudaFusionGuard, we can short-cut it with constant dtype directly instead to
-// save IO memory bandwidth.
-// The reason that we do it after we insert the guard, instead of doing it along
-// during graph fusion/partitioning, is that we needed to handle the fallback
-// differently, since fallback is not inside CudaFusionGuard, and hence doesn't
-// have the dtype as a constant.
-void removeOutputUsedOnlyInDtype(Block* block) {
-  std::vector<Node*> fusions;
-  for (Node* n : block->nodes()) {
-    for (Block* b : n->blocks()) {
-      removeOutputUsedOnlyInDtype(b);
-    }
-    if (n->kind() == prim::CudaFusionGroup) {
-      fusions.push_back(n);
-    }
-  }
-  for (Node* fusion : fusions) {
-    // remove index & reserve from outputs;
-    removeOutputUsedOnlyInDtype(fusion);
-  }
-}
-
-void RemoveProfileIValue(Node* profile_ivalue) {
-  for (const auto& use : profile_ivalue->output()->uses()) {
-    if (use.user->kind() == prim::Constant) {
-      use.user->output()->replaceAllUsesWith(profile_ivalue->input());
-      use.user->destroy();
-    }
-  }
-  profile_ivalue->output()->replaceAllUsesWith(profile_ivalue->input());
-  profile_ivalue->destroy();
-}
-
-void ExtractProfileIValue(Node* profile_ivalue) {
-  auto const_o = createConditionalConstant(profile_ivalue);
-  if (const_o) {
-    auto const_n = const_o->node();
-    const_n->moveAfter(profile_ivalue);
-    profile_ivalue->output()->replaceAllUsesAfterNodeWith(const_n, const_o);
-    // special wiring, we add this input to constant simply in order to create
-    // dependency, which we can trace and remove later;
-    const_n->addInput(profile_ivalue->output());
-  } else {
-    // no profile value available, remove profile_ivalue node;
-    RemoveProfileIValue(profile_ivalue);
-  }
-}
-
-void traverseProfileIValues(
-    Block* block,
-    const std::function<void(Node*)>& func) {
-  std::vector<Node*> profile_ivalues;
-  for (Node* n : block->nodes()) {
-    for (Block* b : n->blocks()) {
-      traverseProfileIValues(b, func);
-    }
-    if (n->kind() == prim::profile_ivalue) {
-      profile_ivalues.push_back(n);
-    }
-  }
-  for (Node* profile_ivalue : profile_ivalues) {
-    func(profile_ivalue);
-  }
-}
-
-// break `linear` layer into `matmul` and `add_optional`. This allows us to fuse
-// the binary operation without supporting gemm.
-// Note that we are not breaking `linear` layer without bias.
-void decomposeLinearOps(Block* block) {
-  std::vector<Node*> linear_nodes;
-  for (Node* n : block->nodes()) {
-    for (Block* b : n->blocks()) {
-      decomposeLinearOps(b);
-    }
-    // only decompose `linear` layer with bias.
-    if (n->kind() == aten::linear &&
-        !n->input(2)->type()->isSubtypeOf(
-            static_cast<c10::TypePtr>(NoneType::get()))) {
-      linear_nodes.push_back(n);
-    }
-  }
-
-  auto graph = block->owningGraph();
-  for (Node* n : linear_nodes) {
-    WithInsertPoint guard(n);
-    auto weight_t = graph->insertNode(graph->create(aten::t, {n->input(1)}, 1));
-    auto matmul = graph->insertNode(
-        graph->create(aten::matmul, {n->input(0), weight_t->output()}, 1));
-    auto input_tensor_type = n->input(0)->type()->cast<c10::TensorType>();
-    auto mat0_size = input_tensor_type->sizes().concrete_sizes();
-    auto mat1_size =
-        n->input(1)->type()->cast<c10::TensorType>()->sizes().concrete_sizes();
-
-    // TODO: The assert is not necessary when we can handle matmul, right now we
-    // are splitting the linear between matmul & bias_add. Our fuser can only
-    // take the second half and we would need the size information.
-    TORCH_INTERNAL_ASSERT(
-        mat0_size.has_value() && mat1_size.has_value(),
-        "concrete shape for linear input & weight are required");
-    auto out_size = mat0_size.value();
-    out_size[out_size.size() - 1] = mat1_size.value()[0];
-    matmul->output()->setType(input_tensor_type->withSizes(out_size));
-
-    // TODO: memory stride should be considered here, our inference above is not
-    // safe.
-    auto bias = graph->insertNode(
-        graph->create(prim::add_optional, {matmul->output(0), n->input(2)}, 1));
-    bias->output()->setType(matmul->output(0)->type());
-
-    n->output()->replaceAllUsesWith(bias->output());
-    n->destroy();
-  }
-}
-
 } // anonymous namespace
 
 void CudaFuseGraph(std::shared_ptr<Graph>& graph) {
-  FUSER_PERF_SCOPE("nvFuser::Manager::CudaFuseGraph");
-  GRAPH_DUMP("Before Fusion: ", graph);
-
-  // TODO: extract & guard profile_ivalue; but how do we restore it???
-  // I don't know how to store edge/node in attribute. so let's abuse data flow
-  // dependency and add inputs to conditional constant generated by
-  // aten::profile_ivalue
-  traverseProfileIValues(graph->block(), ExtractProfileIValue);
-  GRAPH_DEBUG("insert conditional constant from profile_ivalue: ", *graph);
-
+  FUSER_PERF_SCOPE("CudaFuseGraph");
   // TODO: we need to properly restore shape information after fusion.
   // shamelessly use tool from NNC.
   RemoveProfileNodesAndSpecializeTypes(graph);
-  GRAPH_DEBUG("After Profiling Nodes Removed: ", *graph);
-
-  // TODO: separate passes into different file;
-  // TODO: restore decomposition after fusion, in case we are decomposing
-  //       operation that can't be fused;
-  decomposeLinearOps(graph->block());
-  GRAPH_DEBUG("decompose operations by nvfuser: ", *graph);
 
   CudaGraphFuser(graph->block(), graph).run();
-  GRAPH_DEBUG("After Fusion: ", *graph);
-
-  // guard input types as well as conditional constants from
-  // aten::profile_ivalue
   guardFusionGroups(graph->block());
-  GRAPH_DEBUG("After Guard Fusion: ", *graph);
-
-  // mutate `aten::_batch_norm_impl_index` and
-  // `aten::_batch_norm_impl_index_backward` node in the fusion group to WAR
-  // the lack of fusion support on integer output as well as byte-typed tensor.
-  alterBatchNormImpls(graph->block());
-  GRAPH_DEBUG("After _batch_norm_impl_index: ", *graph);
-
-  traverseProfileIValues(graph->block(), RemoveProfileIValue);
-
-  GRAPH_DEBUG("Before remove missing profiling: ", *graph);
-  removeFusionWithMissingProfilingInformation(graph->block());
-  GRAPH_DEBUG("After remove missing profiling: ", *graph);
-
-  // optimization targeting AMP
-  removeOutputUsedOnlyInDtype(graph->block());
-  GRAPH_DEBUG("After removeOutputUsedOnlyInDtype: ", *graph);
   // After FuseGraph some common subexpressions may come back
   EliminateCommonSubexpression(graph);
   // We might have emitted a fair amount of useless shape propagating code, so
@@ -1897,7 +1097,6 @@ void CudaFuseGraph(std::shared_ptr<Graph>& graph) {
   // shamelessly use tool from NNC.
   RemoveTensorTypeSpecializations(graph);
 
-  GRAPH_DUMP("Before Compilation: ", graph);
   // Compile CudaFusionGroup
   compileFusionRecursive(graph->block());
 }
index 2411303..ecec6ae 100644 (file)
@@ -3,16 +3,13 @@
 #include <c10/util/Exception.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/index_reference_replay.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
 #include <torch/csrc/jit/codegen/cuda/transform_iter.h>
 #include <torch/csrc/jit/codegen/cuda/transform_replay.h>
 
@@ -61,14 +58,11 @@ class ContigIDs : public OptInDispatch {
   void handle(Split*) override {}
 
   void handle(Merge* merge) override {
-    const auto gpu_lower = GpuLower::current();
-
     // If either input is non-contiguous so is output.
-    const auto inner = merge->inner();
-    const auto outer = merge->outer();
-
-    if ((!isContig(gpu_lower->lowerValue(inner)->as<kir::IterDomain>()) ||
-         !isContig(gpu_lower->lowerValue(outer)->as<kir::IterDomain>()))) {
+    auto inner = merge->inner();
+    auto outer = merge->outer();
+    if (!isContig(GpuLower::lowerValue(inner)->as<kir::IterDomain>()) ||
+        !isContig(GpuLower::lowerValue(outer)->as<kir::IterDomain>())) {
       return;
     }
 
@@ -116,9 +110,9 @@ class ContigIDs : public OptInDispatch {
       if (root_copy.front() == ordered_inputs.front()) {
         root_copy.pop_front();
         ordered_inputs.pop_front();
-        // This is no longer causing an error in:
-        // ReductionSchedulerMultiDimNonFastest TODO: test reenablement to make
-        // sure it does what's expected
+        // We probably should be able to make access contiguous through
+        // reduction domains, however, for now it's causing issues in predicate
+        // generation. See test: ReductionSchedulerMultiDimNonFastest
         //  } else if (
         //     root_copy.front()->isReduction() ||
         //     root_copy.front()->isBroadcast()) {
@@ -132,10 +126,10 @@ class ContigIDs : public OptInDispatch {
     // top contig ID, lower ids should be placed in the "within_contig_ids" map
     // of top id.
     auto kir_inner =
-        gpu_lower->lowerValue(merge->inner())->as<kir::IterDomain>();
+        GpuLower::lowerValue(merge->inner())->as<kir::IterDomain>();
     auto kir_outer =
-        gpu_lower->lowerValue(merge->outer())->as<kir::IterDomain>();
-    auto kir_out = gpu_lower->lowerValue(merge->out())->as<kir::IterDomain>();
+        GpuLower::lowerValue(merge->outer())->as<kir::IterDomain>();
+    auto kir_out = GpuLower::lowerValue(merge->out())->as<kir::IterDomain>();
     if (ordered_inputs.empty()) {
       if (contig_ids.find(kir_inner) != contig_ids.end()) {
         contig_ids.erase(kir_inner);
@@ -171,14 +165,12 @@ class ContigIDs : public OptInDispatch {
 
   // Check through thie history of ids whose inputs map to root_domain with
   // contiguity root_contiguity. Return unordered_set of all merges that are
-  // contiguous. Ignore root order is primarily used for predicate generation.
-  // In this case we can linearize indexing of any ID that only consists of
-  // merge operations.
+  // contiguous.
   ContigIDs(
       const std::vector<IterDomain*>& ids,
-      const std::vector<IterDomain*>& root_domain,
-      const std::vector<bool>& root_contiguity)
-      : root_domain_(root_domain), root_contiguity_(root_contiguity) {
+      const std::vector<IterDomain*>& _root_domain,
+      const std::vector<bool>& _root_contiguity)
+      : root_domain_(_root_domain), root_contiguity_(_root_contiguity) {
     if (ids.empty()) {
       return;
     }
@@ -190,19 +182,15 @@ class ContigIDs : public OptInDispatch {
         " != ",
         root_contiguity_.size());
 
-    const auto gpu_lower = GpuLower::current();
-
     for (const auto i : c10::irange(root_domain_.size())) {
       if (root_contiguity_[i]) {
         auto kir_root_domain_i =
-            gpu_lower->lowerValue(root_domain_[i])->as<kir::IterDomain>();
+            GpuLower::lowerValue(root_domain_[i])->as<kir::IterDomain>();
         contig_ids.emplace(kir_root_domain_i);
         within_contig_ids[kir_root_domain_i] =
             std::unordered_set<kir::IterDomain*>();
-        is_contig_root[root_domain_[i]] = true;
-      } else {
-        is_contig_root[root_domain_[i]] = false;
       }
+      is_contig_root[root_domain_[i]] = root_contiguity_[i];
     }
 
     auto exprs = ExprSort::getExprs(ids[0]->fusion(), {ids.begin(), ids.end()});
@@ -223,287 +211,59 @@ class ContigIDs : public OptInDispatch {
   }
 };
 
-// Update the HaloInfo mappings for a reference tensor by propagating
-// the halo information from the consumer tensor.
-void updateHaloInfoForReference(
-    const ReferenceTensor& reference,
-    const TensorView* consumer_tv) {
-  const auto gpu_lower = GpuLower::current();
-
-  auto& halo_info = gpu_lower->haloInfo();
-
-  auto* reference_domain = reference.domain;
-  const auto& reference_concrete_map = reference.concrete_to_id;
-
-  for (auto reference_root_axis : reference_domain->getRootDomain()) {
-    // Set default
-    halo_info.setRootAxisInfo(reference_root_axis, AxisHaloInfo());
-    auto consumer_it = std::find_if(
-        consumer_tv->getRootDomain().begin(),
-        consumer_tv->getRootDomain().end(),
-        [&](IterDomain* consumer_root) {
-          auto concrete_id =
-              gpu_lower->caIndexMap().getConcreteMappedID(consumer_root);
-          auto it = reference_concrete_map.find(concrete_id);
-          return it != reference_concrete_map.end() &&
-              it->second == reference_root_axis;
-        });
-    // When no corresponding ID of the consumer exists, the reference
-    // axis can be ignored
-    if (consumer_it == consumer_tv->getRootDomain().end()) {
-      continue;
-    }
-    auto consumer_root_axis = *consumer_it;
-    auto root_axis_info =
-        gpu_lower->haloInfo().getRootAxisInfo(consumer_root_axis);
-    if (root_axis_info.width() == 0) {
-      continue;
-    }
-    halo_info.setRootAxisInfo(reference_root_axis, root_axis_info);
-  }
-
-  halo_info.build(reference_domain);
-
-  return;
-}
-
-// Get a map of IterDomains to halo-extended extents of corresponding
-// reference IterDomains.
-//
-// ref_map: ref-to-consumer in consumer indexing; ref-to-producer in
-// producer indexing
-std::unordered_map<kir::IterDomain*, kir::Val*> getReferenceHaloExtentMap(
-    const ReferenceTensor& reference,
-    const TensorView* consumer_tv,
-    const std::unordered_map<IterDomain*, IterDomain*>& ref_map,
-    const std::unordered_map<kir::IterDomain*, kir::Val*>& extent_map) {
-  const auto gpu_lower = GpuLower::current();
-
-  // First, update HaloInfo with the reference tensor, which reflects
-  // the halo extents of the consumer tensor.
-  updateHaloInfoForReference(reference, consumer_tv);
-
-  const auto& halo_info = gpu_lower->haloInfo();
-
-  std::unordered_map<kir::IterDomain*, kir::Val*> reference_halo_extent_map;
-
-  // Propagate halo extents of the reference to the consumer or
-  // producer tensor
-  for (auto kv : ref_map) {
-    auto ref_id = gpu_lower->lowerValue(kv.first)->as<kir::IterDomain>();
-    auto producer_or_consumer_id =
-        gpu_lower->lowerValue(kv.second)->as<kir::IterDomain>();
-    auto extent = halo_info.getExtent(ref_id);
-    if (extent == nullptr) {
-      auto extent_it = extent_map.find(ref_id);
-      if (extent_it != extent_map.end()) {
-        extent = extent_it->second;
-      } else {
-        extent = ref_id->extent();
-      }
-    }
-    reference_halo_extent_map[producer_or_consumer_id] = extent;
-  }
-
-  return reference_halo_extent_map;
-}
-
-//! Offset of an index of a producer axis with respect to its
-//! corresponding consumer index
-kir::Val* getProducerHaloOffset(
-    const TensorView* producer_tv,
-    size_t producer_axis,
-    const TensorView* consumer_tv) {
-  auto p2c =
-      PairwiseRootDomainMap(producer_tv, consumer_tv)
-          .mapProducerToConsumer(producer_tv->domain(), consumer_tv->domain());
-
-  auto producer_id = producer_tv->getMaybeRFactorDomain()[producer_axis];
-
-  auto it = p2c.find(producer_id);
-  // p2c should always have a mapping for producer_id. The only case
-  // where no mapping exists for a producer axis is when it is a
-  // reduction axis. Since this function is only used for indexing
-  // producer tensors, where reduction axes are skipped, producer_id
-  // should never be a reduction axis.
-  TORCH_INTERNAL_ASSERT(it != p2c.end());
-  IterDomain* consumer_id = it->second;
-
-  const auto& halo_map = GpuLower::current()->haloInfo();
-  const auto p_pad = halo_map.getRootAxisInfo(producer_id).width(0);
-  const auto c_pad = halo_map.getRootAxisInfo(consumer_id).width(0);
-
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  kir::Val* offset = (p_pad->isConst() && c_pad->isConst())
-      ? ir_builder.create<kir::Int>(
-            p_pad->value().value() - c_pad->value().value())
-      : ir_builder.subExpr(p_pad, c_pad);
-
-  // If the consumer is a result of shifting the producer, adjust the
-  // producer index per the offsets argument of the shift op.
-  if (auto shift_op = dynamic_cast<const ShiftOp*>(consumer_tv->definition())) {
-    offset = ir_builder.subExpr(
-        offset, ir_builder.create<kir::Int>(shift_op->offset(producer_axis)));
-  }
-
-  return offset;
-}
-
-//! Offset producer index when necessary
-kir::Val* getProducerIndexWithHalo(
-    const TensorView* producer_tv,
-    size_t producer_axis,
-    kir::Val* producer_index,
-    const TensorView* consumer_tv) {
-  const auto offset =
-      getProducerHaloOffset(producer_tv, producer_axis, consumer_tv);
-
-  if (offset->isZeroInt()) {
-    return producer_index;
-  }
-
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  producer_index = ir_builder.addExpr(producer_index, offset);
-
-  return producer_index;
-}
-
-//! Offset a producer index of a gather expression
-//!
-//! Given an index of a producer root axis, build a new index
-//! expression that accesses a window position that the current loop
-//! structure refers to.
-kir::Val* getProducerIndexWithGather(
-    size_t producer_root_axis,
-    kir::Val* producer_index,
-    const TensorView* producer_tv,
-    const TensorView* consumer_tv,
-    const std::unordered_map<kir::IterDomain*, kir::Val*>& ref_index_map,
-    const std::unordered_map<IterDomain*, IterDomain*>& ref_concrete_map) {
-  auto gather_op = dynamic_cast<const GatherOp*>(consumer_tv->definition());
-
-  // Just return the producer index as is if this is not a gather
-  if (gather_op == nullptr) {
-    return producer_index;
-  }
-
-  // Consumer axis that corresponds to the producer axis
-  int consumer_axis = -1;
-  for (size_t i = 0; i <= producer_root_axis; ++i) {
-    if (producer_tv->getRootDomain()[i]->isReduction()) {
-      continue;
-    }
-    ++consumer_axis;
-  }
-
-  TORCH_INTERNAL_ASSERT(
-      consumer_axis >= 0 &&
-          consumer_axis < (int)gather_op->windowShape().size(),
-      "Invalid consumer axis",
-      consumer_axis,
-      ", producer_axis: ",
-      producer_root_axis);
-
-  // If the window extent is one, no specific offsetting
-  // is necessary
-  if (gather_op->windowShape()[consumer_axis]->isOneInt()) {
-    return producer_index;
-  }
-
-  // Basically, the goal is to build an expression of producer_index +
-  // window_index, so we first need to locate the index expression
-  // that corresponds to the window axis of this producer axis.
-
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  // Locate the root IterDomain of the reference that corresponds to the gather
-  // axis
-  const auto window_root_axis = gather_op->gatherAxis(consumer_axis);
-  auto concrete_window_id = gpu_lower->caIndexMap().getConcreteMappedID(
-      consumer_tv->getRootDomain().at(window_root_axis));
-  auto ref_concrete_map_it = ref_concrete_map.find(concrete_window_id);
-  TORCH_INTERNAL_ASSERT(ref_concrete_map_it != ref_concrete_map.end());
-  IterDomain* reference_root_of_gather_axis = ref_concrete_map_it->second;
-
-  // Now that reference_root_of_gather_axis is the IterDomain for the
-  // window axis, take its corresponding index from the index map
-  auto window_idx =
-      ref_index_map.at(gpu_lower->lowerValue(reference_root_of_gather_axis)
-                           ->as<kir::IterDomain>());
-
-  // Positive (or negative) padding at offset zero means the indexing
-  // shifted to the negative (or positive) direction.
-  auto pad_width = gather_op->padWidth()[consumer_axis][0];
-
-  // producer_index - padding + window_index
-  auto offset_producer_index = ir_builder.addExpr(
-      ir_builder.subExpr(
-          producer_index, ir_builder.create<kir::Int>(pad_width)),
-      window_idx);
-
-  return offset_producer_index;
-}
-
 } // namespace
 
 void IndexCompute::handle(Split* split) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  auto in_id = gpu_lower->lowerValue(split->in())->as<kir::IterDomain>();
-  auto outer_id = gpu_lower->lowerValue(split->outer())->as<kir::IterDomain>();
-  auto inner_id = gpu_lower->lowerValue(split->inner())->as<kir::IterDomain>();
+  auto in_id = GpuLower::lowerValue(split->in())->as<kir::IterDomain>();
+  auto outer_id = GpuLower::lowerValue(split->outer())->as<kir::IterDomain>();
+  auto inner_id = GpuLower::lowerValue(split->inner())->as<kir::IterDomain>();
 
   auto outer_it = index_map_.find(outer_id);
   auto inner_it = index_map_.find(inner_id);
   if (outer_it == index_map_.end() || inner_it == index_map_.end())
     return;
 
-  const auto outer_ind = outer_it->second;
-  const auto inner_ind = inner_it->second;
-
-  const bool outer_zero = isZero(outer_id);
-  const bool inner_zero = isZero(inner_id);
+  auto outer_ind = outer_it->second;
+  auto inner_ind = inner_it->second;
 
-  // We want to mark as zero merged in if we're working with shared or local
-  // memory, and the dimension we're working with is not part of the allocation,
-  // as we have special propagation rules for that scenario.
+  bool outer_zero = outer_ind->isZeroInt();
+  bool inner_zero = inner_ind->isZeroInt();
 
-  // Maybe clear in_id as it could have been mapped over from another
-  // IndexCompute. Uncertain if this is needed but seems to be safe.
-  bool zero_merged_in = hasZeroMerged(in_id) || hasZeroMerged(inner_id) ||
-      hasZeroMerged(outer_id);
+  bool outer_bcast = outer_id->isBroadcast();
+  bool inner_bcast = inner_id->isBroadcast();
 
-  // If both are zero, the split input is also zero
-  if (inner_zero && outer_zero) {
-    zero_.emplace(in_id);
-  }
-
-  if (zero_merged_in) {
+  // Zero inds because a dim is bcast is part of normal traversal, if it's not
+  // bcast but is zero ind then it's from local or smem. In the latter case we
+  // want to propagate this property.
+  if ((outer_zero && !outer_bcast) || (inner_zero && !inner_bcast) ||
+      hasZeroMerged(inner_id) || hasZeroMerged(outer_id)) {
     zero_merged_in_.emplace(in_id);
+  } else {
+    // Maybe clear in_id as it could have been mapped over from another
+    // IndexCompute. Uncertain if this is needed but seems to be safe.
+    if (hasZeroMerged(in_id)) {
+      zero_merged_in_.erase(in_id);
+    }
   }
 
-  if (isZero(in_id)) {
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+
+  if (outer_zero && inner_zero) {
     index_map_[in_id] = ir_builder.create<kir::Int>(0);
     extent_map_[in_id] = ir_builder.create<kir::Int>(0);
-  } else if (zero_merged_in && outer_zero) {
+  } else if (outer_zero) {
     index_map_[in_id] = inner_ind;
+    zero_merged_in_.emplace(in_id);
     extent_map_[in_id] = getExtent(inner_id);
-  } else if (zero_merged_in && inner_zero) {
+  } else if (inner_zero) {
     index_map_[in_id] = outer_ind;
+    zero_merged_in_.emplace(in_id);
     extent_map_[in_id] = getExtent(outer_id);
   } else {
     index_map_[in_id] = ir_builder.addExpr(
         ir_builder.mulExpr(outer_ind, getExtent(inner_id)), inner_ind);
-    // The extent of a root axis should be only updated when its
-    // allocation is partial, i.e., zero_merged_in is true. See issue
-    // #1016 and the FusionIssue1016 test.
-    if (split->in()->definition() != nullptr || zero_merged_in) {
+    if (extent_map_.find(outer_id) != extent_map_.end() ||
+        extent_map_.find(inner_id) != extent_map_.end()) {
       extent_map_[in_id] =
           ir_builder.mulExpr(getExtent(outer_id), getExtent(inner_id));
     }
@@ -511,33 +271,28 @@ void IndexCompute::handle(Split* split) {
 }
 
 void IndexCompute::handle(Merge* merge) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  auto out_id = gpu_lower->lowerValue(merge->out())->as<kir::IterDomain>();
-  auto outer_id = gpu_lower->lowerValue(merge->outer())->as<kir::IterDomain>();
-  auto inner_id = gpu_lower->lowerValue(merge->inner())->as<kir::IterDomain>();
+  auto out_id = GpuLower::lowerValue(merge->out())->as<kir::IterDomain>();
+  auto outer_id = GpuLower::lowerValue(merge->outer())->as<kir::IterDomain>();
+  auto inner_id = GpuLower::lowerValue(merge->inner())->as<kir::IterDomain>();
 
   auto out_it = index_map_.find(out_id);
-  if (out_it == index_map_.end()) {
+  if (out_it == index_map_.end())
     return;
-  }
+
   auto out_ind = out_it->second;
 
-  auto zero = ir_builder.zeroVal();
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+  auto zero = ir_builder.create<kir::Int>(0);
 
-  if (isZero(out_id)) {
+  if (out_ind->isZeroInt()) {
     index_map_[outer_id] = zero;
     index_map_[inner_id] = zero;
     extent_map_[outer_id] = zero;
     extent_map_[inner_id] = zero;
-    zero_.emplace(outer_id);
-    zero_.emplace(inner_id);
     return;
   }
 
   if (!hasZeroMerged(out_id) && contig_ids.find(out_id) != contig_ids.end()) {
-    // Contiguous indexing path
     auto input_ids = ir_utils::iterDomainInputsOfOrderedAs(
         {merge->out()}, td_->getRootDomain());
 
@@ -545,79 +300,45 @@ void IndexCompute::handle(Merge* merge) {
     TORCH_INTERNAL_ASSERT(!input_ids.empty());
 
     for (auto root_id : input_ids) {
-      index_map_[gpu_lower->lowerValue(root_id)->as<kir::IterDomain>()] = zero;
+      index_map_[GpuLower::lowerValue(root_id)->as<kir::IterDomain>()] = zero;
     }
 
-    index_map_[gpu_lower
-                   ->lowerValue(*(input_ids.end() - 1))
+    index_map_[GpuLower::lowerValue(*(input_ids.end() - 1))
                    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
                    ->as<kir::IterDomain>()] = out_ind;
     return;
   }
 
-  kir::Val* inner_extent = getExtent(inner_id);
-
-  // When the reference has halo extent for inner_id, that extent needs to
-  // be used to un-merge
-  if (reference_halo_extent_map_.find(inner_id) !=
-      reference_halo_extent_map_.end()) {
-    inner_extent = reference_halo_extent_map_[inner_id];
-  }
-
-  const auto outer_extent = getExtent(outer_id);
+  Val* inner_extent = getExtent(inner_id);
+  Val* outer_extent = getExtent(outer_id);
 
   if (inner_id->isBroadcast() && inner_extent->isOneInt()) {
-    // Propagate away from broadcast dims
     index_map_[outer_id] = out_ind;
     index_map_[inner_id] = zero;
 
     extent_map_[outer_id] = getExtent(out_id);
   } else if (outer_id->isBroadcast() && outer_extent->isOneInt()) {
-    // Propagate away from broadcast dims
     index_map_[outer_id] = zero;
     index_map_[inner_id] = out_ind;
 
     extent_map_[inner_id] = getExtent(out_id);
   } else if (hasZeroMerged(out_id)) {
-    // Don't propagate to inner id if it's comprised of only broadcast root
-    // domains, unless outer is also all broadcast domains. Index shouldn't be
-    // anything but zero if both inner and outer are all broadcast domains, but
-    // didn't add a hard check for this. See FusionAdvancedIndexing5_CUDA
-    if (!inner_id->isBroadcast() && !outer_id->isBroadcast()) {
-      // If neither dimension is a broadcast (should be true for reference
-      // indexing) pick the preferred path or the inner path.
-      if (preferred_paths_.find(outer_id) != preferred_paths_.end() &&
-          preferred_paths_.find(inner_id) == preferred_paths_.end()) {
-        // Marked that we should prop through outer, not inner.
-        index_map_[outer_id] = out_ind;
-        extent_map_[outer_id] = getExtent(out_id);
-        index_map_[inner_id] = zero;
-        extent_map_[inner_id] = zero;
-      } else {
-        // Prop through inner
-        index_map_[inner_id] = out_ind;
-        extent_map_[inner_id] = getExtent(out_id);
-        index_map_[outer_id] = zero;
-        extent_map_[outer_id] = zero;
-      }
-    } else if (inner_id->isBroadcast() && !outer_id->isBroadcast()) {
-      // Inner is broadcast and outer isn't, prop through outer
-      index_map_[outer_id] = out_ind;
-      extent_map_[outer_id] = getExtent(out_id);
-      index_map_[inner_id] = zero;
-      extent_map_[inner_id] = zero;
-    } else {
-      // Default to propagating through inner
-      index_map_[inner_id] = out_ind;
-      extent_map_[inner_id] = getExtent(out_id);
-      index_map_[outer_id] = zero;
-      extent_map_[outer_id] = zero;
-    }
+    index_map_[inner_id] = out_ind;
+    extent_map_[inner_id] = getExtent(out_id);
+
+    index_map_[outer_id] = zero;
+    extent_map_[outer_id] = zero;
+
     zero_merged_in_.emplace(inner_id);
     zero_merged_in_.emplace(outer_id);
   } else {
-    index_map_[outer_id] = ir_builder.divExpr(out_ind, inner_extent);
-    index_map_[inner_id] = ir_builder.modExpr(out_ind, inner_extent);
+    Val* I = inner_extent;
+
+    Val* outer_ind = ir_builder.divExpr(out_ind, I);
+    Val* inner_ind = ir_builder.modExpr(out_ind, I);
+
+    index_map_[outer_id] = outer_ind;
+    index_map_[inner_id] = inner_ind;
   }
 }
 
@@ -637,19 +358,15 @@ void IndexCompute::handle(Expr* e) {
 // using TransformIter::runBackward;
 IndexCompute::IndexCompute(
     const TensorDomain* _td,
-    std::unordered_map<kir::IterDomain*, kir::Val*> initial_index_map,
-    std::unordered_map<kir::IterDomain*, kir::Val*> extent_map,
-    std::unordered_set<kir::IterDomain*> zero_merged_in,
-    const std::vector<bool>& root_contiguity,
-    std::unordered_set<kir::IterDomain*> preferred_paths,
-    std::unordered_map<kir::IterDomain*, kir::Val*> reference_halo_extent_map)
+    std::unordered_map<kir::IterDomain*, Val*> initial_index_map,
+    std::unordered_map<kir::IterDomain*, Val*> _extent_map,
+    std::unordered_set<kir::IterDomain*> _zero_merged_in,
+    const std::vector<bool>& root_contiguity)
     : td_(_td),
       index_map_(std::move(initial_index_map)),
-      extent_map_(std::move(extent_map)),
-      zero_merged_in_(std::move(zero_merged_in)),
-      preferred_paths_(std::move(preferred_paths)),
-      reference_halo_extent_map_(std::move(reference_halo_extent_map)) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");
+      extent_map_(std::move(_extent_map)),
+      zero_merged_in_(std::move(_zero_merged_in)) {
+  FUSER_PERF_SCOPE("IndexCompute::IndexCompute");
 
   // Make sure we recompute any indices we can that map to a contiguous access
   // in physical memory.
@@ -671,459 +388,427 @@ IndexCompute::IndexCompute(
     }
   }
 
-  // Initialize the zero_ set with domains that do not contibute to
-  // the resulting index. Any domain that is mapped to Int(0), except
-  // for vectorized ones, is included in this set.
-  const auto gpu_lower = GpuLower::current();
-  for (auto dom : td_->domain()) {
-    auto kir_dom = gpu_lower->lowerValue(dom)->as<kir::IterDomain>();
-    auto it = index_map_.find(kir_dom);
-    if (it == index_map_.end()) {
-      continue;
-    }
-    auto idx = it->second;
-    if (idx->isZeroInt() && !isParallelTypeVectorize(dom->getParallelType())) {
-      zero_.emplace(kir_dom);
-    }
-  }
-}
-
-void IndexCompute::run() {
   const std::vector<Val*> domain_vals(
       td_->domain().begin(), td_->domain().end());
 
   traverseFrom(td_->fusion(), domain_vals, false);
 }
 
-kir::Val* IndexCompute::getExtent(kir::IterDomain* id) {
-  if (isParallelTypeThread(id->parallelType())) {
-    auto parallel_dim =
-        GpuLower::current()->parallelDimensionMap().get(id->parallelType());
-    TORCH_INTERNAL_ASSERT(parallel_dim != nullptr);
-    return parallel_dim;
-  } else if (extent_map_.find(id) != extent_map_.end()) {
+Val* IndexCompute::getExtent(kir::IterDomain* id) {
+  if (extent_map_.find(id) != extent_map_.end()) {
     return extent_map_.at(id);
   } else {
     return id->extent();
   }
 }
 
-bool IndexCompute::hasZeroMerged(kir::IterDomain* id) const {
-  return zero_merged_in_.find(id) != zero_merged_in_.end() || isZero(id);
-}
-
-bool IndexCompute::isZero(kir::IterDomain* id) const {
-  return zero_.find(id) != zero_.end();
+bool IndexCompute::hasZeroMerged(kir::IterDomain* id) {
+  return zero_merged_in_.find(id) != zero_merged_in_.end();
 }
 
 IndexCompute IndexCompute::updateIndexCompute(
     const TensorDomain* new_td,
     const std::unordered_map<IterDomain*, IterDomain*>& id_map,
-    const std::vector<bool>& root_contiguity,
-    const std::unordered_map<kir::IterDomain*, kir::Val*>&
-        reference_halo_extent_map) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::updateIndexCompute");
-
-  const auto gpu_lower = GpuLower::current();
+    std::unordered_map<kir::IterDomain*, Val*> new_index_entries,
+    const std::vector<bool>& root_contiguity) {
+  FUSER_PERF_SCOPE("updateIndexCompute");
 
-  std::unordered_map<kir::IterDomain*, kir::Val*> updated_index_map;
-  std::unordered_map<kir::IterDomain*, kir::Val*> updated_extent_map;
+  std::unordered_map<kir::IterDomain*, Val*> updated_index_map =
+      std::move(new_index_entries);
+  std::unordered_map<kir::IterDomain*, Val*> updated_extent_map;
   std::unordered_set<kir::IterDomain*> updated_zero_merged_in;
 
   for (auto id_entry : id_map) {
     kir::IterDomain* prev_id =
-        gpu_lower->lowerValue(id_entry.first)->as<kir::IterDomain>();
+        GpuLower::lowerValue(id_entry.first)->as<kir::IterDomain>();
     kir::IterDomain* new_id =
-        gpu_lower->lowerValue(id_entry.second)->as<kir::IterDomain>();
+        GpuLower::lowerValue(id_entry.second)->as<kir::IterDomain>();
 
     if (index_map_.find(prev_id) != index_map_.end()) {
       updated_index_map[new_id] = index_map_.at(prev_id);
     }
 
-    updated_extent_map[new_id] = getExtent(prev_id);
+    if (extent_map_.find(prev_id) != extent_map_.end()) {
+      updated_extent_map[new_id] = extent_map_.at(prev_id);
+    }
 
     if (zero_merged_in_.find(prev_id) != zero_merged_in_.end()) {
       updated_zero_merged_in.emplace(new_id);
     }
   }
 
-  IndexCompute updated_index_compute(
+  return IndexCompute(
       new_td,
       updated_index_map,
       updated_extent_map,
       updated_zero_merged_in,
-      root_contiguity,
-      {},
-      reference_halo_extent_map);
-  updated_index_compute.run();
+      root_contiguity);
+}
 
-  return updated_index_compute;
+std::vector<bool> IndexCompute::contiguityAnd(
+    const std::vector<bool>& contig1,
+    const std::vector<bool>& contig2) {
+  TORCH_INTERNAL_ASSERT(
+      contig1.size() == contig2.size(),
+      "Called contiguityAnd with mismatched vectors.");
+
+  std::vector<bool> contig_result;
+  std::transform(
+      contig1.begin(),
+      contig1.end(),
+      contig2.begin(),
+      std::back_inserter(contig_result),
+      std::logical_and<>());
+  return contig_result;
 }
 
-namespace {
-// Map indices down to the leaf domains for applying swizzle
-class UpdateLeafIndices : public IterVisitor {
- public:
-  UpdateLeafIndices(
-      const TensorDomain* td,
-      std::unordered_map<kir::IterDomain*, kir::Val*> initial_index_map,
-      std::unordered_map<kir::IterDomain*, kir::Val*> extent_map)
-      : td_(td),
-        index_map_(std::move(initial_index_map)),
-        extent_map_(std::move(extent_map)) {
-    const std::vector<Val*> domain_vals(
-        td_->domain().begin(), td_->domain().end());
-
-    traverseFrom(td_->fusion(), domain_vals, false);
+// TODO: use new mapping functions
+// This mapping might need to go through rfactor, unclear
+std::vector<bool> IndexCompute::contiguityPasC(
+    TensorDomain* producer,
+    TensorDomain* consumer) {
+  FUSER_PERF_SCOPE("contiguityPasC");
+
+  const std::vector<bool>& producer_contiguity = producer->contiguity();
+  std::vector<bool> as_consumer_contiguity;
+
+  auto c_root = consumer->getRootDomain();
+  auto p_root = producer->getRootDomain();
+
+  size_t p_ind = 0;
+  size_t c_ind = 0;
+  while (p_ind < p_root.size()) {
+    if (p_root[p_ind]->isReduction()) {
+      p_ind++;
+    } else if (
+        c_root[c_ind]->isBroadcast() &&
+        p_root[p_ind]->getIterType() != c_root[c_ind]->getIterType()) {
+      c_ind++;
+      as_consumer_contiguity.push_back(false);
+    } else {
+      as_consumer_contiguity.push_back(producer_contiguity[p_ind]);
+      c_ind++;
+      p_ind++;
+    }
   }
 
-  const std::unordered_map<kir::IterDomain*, kir::Val*>& indexMap() const {
-    return index_map_;
+  while (c_ind < c_root.size()) {
+    as_consumer_contiguity.push_back(false);
+    c_ind++;
   }
 
-  const std::unordered_map<kir::IterDomain*, kir::Val*>& extentMap() const {
-    return extent_map_;
+  return as_consumer_contiguity;
+}
+
+namespace {
+
+std::deque<TensorView*> getComputeAtTVStackFrom(TensorView* from_tv) {
+  // What's the computeAt root tensor view in this operation
+  // This tensor is the terminating tensor in the computeAT dag from consumer
+  auto end_tv = from_tv->getComputeAtAxis(0).second;
+
+  // grab all tensor views from producer_tv -> computeAtRoot
+  std::deque<TensorView*> tv_stack;
+
+  // Then immediate consumer
+  auto running_tv = from_tv;
+
+  // Follow computeAt path until we hit end_tv
+  while (running_tv != end_tv) {
+    TORCH_INTERNAL_ASSERT(running_tv->hasComputeAt());
+    tv_stack.push_front(running_tv);
+    running_tv = running_tv->getComputeAtView();
   }
 
- private:
-  using IterVisitor::handle;
+  tv_stack.push_front(end_tv);
+
+  return tv_stack;
+}
+
+std::pair<
+    std::unordered_map<kir::IterDomain*, Val*>,
+    std::unordered_map<kir::IterDomain*, Val*>>
+generateIndexAndExtentMap(
+    std::deque<TensorView*> c2p_tv_stack,
+    std::deque<kir::ForLoop*> loops,
+    const std::unordered_map<kir::ForLoop*, Val*>& loop_to_ind_map,
+    const std::vector<bool>& last_tv_root_contiguity) {
+  if (c2p_tv_stack.empty())
+    return std::make_pair(
+        std::unordered_map<kir::IterDomain*, Val*>(),
+        std::unordered_map<kir::IterDomain*, Val*>());
+
+  // Go through our stack, and map the intermediate IterDomains from common
+  // transformations from consumer to producer
+  std::deque<std::unordered_map<IterDomain*, IterDomain*>> c2p_ID_maps;
+  std::deque<std::unordered_map<IterDomain*, IterDomain*>> p2c_ID_maps;
+
+  // c2p_tv_stack comes in as consumer -> producer
+  // Realized we may want to actually do a pass from producer->consumer first to
+  // propagate iterators outside the compute at position back into consumers, so
+  // we can repropagate back to producer. The need for this was exposed in
+  // https://github.com/csarofeen/pytorch/issues/286
+
+  for (size_t i = 0; i + 1 < c2p_tv_stack.size(); i++) {
+    auto c_tv = c2p_tv_stack[i];
+    auto p_tv = c2p_tv_stack[i + 1];
+
+    // Map root ID's from consumer to producer
+    auto c2p_root_map =
+        TensorDomain::mapRootCtoP(c_tv->domain(), p_tv->domain());
+
+    // Look for matching ID transformations in producer and consumer...
+    BestEffortReplay replay(
+        p_tv->domain()->domain(), c_tv->domain()->domain(), c2p_root_map);
+
+    // and grab the intermediate IterDomain map.
+    c2p_ID_maps.push_back(replay.getReplay());
+
+    // Something wasn't symmetric when using:
+    //
+    // auto p2c_root_map = TensorDomain::mapRootPtoC(p_tv->domain(),
+    // c_tv->domain());
+    //
+    // replay = BestEffortReplay(
+    //     c_tv->domain()->domain(), p_tv->domain()->domain(), p2c_root_map,
+    //     true);
 
-  void handle(Split* split) override {
-    const auto gpu_lower = GpuLower::current();
+    BestEffortReplay replay_p2c(
+        p_tv->domain()->domain(), c_tv->domain()->domain(), c2p_root_map, true);
 
-    auto in_id = gpu_lower->lowerValue(split->in())->as<kir::IterDomain>();
-    auto outer_id =
-        gpu_lower->lowerValue(split->outer())->as<kir::IterDomain>();
-    auto inner_id =
-        gpu_lower->lowerValue(split->inner())->as<kir::IterDomain>();
+    std::unordered_map<IterDomain*, IterDomain*> p2c_id_map;
 
-    // Nothing need to be done when mappings for the output axes
-    // already exist.
-    if (index_map_.find(outer_id) != index_map_.end()) {
-      TORCH_INTERNAL_ASSERT(
-          index_map_.find(inner_id) != index_map_.end(),
-          "Outer exists but inner not found");
-      return;
+    for (auto ent : replay_p2c.getReplay()) {
+      p2c_id_map[ent.second] = ent.first;
     }
 
-    kir::IrBuilder ir_builder(gpu_lower->kernel());
-    auto factor = gpu_lower->lowerValue(split->factor());
-    index_map_[inner_id] = ir_builder.modExpr(index_map_[in_id], factor);
-    extent_map_[inner_id] = factor;
-    index_map_[outer_id] = ir_builder.divExpr(index_map_[in_id], factor);
-    extent_map_[outer_id] = ir_builder.ceilDivExpr(getExtent(in_id), factor);
+    // and grab the intermediate IterDomain map.
+    p2c_ID_maps.push_front(p2c_id_map);
   }
 
-  void handle(Merge* merge) override {
-    const auto gpu_lower = GpuLower::current();
+  // Maps to be used in the c2p propagation
+  std::unordered_map<TensorView*, std::unordered_map<kir::IterDomain*, Val*>>
+      p2c_index_maps;
 
-    auto out_id = gpu_lower->lowerValue(merge->out())->as<kir::IterDomain>();
-    auto outer_id =
-        gpu_lower->lowerValue(merge->outer())->as<kir::IterDomain>();
-    auto inner_id =
-        gpu_lower->lowerValue(merge->inner())->as<kir::IterDomain>();
+  // PROPAGATE PRODUCER -> CONSUMER START
 
-    // Nothing need to be done when mappings for the output axes
-    // already exist.
-    if (index_map_.find(out_id) != index_map_.end()) {
-      return;
-    }
+  std::deque<TensorView*> p2c_tv_stack(
+      c2p_tv_stack.rbegin(), c2p_tv_stack.rend());
 
+  // Setup initial IndexCompute:
+  auto tv = p2c_tv_stack.front();
+  p2c_tv_stack.pop_front();
+  auto td = tv->domain()->domain();
+
+  std::vector<kir::IterDomain*> kir_td;
+
+  std::transform(
+      td.begin(), td.end(), std::back_inserter(kir_td), [](IterDomain* id) {
+        return GpuLower::lowerValue(id)->as<kir::IterDomain>();
+      });
+
+  // Map from all IterDomain's to corresponding index as we process each tv in
+  // the stack
+  std::unordered_map<kir::IterDomain*, Val*> initial_index_map;
+
+  // Match loops to this TV if the loop matchis this TV's ID (could reduce
+  // complexity here)
+
+  while (
+      !loops.empty() &&
+      std::find(kir_td.rbegin(), kir_td.rend(), loops.back()->iter_domain()) !=
+          kir_td.rend()) {
     TORCH_INTERNAL_ASSERT(
-        index_map_.find(outer_id) != index_map_.end(), "Outer ID not found");
-    TORCH_INTERNAL_ASSERT(
-        index_map_.find(inner_id) != index_map_.end(), "Inner ID not found");
+        loop_to_ind_map.find(loops.back()) != loop_to_ind_map.end());
+    initial_index_map[loops.back()->iter_domain()] =
+        loop_to_ind_map.at(loops.back());
+    loops.pop_back();
+  }
+
+  IndexCompute index_compute(
+      tv->domain(),
+      initial_index_map,
+      std::unordered_map<kir::IterDomain*, Val*>(),
+      std::unordered_set<kir::IterDomain*>(),
+      std::vector<bool>(tv->getRootDomain().size(), false));
+
+  p2c_index_maps[tv] = index_compute.indexMap();
+
+  // Go through the tv entire stack
+  while (!p2c_tv_stack.empty()) {
+    // Grab the TV
+    tv = p2c_tv_stack.front();
+    p2c_tv_stack.pop_front();
+    td = tv->domain()->domain();
+    kir_td.clear();
+    std::transform(
+        td.begin(), td.end(), std::back_inserter(kir_td), [](IterDomain* id) {
+          return GpuLower::lowerValue(id)->as<kir::IterDomain>();
+        });
 
-    kir::IrBuilder ir_builder(gpu_lower->kernel());
-    index_map_[out_id] = ir_builder.mulExpr(
-        index_map_[inner_id],
-        ir_builder.mulExpr(index_map_[outer_id], getExtent(inner_id)));
+    // Match loops to this TV if the loop matchis this TV's ID (could reduce
+    // complexity here)
 
-    extent_map_[out_id] =
-        ir_builder.mulExpr(getExtent(outer_id), getExtent(inner_id));
-  }
+    // Map from all IterDomain's to corresponding index as we process each tv in
+    // the stack
+    std::unordered_map<kir::IterDomain*, Val*> new_indices;
 
-  // return extent_map_[id] if exists, else return id->extent()
-  kir::Val* getExtent(kir::IterDomain* id) {
-    if (extent_map_.find(id) != extent_map_.end()) {
-      return extent_map_.at(id);
-    } else {
-      return id->extent();
+    while (!loops.empty() &&
+           std::find(
+               kir_td.rbegin(), kir_td.rend(), loops.back()->iter_domain()) !=
+               kir_td.rend()) {
+      TORCH_INTERNAL_ASSERT(
+          loop_to_ind_map.find(loops.back()) != loop_to_ind_map.end());
+      new_indices[loops.back()->iter_domain()] =
+          loop_to_ind_map.at(loops.back());
+      loops.pop_back();
     }
-  }
 
- private:
-  const TensorDomain* td_;
-  std::unordered_map<kir::IterDomain*, kir::Val*> index_map_;
-  std::unordered_map<kir::IterDomain*, kir::Val*> extent_map_;
-};
+    if (!p2c_ID_maps.empty()) {
+      index_compute = index_compute.updateIndexCompute(
+          tv->domain(),
+          p2c_ID_maps.front(),
+          new_indices,
+          std::vector<bool>(tv->getRootDomain().size(), false));
 
-// Returns halo-extended extent if id has halo. Otherwise, just
-// returns id->extent.
-kir::Val* getHaloExtentOfRootAxis(
-    IterDomain* id,
-    kir::Val* normal_extent = nullptr) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
+      p2c_index_maps[tv] = index_compute.indexMap();
 
-  if (normal_extent == nullptr) {
-    normal_extent = gpu_lower->lowerValue(id->extent());
+      p2c_ID_maps.pop_front();
+    }
   }
 
-  const auto& halo = gpu_lower->haloInfo().getRootAxisInfo(id);
-  if (halo.hasHalo()) {
-    auto halo_extent = ir_builder.addExpr(normal_extent, halo.width());
-    return halo_extent;
-  } else {
-    return normal_extent;
+  // PROPAGATE PRODUCER -> CONSUMER END
+
+  // PROPAGATE CONSUMER -> PRODUCER START
+
+  // Setup initial IndexCompute:
+  tv = c2p_tv_stack.front();
+  c2p_tv_stack.pop_front();
+
+  // Map from all IterDomain's to corresponding index as we process each tv in
+  // the stack
+  initial_index_map = p2c_index_maps.at(tv);
+
+  std::unordered_map<kir::IterDomain*, Val*> initial_extent_map;
+  if (!c2p_ID_maps.empty()) {
+    auto first_id_map = c2p_ID_maps.front();
+    for (auto id_entry : first_id_map) {
+      kir::IterDomain* this_id =
+          GpuLower::lowerValue(id_entry.first)->as<kir::IterDomain>();
+      if (initial_extent_map.find(this_id) == initial_extent_map.end()) {
+        initial_extent_map[this_id] = this_id->extent();
+      }
+    }
   }
-}
 
-} // namespace
+  index_compute = IndexCompute(
+      tv->domain(),
+      initial_index_map,
+      initial_extent_map,
+      std::unordered_set<kir::IterDomain*>(),
+      c2p_tv_stack.empty()
+          ? last_tv_root_contiguity
+          : std::vector<bool>(tv->getRootDomain().size(), false));
 
-IndexSwizzle::IndexSwizzle(
-    const TensorView* tv,
-    std::unordered_map<kir::IterDomain*, kir::Val*> initial_index_map,
-    std::unordered_map<kir::IterDomain*, kir::Val*> extent_map,
-    std::unordered_set<kir::IterDomain*> zero_merged_in)
-    : IndexCompute(
+  // Go through the tv entire stack
+  while (!c2p_tv_stack.empty()) {
+    // Grab the TV
+    tv = c2p_tv_stack.front();
+    c2p_tv_stack.pop_front();
+
+    if (!c2p_ID_maps.empty()) {
+      index_compute = index_compute.updateIndexCompute(
           tv->domain(),
-          std::move(initial_index_map),
-          std::move(extent_map),
-          std::move(zero_merged_in),
-          std::vector<bool>(tv->getRootDomain().size(), false)),
-      tv_(tv),
-      swizzle_type_(tv->swizzleType()),
-      ids_to_swizzle_(tv->axesToSwizzle()) {}
-
-void IndexSwizzle::run() {
-  TORCH_INTERNAL_ASSERT(
-      swizzle_type_ == SwizzleType::NoSwizzle ||
-          swizzle_type_ == SwizzleType::Transpose,
-      "Invalid swizzle type");
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-  if (swizzle_type_ == SwizzleType::Transpose) {
-    // Shifts the second axis by the first axis as ((idx_1 + idx_2) %
-    // ext). Alternatively, ((idx_1 - idx_2) & (ext - 1)) would also
-    // work if ext is a power of two. Practically, ext should be 32 if
-    // the data type of the tensor is float, so the latter approach
-    // should also be fine.
-    TORCH_INTERNAL_ASSERT(tv_->getMemoryType() == MemoryType::Shared);
-    TORCH_INTERNAL_ASSERT(tv_->axesToSwizzle().size() == 2);
-
-    UpdateLeafIndices update_leaves(td_, indexMap(), extentMap());
-    index_map_ = update_leaves.indexMap();
-    extent_map_ = update_leaves.extentMap();
-
-    IterDomain* id_to_swizzle_i = ids_to_swizzle_.at(0);
-    IterDomain* id_to_swizzle_j = ids_to_swizzle_.at(1);
-    kir::IterDomain* id_to_swizzle_i_kir =
-        gpu_lower->lowerValue(id_to_swizzle_i)->as<kir::IterDomain>();
-    kir::IterDomain* id_to_swizzle_j_kir =
-        gpu_lower->lowerValue(id_to_swizzle_j)->as<kir::IterDomain>();
-
-    if (indexMap().find(id_to_swizzle_i_kir) != indexMap().end() &&
-        indexMap().find(id_to_swizzle_j_kir) != indexMap().end()) {
-      auto idx_to_swizzle_i = indexMap().at(id_to_swizzle_i_kir);
-      auto idx_to_swizzle_j = indexMap().at(id_to_swizzle_j_kir);
-
-      auto swizzled_idx = ir_builder.modExpr(
-          ir_builder.addExpr(idx_to_swizzle_i, idx_to_swizzle_j),
-          id_to_swizzle_j_kir->extent());
-      index_map_[id_to_swizzle_j_kir] = swizzled_idx;
-      swizzled_ids_.insert(id_to_swizzle_j);
-      IndexCompute::run();
+          c2p_ID_maps.front(),
+          p2c_index_maps.at(tv),
+          c2p_tv_stack.empty()
+              ? last_tv_root_contiguity
+              : std::vector<bool>(tv->getRootDomain().size(), false));
+
+      c2p_ID_maps.pop_front();
     }
   }
-}
 
-void IndexSwizzle::handle(Expr* e) {
-  auto out_ids = ir_utils::filterByType<IterDomain>(e->outputs());
-  bool needs_update =
-      std::any_of(out_ids.begin(), out_ids.end(), [this](IterDomain* id) {
-        return swizzled_ids_.find(id) != swizzled_ids_.end();
-      });
-  if (!needs_update) {
-    return;
-  }
+  // PROPAGATE CONSUMER -> PRODUCER END
+
+  // Fill in extent map as some mapped indices may not have their extent filled
+  // in it, but consumers of this function expect it to be there
 
-  IndexCompute::handle(e);
-  for (auto input : ir_utils::filterByType<IterDomain>(e->inputs())) {
-    swizzled_ids_.insert(input);
+  std::unordered_map<kir::IterDomain*, Val*> extent_map(
+      index_compute.extentMap());
+  for (auto ind_entry : index_compute.indexMap()) {
+    auto id = ind_entry.first;
+    if (extent_map.find(id) == extent_map.end()) {
+      extent_map[id] = id->extent();
+    }
   }
+
+  return std::make_pair(index_compute.indexMap(), extent_map);
 }
 
-std::vector<kir::Val*> Index::getGlobalProducerStridedIndices(
+} // namespace
+
+kir::TensorIndex* Index::getGlobalProducerIndex(
     TensorView* producer_tv,
-    const TensorView* consumer_tv,
+    TensorView* consumer_tv,
     const std::vector<kir::ForLoop*>& loops) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex");
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
+  FUSER_PERF_SCOPE("getGlobalProducerIndex");
 
-  // Get a reference tensor replayed as existing loop structure
-  auto reference = IndexReferenceReplay::getReference(loops);
-  auto reference_domain = reference.domain;
-  auto reference_id_map = reference.concrete_to_id;
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
 
   // Replay producer to look like consumer so we can index on producer since our
   // loop nests look like consumer
-  auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv);
-  auto producerAsC =
-      TransformReplay::replayPasC(producer_tv, consumer_tv, -1, pairwiseMap)
-          .first;
+  auto producerAsC = TransformReplay::replayPasC(
+                         producer_tv->domain(), consumer_tv->domain(), -1)
+                         .first;
 
-  // Make the producer_tv look like consumer while performing indexing math
+  // Make the actual producer_tv look like consumer while we do the indexing
+  // math in this function
   ir_utils::TVDomainGuard domain_guard(producer_tv, producerAsC);
 
-  // Map reference tensor to producer
-  std::unordered_map<IterDomain*, IterDomain*> root_ref_to_producer;
-  for (auto p_root : producer_tv->getMaybeRFactorDomain()) {
-    auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(p_root);
-    auto ref_id_it = reference_id_map.find(concrete_id);
-    if (ref_id_it != reference_id_map.end()) {
-      root_ref_to_producer[ref_id_it->second] = p_root;
-    }
-  }
-
-  // Index into the reference tensor. Reference indexing will handle vectorized
-  // dims where index should be set to 0
-  auto ref_compute = getReferenceIndexing(loops, reference_domain);
-
-  // Replay producer as reference to get reference to producer ID map
-  BestEffortReplay replay_producer_as_ref(
-      producer_tv->domain()->domain(),
-      reference_domain->domain(),
-      root_ref_to_producer);
-
-  const auto& ref_2_producer = replay_producer_as_ref.getReplay();
-
-  // Forward vectorized IDs to index into producer correctly
-  // We want p_id to be vectorized like consumer just for the indexing, then we
-  // need to switch it back later. Store previous state here when changing. We
-  // need to do this as replaying producer as consumer can use replay best
-  // effort which means some domains may be the originals.
-  std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup;
-  for (auto entry : ref_2_producer) {
-    auto ref_id = entry.first;
-    auto p_id = entry.second;
-    if (ref_id->getParallelType() == ParallelType::Vectorize) {
-      p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType()));
-      p_id->parallelize(ParallelType::Vectorize);
-    } else if (ref_id->getParallelType() == ParallelType::MisalignedVectorize) {
-      p_id->parallelize(ParallelType::MisalignedVectorize);
-    }
-  }
-
-  const auto reference_halo_extent_map = getReferenceHaloExtentMap(
-      reference, consumer_tv, ref_2_producer, ref_compute.extentMap());
+  // grab all tensor views from producer_tv <- computeAtRoot
+  std::deque<TensorView*> tv_stack = getComputeAtTVStackFrom(consumer_tv);
+  tv_stack.push_back(producer_tv);
 
-  // Index into producer using reference indexing
-  auto producer_indexing = ref_compute.updateIndexCompute(
-      producer_tv->domain(),
-      ref_2_producer,
-      producer_tv->domain()->contiguity(),
-      reference_halo_extent_map);
+  std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
+  std::transform(
+      loops.begin(),
+      loops.end(),
+      std::inserter(loop_to_ind_map, loop_to_ind_map.begin()),
+      [](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); });
 
-  // Revert p_ids
-  for (auto entry : p_id_backup) {
-    entry.first->parallelize(entry.second);
-  }
+  auto index_map = generateIndexAndExtentMap(
+                       tv_stack,
+                       std::deque<kir::ForLoop*>(loops.begin(), loops.end()),
+                       loop_to_ind_map,
+                       producer_tv->domain()->contiguity())
+                       .first;
 
   // Indices should now be mapped onto IterDomains in producer, so just grab
   // and use them.
   auto root_dom = producer_tv->getMaybeRFactorDomain();
 
-  // TODO: Abstract stride logic to reuse with consumer indexing
-  auto zero = ir_builder.create<kir::Int>(0);
-  std::vector<kir::Val*> strides(root_dom.size(), nullptr);
-  {
-    int stride_i = 0;
-    for (size_t i = 0; i < root_dom.size(); i++) {
-      if (root_dom[i]->isReduction() ||
-          root_dom[i]->getIterType() == IterType::BroadcastWithoutStride) {
-        strides[i] = zero;
-        continue;
-      }
-      std::stringstream ss;
-      ss << "T" << producer_tv->name() << ".stride[" << stride_i++ << "]";
-      strides[i] = ir_builder.create<kir::NamedScalar>(ss.str(), DataType::Int);
-    }
-  }
-
-  kir::Val* cur_contig_stride = ir_builder.create<kir::Int>(1);
-  // if we have rfactor we can't simplify the indexing like this, we would need
-  // to fix contiguity size to be rfactor size not root size
-  if (root_dom.size() == producer_tv->domain()->contiguity().size()) {
-    for (size_t i = 0; i < root_dom.size(); i++) {
-      auto dim = root_dom.size() - i - 1;
-      if (root_dom[dim]->isReduction()) {
-        continue;
-      }
-      if (root_dom[dim]->getIterType() == IterType::BroadcastWithoutStride) {
-        continue;
-      }
-
-      kir::Val* root_ind = nullptr;
-      auto kir_root_dom =
-          gpu_lower->lowerValue(root_dom[dim])->as<kir::IterDomain>();
-      if (producer_indexing.indexMap().find(kir_root_dom) !=
-          producer_indexing.indexMap().end()) {
-        root_ind = producer_indexing.indexMap().at(kir_root_dom);
-      } else if (
-          root_dom[dim]->getIterType() == IterType::BroadcastWithStride) {
-        root_ind = zero;
-      }
-
-      TORCH_INTERNAL_ASSERT(
-          root_ind != nullptr,
-          "Couldn't find root mapping for TV",
-          producer_tv->name(),
-          " dim: ",
-          i,
-          " id: ",
-          root_dom[dim]);
-
-      if (producer_tv->domain()->contiguity()[dim]) {
-        // If contig, used the stored stride which may be the previous
-        // dimensions stride * previous dimensions size
-        strides[dim] = cur_contig_stride;
-        // Prepare for the next dimension which may also be contiguous, multiply
-        // by extent of this dimension
-        auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]);
-        cur_contig_stride =
-            ir_builder.mulExpr(cur_contig_stride, root_dim_extent);
-      } else {
-        // If non contiguous dimension, keep local stride information, set cur
-        // stride to local stride * local raw extent
-        auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]);
-        cur_contig_stride = ir_builder.mulExpr(strides[dim], root_dim_extent);
-      }
-    }
-  }
-
-  auto vectorize_shift =
-      loops.empty() ? nullptr : loops.back()->vectorize_shift();
+  bool inner_most_dim_contig =
+      root_dom[root_dom.size() - 1]->getIterType() == IterType::Iteration &&
+      producer_tv->domain()->contiguity()[root_dom.size() - 1];
 
   // Global striding
-  std::vector<kir::Val*> strided_inds(root_dom.size(), ir_builder.zeroVal());
+  int64_t stride_i = 0;
+  std::vector<Val*> strided_inds;
   for (const auto i : c10::irange(root_dom.size())) {
-    // If the domain is derived from a trivial reduction, no indexing
-    // to create.
     if (root_dom[i]->isReduction() ||
-        root_dom[i]->getIterType() == IterType::BroadcastWithoutStride ||
-        root_dom[i]->getIterType() == IterType::BroadcastWithStride ||
-        gpu_lower->trivialReductionInfo().isDerived(root_dom[i])) {
+        root_dom[i]->getIterType() == IterType::BroadcastWithoutStride) {
+      continue;
+    } else if (root_dom[i]->getIterType() == IterType::BroadcastWithStride) {
+      stride_i++;
       continue;
     }
 
     auto kir_root_dom_i =
-        gpu_lower->lowerValue(root_dom[i])->as<kir::IterDomain>();
+        GpuLower::lowerValue(root_dom[i])->as<kir::IterDomain>();
 
     TORCH_INTERNAL_ASSERT(
-        producer_indexing.indexMap().find(kir_root_dom_i) !=
-            producer_indexing.indexMap().end(),
+        index_map.find(kir_root_dom_i) != index_map.end(),
         "Couldn't find root mapping for TV",
         producer_tv->name(),
         " dim: ",
@@ -1131,43 +816,34 @@ std::vector<kir::Val*> Index::getGlobalProducerStridedIndices(
         " id: ",
         kir::toString(kir_root_dom_i));
 
-    auto root_ind = producer_indexing.indexMap().at(kir_root_dom_i);
-
-    root_ind = getProducerIndexWithHalo(producer_tv, i, root_ind, consumer_tv);
+    auto root_ind = index_map.at(kir_root_dom_i);
+    TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ind));
 
-    root_ind = getProducerIndexWithGather(
-        i,
-        root_ind,
-        producer_tv,
-        consumer_tv,
-        ref_compute.indexMap(),
-        reference_id_map);
-
-    if (root_ind->isZeroInt()) {
-      continue;
+    if (i == root_dom.size() - 1 && inner_most_dim_contig) {
+      strided_inds.push_back(root_ind);
+    } else if (root_ind->isZeroInt()) {
+      stride_i++;
     } else {
-      auto strided_ind = ir_builder.mulExpr(root_ind, strides[i]);
-      if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
-        strided_inds[i] = ir_builder.addExpr(strided_ind, vectorize_shift);
-      } else {
-        strided_inds[i] = strided_ind;
-      }
+      std::stringstream ss;
+      ss << "T" << producer_tv->name() << ".stride[" << stride_i++ << "]";
+      strided_inds.push_back(ir_builder.mulExpr(
+          root_ind,
+          ir_builder.create<kir::NamedScalar>(ss.str(), DataType::Int)));
     }
   }
 
-  return strided_inds;
+  if (strided_inds.size() == 0)
+    strided_inds.push_back(ir_builder.create<kir::Int>(0));
+
+  return ir_builder.create<kir::TensorIndex>(producer_tv, strided_inds);
 }
 
 namespace {
 
-// Used for local and shared index mapping
-std::unordered_map<kir::ForLoop*, kir::Val*> indexMapFromTV(
-    const TensorView* tv,
-    const std::vector<kir::ForLoop*>& loops,
-    const std::pair<kir::ForLoop*, int64_t>& alloc_point) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
+std::unordered_map<kir::ForLoop*, Val*> indexMapFromTV(
+    TensorView* tv,
+    const std::vector<kir::ForLoop*>& loops) {
+  auto alloc_point = loop_utils::getAllocPoint(tv, loops);
   auto alloc_loop = alloc_point.first;
 
   bool within_alloc = false;
@@ -1175,35 +851,26 @@ std::unordered_map<kir::ForLoop*, kir::Val*> indexMapFromTV(
     within_alloc = true;
   }
 
-  const auto zero = ir_builder.create<kir::Int>(0);
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+  Val* zero = ir_builder.create<kir::Int>(0);
 
-  const bool is_global = tv->getMemoryType() == MemoryType::Global;
-  const bool is_shared = tv->getMemoryType() == MemoryType::Shared;
-  const bool is_local = tv->getMemoryType() == MemoryType::Local;
+  bool is_shared = tv->getMemoryType() == MemoryType::Shared;
+  bool is_local = tv->getMemoryType() == MemoryType::Local;
 
-  std::unordered_map<kir::ForLoop*, kir::Val*> loop_to_ind_map;
+  std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
 
   for (auto loop : loops) {
-    kir::Val* idx = nullptr;
-    // See also LoopNestGenerator::pushAlloc.
     // NOLINTNEXTLINE(bugprone-branch-clone)
     if (!within_alloc) {
-      if ((loop->iter_domain()->isThreadDim() && is_shared) ||
-          (loop->iter_domain()->isThread() && is_global)) {
-        idx = loop->index();
-      } else {
-        idx = zero;
-      }
-    } else if (
-        (loop->iter_domain()->isBlockDim() && is_shared) ||
-        (loop->iter_domain()->isThread() && is_local) || loop->vectorize()) {
-      idx = zero;
+      loop_to_ind_map[loop] = zero;
+    } else if (loop->iter_domain()->isBlockDim() && is_shared) {
+      loop_to_ind_map[loop] = zero;
+    } else if (loop->iter_domain()->isThread() && is_local) {
+      loop_to_ind_map[loop] = zero;
     } else {
-      idx = loop->index();
+      loop_to_ind_map[loop] = loop->index();
     }
 
-    loop_to_ind_map[loop] = idx;
-
     if (!within_alloc && loop == alloc_loop) {
       within_alloc = true;
     }
@@ -1215,189 +882,50 @@ std::unordered_map<kir::ForLoop*, kir::Val*> indexMapFromTV(
 } // namespace
 
 // Producer index for either shared or local memory
-std::vector<kir::Val*> Index::getNonGlobalProducerStridedIndices(
+kir::TensorIndex* Index::getProducerIndex_impl(
     TensorView* producer_tv,
-    const TensorView* consumer_tv,
+    TensorView* consumer_tv,
     const std::vector<kir::ForLoop*>& loops) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  // Get a reference tensor replayed as existing loop structure
-  auto reference = IndexReferenceReplay::getReference(loops);
-  auto reference_domain = reference.domain;
-  auto reference_id_map = reference.concrete_to_id;
-
-  // Replay producer to look like consumer so we can index on producer since our
-  // loop nests look like consumer
-  auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv);
-  auto producer_replayed_as_consumer =
-      TransformReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map)
-          .first;
-
-  ir_utils::TVDomainGuard domain_guard(
-      producer_tv, producer_replayed_as_consumer);
-
-  //  We want to play producer as consumer instead of the other way around since
-  //  consumer may have some broadcasted axes producer doesn't have merged into
-  //  loops producer may use. If we did consumer as producer we wouldn't have
-  //  this information in the mapping.
-  auto replay_PasC =
-      BestEffortReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map);
-
-  auto c2p_map = replay_PasC.getReplay();
-
-  // Grab consumer domain entries and reverse replay map. TODO: Maybe
-  // TransformReplay::replayPasC could return this map
-  decltype(c2p_map) p2c_map;
-  for (auto id : consumer_tv->domain()->domain()) {
-    auto c2p_it = c2p_map.find(id);
-    if (c2p_it != c2p_map.end()) {
-      auto c_id = c2p_it->first;
-      auto p_id = c2p_it->second;
-      p2c_map[p_id] = c_id;
-    }
-  }
-
-  // Find allocation point of producer relative to loop nests. P2C map is
-  // required because producer was replayed as consumer, so we can't use the
-  // regular compute at maps to line up its iter domains with the for loops.
-  auto alloc_point =
-      loop_utils::getAllocPoint(producer_tv, loops, p2c_map, true);
-  std::unordered_map<kir::ForLoop*, kir::Val*> loop_to_ind_map =
-      indexMapFromTV(producer_tv, loops, alloc_point);
-
-  // Map loop nests to indicies, zeroing out those not used due to locality of
-  // memory
-  std::unordered_map<kir::IterDomain*, kir::Val*> ref_id_to_ind_map;
-
-  // Due to rfactor/initialization reference_domain may be bigger than loop nest
-  // structure, ignore IterDomains that aren't present in the loop nest when
-  // indexing reference.
-  TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims());
-  for (size_t loop_i = 0; loop_i < loops.size(); loop_i++) {
-    auto ref_axis = gpu_lower->lowerValue(reference_domain->axis(loop_i))
-                        ->as<kir::IterDomain>();
-    ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]];
-  }
-
-  // Map reference tensor to producer
-  std::unordered_map<IterDomain*, IterDomain*> root_ref_to_producer;
-  for (auto p_root : producer_tv->getMaybeRFactorDomain()) {
-    auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(p_root);
-    auto ref_id_it = reference_id_map.find(concrete_id);
-    if (ref_id_it != reference_id_map.end()) {
-      root_ref_to_producer[ref_id_it->second] = p_root;
-    }
-  }
-
-  // Grab roots that map into producer and save them into the preferred roots
-  // set for references indexing
-  std::unordered_set<IterDomain*> preferred_roots;
-  for (auto entry : root_ref_to_producer) {
-    if (entry.second->isBroadcast() || entry.second->isReduction()) {
-      continue;
-    }
-    preferred_roots.emplace(entry.first);
-  }
-
-  // Make sure propagation of indexing while mixing with 0 indicies we propagate
-  // in a way that the producer will be able to see what's going on (propagating
-  // into common roots of reference and producer).
-  auto preferred_paths = buildPreferredPaths(reference_domain, preferred_roots);
-
-  // Index into the reference tensor
-  auto ref_compute = getReferenceIndexing(
-      loops, reference_domain, ref_id_to_ind_map, preferred_paths);
-
-  // Directly replay the producer as the reference to get the mapping of
-  // reference to producer we will use to map the indexing into producer
-  BestEffortReplay replay_producer_as_ref(
-      producer_tv->domain()->domain(),
-      reference_domain->domain(),
-      root_ref_to_producer);
-
-  const auto& ref_2_producer = replay_producer_as_ref.getReplay();
-
-  // Forward vectorized IDs to index into producer correctly
-  // We want p_id to be vectorized like consumer just for the indexing, then we
-  // need to switch it back later. Store previous state here when changing. We
-  // need to do this as replaying producer as consumer can use replay best
-  // effort which means some domains may be the originals.
-  std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup;
-  for (auto entry : ref_2_producer) {
-    auto ref_id = entry.first;
-    auto p_id = entry.second;
-    if (ref_id->getParallelType() == ParallelType::Vectorize) {
-      p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType()));
-      p_id->parallelize(ParallelType::Vectorize);
-    } else if (ref_id->getParallelType() == ParallelType::MisalignedVectorize) {
-      p_id->parallelize(ParallelType::MisalignedVectorize);
-    }
-  }
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
 
-  // Index into producer using reference indexing
+  // producer_tv->domain() is not replayed as the loop strucutre we were
+  // provided, so replay it to match consumer_tv which is.
+  auto producerAsC = TransformReplay::replayPasC(
+                         producer_tv->domain(), consumer_tv->domain(), -1)
+                         .first;
 
-  const auto reference_halo_extent_map = getReferenceHaloExtentMap(
-      reference, consumer_tv, ref_2_producer, ref_compute.extentMap());
-
-  auto producer_indexing = ref_compute.updateIndexCompute(
-      producer_tv->domain(),
-      ref_2_producer,
-      producer_tv->domain()->contiguity(),
-      reference_halo_extent_map);
-
-  // Revert p_ids
-  for (auto entry : p_id_backup) {
-    entry.first->parallelize(entry.second);
-  }
+  // Set producer_tv with the domain replayed as consumer to grab the right
+  // indices. The guard will reset the domain when this scope ends.
+  ir_utils::TVDomainGuard domain_guard(producer_tv, producerAsC);
 
-  IndexSwizzle index_swizzle(
-      producer_tv,
-      producer_indexing.indexMap(),
-      producer_indexing.extentMap(),
-      producer_indexing.zeroMergedIn());
+  // grab all tensor views from producer_tv <- computeAtRoot
+  std::deque<TensorView*> tv_stack = getComputeAtTVStackFrom(consumer_tv);
+  tv_stack.push_back(producer_tv);
 
-  index_swizzle.run();
+  std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map =
+      indexMapFromTV(producer_tv, loops);
 
-  auto index_map = index_swizzle.indexMap();
-  auto extent_map = producer_indexing.extentMap();
+  auto index_and_extent_map = generateIndexAndExtentMap(
+      tv_stack,
+      std::deque<kir::ForLoop*>(loops.begin(), loops.end()),
+      loop_to_ind_map,
+      std::vector<bool>(producer_tv->getRootDomain().size(), false));
+  auto index_map = index_and_extent_map.first;
+  auto extent_map = index_and_extent_map.second;
 
   // Indices should now be mapped onto IterDomains in producer, so just grab
   // and use them.
   auto root_dom = producer_tv->getMaybeRFactorDomain();
 
-  // Figure out which root axes we don't need to index
-  std::unordered_set<IterDomain*> skip_indexing;
-
-  for (auto root_id : root_dom) {
-    // Already taken care of because we can detect no indexing required
-    if (root_id->isBroadcast() || root_id->isReduction() ||
-        gpu_lower->trivialReductionInfo().isDerived(root_id)) {
-      skip_indexing.insert(root_id);
-      continue;
-    }
-
-    // Already an entry for this root domain, continue
-    if (index_map.find(gpu_lower->lowerValue(root_id)->as<kir::IterDomain>()) !=
-        index_map.end()) {
-      continue;
-    }
-
-    // Maps to consumers trivial reduction, don't index
-    if (p2c_map.find(root_id) != p2c_map.end() &&
-        gpu_lower->trivialReductionInfo().isDerived(p2c_map.at(root_id))) {
-      skip_indexing.emplace(root_id);
-    }
-  }
+  std::vector<Val*> strided_inds;
 
-  std::vector<kir::Val*> strided_inds(root_dom.size(), ir_builder.zeroVal());
   for (const auto i : c10::irange(root_dom.size())) {
-    if (skip_indexing.count(root_dom[i])) {
+    if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast()) {
       continue;
     }
 
     auto kir_root_dom_i =
-        gpu_lower->lowerValue(root_dom[i])->as<kir::IterDomain>();
+        GpuLower::lowerValue(root_dom[i])->as<kir::IterDomain>();
 
     TORCH_INTERNAL_ASSERT(
         index_map.find(kir_root_dom_i) != index_map.end(),
@@ -1409,34 +937,25 @@ std::vector<kir::Val*> Index::getNonGlobalProducerStridedIndices(
         kir::toString(kir_root_dom_i));
 
     auto root_ind_i = index_map.at(kir_root_dom_i);
-
-    root_ind_i =
-        getProducerIndexWithHalo(producer_tv, i, root_ind_i, consumer_tv);
-
-    root_ind_i = getProducerIndexWithGather(
-        i,
-        root_ind_i,
-        producer_tv,
-        consumer_tv,
-        ref_compute.indexMap(),
-        reference_id_map);
+    TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ind_i));
 
     if (root_ind_i->isZeroInt()) {
       continue;
     }
 
     // Compute striding for this index.
-    kir::Val* stride = nullptr;
+    Val* stride = nullptr;
     for (size_t j = i + 1; j < root_dom.size(); j++) {
-      if (skip_indexing.count(root_dom[j])) {
+      if (root_dom[j]->isBroadcast() || root_dom[j]->isReduction()) {
         continue;
       }
 
       auto kir_root_dom_j =
-          gpu_lower->lowerValue(root_dom[j])->as<kir::IterDomain>();
+          GpuLower::lowerValue(root_dom[j])->as<kir::IterDomain>();
 
       TORCH_INTERNAL_ASSERT(
-          index_map.find(kir_root_dom_j) != index_map.end(),
+          index_map.find(kir_root_dom_j) != index_map.end() &&
+              extent_map.find(kir_root_dom_j) != extent_map.end(),
           "Couldn't find root mapping for TV",
           consumer_tv->name(),
           " dim: ",
@@ -1445,11 +964,9 @@ std::vector<kir::Val*> Index::getNonGlobalProducerStridedIndices(
           root_dom[i]);
 
       auto root_ind_j = index_map.at(kir_root_dom_j);
-      auto root_ext_j = extent_map.find(kir_root_dom_j) == extent_map.end()
-          ? kir_root_dom_j->extent()
-          : extent_map.at(kir_root_dom_j);
+      auto root_ext_j = extent_map.at(kir_root_dom_j);
 
-      root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j);
+      TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ext_j));
 
       if (!root_ind_j->isZeroInt()) {
         if (stride == nullptr) {
@@ -1461,273 +978,125 @@ std::vector<kir::Val*> Index::getNonGlobalProducerStridedIndices(
     }
 
     if (stride != nullptr) {
-      strided_inds[i] = ir_builder.mulExpr(root_ind_i, stride);
+      strided_inds.push_back(ir_builder.mulExpr(root_ind_i, stride));
     } else {
-      strided_inds[i] = root_ind_i;
+      strided_inds.push_back(root_ind_i);
     }
   }
 
-  return strided_inds;
+  if (strided_inds.size() == 0)
+    strided_inds.push_back(ir_builder.create<kir::Int>(0));
+
+  return ir_builder.create<kir::TensorIndex>(producer_tv, strided_inds);
 }
 
-std::vector<kir::Val*> Index::getGlobalConsumerStridedIndices(
-    const TensorView* consumer_tv,
+kir::TensorIndex* Index::getGlobalConsumerIndex(
+    TensorView* consumer_tv,
     const std::vector<kir::ForLoop*>& loops) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  // Get a reference tensor replayed as existing loop structure
-  auto reference = IndexReferenceReplay::getReference(loops);
-  auto reference_domain = reference.domain;
-  auto reference_id_map = reference.concrete_to_id;
-
-  // Map reference tensor to consumer
-  std::unordered_map<IterDomain*, IterDomain*> root_ref_to_consumer;
-  for (auto c_root : consumer_tv->getMaybeRFactorDomain()) {
-    auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(c_root);
-    auto ref_id_it = reference_id_map.find(concrete_id);
-    if (ref_id_it != reference_id_map.end()) {
-      root_ref_to_consumer[ref_id_it->second] = c_root;
-    }
-  }
-
-  BestEffortReplay replay_consumer_as_ref(
-      consumer_tv->domain()->domain(),
-      reference_domain->domain(),
-      root_ref_to_consumer);
+  FUSER_PERF_SCOPE("getGlobalConsumerIndex");
 
-  const auto& ref_2_consumer = replay_consumer_as_ref.getReplay();
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
 
-  // Index into the reference tensor. Reference indexing will handle vectorized
-  // dims where index should be set to 0
-  auto ref_compute = getReferenceIndexing(loops, reference_domain);
+  // grab all tensor views from producer_tv <- computeAtRoot
+  std::deque<TensorView*> tv_stack = getComputeAtTVStackFrom(consumer_tv);
 
-  // Index into consumer using reference indexing
-
-  const auto reference_halo_extent_map = getReferenceHaloExtentMap(
-      reference, consumer_tv, ref_2_consumer, ref_compute.extentMap());
+  std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
+  std::transform(
+      loops.begin(),
+      loops.end(),
+      std::inserter(loop_to_ind_map, loop_to_ind_map.begin()),
+      [](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); });
 
-  auto consumer_indexing = ref_compute.updateIndexCompute(
-      consumer_tv->domain(),
-      ref_2_consumer,
-      consumer_tv->domain()->contiguity(),
-      reference_halo_extent_map);
+  auto index_map = generateIndexAndExtentMap(
+                       tv_stack,
+                       std::deque<kir::ForLoop*>(loops.begin(), loops.end()),
+                       loop_to_ind_map,
+                       consumer_tv->domain()->contiguity())
+                       .first;
 
   // Indices should now be mapped onto IterDomains in consumer, so just grab
   // and use them.
   auto root_dom = consumer_tv->getMaybeRFactorDomain();
 
-  // TODO: Abstract stride logic to reuse with producer indexing
-  auto zero = ir_builder.zeroVal();
-  std::vector<kir::Val*> strides(root_dom.size(), zero);
-  {
-    int stride_i = 0;
-    for (size_t i = 0; i < root_dom.size(); i++) {
-      if (root_dom[i]->isReduction() ||
-          root_dom[i]->getIterType() == IterType::BroadcastWithoutStride) {
-        strides[i] = zero;
-        continue;
-      }
-      std::stringstream ss;
-      ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]";
-      strides[i] = ir_builder.create<kir::NamedScalar>(ss.str(), DataType::Int);
-    }
-  }
-
-  kir::Val* cur_contig_stride = ir_builder.oneVal();
-  // if we have rfactor we can't simplify the indexing like this, we would need
-  // to fix contiguity size to be rfactor size not root size
-  if (root_dom.size() == consumer_tv->domain()->contiguity().size()) {
-    for (size_t i = 0; i < root_dom.size(); i++) {
-      auto dim = root_dom.size() - i - 1;
-      if (root_dom[dim]->isReduction()) {
-        continue;
-      }
-      if (root_dom[dim]->getIterType() == IterType::BroadcastWithoutStride) {
-        continue;
-      }
-
-      kir::Val* root_ind = nullptr;
-      auto kir_root_dom =
-          gpu_lower->lowerValue(root_dom[dim])->as<kir::IterDomain>();
-      if (consumer_indexing.indexMap().find(kir_root_dom) !=
-          consumer_indexing.indexMap().end()) {
-        root_ind = consumer_indexing.indexMap().at(kir_root_dom);
-      } else if (
-          root_dom[dim]->getIterType() == IterType::BroadcastWithStride) {
-        root_ind = zero;
-      }
+  bool inner_most_dim_contig =
+      root_dom[root_dom.size() - 1]->getIterType() == IterType::Iteration &&
+      consumer_tv->domain()->contiguity()[root_dom.size() - 1];
 
-      TORCH_INTERNAL_ASSERT(
-          root_ind != nullptr,
-          "Couldn't find root mapping for TV",
-          consumer_tv->name(),
-          " dim: ",
-          i,
-          " id: ",
-          root_dom[dim]);
-
-      if (consumer_tv->domain()->contiguity()[dim]) {
-        // If contig, used the stored stride which may be the previous
-        // dimensions stride * previous dimensions size
-        strides[dim] = cur_contig_stride;
-        // Prepare for the next dimension which may also be contiguous, multiply
-        // by extent of this dimension
-        auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]);
-        cur_contig_stride =
-            ir_builder.mulExpr(cur_contig_stride, root_dim_extent);
-      } else {
-        // If non contiguous dimension, keep local stride information, set cur
-        // stride to local stride * local raw extent
-        cur_contig_stride = ir_builder.mulExpr(
-            strides[dim], getHaloExtentOfRootAxis(root_dom[dim]));
-      }
-    }
-  }
-
-  auto vectorize_shift =
-      loops.empty() ? nullptr : loops.back()->vectorize_shift();
-
-  // Global striding
-  std::vector<kir::Val*> strided_inds(root_dom.size(), ir_builder.zeroVal());
+  int64_t stride_i = 0;
+  std::vector<Val*> strided_inds;
   for (const auto i : c10::irange(root_dom.size())) {
-    // See a comment in indexing to root domains in getGlobalProducerIndex.
     if (root_dom[i]->isReduction() ||
-        root_dom[i]->getIterType() == IterType::BroadcastWithoutStride ||
-        root_dom[i]->getIterType() == IterType::BroadcastWithStride ||
-        gpu_lower->trivialReductionInfo().isDerived(root_dom[i])) {
+        root_dom[i]->getIterType() == IterType::BroadcastWithoutStride) {
+      continue;
+    } else if (root_dom[i]->getIterType() == IterType::BroadcastWithStride) {
+      stride_i++;
       continue;
     }
 
     auto kir_root_dom_i =
-        gpu_lower->lowerValue(root_dom[i])->as<kir::IterDomain>();
+        GpuLower::lowerValue(root_dom[i])->as<kir::IterDomain>();
 
     TORCH_INTERNAL_ASSERT(
-        consumer_indexing.indexMap().find(kir_root_dom_i) !=
-            consumer_indexing.indexMap().end(),
+        index_map.find(kir_root_dom_i) != index_map.end(),
         "Couldn't find root mapping for TV",
         consumer_tv->name(),
         " dim: ",
         i,
         " id: ",
         kir::toString(kir_root_dom_i));
+    auto ind = index_map.at(kir_root_dom_i);
 
-    auto root_ind = consumer_indexing.indexMap().at(kir_root_dom_i);
-
-    if (root_ind->isZeroInt()) {
-      continue;
+    if (i == root_dom.size() - 1 && inner_most_dim_contig) {
+      strided_inds.push_back(ind);
+    } else if (ind->isZeroInt()) {
+      stride_i++;
     } else {
-      auto strided_ind = ir_builder.mulExpr(root_ind, strides[i]);
-      if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
-        strided_inds[i] = ir_builder.addExpr(strided_ind, vectorize_shift);
-      } else {
-        strided_inds[i] = strided_ind;
-      }
+      std::stringstream ss;
+      ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]";
+      strided_inds.push_back(ir_builder.mulExpr(
+          ind, ir_builder.create<kir::NamedScalar>(ss.str(), DataType::Int)));
     }
   }
 
-  return strided_inds;
+  if (strided_inds.size() == 0)
+    strided_inds.push_back(ir_builder.create<kir::Int>(0));
+
+  return ir_builder.create<kir::TensorIndex>(consumer_tv, strided_inds);
 }
 
 // Consumer index for either shared or local memory
-std::vector<kir::Val*> Index::getNonGlobalConsumerStridedIndices(
-    const TensorView* consumer_tv,
+kir::TensorIndex* Index::getConsumerIndex_impl(
+    TensorView* consumer_tv,
     const std::vector<kir::ForLoop*>& loops) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  // Get a reference tensor replayed as existing loop structure
-  auto reference = IndexReferenceReplay::getReference(loops);
-  auto reference_domain = reference.domain;
-  auto reference_id_map = reference.concrete_to_id;
-
-  auto alloc_point = loop_utils::getAllocPoint(consumer_tv, loops);
-  std::unordered_map<kir::ForLoop*, kir::Val*> loop_to_ind_map =
-      indexMapFromTV(consumer_tv, loops, alloc_point);
-
-  // Map loop nests to indicies, zeroing out those not used due to locality of
-  // memory
-  std::unordered_map<kir::IterDomain*, kir::Val*> ref_id_to_ind_map;
-
-  // Due to rfactor/initialization reference_domain may be bigger than loop nest
-  // structure, ignore IterDomains that aren't present in the loop nest when
-  // indexing reference.
-  TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims());
-  for (size_t loop_i = 0; loop_i < loops.size(); loop_i++) {
-    auto ref_axis = gpu_lower->lowerValue(reference_domain->axis(loop_i))
-                        ->as<kir::IterDomain>();
-    ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]];
-  }
-
-  // Map reference tensor to consumer
-  std::unordered_map<IterDomain*, IterDomain*> root_ref_to_consumer;
-  for (auto c_root : consumer_tv->getMaybeRFactorDomain()) {
-    auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(c_root);
-    auto ref_id_it = reference_id_map.find(concrete_id);
-    if (ref_id_it != reference_id_map.end()) {
-      root_ref_to_consumer[ref_id_it->second] = c_root;
-    }
-  }
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
 
-  // Grab roots that map into consumer and save them into the preferred roots
-  // set for references indexing
-  std::unordered_set<IterDomain*> preferred_roots;
-  for (auto entry : root_ref_to_consumer) {
-    if (entry.second->isBroadcast() || entry.second->isReduction()) {
-      continue;
-    }
-    preferred_roots.emplace(entry.first);
-  }
-
-  // Make sure propagation of indexing while mixing with 0 indicies we propagate
-  // in a way that consumer will be able to see what's going on.
-  auto preferred_paths = buildPreferredPaths(reference_domain, preferred_roots);
-
-  // Index into the reference tensor
-  auto ref_compute = getReferenceIndexing(
-      loops, reference_domain, ref_id_to_ind_map, preferred_paths);
-
-  BestEffortReplay replay_consumer_as_ref(
-      consumer_tv->domain()->domain(),
-      reference_domain->domain(),
-      root_ref_to_consumer);
-
-  const auto& ref_2_consumer = replay_consumer_as_ref.getReplay();
+  // grab all tensor views from consumer_tv <- computeAtRoot
+  std::deque<TensorView*> tv_stack = getComputeAtTVStackFrom(consumer_tv);
 
-  const auto reference_halo_extent_map = getReferenceHaloExtentMap(
-      reference, consumer_tv, ref_2_consumer, ref_compute.extentMap());
+  std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map =
+      indexMapFromTV(consumer_tv, loops);
 
-  // Index into consumer using reference indexing
-  auto consumer_indexing = ref_compute.updateIndexCompute(
-      consumer_tv->domain(),
-      ref_2_consumer,
-      consumer_tv->domain()->contiguity(),
-      reference_halo_extent_map);
+  auto index_and_extent_map = generateIndexAndExtentMap(
+      tv_stack,
+      std::deque<kir::ForLoop*>(loops.begin(), loops.end()),
+      loop_to_ind_map,
+      std::vector<bool>(consumer_tv->getRootDomain().size(), false));
 
-  IndexSwizzle index_swizzle(
-      consumer_tv,
-      consumer_indexing.indexMap(),
-      consumer_indexing.extentMap(),
-      consumer_indexing.zeroMergedIn());
-
-  index_swizzle.run();
-
-  auto index_map = index_swizzle.indexMap();
-  auto extent_map = consumer_indexing.extentMap();
+  auto index_map = index_and_extent_map.first;
+  auto extent_map = index_and_extent_map.second;
 
   // Indices should now be mapped onto IterDomains in consumer, so just grab
   // and use them.
   auto root_dom = consumer_tv->getMaybeRFactorDomain();
-  std::vector<kir::Val*> strided_inds(root_dom.size(), ir_builder.zeroVal());
+
+  std::vector<Val*> strided_inds;
   for (const auto i : c10::irange(root_dom.size())) {
-    if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() ||
-        gpu_lower->trivialReductionInfo().isDerived(root_dom[i])) {
+    if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast()) {
       continue;
     }
 
     auto kir_root_dom_i =
-        gpu_lower->lowerValue(root_dom[i])->as<kir::IterDomain>();
+        GpuLower::lowerValue(root_dom[i])->as<kir::IterDomain>();
 
     TORCH_INTERNAL_ASSERT(
         index_map.find(kir_root_dom_i) != index_map.end(),
@@ -1737,25 +1106,26 @@ std::vector<kir::Val*> Index::getNonGlobalConsumerStridedIndices(
         i,
         " id: ",
         kir::toString(kir_root_dom_i));
+    auto root_ind_i = index_map.at(kir_root_dom_i);
+    TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ind_i));
 
-    const auto root_ind_i = index_map.at(kir_root_dom_i);
     if (root_ind_i->isZeroInt()) {
       continue;
     }
 
     // Compute striding for this index.
-    kir::Val* stride = nullptr;
+    Val* stride = nullptr;
     for (size_t j = i + 1; j < root_dom.size(); j++) {
-      if (root_dom[j]->isBroadcast() || root_dom[j]->isReduction() ||
-          gpu_lower->trivialReductionInfo().isDerived(root_dom[j])) {
+      if (root_dom[j]->isBroadcast() || root_dom[j]->isReduction()) {
         continue;
       }
 
       auto kir_root_dom_j =
-          gpu_lower->lowerValue(root_dom[j])->as<kir::IterDomain>();
+          GpuLower::lowerValue(root_dom[j])->as<kir::IterDomain>();
 
       TORCH_INTERNAL_ASSERT(
-          index_map.find(kir_root_dom_j) != index_map.end(),
+          index_map.find(kir_root_dom_j) != index_map.end() &&
+              extent_map.find(kir_root_dom_j) != extent_map.end(),
           "Couldn't find root mapping for TV",
           consumer_tv->name(),
           " dim: ",
@@ -1764,12 +1134,8 @@ std::vector<kir::Val*> Index::getNonGlobalConsumerStridedIndices(
           root_dom[i]);
 
       auto root_ind_j = index_map.at(kir_root_dom_j);
-      auto root_ext_j = extent_map.find(kir_root_dom_j) == extent_map.end()
-          ? kir_root_dom_j->extent()
-          : extent_map.at(kir_root_dom_j);
-
-      root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j);
-
+      auto root_ext_j = extent_map.at(kir_root_dom_j);
+      TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ext_j));
       if (!root_ind_j->isZeroInt()) {
         if (stride == nullptr) {
           stride = root_ext_j;
@@ -1780,131 +1146,72 @@ std::vector<kir::Val*> Index::getNonGlobalConsumerStridedIndices(
     }
 
     if (stride != nullptr) {
-      strided_inds[i] = ir_builder.mulExpr(root_ind_i, stride);
+      strided_inds.push_back(ir_builder.mulExpr(root_ind_i, stride));
     } else {
-      strided_inds[i] = root_ind_i;
+      strided_inds.push_back(root_ind_i);
     }
   }
 
-  return strided_inds;
+  if (strided_inds.size() == 0)
+    strided_inds.push_back(ir_builder.create<kir::Int>(0));
+
+  return ir_builder.create<kir::TensorIndex>(consumer_tv, strided_inds);
 }
 
-std::vector<kir::Val*> Index::getProducerStridedIndices(
+// Producer is the inputs of an expression
+kir::TensorIndex* Index::getProducerIndex(
     TensorView* producer,
-    const TensorView* consumer,
+    TensorView* consumer,
     const std::vector<kir::ForLoop*>& loops) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices");
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
+  FUSER_PERF_SCOPE("Index::getProducerIndex");
+
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
 
   if (producer->domain()->noReductions().size() == 0) {
-    return std::vector<kir::Val*>(
-        producer->getMaybeRFactorDomain().size(), ir_builder.zeroVal());
+    return ir_builder.create<kir::TensorIndex>(producer, std::vector<Val*>{});
   }
 
-  std::vector<kir::Val*> strided_indices;
   if (producer->getMemoryType() == MemoryType::Global) {
-    strided_indices =
-        getGlobalProducerStridedIndices(producer, consumer, loops);
-  } else {
-    strided_indices =
-        getNonGlobalProducerStridedIndices(producer, consumer, loops);
+    return getGlobalProducerIndex(producer, consumer, loops);
   }
 
-  TORCH_INTERNAL_ASSERT(
-      strided_indices.size() == producer->getMaybeRFactorDomain().size());
-
-  return strided_indices;
+  return getProducerIndex_impl(producer, consumer, loops);
 }
 
-// Producer is the inputs of an expression
-kir::TensorIndex* Index::getProducerIndex(
-    TensorView* producer,
-    const TensorView* consumer,
+// Consumer is the output of an expression
+kir::TensorIndex* Index::getConsumerIndex(
+    TensorView* consumer,
     const std::vector<kir::ForLoop*>& loops) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  auto strided_indices = getProducerStridedIndices(producer, consumer, loops);
-  return ir_builder.create<kir::TensorIndex>(producer, strided_indices);
-}
+  FUSER_PERF_SCOPE("Index::getConsumerIndex");
 
-std::vector<kir::Val*> Index::getConsumerStridedIndices(
-    const TensorView* consumer,
-    const std::vector<kir::ForLoop*>& loops) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::Index::getConsumerStridedIndices");
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
 
   if (consumer->domain()->noReductions().size() == 0) {
-    return std::vector<kir::Val*>(
-        consumer->getMaybeRFactorDomain().size(), ir_builder.zeroVal());
+    return ir_builder.create<kir::TensorIndex>(consumer, std::vector<Val*>{});
   }
 
-  std::vector<kir::Val*> strided_indices;
   if (consumer->getMemoryType() == MemoryType::Global) {
-    strided_indices = getGlobalConsumerStridedIndices(consumer, loops);
-  } else {
-    strided_indices = getNonGlobalConsumerStridedIndices(consumer, loops);
+    return getGlobalConsumerIndex(consumer, loops);
   }
 
-  TORCH_INTERNAL_ASSERT(
-      strided_indices.size() == consumer->getMaybeRFactorDomain().size());
-
-  return strided_indices;
-}
-
-// Consumer is the output of an expression
-kir::TensorIndex* Index::getConsumerIndex(
-    const TensorView* consumer,
-    const std::vector<kir::ForLoop*>& loops) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  auto strided_indices = getConsumerStridedIndices(consumer, loops);
-  return ir_builder.create<kir::TensorIndex>(consumer, strided_indices);
+  return getConsumerIndex_impl(consumer, loops);
 }
 
 // Basically just copy getGlobalConsumerIndex, just don't do the striding and
 // return std::vector of Vals
-//
-// TODO(kir): replace pair with struct
-//
-std::pair<std::vector<kir::Val*>, bool> Index::getConsumerRootPredIndices(
-    const kir::TensorView* kir_consumer_tv,
+std::pair<std::vector<Val*>, bool> Index::getConsumerRootPredIndices(
+    TensorView* consumer_tv,
     const std::vector<kir::ForLoop*>& loops,
     const std::vector<bool>& root_contiguity,
-    bool unswitch) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::Index::getConsumerRootPredIndices");
-
-  auto consumer_tv = kir_consumer_tv->fuserTv();
-
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  // Get a reference tensor replayed as existing loop structure
-  ReferenceTensor reference = IndexReferenceReplay::getReference(loops);
-  auto reference_domain = reference.domain;
-  auto reference_id_map = reference.concrete_to_id;
-
-  // Map reference tensor to consumer
-  std::unordered_map<IterDomain*, IterDomain*> root_ref_to_consumer;
-  for (auto c_root : consumer_tv->getMaybeRFactorDomain()) {
-    auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(c_root);
-    auto ref_id_it = reference_id_map.find(concrete_id);
-    if (ref_id_it != reference_id_map.end()) {
-      root_ref_to_consumer[ref_id_it->second] = c_root;
-    }
-  }
+    bool unroll) {
+  FUSER_PERF_SCOPE("Index::getConsumerRootPredIndices");
 
-  BestEffortReplay replay_consumer_as_ref(
-      consumer_tv->domain()->domain(),
-      reference_domain->domain(),
-      root_ref_to_consumer);
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
 
-  const auto& ref_2_consumer = replay_consumer_as_ref.getReplay();
+  // grab all tensor views from producer_tv <- computeAtRoot
+  std::deque<TensorView*> tv_stack = getComputeAtTVStackFrom(consumer_tv);
 
-  std::unordered_map<kir::ForLoop*, kir::Val*> loop_to_ind_map;
+  std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
 
   std::transform(
       loops.begin(),
@@ -1912,397 +1219,70 @@ std::pair<std::vector<kir::Val*>, bool> Index::getConsumerRootPredIndices(
       std::inserter(loop_to_ind_map, loop_to_ind_map.begin()),
       [](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); });
 
-  if (unswitch) {
-    bool within_unswitch = false;
-    const auto one = ir_builder.create<kir::Int>(1);
+  if (unroll) {
+    bool within_unroll = false;
+    Val* one = ir_builder.create<kir::Int>(1);
     for (auto loop : loops) {
-      if (loop->iter_domain()->parallelType() == ParallelType::Unroll ||
-          loop->iter_domain()->parallelType() == ParallelType::Unswitch ||
-          loop->iter_domain()->parallelType() == ParallelType::Vectorize) {
-        within_unswitch = true;
+      if (loop->iter_domain()->getParallelType() == ParallelType::Unroll) {
+        within_unroll = true;
       }
 
-      if (within_unswitch) {
-        if (loop->iter_domain()->isThread()) {
-          loop_to_ind_map[loop] = loop->start();
-        } else {
-          loop_to_ind_map[loop] = ir_builder.subExpr(loop->stop(), one);
-        }
+      if (within_unroll && !loop->iter_domain()->isThread()) {
+        loop_to_ind_map[loop] =
+            ir_builder.subExpr(loop->iter_domain()->extent(), one);
       }
     }
   }
 
-  std::unordered_map<kir::IterDomain*, kir::Val*> ref_id_to_ind_map;
-  // Due to rfactor/initialization reference_domain may be bigger than loop nest
-  // structure
-  TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims());
-  for (size_t loop_i = 0; loop_i < loops.size(); loop_i++) {
-    auto ref_axis = gpu_lower->lowerValue(reference_domain->axis(loop_i))
-                        ->as<kir::IterDomain>();
-    ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]];
-  }
-
-  // Index into the reference tensor
-  auto ref_compute =
-      getReferenceIndexing(loops, reference_domain, ref_id_to_ind_map, {});
-
-  const auto reference_halo_extent_map = getReferenceHaloExtentMap(
-      reference, consumer_tv, ref_2_consumer, ref_compute.extentMap());
-
-  // Index into consumer using reference indexing
-  auto consumer_indexing = ref_compute.updateIndexCompute(
-      consumer_tv->domain(),
-      ref_2_consumer,
-      root_contiguity,
-      reference_halo_extent_map);
+  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
+  auto index_map = generateIndexAndExtentMap(
+                       tv_stack,
+                       std::deque<kir::ForLoop*>(loops.begin(), loops.end()),
+                       loop_to_ind_map,
+                       root_contiguity)
+                       .first;
 
   // Indices should now be mapped onto IterDomains in consumer, so just grab
   // and use them.
 
-  // If we are generating a predicate for initialization, we should use
-  // rfactor instead of root_dom. If we are generating a predicate for
-  // actual reduction expr, reduction axes should have their indices
-  // mapped to non-zero symbolic vals.
-  bool buffer_init = false;
-  for (auto consumer_id : kir_consumer_tv->domain()->domain()) {
-    if (consumer_id->isReduction()) {
-      if (consumer_indexing.indexMap().find(consumer_id) !=
-          consumer_indexing.indexMap().end()) {
-        if (!consumer_indexing.indexMap().at(consumer_id)->isZeroInt()) {
-          buffer_init = false;
-          break;
+  // If we are generating a predicate for initialization check if we should use
+  // rfactor instead of root_dom
+  bool use_rfactor = true;
+  if (consumer_tv->hasRFactor()) {
+    auto rfactor_dom = consumer_tv->getMaybeRFactorDomain();
+    for (auto rfactor_id : rfactor_dom) {
+      if (rfactor_id->isReduction()) {
+        auto kir_rfactor_id =
+            GpuLower::lowerValue(rfactor_id)->as<kir::IterDomain>();
+        if (index_map.find(kir_rfactor_id) != index_map.end()) {
+          if (!index_map.at(kir_rfactor_id)->isZeroInt()) {
+            use_rfactor = false;
+            break;
+          }
         }
       }
-      buffer_init = true;
     }
   }
 
-  // If we are initializing a reduction buffer and the tensor has a
-  // rfactor root, the predicate should be based on the rfactor root.
-  const auto root_domain =
-      (buffer_init && kir_consumer_tv->domain()->hasRFactor())
-      ? kir_consumer_tv->domain()->rfactorDomain()
-      : kir_consumer_tv->domain()->rootDomain();
-
-  const auto zero = ir_builder.create<kir::Int>(0);
-  std::vector<kir::Val*> root_inds(root_domain.size(), zero);
+  auto root_dom = use_rfactor ? consumer_tv->getMaybeRFactorDomain()
+                              : consumer_tv->getRootDomain();
 
-  for (const auto i : c10::irange(root_domain.size())) {
-    if (root_domain[i]->isBroadcast() ||
-        gpu_lower->trivialReductionInfo().isDerived(root_domain[i])) {
-      continue;
-    }
-    const auto it = consumer_indexing.indexMap().find(root_domain[i]);
-    if (it != consumer_indexing.indexMap().end()) {
-      root_inds[i] = it->second;
-    }
-  }
-
-  return {root_inds, buffer_init};
-}
-
-namespace {
-struct PredicateContigInfo {
- public:
-  // Iteration domain that is only comprised of merge transformations
-  IterDomain* contig_id;
-  // The set of root iteration domains that make up the contig_id
-  std::unordered_set<IterDomain*> root_ids;
-};
-
-// Find iteration domains in the history of reference comprised only of
-// merge operations. Only return iteration domains that are subsequently fed
-// into a split, or are in the provided domain. In other words, we don't want to
-// return every IterDomain that's contiguous, just the one closest to the
-// leaves. Predicates are not associated with physical memory so we can treat
-// all of them as contiguous merges.
-std::vector<PredicateContigInfo> getPredicateContigIds(
-    std::vector<IterDomain*> reference_domain) {
-  auto root_vals = IterVisitor::getInputsTo(
-      {reference_domain.begin(), reference_domain.end()});
-  auto root_ids = ir_utils::filterByType<IterDomain>(root_vals);
-
-  // Mark all roots as being originally "contiguous"
-  std::vector<IterDomain*> contiguous_ids(root_ids.begin(), root_ids.end());
-
-  // Dereference root_vals.begin below, so make sure there's at least one entry
-  if (root_vals.empty()) {
-    return std::vector<PredicateContigInfo>();
-  }
-
-  // Run through iteration domain history
-  auto exprs = ExprSort::getExprs(
-      (*root_vals.begin())->fusion(),
-      {reference_domain.begin(), reference_domain.end()});
-
-  for (auto expr : exprs) {
-    // If not a merge, output is not contiguous
-    if (expr->isA<Merge>()) {
-      auto merge = expr->as<Merge>();
-      auto inner_contig_it = std::find(
-          contiguous_ids.begin(), contiguous_ids.end(), merge->inner());
-      auto outer_contig_it = std::find(
-          contiguous_ids.begin(), contiguous_ids.end(), merge->outer());
-
-      if (inner_contig_it != contiguous_ids.end() &&
-          outer_contig_it != contiguous_ids.end()) {
-        // If inner and outer are contiguous, out must be contiguous. Remove
-        // inner and outer, and add out.
-        contiguous_ids.erase(outer_contig_it);
-        contiguous_ids.erase(std::find(
-            contiguous_ids.begin(), contiguous_ids.end(), merge->inner()));
-        contiguous_ids.emplace_back(merge->out());
-      }
-    }
-  }
-
-  std::vector<PredicateContigInfo> contig_id_infos;
-
-  // Create entries and return them
-  for (auto contig_id : contiguous_ids) {
-    auto contig_root_vals = IterVisitor::getInputsTo({contig_id});
-    auto contig_root_ids = ir_utils::filterByType<IterDomain>(contig_root_vals);
-    PredicateContigInfo contig_id_info;
-    contig_id_info.contig_id = contig_id;
-    contig_id_info.root_ids = std::unordered_set<IterDomain*>(
-        contig_root_ids.begin(), contig_root_ids.end());
-    contig_id_infos.push_back(contig_id_info);
-  }
-  return contig_id_infos;
-}
-
-} // namespace
-
-// Returns predicates and the concrete (by loop map) root domains they cover
-std::pair<std::vector<kir::Bool*>, std::vector<std::unordered_set<IterDomain*>>>
-Index::getReferenceRootPredicates(
-    const kir::TensorView* kir_consumer_tv,
-    const std::vector<kir::ForLoop*>& loops,
-    bool unswitch) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::Index::getReferenceRootPredicates");
-
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  // Get a reference tensor replayed as existing loop structure
-  ReferenceTensor reference = IndexReferenceReplay::getReference(loops);
-  auto reference_domain = reference.domain;
-  auto reference_id_map = reference.concrete_to_id;
-
-  std::unordered_map<kir::ForLoop*, kir::Val*> loop_to_ind_map;
-
-  std::transform(
-      loops.begin(),
-      loops.end(),
-      std::inserter(loop_to_ind_map, loop_to_ind_map.begin()),
-      [](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); });
-
-  // If unswitch don't directly use indices from for loop, use for loop extent
-  // minus 1
-  if (unswitch) {
-    TORCH_INTERNAL_ASSERT(
-        loops.size() <= reference_domain->nDims(),
-        "Invalid reference generated.");
-    bool within_unswitch = false;
-    const auto one = ir_builder.create<kir::Int>(1);
-    for (size_t loop_i = 0; loop_i < loops.size(); loop_i++) {
-      auto loop = loops[loop_i];
-      auto ref_id = reference_domain->axis(loop_i);
-      if (loop->iter_domain()->parallelType() == ParallelType::Unroll ||
-          loop->iter_domain()->parallelType() == ParallelType::Unswitch ||
-          loop->iter_domain()->parallelType() == ParallelType::Vectorize) {
-        within_unswitch = true;
-      }
-
-      if (within_unswitch) {
-        // Rely on the reference to check broadcasting. The for loop could be
-        // broadcasted on a constant value from an unroll split. Since reference
-        // may convert this to an iter domain, that for loop could be valid to
-        // generate predication from.
-        if (ref_id->isBroadcast()) {
-          // Ignore indexing into broadcasted dimensions.
-          continue;
-        } else if (loop->iter_domain()->isThread()) {
-          loop_to_ind_map[loop] = loop->start();
-        } else {
-          loop_to_ind_map[loop] = ir_builder.subExpr(loop->stop(), one);
-        }
-      }
-    }
-  }
-
-  // Add magic zero to a loop pretty far inside in indexing
-  kir::IterDomain* magic_zero_loop = nullptr;
-  std::unordered_map<kir::IterDomain*, kir::Val*> ref_id_to_ind_map;
-  // Due to rfactor/initialization reference_domain may be bigger than loop nest
-  // structure
-  TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims());
-  for (size_t loop_i = 0; loop_i < loops.size(); loop_i++) {
-    auto loop = loops[loop_i];
-    auto ind = loop_to_ind_map[loops[loop_i]];
-    auto ref_axis = reference_domain->axis(loop_i);
-    auto kir_ref_axis = gpu_lower->lowerValue(ref_axis)->as<kir::IterDomain>();
-
-    if (Index::protectWithMagicZero(loop, ref_axis, ind)) {
-      magic_zero_loop = kir_ref_axis;
-    }
-
-    ref_id_to_ind_map[kir_ref_axis] = loop_to_ind_map[loop];
-  }
-
-  if (ref_id_to_ind_map.count(magic_zero_loop)) {
-    ref_id_to_ind_map[magic_zero_loop] = ir_builder.addExpr(
-        ref_id_to_ind_map[magic_zero_loop], ir_builder.magicZeroVal());
-  }
-
-  auto consumer_tv = kir_consumer_tv->fuserTv();
-
-  // Map reference tensor to consumer
-  std::unordered_map<IterDomain*, IterDomain*> root_ref_to_consumer;
-  for (auto c_root : consumer_tv->getMaybeRFactorDomain()) {
-    auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(c_root);
-    auto ref_id_it = reference_id_map.find(concrete_id);
-    if (ref_id_it != reference_id_map.end()) {
-      root_ref_to_consumer[ref_id_it->second] = c_root;
-    }
-  }
-
-  BestEffortReplay replay_consumer_as_ref(
-      consumer_tv->domain()->domain(),
-      reference_domain->domain(),
-      root_ref_to_consumer);
-
-  const auto& ref_2_consumer = replay_consumer_as_ref.getReplay();
-
-  // Halo information is not currently used as lower_shift will take care of the
-  // predicate generation and is still using the older function:
-  // getConsumerRootPredIndices
-
-  // Generate halo information for reference.
-  updateHaloInfoForReference(reference, consumer_tv);
-
-  std::unordered_map<kir::IterDomain*, kir::Val*> reference_halo_extent_map;
-
-  const auto& halo_info = gpu_lower->haloInfo();
-
-  // Generate map from reference iter domains to halo extents
-  for (auto entry : ref_2_consumer) {
-    auto ref_id = entry.first;
-    auto extent = halo_info.getExtent(ref_id);
-    if (extent != nullptr) {
-      reference_halo_extent_map[gpu_lower->lowerValue(ref_id)
-                                    ->as<kir::IterDomain>()] = extent;
-    }
-  }
-
-  // Index into the reference tensor
-  auto ref_indexing = getReferenceIndexing(
-      loops,
-      reference_domain,
-      ref_id_to_ind_map,
-      {},
-      reference_halo_extent_map);
-
-  // If we are initializing a reduction buffer and the tensor has a
-  // rfactor root, the predicate should be based on the rfactor root.
-  const auto root_domain = reference_domain->getRootDomain();
-
-  // Get the contiguous ids we need to generate predicates for
-  auto contig_id_infos = getPredicateContigIds(reference_domain->domain());
-
-  // Roots in contiguous processing is based on reference roots, want to convert
-  // these to concrete roots, flip reference's concrete_to_id map as reference
-  // ids are not part of compute at maps.
-  decltype(reference_id_map) ref_id_to_concrete;
-  std::transform(
-      reference_id_map.begin(),
-      reference_id_map.end(),
-      std::inserter(ref_id_to_concrete, ref_id_to_concrete.begin()),
-      [](auto entry) { return std::make_pair(entry.second, entry.first); });
-
-  // Track which roots have been handled by the generated predicates
-  std::vector<std::unordered_set<IterDomain*>> handeled_roots;
-
-  std::vector<kir::Bool*> predicates;
-
-  for (auto contig_id_entry : contig_id_infos) {
-    auto contig_id = contig_id_entry.contig_id;
-    // No predicates needed for braodcasted indices.
-    if (contig_id->isBroadcast() ||
-        gpu_lower->trivialReductionInfo().isDerived(contig_id)) {
-      continue;
-    }
-
-    auto root_ids = contig_id_entry.root_ids;
-    auto kir_contig_id =
-        gpu_lower->lowerValue(contig_id)->as<kir::IterDomain>();
-
-    const auto it = ref_indexing.indexMap().find(kir_contig_id);
-
-    // First condition below is due to broadcasts in consumers of consumer that
-    // are not in consumer there can be unresolved indexing in the reference
-    // tensor. This can happen when we have something like: TV3[i1o*i2, i1i] and
-    // TV1[i2] where tv3 and tv1 share their outer dimension. i1 will be part of
-    // reference tensors root domain, but when indexing into TV1 there aren't
-    // enough indices to resolve it.
-    //
-    // The condition also happens with Misaligned predicates, where
-    // inner-most vectorized loops are not included in the loops
-    // parameter. Predicates involving vectorized loops are separately
-    // generated in lower_misaligned_vectorization.
-    //
-    // Second condition is simply to avoid predication on broadcasting axes as
-    // it's not required.
-    if (it == ref_indexing.indexMap().end() || it->second->isZeroInt()) {
+  std::vector<Val*> root_inds(root_dom.size(), ir_builder.create<kir::Int>(0));
+  for (const auto i : c10::irange(root_dom.size())) {
+    if (root_dom[i]->isBroadcast()) {
       continue;
     }
 
-    // Use the iteration domains extent unless there's a halo extent
-    auto extent = kir_contig_id->extent();
-
-    auto halo_extent_it = reference_halo_extent_map.find(kir_contig_id);
-    if (halo_extent_it != reference_halo_extent_map.end()) {
-      extent = halo_extent_it->second;
-    }
-
-    // If the index definition is "simple" and the extent is "simple" then our
-    // for loop goes exactly across the iteration domain extent so no predicate
-    // needed.
-    if (it->second->definition() == nullptr &&
-        extent->definition() == nullptr) {
-      continue;
+    auto kir_root_dom_i =
+        GpuLower::lowerValue(root_dom[i])->as<kir::IterDomain>();
+    if (index_map.find(kir_root_dom_i) != index_map.end()) {
+      auto ind = index_map.at(kir_root_dom_i);
+      TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(ind))
+      root_inds[i] = ind;
     }
-
-    predicates.push_back(
-        ir_builder.ltExpr(it->second, extent)->as<kir::Bool>());
-
-    // Transform roots from reference to concrete roots (based on loop compute
-    // at map)
-    std::unordered_set<IterDomain*> concrete_root_ids;
-    std::transform(
-        contig_id_entry.root_ids.begin(),
-        contig_id_entry.root_ids.end(),
-        std::inserter(concrete_root_ids, concrete_root_ids.begin()),
-        [&ref_id_to_concrete](IterDomain* root_id) {
-          return ref_id_to_concrete.at(root_id);
-        });
-    handeled_roots.push_back(concrete_root_ids);
   }
 
-  return {predicates, handeled_roots};
-}
-
-bool Index::protectWithMagicZero(
-    kir::ForLoop* loop,
-    IterDomain* reference_domain,
-    kir::Val* ind) {
-  bool ref_dom_simple =
-      (reference_domain == nullptr ? true
-                                   : reference_domain->definition() != nullptr);
-  bool ind_simple =
-      (ind == nullptr ? true
-                      : ind->definition() != nullptr && !ind->isZeroInt());
-  return loop->isUnrollable() && (!ref_dom_simple || !ind_simple);
+  return std::make_pair(root_inds, use_rfactor);
 }
 
 } // namespace cuda
index 637362d..7b4b67d 100644 (file)
@@ -1,7 +1,6 @@
 #pragma once
 
 #include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
 
 #include <unordered_map>
 #include <unordered_set>
@@ -60,38 +59,31 @@ namespace fuser {
 namespace cuda {
 
 class IndexCompute : public BackwardVisitor {
- protected:
+ private:
   using BackwardVisitor::handle;
-
   void handle(Split*) override;
   void handle(Merge*) override;
   void handle(Expr*) override;
 
   // return extent_map_[id] if exists, else return id->extent()
-  kir::Val* getExtent(kir::IterDomain* id);
+  Val* getExtent(kir::IterDomain* id);
 
-  //! True if a domain is not used to index
-  bool isZero(kir::IterDomain* id) const;
-  //! True if any dependent of a domain is not used to index
-  bool hasZeroMerged(kir::IterDomain* id) const;
+  bool hasZeroMerged(kir::IterDomain* id);
 
   // Tensor domain we're mapping back to root
-  const TensorDomain* td_; // NOLINT
+  const TensorDomain* td_;
 
   // Map we update as we propagate backward, containing all IDs in the
   // propagation. Initial indices are mapped with this map at tv->domain()
   // and are back propagated to tv->rootDomain(). This index_map_ keeps the
   // indices at intermediate IterDomain's in that back propagation.
-  std::unordered_map<kir::IterDomain*, kir::Val*> index_map_; // NOLINT
+  std::unordered_map<kir::IterDomain*, Val*> index_map_;
 
   // Map from IterDomain to their broadcasted extent. If a TV has I0*I1 but its
   // producer has B0*I1 this map will contain a mapping from the ID{B0*I1} to
   // the extent I0*I1. Also contains updated extents if we merge in a 0 index.
   // See zero_merged_in_.
-  std::unordered_map<kir::IterDomain*, kir::Val*> extent_map_; // NOLINT
-
-  // Keeps track of domains that do not contribute to indexing
-  std::unordered_set<kir::IterDomain*> zero_; // NOLINT
+  std::unordered_map<kir::IterDomain*, Val*> extent_map_;
 
   // This set keeps track of IterDomain's that have had a zero index merged into
   // them. This happens if we do something like tv->axis(0)->split(4) then
@@ -104,71 +96,45 @@ class IndexCompute : public BackwardVisitor {
   // IDs that are a result of contiguous merges
   std::unordered_set<kir::IterDomain*> contig_ids;
 
-  // Mentions if we should propagate an index down a particular IterDomain path
-  // if there's an option
-  std::unordered_set<kir::IterDomain*> preferred_paths_;
-
-  // Map from IterDomains to halo-extended extents in corresponding
-  // reference tensor
-  std::unordered_map<kir::IterDomain*, kir::Val*> reference_halo_extent_map_;
-
  public:
-  const std::unordered_map<kir::IterDomain*, kir::Val*>& indexMap() const {
+  const std::unordered_map<kir::IterDomain*, Val*> indexMap() const {
     return index_map_;
   }
 
-  const std::unordered_map<kir::IterDomain*, kir::Val*>& extentMap() const {
+  const std::unordered_map<kir::IterDomain*, Val*> extentMap() const {
     return extent_map_;
   }
 
-  const std::unordered_set<kir::IterDomain*>& zeroMergedIn() const {
+  std::unordered_set<kir::IterDomain*> zeroMergedIn() const {
     return zero_merged_in_;
   }
 
   // Propagate back from _td using initial_index_map
   IndexCompute(
       const TensorDomain* _td,
-      std::unordered_map<kir::IterDomain*, kir::Val*> initial_index_map,
-      std::unordered_map<kir::IterDomain*, kir::Val*> _extent_map,
+      std::unordered_map<kir::IterDomain*, Val*> initial_index_map,
+      std::unordered_map<kir::IterDomain*, Val*> _extent_map,
       std::unordered_set<kir::IterDomain*> _zero_merged_in,
-      const std::vector<bool>& _root_contiguity,
-      std::unordered_set<kir::IterDomain*> preferred_paths = {},
-      std::unordered_map<kir::IterDomain*, kir::Val*>
-          reference_halo_extent_map = {});
+      const std::vector<bool>& _root_contiguity);
 
   // Updates index_map, extent_map, and zero_merged_in based on id_map and
-  // returns a new IndexCompute ready to be used.
+  // returns a new IndexCompute ready to be used. new_index_entries are not
+  // mapped, but are added to index_map.
   IndexCompute updateIndexCompute(
       const TensorDomain* new_td,
       const std::unordered_map<IterDomain*, IterDomain*>& id_map,
-      const std::vector<bool>& _root_contiguity,
-      const std::unordered_map<kir::IterDomain*, kir::Val*>&
-          reference_halo_extent_map = {});
-
-  virtual void run();
-};
-
-//! Apply swizzle and update root indices accordingly
-class IndexSwizzle : public IndexCompute {
- public:
-  IndexSwizzle(
-      const TensorView* tv,
-      std::unordered_map<kir::IterDomain*, kir::Val*> initial_index_map,
-      std::unordered_map<kir::IterDomain*, kir::Val*> extent_map,
-      std::unordered_set<kir::IterDomain*> zero_merged_in);
-
-  void run() override;
-
- protected:
-  using IndexCompute::handle;
-
-  void handle(Expr* e) override;
-
- private:
-  const TensorView* tv_ = nullptr;
-  SwizzleType swizzle_type_ = SwizzleType::NoSwizzle;
-  std::vector<IterDomain*> ids_to_swizzle_;
-  std::unordered_set<IterDomain*> swizzled_ids_;
+      std::unordered_map<kir::IterDomain*, Val*> new_index_entries,
+      const std::vector<bool>& _root_contiguity);
+
+  // Map producer contiguity information to consumer, if entries don't match
+  // mark as false
+  static std::vector<bool> contiguityPasC(
+      TensorDomain* producer,
+      TensorDomain* consumer);
+
+  static std::vector<bool> contiguityAnd(
+      const std::vector<bool>& contig1,
+      const std::vector<bool>& contig2);
 };
 
 // Simple interface for IndexCompute
@@ -177,25 +143,25 @@ class IndexSwizzle : public IndexCompute {
 class Index {
  private:
   // Producer indexing if it's in shared or local memory
-  static std::vector<kir::Val*> getNonGlobalProducerStridedIndices(
+  static kir::TensorIndex* getProducerIndex_impl(
       TensorView* producer,
-      const TensorView* consumer,
+      TensorView* consumer,
       const std::vector<kir::ForLoop*>& loops);
 
   // Consumer indexing if it's in shared or local memory
-  static std::vector<kir::Val*> getNonGlobalConsumerStridedIndices(
-      const TensorView* consumer,
+  static kir::TensorIndex* getConsumerIndex_impl(
+      TensorView* consumer,
       const std::vector<kir::ForLoop*>& loops);
 
   // Producer if it's in global memory
-  static std::vector<kir::Val*> getGlobalProducerStridedIndices(
+  static kir::TensorIndex* getGlobalProducerIndex(
       TensorView* producer,
-      const TensorView* consumer,
+      TensorView* consumer,
       const std::vector<kir::ForLoop*>& loops);
 
   // Consumer indexing if it's in global memory
-  static std::vector<kir::Val*> getGlobalConsumerStridedIndices(
-      const TensorView* consumer,
+  static kir::TensorIndex* getGlobalConsumerIndex(
+      TensorView* consumer,
       const std::vector<kir::ForLoop*>& loops);
 
  public:
@@ -205,77 +171,22 @@ class Index {
   // Producer indexing dispatch
   static kir::TensorIndex* getProducerIndex(
       TensorView* producer,
-      const TensorView* consumer,
+      TensorView* consumer,
       const std::vector<kir::ForLoop*>& loops);
 
   // Consumer index dispatch
   static kir::TensorIndex* getConsumerIndex(
-      const TensorView* consumer,
-      const std::vector<kir::ForLoop*>& loops);
-
-  //! Returns a vector of strided indices mapped onto the (rfactor)
-  //! root domain of a producer tensor. The size of the returned
-  //! vector is guaranteed to be equal to the number of axes of the
-  //! indexing root domain.
-  static std::vector<kir::Val*> getProducerStridedIndices(
-      TensorView* producer,
-      const TensorView* consumer,
-      const std::vector<kir::ForLoop*>& loops);
-
-  //! Returns a vector of strided indices mapped onto the (rfactor)
-  //! root domain of a consumer tensor. The size of the returned
-  //! vector is guaranteed to be equal to the number of axes of the
-  //! indexing root domain.
-  static std::vector<kir::Val*> getConsumerStridedIndices(
-      const TensorView* consumer,
+      TensorView* consumer,
       const std::vector<kir::ForLoop*>& loops);
 
   // Consumer indices for predicates, keep all indices matching in root domain.
   // Even those not used for physical addressing. Returns pair <root indices, if
   // indices are mapped to rfactor dom>
-  static std::pair<std::vector<kir::Val*>, bool> getConsumerRootPredIndices(
-      const kir::TensorView* consumer,
+  static std::pair<std::vector<Val*>, bool> getConsumerRootPredIndices(
+      TensorView* consumer,
       const std::vector<kir::ForLoop*>& loops,
       const std::vector<bool>& root_contiguity,
-      bool unswitch = false);
-
-  //! Take a consumer tensorview and loop nest and generates predicates
-  //! associated with the concrete roots of the loop nest. Returns a list of
-  //! predicates, and a list of concrete roots they're associated with. It is
-  //! assumed that no predicate is required if index[i] is an index directly
-  //! from a for loop. This will not catch all cases if we actually have static
-  //! size information for example:
-  //!
-  //! TV[I].split(4)
-  //! would produce the code:
-  //! for(i : I/4)
-  //!   for(j : 4)
-  //!     if( i * 4 + j < TV.size(0))
-  //!       TV[i * 4 + j]...
-  //!
-  //! However if we had TV.size[0] = 16 at "compile time" then we wouldn't need
-  //! the predicate. This will be caught by canOmitPredicate in the predicate
-  //! lowering
-  // TODO: Replace pair of vectors with vector of
-  static std::pair<
-      std::vector<kir::Bool*>,
-      std::vector<std::unordered_set<IterDomain*>>>
-  getReferenceRootPredicates(
-      const kir::TensorView* kir_consumer_tv,
-      const std::vector<kir::ForLoop*>& loops,
-      bool unswitch = false);
-
-  // Determine if we may run into over reuse of predicates or registers in the
-  // compiler. If the loop can be unrolled and the index and domain are not
-  // "simple" we likely want the loop protected.
-  //
-  // Magic zero protection should only be done for global memory and predicates.
-  // We should avoid use on registers. Shared memory does not require it, but
-  // likely wouldn't hurt.
-  static bool protectWithMagicZero(
-      kir::ForLoop* loop,
-      IterDomain* reference_domain = nullptr,
-      kir::Val* ind = nullptr);
+      bool unroll = false);
 };
 
 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp
deleted file mode 100644 (file)
index b3025bd..0000000
+++ /dev/null
@@ -1,424 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/index_reference_replay.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-// We're going to replay this split operation on the corresponding ID
-void IndexReferenceReplay::handle(Split* s) {
-  auto in = s->in();
-
-  auto concrete_in = GpuLower::current()->caIndexMap().getConcreteMappedID(in);
-  auto mapped_in_it = concrete_to_id_.find(concrete_in);
-  if (mapped_in_it == concrete_to_id_.end()) {
-    // If we can't find the concrete IDs in our local map, don't do anything.
-    return;
-  }
-
-  auto mapped_in = mapped_in_it->second;
-
-  if (leaf_ids_.find(mapped_in) == leaf_ids_.end()) {
-    // If ID has already been replayed, don't do anything.
-    return;
-  }
-
-  auto replayed_outs =
-      IterDomain::split(mapped_in, s->factor(), s->innerSplit());
-
-  auto concrete_outer =
-      GpuLower::current()->caIndexMap().getConcreteMappedID(s->outer());
-  auto concrete_inner =
-      GpuLower::current()->caIndexMap().getConcreteMappedID(s->inner());
-
-  // Update leaf id set and concrete id map
-  leaf_ids_.erase(mapped_in);
-  leaf_ids_.emplace(replayed_outs.first);
-  leaf_ids_.emplace(replayed_outs.second);
-  concrete_to_id_[concrete_outer] = replayed_outs.first;
-  concrete_to_id_[concrete_inner] = replayed_outs.second;
-}
-
-// We're going to replay this merge operation on the corresponding IDs
-void IndexReferenceReplay::handle(Merge* m) {
-  auto in_outer = m->outer();
-  auto in_inner = m->inner();
-
-  auto concrete_in_outer =
-      GpuLower::current()->caIndexMap().getConcreteMappedID(in_outer);
-  auto concrete_in_inner =
-      GpuLower::current()->caIndexMap().getConcreteMappedID(in_inner);
-
-  auto mapped_in_outer_it = concrete_to_id_.find(concrete_in_outer);
-  auto mapped_in_inner_it = concrete_to_id_.find(concrete_in_inner);
-
-  if (mapped_in_outer_it == concrete_to_id_.end() ||
-      mapped_in_inner_it == concrete_to_id_.end()) {
-    // If we can't find the concrete IDs in our local map, don't do anything.
-    return;
-  }
-
-  auto mapped_in_outer = mapped_in_outer_it->second;
-  auto mapped_in_inner = mapped_in_inner_it->second;
-
-  if (leaf_ids_.find(mapped_in_outer) == leaf_ids_.end() &&
-      leaf_ids_.find(mapped_in_inner) == leaf_ids_.end()) {
-    // If ID has already been replayed, don't do anything.
-    return;
-  }
-  auto replayed = IterDomain::merge(mapped_in_outer, mapped_in_inner);
-
-  auto concrete_out =
-      GpuLower::current()->caIndexMap().getConcreteMappedID(m->out());
-
-  // Update leaf id set and concrete id map
-  leaf_ids_.erase(mapped_in_outer);
-  leaf_ids_.erase(mapped_in_inner);
-  leaf_ids_.emplace(replayed);
-  concrete_to_id_[concrete_out] = replayed;
-}
-
-TensorDomain* IndexReferenceReplay::computeReplay() {
-  auto gpu_lower = GpuLower::current();
-  // Throw an error when two loops are mapped with each other, which
-  // violates an assumption that unique mappings between concrete
-  // IterDomains and the IterDomains of the loop structure must be
-  // established. It should be a reasonable assumption, but fusions
-  // like below won't work:
-  // tv0 = [I0]
-  // tv1 = broadcast(tv0, {true, false});
-  // tv2 = broadcast(tv0, {false, true});
-  // tv3 = tv1 + tv2
-  // Notice that the two axes of each of tv1, tv2 and tv3 are mapped
-  // with each other. We believe it is unlikely this limitation
-  // becomes a real concern in practice.
-  for (auto it_i = loop_structure_.begin(); it_i != loop_structure_.end();
-       ++it_i) {
-    for (auto it_j = it_i + 1; it_j != loop_structure_.end(); ++it_j) {
-      TORCH_INTERNAL_ASSERT(
-          !gpu_lower->caIndexMap().areMapped(
-              (*it_i)->iter_domain(), (*it_j)->iter_domain()),
-          "Unsupported loop structure. Two loops are mapped together.");
-    }
-  }
-
-  // Grab the iter domain's from the loop structure
-  std::vector<IterDomain*> fusion_loop_structure;
-
-  std::transform(
-      loop_structure_.begin(),
-      loop_structure_.end(),
-      std::back_inserter(fusion_loop_structure),
-      [&](kir::ForLoop* fl) {
-        auto fid = gpu_lower->caIndexMap().toFusion(fl->iter_domain());
-        return fid;
-      });
-
-  // Get any and all inputs that generated the provided loop structure, some
-  // root inputs may be mapped to eachother but not identical
-  auto all_inputs = InputsOf::outputs(
-      FusionGuard::getCurFusion(),
-      std::vector<Val*>(
-          fusion_loop_structure.begin(), fusion_loop_structure.end()));
-
-  // Make sure all inputs are iter domains, ignoring anything like split factor
-  // inputs
-  auto all_iter_inputs = ir_utils::filterByType<IterDomain>(all_inputs);
-
-  // Sort out the inputs as there could be entires that map to eachother, and
-  // they can be a combiantion of iteration, reduction, and broadcast. Order as
-  // iter, reduction, then broadcast for iterating and removing duplicate mapped
-  // entries. Since these are input IterDomains we mainly want to prioritize
-  // non-broadcast "versions" of the iter domain if it shows up more than once.
-  // We could get both if we have a compute at structure where a consumer has a
-  // concrete iter domain but it's producer has a broadcast domain, and the
-  // compute at axis is across a split on this domain. The producer would give a
-  // broadcast input, consumer would have iter domain input.
-  // Additionally, we prefer non-reduction iter domains over reduciton
-  // domains, but this is just optional and not necessary for correctness.
-  std::vector<IterDomain*> sorted_inputs;
-  std::copy_if(
-      all_iter_inputs.begin(),
-      all_iter_inputs.end(),
-      std::back_inserter(sorted_inputs),
-      [](IterDomain* id) { return !id->isBroadcast() && !id->isReduction(); });
-  std::copy_if(
-      all_iter_inputs.begin(),
-      all_iter_inputs.end(),
-      std::back_inserter(sorted_inputs),
-      [](IterDomain* id) { return id->isReduction(); });
-  std::copy_if(
-      all_iter_inputs.begin(),
-      all_iter_inputs.end(),
-      std::back_inserter(sorted_inputs),
-      [](IterDomain* id) { return id->isBroadcast(); });
-
-  // Produce a non repetitive set of inputs. Remove "duplicate" IterDomains that
-  // map to eachother.
-  std::vector<IterDomain*> root_axes;
-  for (auto root_id : sorted_inputs) {
-    auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(root_id);
-    if (concrete_to_id_.find(concrete_id) != concrete_to_id_.end()) {
-      continue;
-    }
-
-    // Make a copy of the root_id for the reference to "own"
-    IterDomain* root_id_copy = root_id->clone();
-
-    // Initialize root axes, concrete map, and leaf map for replay.
-    root_axes.push_back(root_id_copy);
-    concrete_to_id_[concrete_id] = root_id_copy;
-    leaf_ids_.emplace(root_id_copy);
-  }
-
-  // Order is important here, replay expressions from loops outside to inside.
-  auto replay_exprs = ExprSort::getExprs(
-      FusionGuard::getCurFusion(),
-      {fusion_loop_structure.begin(), fusion_loop_structure.end()});
-
-  // Run the reference replay
-  for (auto expr : replay_exprs) {
-    OptInDispatch::handle(expr);
-  }
-
-  // Construct a tensor that's representitive of the replayed loop structure.
-  std::vector<IterDomain*> loops_replayed_domain;
-
-  // Grab a set of concrete leaf ids to make it easier to search which for loop
-  // matches the leaf id from the replay.
-  std::unordered_set<IterDomain*> concrete_leaf_ids;
-  for (auto entry : concrete_to_id_) {
-    if (leaf_ids_.find(entry.second) != leaf_ids_.end()) {
-      concrete_leaf_ids.emplace(entry.first);
-    }
-  }
-
-  // Figure out which ID's that were replayed correspond to the respective loops
-  // that were replayed.
-  std::transform(
-      fusion_loop_structure.begin(),
-      fusion_loop_structure.end(),
-      std::back_inserter(loops_replayed_domain),
-      [&](IterDomain* loop_id) {
-        for (auto id : concrete_leaf_ids) {
-          // Matching has to be done on loop map, though replay was done in ID
-          // map, so we need to manually check that things are mapped in the
-          // loop map. Cannot simply look up concrete IDs to match them as index
-          // map and loop map do not have the same concrete id mapping. We also
-          // allow matching explicitly through the index map. Index map is not
-          // gauranteed to be contained in loop map, therefore if we generate
-          // mappings to conrete id's through the index map, the mapping from
-          // those ID's to the ID's we replay are not gauranteed to be in loop
-          // map. The reverse is also true, so for validation make sure one of
-          // the mappings exist. For reference check the difference between:
-          // AdvancedLowering5 test and AdvancedIndexing1.
-          if (gpu_lower->caLoopMap().areMapped(id, loop_id) ||
-              gpu_lower->caIndexMap().areMapped(id, loop_id)) {
-            concrete_leaf_ids.erase(id);
-            auto replayed_id = concrete_to_id_.at(id);
-            // Propagate parallelization and vectorization. Necessary
-            // for indexing. IndexCompute::getExtent depends on the
-            // propagated parallelization.
-            if (isParallelTypeVectorize(loop_id->getParallelType()) ||
-                isParallelTypeThread(loop_id->getParallelType())) {
-              replayed_id->parallelize(loop_id->getParallelType());
-            }
-            return replayed_id;
-          }
-        }
-
-        TORCH_INTERNAL_ASSERT(
-            false,
-            "Could not find required iter domain in reference replay: ",
-            loop_id);
-      });
-
-  // Add any remaining leaf iter domains, this can happen from rfactor patterns.
-  for (auto entry : concrete_leaf_ids) {
-    loops_replayed_domain.push_back(concrete_to_id_.at(entry));
-  }
-  if (replay_exprs.empty()) {
-    auto domain = new TensorDomain(
-        // If there was no replay only return a domain with a root domain.
-        loops_replayed_domain);
-    return domain;
-  } else {
-    auto domain = new TensorDomain(root_axes, loops_replayed_domain);
-    return domain;
-  }
-}
-
-IndexCompute getReferenceIndexing(
-    const std::vector<kir::ForLoop*>& loop_structure,
-    TensorDomain* reference_tensor) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  // Create a simple index mapping from loop iter domains to their local index.
-  // This is only applicable to global memory buffers.
-  std::unordered_map<kir::IterDomain*, kir::Val*> initial_index_map;
-
-  TORCH_INTERNAL_ASSERT(loop_structure.size() <= reference_tensor->nDims());
-  int magic_zero_loop = -1;
-  for (size_t loop_i = 0; loop_i < loop_structure.size(); loop_i++) {
-    auto ref_axis = reference_tensor->axis(loop_i);
-    auto kir_ref_axis = gpu_lower->lowerValue(ref_axis)->as<kir::IterDomain>();
-    auto loop = loop_structure[loop_i];
-    auto ind = loop->index();
-    ;
-
-    initial_index_map[kir_ref_axis] = ind;
-    if (loop->vectorize()) {
-      initial_index_map[kir_ref_axis] = ir_builder.create<kir::Int>(0);
-    }
-
-    if (Index::protectWithMagicZero(loop, ref_axis, ind)) {
-      magic_zero_loop = (int)loop_i;
-    }
-  }
-
-  // Add magic zero to a fairly inner most index
-  if (magic_zero_loop >= 0) {
-    auto ref_id = gpu_lower->lowerValue(reference_tensor->axis(magic_zero_loop))
-                      ->as<kir::IterDomain>();
-    initial_index_map[ref_id] = ir_builder.addExpr(
-        initial_index_map[ref_id], ir_builder.magicZeroVal());
-  }
-
-  // Send to the other version of reference indexing that directly takes the
-  // index map
-  return getReferenceIndexing(
-      loop_structure, reference_tensor, initial_index_map, {});
-}
-
-IndexCompute getReferenceIndexing(
-    const std::vector<kir::ForLoop*>& loop_structure,
-    TensorDomain* reference_tensor,
-    std::unordered_map<kir::IterDomain*, kir::Val*> index_map,
-    std::unordered_set<IterDomain*> preferred_paths,
-    std::unordered_map<kir::IterDomain*, kir::Val*> halo_extent_map) {
-  auto gpu_lower = GpuLower::current();
-
-  // I thought this might be necesasry, but turns out it's not. I think it's
-  // because of the root ordering above, however leaving it in case we find
-  // out it is necessary in some cases. At the time of commiting, cuda-memcheck
-  // passed without this.
-  //
-  // std::unordered_map<kir::IterDomain*,
-  // kir::Val*> reference_extent_map; for (auto loop : loop_structure) {
-  //   // If there's a broadcast merged in the for loop ID we want to track its
-  //   // extent
-  //   auto inputs = InputsOf::outputs(
-  //       FusionGuard::getCurFusion(),
-  //       {gpu_lower->caIndexMap().toFusion(loop->iter_domain())});
-
-  //   auto iter_inputs = ir_utils::filterByType<IterDomain>(inputs);
-
-  //   // If any of the inputs are a broadcast, explicitly mark the loop id's
-  //   // extent
-  //   if (std::any_of(iter_inputs.begin(), iter_inputs.end(), [](IterDomain*
-  //   id) {
-  //         return id->isBroadcast();
-  //       })) {
-  //     reference_extent_map[loop->iter_domain()] =
-  //     loop->iter_domain()->extent();
-  //   }
-  // }
-
-  // Convert to preferred_path to kir::IterDomain for IndexCompute
-  std::unordered_set<kir::IterDomain*> kir_preferred_path;
-  std::transform(
-      preferred_paths.begin(),
-      preferred_paths.end(),
-      std::inserter(kir_preferred_path, kir_preferred_path.begin()),
-      [&gpu_lower](IterDomain* id) {
-        return gpu_lower->lowerValue(id)->as<kir::IterDomain>();
-      });
-
-  IndexCompute compute(
-      reference_tensor,
-      index_map, // NOLINT
-      // reference_extent_map, // Seems this is not necessary, see comment above
-      // in this function
-      {},
-      std::unordered_set<kir::IterDomain*>(),
-      reference_tensor->contiguity(),
-      kir_preferred_path,
-      halo_extent_map);
-
-  compute.run();
-
-  return compute;
-}
-
-namespace {
-
-// Class to track through the reference what path to take for zero merged in
-// indices if we're indexing shared memory or local memory. Use marked root
-// domains and traverse through the replay to mark paths to get to them during a
-// backward replay.
-class PreferredPathCompute : public IterVisitor {
- private:
-  void handle(Expr* e) override {
-    // If an input ID is marked, propagate the marking to outputs of the
-    // expression
-    auto all_iter_inputs = ir_utils::filterByType<IterDomain>(e->inputs());
-    if (std::any_of(
-            all_iter_inputs.begin(),
-            all_iter_inputs.end(),
-            [&](IterDomain* inp_id) {
-              return this->preferred_path.find(inp_id) !=
-                  this->preferred_path.end();
-            })) {
-      auto all_iter_outputs = ir_utils::filterByType<IterDomain>(e->outputs());
-      preferred_path.insert(all_iter_outputs.begin(), all_iter_outputs.end());
-    }
-  }
-
- private:
-  // If making a choice these are the iter domains to prefer when traversing
-  // backward.
-  std::unordered_set<IterDomain*> preferred_path;
-
- public:
-  static std::unordered_set<IterDomain*> compute(
-      TensorDomain* reference_domain,
-      const std::unordered_set<IterDomain*>& preferred_roots) {
-    // TODO: assert all provided preferred roots are in the history of reference
-    // domain.
-
-    PreferredPathCompute compute;
-    // Init preferred path
-    compute.preferred_path = preferred_roots;
-
-    // Propagate
-    compute.traverseFrom(
-        FusionGuard::getCurFusion(),
-        std::vector<Val*>(
-            reference_domain->domain().begin(),
-            reference_domain->domain().end()));
-
-    return compute.preferred_path;
-  }
-};
-} // namespace
-
-// External interface for preferred path propagation.
-std::unordered_set<IterDomain*> buildPreferredPaths(
-    TensorDomain* reference_tensor,
-    const std::unordered_set<IterDomain*>& preferred_roots) {
-  return PreferredPathCompute::compute(reference_tensor, preferred_roots);
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/index_reference_replay.h b/torch/csrc/jit/codegen/cuda/index_reference_replay.h
deleted file mode 100644 (file)
index 45cd65d..0000000
+++ /dev/null
@@ -1,83 +0,0 @@
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-struct ReferenceTensor {
-  TensorDomain* domain = nullptr;
-
-  // Map from concrete iteration domains in ComputeAtMaps to iter domains
-  // including those used to construct domain.
-  std::unordered_map<IterDomain*, IterDomain*> concrete_to_id;
-};
-
-class IndexReferenceReplay : public OptInDispatch {
- private:
-  IndexReferenceReplay(const std::vector<kir::ForLoop*>& loop_structure)
-      : loop_structure_(loop_structure) {}
-
-  // We're going to replay this split operation on the corresponding ID
-  void handle(Split* s) override;
-
-  // We're going to replay this merge operation on the corresponding IDs
-  void handle(Merge* m) override;
-
-  TensorDomain* computeReplay();
-
-  using OptInDispatch::handle;
-
- private:
-  const std::vector<kir::ForLoop*>& loop_structure_;
-
-  // Replay map
-  std::unordered_map<IterDomain*, IterDomain*> concrete_to_id_;
-
-  // Replay map
-  std::unordered_set<IterDomain*> leaf_ids_;
-
- public:
-  static ReferenceTensor getReference(
-      const std::vector<kir::ForLoop*>& loop_structure) {
-    auto replay = IndexReferenceReplay(loop_structure);
-    ReferenceTensor ref;
-    ref.domain = replay.computeReplay();
-    ref.concrete_to_id = replay.concrete_to_id_;
-    return ref;
-  }
-};
-
-// Index into the reference based on the provided index map.
-IndexCompute getReferenceIndexing(
-    const std::vector<kir::ForLoop*>& loop_structure,
-    TensorDomain* reference_domain,
-    std::unordered_map<kir::IterDomain*, kir::Val*> index_map,
-    std::unordered_set<IterDomain*> preferred_path,
-    std::unordered_map<kir::IterDomain*, kir::Val*> halo_extent_map = {});
-
-// Short cut for global TVs. Index into the reference based on all loop indicies
-// in the loop structure.
-IndexCompute getReferenceIndexing(
-    const std::vector<kir::ForLoop*>& loop_structure,
-    TensorDomain* reference_domain);
-
-// When indexing there are sometimes an option to propagate an index down
-// multiple paths. This will return the IterDomains in the history of the
-// reference domain and mark which paths should be taken (if there's a
-// preference) to reach the roots provided in preferred_roots.
-std::unordered_set<IterDomain*> buildPreferredPaths(
-    TensorDomain* reference_domain,
-    const std::unordered_set<IterDomain*>& preferred_roots);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index 3210624..49196be 100644 (file)
@@ -16,7 +16,7 @@ namespace cuda {
 namespace inst {
 
 Trace::Trace() {
-  const char* trace_filename = getenv("PYTORCH_NVFUSER_TRACE");
+  const char* trace_filename = getenv("PYTORCH_CUDA_FUSER_TRACE");
   if (trace_filename != nullptr) {
     log_file_ = fopen(trace_filename, "w");
     TORCH_CHECK(log_file_ != nullptr, "Can't open trace file");
index d0670e3..06dde89 100644 (file)
@@ -2,10 +2,6 @@
 
 #include <torch/csrc/jit/codegen/cuda/utils.h>
 
-#include <nvToolsExt.h>
-
-// NOLINTNEXTLINE(modernize-deprecated-headers)
-#include <stdio.h>
 #include <chrono>
 #include <cstdio>
 
@@ -20,7 +16,7 @@ namespace inst {
 //! This class is not intended to be used directly. Instead, the operations
 //! to be traced are marked (for example using the FUSER_PERF_SCOPE macro)
 //!
-//! In order to enable tracing, the `PYTORCH_NVFUSER_TRACE` environment
+//! In order to enable tracing, the `PYTORCH_CUDA_FUSER_TRACE` environment
 //! variable is set to point to a trace file (ex `test.trace`). The file name
 //! may be a relative or an absolute path.
 //!
@@ -45,11 +41,9 @@ class Trace : public NonCopyable {
     if (log_file_ != nullptr) {
       logEvent('B', name);
     }
-    nvtxRangePushA(name);
   }
 
   void endEvent(const char* name) {
-    nvtxRangePop();
     if (log_file_ != nullptr) {
       logEvent('E', name);
     }
index 8ef51a1..cf8f378 100644 (file)
@@ -23,34 +23,28 @@ CudaFuserInterface* getFuserInterface() {
 
 void compileFusionGroup(Node* fusion_node) {
   TORCH_CHECK(
-      getFuserInterface()->fn_compile_n != nullptr,
+      getFuserInterface()->fn_compile_n_ != nullptr,
       "Running the CUDA fuser requires a CUDA build.");
-  getFuserInterface()->fn_compile_n(fusion_node);
+  getFuserInterface()->fn_compile_n_(fusion_node);
 }
 
 void runFusionGroup(const Node* fusion_node, Stack& stack) {
   TORCH_CHECK(
-      getFuserInterface()->fn_run_n_s != nullptr,
+      getFuserInterface()->fn_run_n_s_ != nullptr,
       "Running the CUDA fuser requires a CUDA build.");
-  getFuserInterface()->fn_run_n_s(fusion_node, stack);
+  getFuserInterface()->fn_run_n_s_(fusion_node, stack);
 }
 
 void fuseGraph(std::shared_ptr<Graph>& graph) {
   TORCH_CHECK(
-      getFuserInterface()->fn_fuse_graph != nullptr,
+      getFuserInterface()->fn_fuse_graph_ != nullptr,
       "Running the CUDA fuser requires a CUDA build.");
-  getFuserInterface()->fn_fuse_graph(graph);
+  getFuserInterface()->fn_fuse_graph_(graph);
 }
 
 bool canFuseNode(const Node* node) {
-  return getFuserInterface()->fn_can_fuse_n != nullptr &&
-      getFuserInterface()->fn_can_fuse_n(node);
-}
-
-void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr) {
-  if (getFuserInterface()->fn_insert_profile_inodes) {
-    getFuserInterface()->fn_insert_profile_inodes(pr);
-  }
+  return getFuserInterface()->fn_can_fuse_n_ != nullptr &&
+      getFuserInterface()->fn_can_fuse_n_(node);
 }
 
 //! [ Note -- type guard logic in CudaFusionGuard ]
@@ -96,7 +90,8 @@ bool complyWith(
   // check a. if num_dimension check fails or scalar type check fails
   if (*guard_tensor_type->dim() != static_cast<size_t>(tensor.ndimension()) ||
       (guard_tensor_type->scalarType().has_value() &&
-       (guard_tensor_type->scalarType().value() != tensor.scalar_type()))) {
+       (guard_tensor_type->scalarType().value() != tensor.scalar_type())) ||
+      tensor.requires_grad()) {
     return false;
   }
 
@@ -137,8 +132,7 @@ bool complyWith(
         if (j != 0) {
           // we use contiguity to collapse dimension, if size == 1, it is
           // always collapsible
-          // computeStrideProps also default to contiguous when stride == 1
-          if (t_sizes[sorted_index] != 1 && t_strides[sorted_index] != 1) {
+          if (t_sizes[sorted_index] != 1) {
             TORCH_INTERNAL_ASSERT(
                 stride_properties[j - 1]->stride_index_.has_value(),
                 "Counknown index is meaningless");
@@ -184,59 +178,6 @@ bool complyWith(
 
 namespace {
 
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-RegisterOperators size_eq_guard({
-    Operator(
-        //"prim::CudaFusionSizeEq(int[] size, int[] ref) -> bool",
-        "prim::CudaFusionSizeEq(...) -> bool",
-        // prim::CudaFusionGuard returns a fresh Boolean type without aliasing.
-        // if we would ever return refined tensor, which would change aliasing
-        // analysis, we should update aliasdb pass.
-        [](const Node* node) -> Operation {
-          return [](Stack* stack) {
-            at::ArrayRef<IValue> inputs = last(stack, 2);
-            drop(stack, 2);
-
-            if (!fuser::cuda::getCudaFusionGuardMode()) {
-              push(stack, IValue(true));
-              return;
-            }
-
-            // auto inp = inputs[0].toIntList();
-            TORCH_INTERNAL_ASSERT(
-                inputs[1].isIntList(), "reference needs to be of int list");
-            auto ref = inputs[1].toIntList();
-
-            auto ret = true;
-            if (ref.empty()) {
-              ret = inputs[0].isNone();
-            } else {
-              if (inputs[0].isIntList()) {
-                auto inp = inputs[0].toIntList();
-                if (inp.size() != ref.size()) {
-                  push(stack, IValue(false));
-                  return;
-                }
-
-                for (size_t i = 0; i < inp.size(); i++) {
-                  if (((inp[i] == 1) != (ref[i] == 1))) {
-                    ret = false;
-                    break;
-                  }
-                }
-              } else {
-                ret = false;
-              }
-            }
-
-            push(stack, IValue(ret));
-            return;
-          };
-        },
-        aliasAnalysisFromSchema()),
-});
-
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 RegisterOperators reg_fusion({
     Operator(
         prim::CudaFusionGroup,
@@ -289,24 +230,6 @@ RegisterOperators reg_guard({
         },
         aliasAnalysisFromSchema()),
 });
-
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-RegisterOperators reg_add_optional({
-    Operator(
-        "prim::add_optional(Tensor(a) input, Tensor? bias) -> Tensor(a)",
-        [](const Node* node) -> Operation {
-          return [](Stack* stack) {
-            IValue input, bias;
-            pop(stack, input, bias);
-            if (bias.isNone()) {
-              push(stack, std::move(input));
-            } else {
-              push(stack, at::add(input.toTensor(), bias.toTensor(), 1.0));
-            }
-          };
-        },
-        aliasAnalysisFromSchema()),
-});
 } // namespace
 
 } // namespace jit
index d7924ed..00d94a9 100644 (file)
@@ -21,11 +21,10 @@ TORCH_API std::atomic<bool>& getCudaFusionGuardMode();
 
 // dummy struct to allow API registration
 struct CudaFuserInterface {
-  void (*fn_compile_n)(Node*) = nullptr;
-  void (*fn_run_n_s)(const Node*, Stack&) = nullptr;
-  void (*fn_fuse_graph)(std::shared_ptr<Graph>&) = nullptr;
-  bool (*fn_can_fuse_n)(const Node*) = nullptr;
-  void (*fn_insert_profile_inodes)(ProfilingRecord* pr) = nullptr;
+  void (*fn_compile_n_)(Node*) = nullptr;
+  void (*fn_run_n_s_)(const Node*, Stack&) = nullptr;
+  void (*fn_fuse_graph_)(std::shared_ptr<Graph>&) = nullptr;
+  bool (*fn_can_fuse_n_)(const Node*) = nullptr;
 };
 
 // Get interface, this is used by registration and user facing API internally
@@ -35,7 +34,6 @@ C10_EXPORT void compileFusionGroup(Node* fusion_node);
 C10_EXPORT void runFusionGroup(const Node* fusion_node, Stack& stack);
 C10_EXPORT void fuseGraph(std::shared_ptr<Graph>&);
 C10_EXPORT bool canFuseNode(const Node* node);
-C10_EXPORT void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr);
 
 C10_EXPORT bool complyWith(
     const at::Tensor& tensor,
index 72d81a8..b010685 100644 (file)
@@ -43,44 +43,61 @@ void Statement::print() const {
 }
 
 // When we create a Val we immediately register them with the active fusion.
-Val::Val(ValType _vtype, DataType _dtype, bool register_val)
+Val::Val(ValType _vtype, DataType _dtype, bool register_val, bool lowered)
     : vtype_(_vtype), dtype_(_dtype) {
   Fusion* fusion = FusionGuard::getCurFusion();
   TORCH_CHECK(
       fusion != nullptr, "No active fusion group found when creating a Val.");
   fusion_ = fusion;
   if (register_val) {
-    name_ = fusion_->registerVal(this);
+    if (lowered) {
+      name_ = fusion_->registerLoweredVal(this);
+    } else {
+      name_ = fusion_->registerVal(this);
+    }
   }
 }
 
-// NOTE: we don't clone the definition_ and uses_ here
-//  since they may introduce cloning cycles. Instead, we copy
-//  the original pointers and we'll fix them up later part of the
-//  Fusion copy. Neither definition_ nor uses_ are copied through
-//  this constructor now leaving them to be resolved by later stages
-//
-Val::Val(const Val* src, IrCloner* ir_cloner)
-    : Statement(src, ir_cloner),
-      vtype_(src->vtype_),
-      dtype_(src->dtype_),
-      is_fusion_input_(src->is_fusion_input_),
-      is_fusion_output_(src->is_fusion_output_) {}
-
-const std::vector<Expr*>& Val::uses() const {
-  if (vtype_ == ValType::TensorView) {
-    if (!fusion()->isTVUseInfoValid() && !fusion()->isUpdatingTVUseInfo()) {
-      fusion()->resetTvUses();
-    }
+namespace {
+
+// TODO(kir): remove this
+ValType lowerValType(ValType vtype) {
+  switch (vtype) {
+    case ValType::Scalar:
+      return ValType::KirScalar;
+    case ValType::NamedScalar:
+      return ValType::KirNamedScalar;
+    case ValType::TensorDomain:
+      return ValType::KirTensorDomain;
+    case ValType::IterDomain:
+      return ValType::KirIterDomain;
+    case ValType::TensorView:
+      return ValType::KirTensorView;
+    default:
+      TORCH_CHECK(false, "Unexpected");
   }
-  return uses_;
 }
 
-namespace {
+} // namespace
+
+// TODO(kir): remove this
+Val::Val(const Val* fusion_ir_node)
+    : vtype_(lowerValType(fusion_ir_node->vtype_)),
+      dtype_(fusion_ir_node->dtype_) {
+  // The lowered nodes preserve the names from the fusion IR counterparts
+  name_ = fusion_ir_node->name_;
+  fusion_ = fusion_ir_node->fusion_;
+  fusion_->registerLoweredVal(this);
+}
 
-// Traverse definition of all values involved in constructing the provided val.
+Val::Val(const Val* src, IrCloner* ir_cloner)
+    : Statement(src, ir_cloner), vtype_(src->vtype_), dtype_(src->dtype_) {}
+
+// Traverse origin of all values involved in constructing the provided val.
 // Check if all values involved are constant values, meaning the provided
 // val is also a constant value.
+namespace {
+
 class ConstCheck : OptOutConstDispatch {
  private:
   bool is_const_ = true;
@@ -89,8 +106,12 @@ class ConstCheck : OptOutConstDispatch {
     is_const_ = is_const_ && b->isConst();
   }
 
-  void handle(const Double* d) override {
-    is_const_ = is_const_ && d->isConst();
+  void handle(const Float* f) override {
+    is_const_ = is_const_ && f->isConst();
+  }
+
+  void handle(const Half* h) override {
+    is_const_ = is_const_ && h->isConst();
   }
 
   void handle(const Int* i) override {
@@ -101,6 +122,26 @@ class ConstCheck : OptOutConstDispatch {
     is_const_ = is_const_ && false;
   }
 
+  void handle(const kir::Bool* b) override {
+    is_const_ = is_const_ && b->isConst();
+  }
+
+  void handle(const kir::Float* f) override {
+    is_const_ = is_const_ && f->isConst();
+  }
+
+  void handle(const kir::Half* h) override {
+    is_const_ = is_const_ && h->isConst();
+  }
+
+  void handle(const kir::Int* i) override {
+    is_const_ = is_const_ && i->isConst();
+  }
+
+  void handle(const kir::NamedScalar* ns) override {
+    is_const_ = is_const_ && false;
+  }
+
   void handle(const Expr* expr) override {
     for (auto inp : expr->inputs()) {
       handle(inp);
@@ -108,11 +149,11 @@ class ConstCheck : OptOutConstDispatch {
   }
 
   void handle(const Val* val) override {
-    if (val->definition() != nullptr) {
-      handle(val->definition());
-    } else {
+    const Expr* orig = FusionGuard::getCurFusion()->origin(val);
+    if (orig != nullptr)
+      handle(orig);
+    else
       OptOutConstDispatch::handle(val);
-    }
   }
 
  public:
@@ -135,6 +176,8 @@ c10::optional<int64_t> Val::getInt() const {
   if (isConstScalar() && isAnInt()) {
     if (this->getValType() == ValType::Scalar) {
       return this->as<Int>()->value();
+    } else if (this->getValType() == ValType::KirScalar) {
+      return this->as<kir::Int>()->value();
     }
   }
   return c10::optional<int64_t>();
@@ -156,26 +199,17 @@ c10::optional<DataType> Val::getDataType() const {
   return dtype_;
 }
 
-bool Val::isProducerOf(const Val* other) const {
-  TORCH_INTERNAL_ASSERT(other != nullptr);
-  TORCH_INTERNAL_ASSERT(fusion() == other->fusion());
-
-  if (definition() == nullptr) {
-    return false;
-  }
-  return std::any_of(
-      definition()->inputs().begin(),
-      definition()->inputs().end(),
-      [other](const Val* input) { return input == other; });
+Expr* Val::getOrigin() {
+  return fusion_->origin(this);
 }
 
-bool Val::isConsumerOf(const Val* other) const {
-  return other->isProducerOf(this);
+const Expr* Val::getOrigin() const {
+  return fusion_->origin(this);
 }
 
 // We don't register with the active fusion in Expr as this needs to be done
 // after inputs and outputs are registered with the Expr
-Expr::Expr(ExprType type) : type_{type} {
+Expr::Expr(ExprType _type) : type_{_type} {
   Fusion* fusion = FusionGuard::getCurFusion();
   if (fusion == nullptr)
     TORCH_CHECK(false, "No active fusion group found when creating an Expr.");
@@ -188,25 +222,15 @@ Expr::Expr(const Expr* src, IrCloner* ir_cloner)
       inputs_(ir_cloner->clone(src->inputs_)),
       outputs_(ir_cloner->clone(src->outputs_)) {}
 
-bool Expr::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<Expr>()) {
+bool Expr::sameAs(const Expr* const other) const {
+  if (getExprType() != other->getExprType())
     return false;
-  }
-  const Expr* other_expr = other->as<Expr>();
-  if (getExprType() != other_expr->getExprType()) {
-    return false;
-  }
-  if (inputs().size() != other_expr->inputs().size() ||
-      outputs().size() != other_expr->outputs().size()) {
+  if (inputs().size() != other->inputs().size() ||
+      outputs().size() != other->outputs().size())
     return false;
-  }
   for (const auto i : c10::irange(inputs().size())) {
-    if (!input(i)->sameAs(other_expr->input(i))) {
+    if (!input(i)->sameAs(other->input(i)))
       return false;
-    }
   }
   return true;
 }
index 2d4cd82..5c8757e 100644 (file)
@@ -9,6 +9,7 @@
 #include <torch/csrc/jit/codegen/cuda/utils.h>
 
 #include <cstdint>
+#include <deque>
 #include <iostream>
 #include <limits>
 #include <memory>
@@ -37,7 +38,7 @@ namespace cuda {
 
 using StmtNameType = unsigned int;
 
-constexpr StmtNameType kInvalidStmName =
+constexpr StmtNameType UNINITIALIZED_STMTNAMETYPE =
     std::numeric_limits<unsigned int>::max();
 
 class Fusion;
@@ -49,16 +50,17 @@ class BinaryOp;
 class IterDomain;
 class IrCloner;
 
-//! Statement is the highest level node representation. Everything that is
-//! considered "IR" will be derived from this class at some point. Both Values
-//! and Expr's are a Statement. If there will ever be any more fundamental
-//! types, they will also derive from Statement.
-//!
-//! We use Statements to pass around nodes of unknown compile type. Therefore it
-//! is also important for the design to have a dispatch system for a Statment.
-//! Basically beinng able to succienctly traverse down the inhereitance stack of
-//! a Statment at runtime. This is currently implemented in dispatch.h
-//!
+/*
+ * Statement is the highest level node representation. Everything that is
+ * considered "IR" will be derived from this class at some point. Both Values
+ * and Expr's are a Statement. If there will ever be any more fundamental types,
+ * they will also derive from Statement.
+ *
+ * We use Statements to pass around nodes of unknown compile type. Therefore it
+ * is also important for the design to have a dispatch system for a Statment.
+ * Basically beinng able to succienctly traverse down the inhereitance stack of
+ * a Statment at runtime. This is currently implemented in dispatch.h
+ */
 class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase {
   friend void swap(Fusion&, Fusion&) noexcept;
 
@@ -123,7 +125,7 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase {
 
   // Return if this statement is the same as another statement
   // TODO: should this run through dispatch on this and other?
-  virtual bool sameAs(const Statement* other) const {
+  bool sameAs(const Statement* const other) const {
     return this == other;
   }
 
@@ -131,42 +133,44 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase {
 
  protected:
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
-  StmtNameType name_ = kInvalidStmName;
+  StmtNameType name_ = UNINITIALIZED_STMTNAMETYPE;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   Fusion* fusion_ = nullptr;
 };
 
-//! A Val represents a "value." These are objects, like tensors, scalars, and
-//! memory locations, that are inputs and outputs of computations (represented
-//! by Exprs, below)
-//!
-//! Vals are constant and unique and should always be passed
-//! around as a pointer. Val can generally be thought of as representing any
-//! type of data. Some examples: a constant size like convolution filter width a
-//! runtime constant like batch normalizations momentum a "symbolic" tensor like
-//! one passed down from the JIT a memory buffer used in device code
-//!
-//! Adding a Val:
-//! Right now adding a Val is quite involved. Val's can be defined in ir.h or in
-//! their own header file. The following is what is currently needed to add a
-//! new Val:
-//!
-//! 1) Definition inheriting from Val
-//!     - Members must be private or protected
-//!     - Accessor functions for members
-//!     - Must call Val constructor, Val constructor registers with fusion
-//!     - Implementation of bool sameAs(...)
-//!     - Must implement a "cloning" constructor, ex.
-//!        Int::Int(const Int* src, IrCloner* ir_cloner)
-//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
-//! 3) Default mutator function should be added to mutator.cpp
-//! 4a) Printing functions should be added to ir_iostream.h/.cpp
-//! 4b) Graphviz generation must be added to ir_graphviz.h/.cpp
-//! 5) An enum value must be added to ValType in type.h
-//! 6) A string entry must be added in val_type_string_map
-//!
+/*
+ * A Val represents a "value." These are objects, like tensors, scalars, and
+ * memory locations, that are inputs and outputs of computations (represented
+ * by Exprs, below). Vals are constant and unique and should always be passed
+ * around as a pointer. Val can generally be thought of as representing any type
+ * of data. Some examples: a constant size like convolution filter width a
+ * runtime constant like batch normalizations momentum a "symbolic" tensor like
+ * one passed down from the JIT a memory buffer used in device code
+ *
+ * Adding a Val:
+ * Right now adding a Val is quite involved. Val's can be defined in ir.h or in
+ * their own header file. The following is what is currently needed to add a new
+ * Val:
+ * 1) Definition inheriting from Val
+ *     - Members must be private or protected
+ *     - Accessor functions for members
+ *     - Must call Val constructor, Val constructor registers with fusion
+ *     - Implementation of bool sameAs(...)
+ *     - Must implement a "cloning" constructor, ex.
+ *        Int::Int(const Int* src, IrCloner* ir_cloner)
+ * 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
+ * 3) Default mutator function should be added to mutator.cpp
+ * 4a) Printing functions should be added to ir_iostream.h/.cpp
+ * 4b) Graphviz generation must be added to ir_graphviz.h/.cpp
+ * 5) An enum value must be added to ValType in type.h
+ * 6) A string entry must be added in val_type_string_map
+ */
 class TORCH_CUDA_CU_API Val : public Statement {
  public:
+  ~Val() override = default;
+
+  Val() = delete;
+
   // We may not want to register this value during Val's constructor. The reason
   // for this is that if we register the val, then ina derived constructor try
   // to throw, fusion's destructor will get called, but the pointer to this Val
@@ -175,24 +179,31 @@ class TORCH_CUDA_CU_API Val : public Statement {
   explicit Val(
       ValType _vtype,
       DataType _dtype = DataType::Null,
-      bool register_val = true);
+      bool register_val = true,
+      bool lowered = false);
+
+  // Lowers an existing Fusion IR node into a Kernel IR counterpart
+  explicit Val(const Val* fusion_ir_node);
 
   Val(const Val* src, IrCloner* ir_cloner);
 
-  // TODO: why is this optional?
-  //
+  // TODO: Values are unique and not copyable
+  Val(const Val& other) = delete;
+  Val& operator=(const Val& other) = delete;
+
+  Val(Val&& other) = delete;
+  Val& operator=(Val&& other) = delete;
+
   c10::optional<ValType> getValType() const override {
     return vtype_;
   }
 
   // Throws if no DataType is found. Vals must have a DataType
-  //
-  // TODO: why is this optional?
-  //
   c10::optional<DataType> getDataType() const override;
 
   bool isScalar() const {
-    return vtype_ == ValType::Scalar || vtype_ == ValType::NamedScalar;
+    return vtype_ == ValType::Scalar || vtype_ == ValType::NamedScalar ||
+        vtype_ == ValType::KirScalar || vtype_ == ValType::KirNamedScalar;
   }
 
   bool isConstScalar() const;
@@ -208,28 +219,8 @@ class TORCH_CUDA_CU_API Val : public Statement {
 
   // Returns the Expr that this value is an output of, returns nullptr if none
   // was found
-  Expr* definition() const {
-    if (is_fusion_input_) {
-      return nullptr;
-    }
-    return definition_;
-  }
-
-  const std::vector<Expr*>& uses() const;
-
-  bool isFusionInput() const {
-    return is_fusion_input_;
-  }
-
-  bool isFusionOutput() const {
-    return is_fusion_output_;
-  }
-
-  //! Returns true when other is a producer of this
-  bool isProducerOf(const Val* other) const;
-
-  //! Returns true when other is a consumer of this
-  bool isConsumerOf(const Val* other) const;
+  Expr* getOrigin();
+  const Expr* getOrigin() const;
 
   bool sameType(const Statement* other) override {
     return Statement::sameType(other) &&
@@ -239,7 +230,7 @@ class TORCH_CUDA_CU_API Val : public Statement {
   // TODO: Make this more sophisticated. A value being the same as another value
   // should be evaluated based on the DAG that created it, and that DAGs leaf
   // nodes
-  bool sameAs(const Statement* other) const override {
+  bool sameAs(const Val* const other) const {
     return this == other;
   }
 
@@ -254,83 +245,63 @@ class TORCH_CUDA_CU_API Val : public Statement {
   static Statement* mutatorDispatch(T mutator, Val*);
 
  protected:
-  friend Fusion;
-
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   const ValType vtype_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   const DataType dtype_;
-
-  // Following is managed by Fusion and can change.
-  void setDefinition(Expr* expr) {
-    definition_ = expr;
-  }
-
-  void setIsFusionInput(bool is_fusion_input) {
-    is_fusion_input_ = is_fusion_input;
-  }
-
-  void setIsFusionOutput(bool is_fusion_output) {
-    is_fusion_output_ = is_fusion_output;
-  }
-
-  void setUses(const std::vector<Expr*>& uses) {
-    uses_ = uses;
-  }
-
- private:
-  // Following is managed by Fusion and can change.
-  bool is_fusion_input_ = false;
-  bool is_fusion_output_ = false;
-
-  Expr* definition_ = nullptr;
-  std::vector<Expr*> uses_;
 };
 
-//!  A Expr represents a "computation." These are functions that takes inputs
-//!  and produce outputs, inputs and outputs all being Vals. There are
-//!  specializations of BinaryOp which takes 2 inputs and produces 1 output, and
-//!  UnaryOp which takes 1 input and produces 1 output. Exprs are unique and
-//!  immutable. Conceptually, Exprs could always be manipulated using unique
-//!  pointers, and we could add this later. However, for now Exprs can be
-//!  replaced in a fusion, but they cannot be modified in place.
-//!
-//!  The IR is static single assignment (SSA). Values can only be defined as an
-//!  output of an Expr once. If they are re-defined the original definition is
-//!  deleted from the program, as opposed to an ordered redefinition of the
-//!  value in the program.
-//!
-//!  Note: Registering an Expr with a Fusion is actually 2 parts, one part is
-//!  done in the Expr constructor, so that should be called on anything that
-//!  inherits Expr. The issue with having registration in Expr's constructor, is
-//!  that the constructor of an Expr will set ouputs and inputs. This
-//!  information is important for registration with Fuser, so it can track the
-//!  dependency chain.
-//!
-//!  Adding an Expr:
-//!  Right now adding an Expr is quite involved. Expr's can be defined in ir.h
-//!  or in their own header file. The following is what is currently needed for
-//!  Expr definitions:
-//!
-//! 1) Definition inheriting from Expr.
-//!      - Members must be private or protected
-//!      - Accessor functions for members
-//!      - Constructors need to register with the Fusion after inputs/outputs
-//!         are defined
-//!      - Implementation of bool sameAs(...)
-//!  2) dispatch.h/.cpp must be updated to include dispatch of the new Val
-//!  3) Default mutator function should be added to mutator.h/.cpp
-//!  4) Printing functions should be added to ir_iostream.h/.cpp
-//!  5) Lower case convenience functions should be added to arith.h/.cpp (If
-//!     user facing)
-//!  6) An enum value must be added to ExprType in type.h
-//!  7) A string entry must be added in expr_type_string_map
-//!  8) Entry added to ir_graphviz .cpp/.h
-//!
+//  A Expr represents a "computation." These are functions that takes inputs
+//  and produce outputs, inputs and outputs all being Vals. There are
+//  specializations of BinaryOp which takes 2 inputs and produces 1 output, and
+//  UnaryOp which takes 1 input and produces 1 output. Exprs are unique and
+//  immutable. Conceptually, Exprs could always be manipulated using unique
+//  pointers, and we could add this later. However, for now Exprs can be
+//  replaced in a fusion, but they cannot be modified in place.
+
+//  The IR is static single assignment (SSA). Values can only be defined as an
+//  output of an Expr once. If they are re-defined the original definition is
+//  deleted from the program, as opposed to an ordered redefinition of the value
+//  in the program.
+
+//  Note: Registering an Expr with a Fusion is actually 2 parts, one part is
+//  done in the Expr constructor, so that should be called on anything that
+//  inherits Expr. The issue with having registration in Expr's constructor, is
+//  that the constructor of an Expr will set ouputs and inputs. This information
+//  is important for registration with Fuser, so it can track the dependency
+//  chain.
+
+//  Adding an Expr:
+//  Right now adding an Expr is quite involved. Expr's can be defined in ir.h or
+//  in their own header file. The following is what is currently needed for Expr
+//  definitions:
+//  1) Definition inheriting from Expr.
+//      - Members must be private or protected
+//      - Accessor functions for members
+//      - Constructors need to register with the Fusion after inputs/outputs are
+//         defined
+//      - Implementation of bool sameAs(...)
+//  2) dispatch.h/.cpp must be updated to include dispatch of the new Val
+//  3) Default mutator function should be added to mutator.h/.cpp
+//  4) Printing functions should be added to ir_iostream.h/.cpp
+//  5) Lower case convenience functions should be added to arith.h/.cpp (If user
+//   facing)
+//  6) An enum value must be added to ExprType in type.h
+//  7) A string entry must be added in expr_type_string_map
+//  8) Entry added to ir_graphviz .cpp/.h
+
 class TORCH_CUDA_CU_API Expr : public Statement {
  public:
-  explicit Expr(ExprType type);
+  Expr() = delete;
+  explicit Expr(ExprType _type);
   Expr(const Expr* src, IrCloner* ir_cloner);
+  ~Expr() override = default;
+
+  Expr(const Expr& other) = delete;
+  Expr& operator=(const Expr& other) = delete;
+
+  Expr(Expr&& other) = delete;
+  Expr& operator=(Expr&& other) = delete;
 
   c10::optional<ExprType> getExprType() const override {
     return type_;
@@ -340,7 +311,7 @@ class TORCH_CUDA_CU_API Expr : public Statement {
     return type_;
   }
 
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const Expr* const other) const;
 
   // Input/output accessors
   const auto& inputs() const {
index 0c9bbae..72ae3d5 100644 (file)
@@ -67,8 +67,12 @@ void IrCloner::handle(const Bool* b) {
   clone_ = new Bool(b, this);
 }
 
-void IrCloner::handle(const Double* d) {
-  clone_ = new Double(d, this);
+void IrCloner::handle(const Float* f) {
+  clone_ = new Float(f, this);
+}
+
+void IrCloner::handle(const Half* h) {
+  clone_ = new Half(h, this);
 }
 
 void IrCloner::handle(const Int* i) {
@@ -103,22 +107,6 @@ void IrCloner::handle(const ReductionOp* op) {
   clone_ = new ReductionOp(op, this);
 }
 
-void IrCloner::handle(const WelfordOp* op) {
-  clone_ = new WelfordOp(op, this);
-}
-
-void IrCloner::handle(const TransposeOp* op) {
-  clone_ = new TransposeOp(op, this);
-}
-
-void IrCloner::handle(const ShiftOp* op) {
-  clone_ = new ShiftOp(op, this);
-}
-
-void IrCloner::handle(const GatherOp* op) {
-  clone_ = new GatherOp(op, this);
-}
-
 void IrCloner::handle(const Split* split) {
   clone_ = new Split(split, this);
 }
index 4b9be75..badec64 100644 (file)
@@ -13,11 +13,7 @@ namespace cuda {
 
 class Fusion;
 
-//! Clones nodes from an exiting Fusion
-//!
-//! \warning IrCloner machinery is a specialized helper for implementing
-//!   Fusion copy operations and it's not intended for any other uses
-//!
+// Clones nodes from an exiting Fusion
 class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
   friend class Statement;
 
@@ -60,7 +56,8 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
   void handle(const IterDomain*) override;
 
   void handle(const Bool*) override;
-  void handle(const Double*) override;
+  void handle(const Float*) override;
+  void handle(const Half*) override;
   void handle(const Int*) override;
   void handle(const NamedScalar*) override;
 
@@ -69,10 +66,6 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
   void handle(const TernaryOp*) override;
   void handle(const BroadcastOp*) override;
   void handle(const ReductionOp*) override;
-  void handle(const WelfordOp*) override;
-  void handle(const TransposeOp*) override;
-  void handle(const ShiftOp*) override;
-  void handle(const GatherOp*) override;
 
   void handle(const Split*) override;
   void handle(const Merge*) override;
index 5ca8d54..ea3d31b 100644 (file)
@@ -43,14 +43,25 @@ class IrNodeLabel : private OptInConstDispatch {
     }
   }
 
-  void handle(const Double* d) override {
-    if (d->isSymbolic()) {
-      label_ << "d" << d->name();
+  void handle(const Float* f) override {
+    if (f->isSymbolic()) {
+      label_ << "f" << f->name();
     } else {
       if (detail_level_ >= DetailLevel::Explicit) {
-        label_ << "d" << d->name() << "=";
+        label_ << "f" << f->name() << "=";
       }
-      label_ << *d->value();
+      label_ << std::fixed << std::setprecision(2) << *f->value();
+    }
+  }
+
+  void handle(const Half* h) override {
+    if (h->isSymbolic()) {
+      label_ << "h" << h->name();
+    } else {
+      if (detail_level_ >= DetailLevel::Explicit) {
+        label_ << "h" << h->name() << "=";
+      }
+      label_ << *h->value();
     }
   }
 
@@ -78,12 +89,14 @@ class IrNodeLabel : private OptInConstDispatch {
       label_ << IrNodeLabel::gen(id->start()) << " : ";
     }
     label_ << IrNodeLabel::gen(id->extent());
+    if (id->rawExtent() != id->extent()) {
+      label_ << "\\<" << IrNodeLabel::gen(id->rawExtent()) << "\\>";
+    }
     label_ << ")";
   }
 
   void handle(const Split* split) override {
-    label_ << "Split(inner=" << (split->innerSplit() ? "true" : "false")
-           << ", factor=" << IrNodeLabel::gen(split->factor()) << ")";
+    label_ << "Split(factor=" << IrNodeLabel::gen(split->factor()) << ")";
   }
 
   void handle(const Merge* merge) override {
@@ -95,64 +108,28 @@ class IrNodeLabel : private OptInConstDispatch {
   const DetailLevel detail_level_;
 };
 
-// Small color palette from the X11 theme
-static const char* getColorFromIndex(size_t index) {
-  const size_t number_of_colors = 10;
-  index = index % number_of_colors;
-  switch (index) {
-    case 0: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-      return "azure";
-    case 1: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-      return "pink";
-    case 2: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-      return "green";
-    case 3: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-      return "grey";
-    case 4: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-      return "yellow";
-    case 5: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-      return "lavender";
-    case 6: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-      return "cyan";
-    case 7: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-      return "white";
-    case 8: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-      return "magenta";
-    case 9: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-      return "red";
-    default:
-      break;
-  }
-  return "";
-}
-
 } // anonymous namespace
 
 void IrGraphGenerator::print(
     const Fusion* fusion,
     const char* filename,
-    DetailLevel detail_level,
-    ExprColorMap* expr_color_map) {
+    DetailLevel detail_level) {
   std::ofstream dot_file(filename);
   TORCH_CHECK(dot_file.good(), "Failed to open the IR graph file");
-  dot_file << toGraphviz(fusion, detail_level, expr_color_map);
+  dot_file << toGraphviz(fusion, detail_level);
 }
 
 std::string IrGraphGenerator::toGraphviz(
     const Fusion* fusion,
-    DetailLevel detail_level,
-    ExprColorMap* expr_color_map) {
-  IrGraphGenerator ir_graph(fusion, detail_level, expr_color_map);
+    DetailLevel detail_level) {
+  IrGraphGenerator ir_graph(fusion, detail_level);
   return ir_graph.generate();
 }
 
 IrGraphGenerator::IrGraphGenerator(
     const Fusion* fusion,
-    DetailLevel detail_level,
-    ExprColorMap* expr_color_map)
-    : detail_level_(detail_level),
-      fusion_(fusion),
-      expr_color_map_(expr_color_map) {
+    DetailLevel detail_level)
+    : detail_level_(detail_level), fusion_(fusion) {
   // setup inputs & outputs
   // (indexes used to quickly check if a value is fusion input or output)
   for (const auto* input : fusion->inputs()) {
@@ -195,13 +172,7 @@ void IrGraphGenerator::addArc(
 void IrGraphGenerator::printExpr(const Expr* expr, const std::string& label) {
   graph_def_ << "    " << getid(expr) << " "
              << "[label=\"" << label << "\", shape=oval, color=blue, "
-             << "style=filled, fillcolor=";
-  if (expr_color_map_ != nullptr && expr_color_map_->count(expr)) {
-    graph_def_ << getColorFromIndex(expr_color_map_->at(expr));
-  } else {
-    graph_def_ << "azure";
-  }
-  graph_def_ << "];\n";
+             << "style=filled, fillcolor=azure];\n";
 }
 
 void IrGraphGenerator::printValue(const Val* val, const std::string& label) {
@@ -324,7 +295,7 @@ void IrGraphGenerator::handle(const Statement* s) {
 void IrGraphGenerator::handle(const Val* v) {
   if (!visited(v)) {
     visited_.insert(v);
-    if (const auto* def = v->definition()) {
+    if (const auto* def = fusion_->origin(v)) {
       handle(def);
     }
     OptInConstDispatch::handle(v);
@@ -355,15 +326,24 @@ void IrGraphGenerator::handle(const IterDomain* id) {
     addArc(id->start(), id, "[color=gray]");
   }
 
-  addArc(id->extent(), id, "[color=gray]");
+  addArc(id->rawExtent(), id, "[color=gray]");
+
+  if (detail_level_ >= DetailLevel::Explicit &&
+      id->rawExtent() != id->extent()) {
+    addArc(id->extent(), id, "[color=gray, style=dashed]");
+  }
 }
 
 void IrGraphGenerator::handle(const Bool* b) {
   printValue(b, IrNodeLabel::gen(b, detail_level_));
 }
 
-void IrGraphGenerator::handle(const Double* d) {
-  printValue(d, IrNodeLabel::gen(d, detail_level_));
+void IrGraphGenerator::handle(const Float* f) {
+  printValue(f, IrNodeLabel::gen(f, detail_level_));
+}
+
+void IrGraphGenerator::handle(const Half* h) {
+  printValue(h, IrNodeLabel::gen(h, detail_level_));
 }
 
 void IrGraphGenerator::handle(const Int* i) {
@@ -399,6 +379,13 @@ void IrGraphGenerator::handle(const TensorView* tv) {
   graph_def_ << "    " << getid(tv) << " [label=\"" << label.str()
              << "\", shape=Mrecord, color=brown, " << style << "];\n";
 
+  if (const auto* compute_at_view = tv->getComputeAtView()) {
+    std::stringstream arc_style;
+    arc_style << "[color=red, style=dashed, label=\""
+              << "ComputeAt(" << tv->getRelativeComputeAtAxis() << ")\"]";
+    addArc(tv, compute_at_view, arc_style.str());
+  }
+
   tensor_views_.push_back(tv);
 }
 
index 7bf7420..5632bc2 100644 (file)
@@ -42,25 +42,16 @@ class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch {
     Verbose, // Includes all values and dead definitions
   };
 
-  using ExprColorMap = std::unordered_map<const Expr*, size_t>;
-
  public:
   static void print(
       const Fusion* fusion,
       const char* filename,
-      DetailLevel detail_level = DetailLevel::Basic,
-      ExprColorMap* expr_color_map = nullptr);
+      DetailLevel detail_level = DetailLevel::Basic);
 
-  static std::string toGraphviz(
-      const Fusion* fusion,
-      DetailLevel detail_level,
-      ExprColorMap* expr_color_map = nullptr);
+  static std::string toGraphviz(const Fusion* fusion, DetailLevel detail_level);
 
  private:
-  IrGraphGenerator(
-      const Fusion* fusion,
-      DetailLevel detail_level,
-      ExprColorMap* expr_color_map = nullptr);
+  IrGraphGenerator(const Fusion* fusion, DetailLevel detail_level);
   ~IrGraphGenerator() override = default;
 
   std::string generate();
@@ -77,7 +68,8 @@ class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch {
   void handle(const IterDomain*) override;
 
   void handle(const Bool*) override;
-  void handle(const Double*) override;
+  void handle(const Float*) override;
+  void handle(const Half*) override;
   void handle(const Int*) override;
   void handle(const NamedScalar*) override;
 
@@ -116,7 +108,6 @@ class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch {
   std::vector<const TensorView*> tensor_views_;
   std::vector<std::string> arcs_;
   int next_id_ = 1;
-  ExprColorMap* expr_color_map_ = nullptr;
 };
 
 } // namespace cuda
index 9a9ca52..87d7bfc 100644 (file)
@@ -8,30 +8,38 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 
-//! Nodes in here are intended to be "user facing" users in this sense being
-//! those that want to be able to generate CUDA code.
+/*
+ * Nodes in here are intended to be "user facing" users in this sense being
+ * those that want to be able to generate CUDA code.
+ */
 
 namespace torch {
 namespace jit {
 namespace fuser {
 namespace cuda {
 
-class WelfordResult;
-
-//! A Bool value
-//!
-//! This value can be a symbolic value (defined after the kernel
-//! is compiled) or a constant value (inlined into the kernel definition).
-//!
+/*
+ * A Bool value.
+ * This value can be a symbolic value (defined after the kernel
+ * is compiled) or a constant value (inlined into the kernel definition).
+ */
 class TORCH_CUDA_CU_API Bool : public Val {
  public:
+  ~Bool() override = default;
+
   Bool() : Val(ValType::Scalar, DataType::Bool), maybe_value_{c10::nullopt} {}
 
-  explicit Bool(bool value)
-      : Val(ValType::Scalar, DataType::Bool), maybe_value_{value} {}
+  explicit Bool(bool _value)
+      : Val(ValType::Scalar, DataType::Bool), maybe_value_{_value} {}
 
   Bool(const Bool* src, IrCloner* ir_cloner);
 
+  Bool(const Bool& other) = delete;
+  Bool& operator=(const Bool& other) = delete;
+
+  Bool(Bool&& other) = delete;
+  Bool& operator=(Bool&& other) = delete;
+
   bool isSymbolic() const {
     return !(maybe_value_.has_value());
   }
@@ -42,26 +50,35 @@ class TORCH_CUDA_CU_API Bool : public Val {
     return maybe_value_;
   }
 
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const Bool* const other) const;
 
  private:
   const c10::optional<bool> maybe_value_;
 };
 
-//! A Float64 value. For now we don't have any other type besides
-//! Float64. This value can be a symbolic value (defined after the kernel
-//! is compiled) or a constant value (inlined into the kernel definition).
-class TORCH_CUDA_CU_API Double : public Val {
+/*
+ * A Float32 value. For now we don't have any other type besides
+ * Float32. This value can be a symbolic value (defined after the kernel
+ * is compiled) or a constant value (inlined into the kernel definition).
+ */
+class TORCH_CUDA_CU_API Float : public Val {
  public:
   using ScalarType = double;
 
-  Double()
-      : Val(ValType::Scalar, DataType::Double), maybe_value_{c10::nullopt} {}
+  ~Float() override = default;
 
-  explicit Double(ScalarType value)
-      : Val(ValType::Scalar, DataType::Double), maybe_value_{value} {}
+  Float() : Val(ValType::Scalar, DataType::Float), maybe_value_{c10::nullopt} {}
 
-  Double(const Double* src, IrCloner* ir_cloner);
+  explicit Float(ScalarType _value)
+      : Val(ValType::Scalar, DataType::Float), maybe_value_{_value} {}
+
+  Float(const Float* src, IrCloner* ir_cloner);
+
+  Float(const Float& other) = delete;
+  Float& operator=(const Float& other) = delete;
+
+  Float(Float&& other) = delete;
+  Float& operator=(Float&& other) = delete;
 
   bool isSymbolic() const {
     return !(maybe_value_.has_value());
@@ -73,25 +90,71 @@ class TORCH_CUDA_CU_API Double : public Val {
     return maybe_value_;
   }
 
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const Float* const other) const;
 
  private:
   const c10::optional<ScalarType> maybe_value_;
 };
 
-//! An Int64 value. If used for indexing it's set as size_t. Otherwise it's an
-//! inlined literal in the kernel.
+/*
+ * An IEEE 754 Float16 value.
+ * This value can be a symbolic value (defined after the kernel
+ * is compiled) or a constant value (inlined into the kernel definition).
+ */
+class TORCH_CUDA_CU_API Half : public Val {
+ public:
+  ~Half() override = default;
+
+  Half() : Val(ValType::Scalar, DataType::Half), maybe_value_{c10::nullopt} {}
+
+  explicit Half(float _value)
+      : Val(ValType::Scalar, DataType::Half), maybe_value_{_value} {}
+
+  Half(const Half* src, IrCloner* ir_cloner);
+
+  Half(const Half& other) = delete;
+  Half& operator=(const Half& other) = delete;
+
+  Half(Half&& other) = delete;
+  Half& operator=(Half&& other) = delete;
+
+  bool isSymbolic() const {
+    return !(maybe_value_.has_value());
+  }
+  bool isConst() const {
+    return maybe_value_.has_value();
+  }
+  c10::optional<float> value() const {
+    return maybe_value_;
+  }
+
+  bool sameAs(const Half* const other) const;
+
+ private:
+  const c10::optional<float> maybe_value_;
+};
+
+// An Int64 value. If used for indexing it's set as size_t. Otherwise it's an
+// inlined literal in the kernel.
 class TORCH_CUDA_CU_API Int : public Val {
  public:
   using ScalarType = int64_t;
 
+  ~Int() override = default;
+
   Int() : Val(ValType::Scalar, DataType::Int), maybe_value_{c10::nullopt} {}
 
-  explicit Int(ScalarType value)
-      : Val(ValType::Scalar, DataType::Int), maybe_value_{value} {}
+  explicit Int(ScalarType _value)
+      : Val(ValType::Scalar, DataType::Int), maybe_value_{_value} {}
 
   Int(const Int* src, IrCloner* ir_cloner);
 
+  Int(const Int& other) = delete;
+  Int& operator=(const Int& other) = delete;
+
+  Int(Int&& other) = delete;
+  Int& operator=(Int&& other) = delete;
+
   bool isSymbolic() const {
     return !(maybe_value_.has_value());
   }
@@ -102,62 +165,59 @@ class TORCH_CUDA_CU_API Int : public Val {
     return maybe_value_;
   }
 
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const Int* const other) const;
 
  private:
   const c10::optional<ScalarType> maybe_value_;
 };
 
-//! Mode during propagation of computeAt, standard will throw an error if
-//! computeAt position provided can't be satisfied, best effort will lower the
-//! computeAt position as needed during traversal, most inlined will increase
-//! the compute at position to maximum possible through traversal.
-enum class ComputeAtMode { Standard, BestEffort, MostInlined };
-
 class ComputeAt;
-class TransformPropagator;
-class TransformIter;
 class TransformReplay;
+class TransformIter;
 class OptOutMutator;
+class LoopNestGenerator;
 
 namespace ir_utils {
 class TVDomainGuard;
 }
 
-//! TensorView is our primitive Tensor Type used in code generation. It can be
-//! thought of as representing physical memory, however, its dimensionality is
-//! modifed as split/merge/computeAt functions are called. The history of
-//! these transformations are kept and used for generating actual code
-//! referncing physical memory. Generally when users are thinking of code
-//! generation in reference to a Tensor, this is the class they should be
-//! interacting with.
-//!
-//! The reason we need both TensorView and TensorDomain is that we need to have
-//! a record of both what is being computed and how it is being computed. For
-//! example we may have the operation:
-//!
-//!   TV3[I, J, K] = TV2[I, J, K] + TV1[I, J, K]
-//!
-//! The mathematical operations here are on the tensor views TV1, TV2, and
-//! TV3. This operation is a pointwise operation. To compute this pointwise
-//! operation we iterate over the 3D TensorDomain [I, J, K], where K is the
-//! fastest changing dimension.
-//!
-//! \todo Need to work on the const model for TensorView, making all functions
-//! that should be const, const. Gave this a try but expanded really quickly.
-//! getComputeAtAxis not being const because it can return a TV that some expect
-//! to be non-const is the biggest headache.
-//!
+// TensorView is our primitive Tensor Type used in code generation. It can be
+// thought of as representing physical memory, however, its dimensionality is
+// modifed as split/merge/computeAt functions are called. The history of
+// these transformations are kept and used for generating actual code referncing
+// physical memory. Generally when users are thinking of code generation in
+// reference to a Tensor, this is the class they should be interacting with.
+//
+// The reason we need both TensorView and TensorDomain is that we need to have a
+// record of both what is being computed and how it is being computed. For
+// example we may have the operation: TV3[I, J, K] = TV2[I, J, K] + TV1[I, J, K]
+// The mathematical operations here are on the tensor views TV1, TV2, and TV3.
+// This operation is a pointwise operation. To compute this pointwise operation
+// we iterate over the 3D TensorDomain [I, J, K], where K is the fastest
+// changing dimension.
+//
+// TODO: Need to work on the const model for TensorView, making all functions
+// that should be const, const. Gave this a try but expanded really quickly.
+// getComputeAtAxis not being const because it can return a TV that some expect
+// to be non-const is the biggest headache.
 class TORCH_CUDA_CU_API TensorView : public Val {
  public:
+  ~TensorView() override = default;
+
+  TensorView(const TensorView& other) = delete;
+  TensorView& operator=(const TensorView& other) = delete;
+
+  TensorView(TensorView&& other) = delete;
+  TensorView& operator=(TensorView&& other) = delete;
+
   TensorView(
-      TensorDomain* domain,
+      TensorDomain* _domain,
       DataType dtype,
       MemoryType mtype = MemoryType::Local);
 
-  explicit TensorView(const std::shared_ptr<c10::TensorType>& tensor_type);
+  TensorView(const std::shared_ptr<c10::TensorType>& tensor_type);
 
-  explicit TensorView(const std::shared_ptr<Value>& jit_value)
+  TensorView(const std::shared_ptr<Value>& jit_value)
       : TensorView(jit_value->type()->cast<c10::TensorType>()) {}
 
   TensorView(const TensorView* src, IrCloner* ir_cloner);
@@ -169,15 +229,10 @@ class TORCH_CUDA_CU_API TensorView : public Val {
   bool hasReduction() const;
   bool hasBlockReduction() const;
   bool hasGridReduction() const;
+  bool hasBlockBroadcast() const;
   bool hasBroadcast() const;
   bool hasRFactor() const;
 
-  //! This is the previous hasReduction logic,
-  //! kept here exclusively for lower loop pass will
-  //! deprecate when Fusion IR pass can convert
-  //! trivial reductions
-  bool hasAnyReduction() const;
-
   c10::optional<unsigned int> getReductionAxis() const;
 
   const std::vector<IterDomain*>& getRootDomain() const;
@@ -190,71 +245,66 @@ class TORCH_CUDA_CU_API TensorView : public Val {
 
   IterDomain* axis(int pos) const;
 
-  // Does it share outer axes with other tensors?
+  // Is there an active computeAt TensorView/Axis
   bool hasComputeAt() const {
-    return compute_at_pos_ > 0;
+    return compute_at_view_ != nullptr;
   }
 
-  bool hasMaxProducerPosition() const {
-    return max_producer_pos_ > 0;
+  // Return the TensorView we're computing at
+  TensorView* getComputeAtView() const {
+    return compute_at_view_;
   }
 
   size_t nDims() const;
 
-  // Returns the position that this tensor is produced at relative to its axes.
-  unsigned int getComputeAtPosition() const {
-    return compute_at_pos_;
+  // Return compute at axis relative to this domain
+  unsigned int getThisComputeAtAxis() const {
+    return this_compute_at_axis_;
+  }
+
+  // Return compute at axis relative to compute at view
+  unsigned int getRelativeComputeAtAxis() const {
+    return relative_compute_at_axis_;
+  }
+
+  // Return position in compute_at_view that lines up with this->axis(pos)?
+  int getComputeAtRelPos(int pos);
+
+  // Will check if an axis is inside computeAtAxis and will fetch the reference
+  // to be used in code generation.
+  std::pair<int, TensorView*> getComputeAtPos(int pos) {
+    pos = normalizeAxisPos(pos);
+    TORCH_INTERNAL_ASSERT(
+        nDims() > 0, "Tried to access a computeAt axis in a 0-dim TensorView");
+    if (!hasComputeAt() || getThisComputeAtAxis() <= (unsigned int)pos)
+      return std::make_pair(pos, this);
+    return compute_at_view_->getComputeAtPos(getComputeAtRelPos(pos));
+  }
+
+  std::pair<IterDomain*, TensorView*> getComputeAtAxis(int pos) {
+    const auto computeAtPos = getComputeAtPos(pos);
+    return std::make_pair(
+        computeAtPos.second->axis(computeAtPos.first), computeAtPos.second);
   }
 
-  // Returns the maximum position of producers are being computed at relative to
-  // this tensor. This position dictates the clear expectations of producers.
-  unsigned int getMaxProducerPosition() const {
-    return max_producer_pos_;
+  // Compute this TensorView relative to another tensor at axis
+  TensorView* computeAt(TensorView* consumer, int axis);
+
+  void clearComputeAt() {
+    this_compute_at_axis_ = 0;
+    relative_compute_at_axis_ = 0;
+    compute_at_view_ = nullptr;
   }
 
-  //! This is used when we disconnect a tensorview from a reduction
-  //!  operation and connect it to a non-reduction operator. We need
-  //!  to remove the reduction ids on the tv in this case.
-  //! Currently only used in translate welford, and this function may
-  //!  be refactored or extended if any more use cases appear.
-  void clearReductionIterDomains();
-
-  //! Compute this TensorView relative to a consumer position, -1 will
-  //! compute tensors inline with each other, 0 doesn't share
-  //! any loop nests between the tensors. It's an error when the given
-  //! position is not legally viable. Alternatively, when the mode
-  //! parameter is ComputeAtMode::BestEffort, the position is lowered
-  //! one by one until a valid position is found. When
-  //! ComputeAtMode::MostInlined is given, the position parameter is
-  //! ignored, and the deepest possible position is searched.
-  TensorView* computeAt(
-      TensorView* consumer,
-      int position,
-      ComputeAtMode mode = ComputeAtMode::Standard);
-
-  //! Compute this tensor to consumer, at local position, -1 will compute
-  //! tensors inline with eachother, 0 doesn't share any loop nests between the
-  //! tensors. The mode parameter can be used in the same manner as computeAt.
-  TensorView* computeWith(
-      TensorView* consumer,
-      int position,
-      ComputeAtMode mode = ComputeAtMode::Standard);
-
-  // Split "axis" into 2 axes
-  //! inner_split dictates if the factor section of the split should be inside
-  //! the
-  //! remainer or outside.
-  //! e.g. split(0, 4, inner_split = true) will result in:
-  //! tv[id{extent}] -> tv[id{ceilDiv(extent, factor)}, id{factor}]
-  //! e.g. split(0, 4, inner_split = false) will result in:
-  //! tv[id{extent}] -> tv[id{factor}, id{ceilDiv(extent, factor)}]
-  TensorView* split(int axis, unsigned int factor, bool inner_split = true);
+  // Split "axis" into 2 axes where the inner axes is size of "factor"
+  // and outer axis is size axis.size() / factor
+  TensorView* split(int axis, unsigned int factor);
 
   // Split "axis" into 2 axes where the inner axes is size of "factor"
   // and outer axis is size axis.size() / factor. Factor can be a symbolic
   // value instead of constant. This requires setting the symbolic value as an
   // input, or using a parallel dim from NamedScalar::getParallelDim
-  TensorView* split(int axis, Val* factor, bool inner_split = true);
+  TensorView* split(int axis, Val* factor);
 
   // Merge axis_o and axis_i into 1 IterDomain
   TensorView* merge(int axis_o, int axis_i);
@@ -267,17 +317,6 @@ class TORCH_CUDA_CU_API TensorView : public Val {
   // Reorder axes according to old2new[old_pos] = new_pos
   TensorView* reorder(const std::unordered_map<int, int>& old2new);
 
-  //! Swizzle indices to improve memory access efficiency.
-  //!
-  //! Swizzle::Transpose is a pattern commonly used to avoid bank
-  //! conflicts in shared memory. It takes two axes and shifts the
-  //! second axis by the first axis as ((axis1 + axis2) % extent). The
-  //! memory type must be Shared.
-  //!
-  //! \input type Swizzle pattern such as transpose.
-  //! \input axes Axes to swizzle
-  TensorView* swizzle(SwizzleType type, const std::vector<int>& axes);
-
   // WARNING: rFactor does not return this TensorView, ir returns a new
   //  tensorview consumed by this!
   //
@@ -298,15 +337,6 @@ class TORCH_CUDA_CU_API TensorView : public Val {
   //
   TensorView* rFactor(const std::vector<int>& axes);
 
-  //! Welford Version of rFactor, semantically similar with
-  //!  the reduction version except that the rfactor is done
-  //!  in a multi-output scan pattern
-  WelfordResult rFactor(
-      const std::vector<int>& axes,
-      TensorView* avg,
-      TensorView* var,
-      TensorView* n);
-
   // Create a TensorView before the original tensor. A common use case is to
   // write results into shared memory or registers before moving to global
   // memory. Analogous to TVM Cache_Write
@@ -316,41 +346,39 @@ class TORCH_CUDA_CU_API TensorView : public Val {
   // read tensor into shared memory or registers. Analogous to TVM Cache_Read
   TensorView* cache_after();
 
-  // For a fusion output with other uses, we want to avoid writing to global
-  // memory and then reading the output again. We write to global memory
-  // separately after an operation. We replace this fusion output with the
-  // direct write TensorView.
-  TensorView* cache_fork();
-
   MemoryType getMemoryType() const {
     return memory_type_;
   }
 
   void setMemoryType(MemoryType mt);
 
-  SwizzleType swizzleType() const {
-    return swizzle_type_;
-  }
-
-  const std::vector<IterDomain*>& axesToSwizzle() const {
-    return axes_to_swizzle_;
-  }
-
-  friend TORCH_CUDA_CU_API TransformPropagator;
   friend TORCH_CUDA_CU_API TransformReplay;
   friend TORCH_CUDA_CU_API OptOutMutator;
+  friend TORCH_CUDA_CU_API LoopNestGenerator;
   friend ComputeAt;
+  friend void IrFixComputeAt(Fusion*);
   friend void adjustMemoryTypes(Fusion* fusion);
   friend class ir_utils::TVDomainGuard;
 
  protected:
+  // Make an exact copy of this tensor (similar to clone()), however, also grabs
+  // the same name. Current use of this is for initialization of reductions.
+  // This will break our dependency chain as it is a literal clone of a
+  // TensorView but it has a different dependency chain. We need to improve our
+  // dependency model to allow for initailziation of reduction buffers. The only
+  // reason we can get away with this for now is because we don't use dependency
+  // analysis for the IR after we call this.
+  TensorView* unsafeClone() const;
+
   void setDomain(TensorDomain* td) {
     domain_ = td;
   }
 
-  void setComputeAt(unsigned int this_pos, bool decrease = false);
+  void setComputeAt(TensorView* computeAtView, int axis);
 
-  void setMaxProducer(unsigned int this_pos, bool decrease = false);
+  // Set all computeAt members without checking any correctness. Useful for
+  // computeAt with outputs relative to eachother
+  void setComputeAt(TensorView* computeAtView, int thisPos, int relPos);
 
  private:
   int normalizeAxisPos(int pos) const {
@@ -360,53 +388,30 @@ class TORCH_CUDA_CU_API TensorView : public Val {
     return pos;
   }
 
-  //! A helper function to maintain the consistency of welford output
-  //! schedules when doing rfactor on welford ops.
-  TensorView* welfordRfactorHelper(
-      TensorView* tv,
-      const std::vector<int>& axes);
-
- private:
-  TensorDomain* domain_ = nullptr;
-  unsigned int compute_at_pos_ = 0;
-  unsigned int max_producer_pos_ = 0;
-  MemoryType memory_type_ = MemoryType::Local;
-  SwizzleType swizzle_type_ = SwizzleType::NoSwizzle;
-  std::vector<IterDomain*> axes_to_swizzle_;
-};
-
-//! A simple TensorView builder
-//!
-//! Example usage:
-//!
-//!   auto tv = TensorViewBuilder()
-//!       .ndims(ndims)
-//!       .dtype(dtype)
-//!       .contiguity(contiguity)
-//!       .build();
-//!
-class TORCH_CUDA_CU_API TensorViewBuilder {
- public:
-  //! Set the number of dimensions of the tensor (default 0, meaning scalar)
-  TensorViewBuilder& ndims(size_t ndims);
-
-  //! Set the data type of the tensor (default DataType::Float)
-  TensorViewBuilder& dtype(DataType dtype);
-
-  //! Set the contiguity information (default non-contiguous)
-  TensorViewBuilder& contiguity(std::vector<bool> contiguity);
+  // In Cache Before, for the origin expr of the original tensor,
+  // we create a new operation where the original tensor is replaced
+  // with the new cache tensor. This function creates a new expr
+  // given the consumer, the output of the expression.
+  void createExprConsumer(Expr* expr, TensorView* consumer);
 
-  //! Set the shape (default 0 dimensional, ie. scalar)
-  TensorViewBuilder& shape(std::vector<int64_t> shape);
+  // In Cache After, for all the uses of the original tensor, we create
+  // a new operation where the original tensor is replaced with the new
+  // cache tensor. This function creates a new expr given a producer,
+  // an input for the expression.
+  void createExprProducer(
+      Expr* expr,
+      TensorView* current,
+      TensorView* producer);
 
-  //! Creates a new TensorView with the specified options
-  TensorView* build() const;
+  void setThisComputeAtAxis();
 
  private:
-  size_t ndims_ = 0;
-  DataType dtype_ = DataType::Float;
-  std::vector<bool> contiguity_;
-  std::vector<int64_t> shape_;
+  TensorDomain* domain_ = nullptr;
+  TensorView* compute_at_view_ = nullptr;
+  // compute at axis in compute at view
+  unsigned int relative_compute_at_axis_ = 0;
+  unsigned int this_compute_at_axis_ = 0;
+  MemoryType memory_type_ = MemoryType::Local;
 };
 
 } // namespace cuda
index 90f59b2..87b5b8f 100644 (file)
@@ -6,36 +6,48 @@
 #include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
 
-//! Nodes in here should generally not be used by users. They should be behind
-//! the scenes and users shouldn't have to be aware of what they do to use the
-//! code generator
-//!
-//! \todo improve implementation bool IterDomain::sameAs(const IterDomain*)
-//! \todo Add testing of sameAs functions for these nodes
-//!
+/*
+ * Nodes in here should generally not be used by users. They should be behind
+ * the scenes and users shouldn't have to be aware of what they do to use the
+ * code generator.
+ */
 
 namespace torch {
 namespace jit {
 namespace fuser {
 namespace cuda {
 
-//! Returns true if both v1 and v2 are scalars, are the same type of scalars,
-//! and dispatches to the inherited Val type's `->sameAs` call. e.g. if both
-//! vals are `Int` will dispatch to v1->as<Int>()->sameAs(v2.as<Int>())
+// Returns true if both v1 and v2 are scalars, are the same type of scalars, and
+// dispatches to the inherited Val type's `->sameAs` call. e.g. if both vals are
+// `Int` will dispatch to v1->as<Int>()->sameAs(v2.as<Int>())
 bool areEqualScalars(Val* v1, Val* v2);
 
-//! A specialization for Unary operations. Unary operations take in a single
-//! input and produce a single output. Examples include:
-//!   1) Casting operation i.e. float(a_val)
-//!   2) Negation i.e. val * -1
-//!   3) Reduction across a dimension i.e. val.sum(axis=2)
-//!   4) split/merge
+/*
+ * TODO: improve implementation bool IterDomain::sameAs(const IterDomain*) const
+ * TODO: Add testing of sameAs functions for these nodes
+ */
+
+/*
+ * A specialization for Unary operations. Unary operations take in a single
+ * input and produce a single output. Examples include:
+ *   1) Casting operation i.e. float(a_val)
+ *   2) Negation i.e. val * -1
+ *   3) Reduction across a dimension i.e. val.sum(axis=2)
+ *   4) split/merge
+ */
 class TORCH_CUDA_CU_API UnaryOp : public Expr {
  public:
-  UnaryOp(UnaryOpType type, Val* out, Val* in);
+  ~UnaryOp() override = default;
+  UnaryOp(UnaryOpType _type, Val* _out, Val* _in);
 
   UnaryOp(const UnaryOp* src, IrCloner* ir_cloner);
 
+  UnaryOp(const UnaryOp& other) = delete;
+  UnaryOp& operator=(const UnaryOp& other) = delete;
+
+  UnaryOp(UnaryOp&& other) = delete;
+  UnaryOp& operator=(UnaryOp&& other) = delete;
+
   Val* out() const {
     return out_;
   }
@@ -47,7 +59,7 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr {
     return unary_op_type_;
   }
 
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const UnaryOp* const other) const;
 
  private:
   const UnaryOpType unary_op_type_;
@@ -55,16 +67,25 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr {
   Val* const in_ = nullptr;
 };
 
-//! A specialization for Binary operations. Binary operations take in two inputs
-//! and produce a single output. Examples include:
-//!  1) Add/mul/div/mod/sub (A * B)
-//!  2) LT (A < B)
+/*
+ * A specialization for Binary operations. Binary operations take in two inputs
+ * and produce a single output. Examples include:
+ *  1) Add/mul/div/mod/sub (A * B)
+ *  2) LT (A < B)
+ */
 class TORCH_CUDA_CU_API BinaryOp : public Expr {
  public:
-  BinaryOp(BinaryOpType type, Val* out, Val* lhs, Val* rhs);
+  ~BinaryOp() override = default;
+  BinaryOp(BinaryOpType _type, Val* _out, Val* _lhs, Val* _rhs);
 
   BinaryOp(const BinaryOp* src, IrCloner* ir_cloner);
 
+  BinaryOp(const BinaryOp& other) = delete;
+  BinaryOp& operator=(const BinaryOp& other) = delete;
+
+  BinaryOp(BinaryOp&& other) = delete;
+  BinaryOp& operator=(BinaryOp&& other) = delete;
+
   Val* out() const {
     return out_;
   }
@@ -79,7 +100,7 @@ class TORCH_CUDA_CU_API BinaryOp : public Expr {
     return binary_op_type_;
   }
 
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const BinaryOp* other) const;
 
  private:
   const BinaryOpType binary_op_type_;
@@ -88,17 +109,23 @@ class TORCH_CUDA_CU_API BinaryOp : public Expr {
   Val* const rhs_ = nullptr;
 };
 
-//! Broadcast in to match out. is_broadcast_dims are relative to out. Where
-//! is_broadcast_dims.size() == out->nDims().
+/*
+ * Broadcast _in to match _out. broadcast_dims are relative to out. Where
+ * broadcast_dims.size() + _in->nDims() == _out->nDims().
+ */
 class TORCH_CUDA_CU_API BroadcastOp : public Expr {
  public:
-  //! \param out The output tensor
-  //! \param in The input tensor
-  //! \param is_broadcast_dims True when output dim is a new broadcast domain
-  BroadcastOp(Val* out, Val* in, std::vector<bool> is_broadcast_dims);
+  ~BroadcastOp() override = default;
+  BroadcastOp(Val* _out, Val* _in);
 
   BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner);
 
+  BroadcastOp(const BroadcastOp& other) = delete;
+  BroadcastOp& operator=(const BroadcastOp& other) = delete;
+
+  BroadcastOp(BroadcastOp&& other) = delete;
+  BroadcastOp& operator=(BroadcastOp&& other) = delete;
+
   Val* out() const {
     return out_;
   }
@@ -106,40 +133,33 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr {
     return in_;
   }
 
-  bool isBroadcastDim(size_t dim) const {
-    return is_broadcast_dims_.at(dim);
-  }
-
-  const std::vector<bool>& getBroadcastDimFlags() const {
-    return is_broadcast_dims_;
-  }
-
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const BroadcastOp* const other) const;
 
  private:
   Val* const out_ = nullptr;
   Val* const in_ = nullptr;
-
-  //! The same list passed to the broadcast arithmetic op. Each
-  //! element corresponds to an IterDomain of the output tensor and is
-  //! true when the IterDomain is a new broadcast domain. Note
-  //! that the output tensor may have other broadcast domains whose
-  //! flags are false because the input tensor may already have
-  //! broadcast domains.
-  const std::vector<bool> is_broadcast_dims_;
 };
 
-//! Reduction operation. Out is first initialized to _init. Then
-//! reduction_op_type is used to update out as out = reductionOp(out, in).
-//! Output's axes marked as reduction will be reduced to produce an output
-//! tensor. The output tensors size will be the size of all
-//! non-reduction/non-broadcast dimensions.
+/*
+ * Reduction operation. Out is first initialized to _init. Then
+ * _reduction_op_type is used to update out as out = reductionOp(out, in).
+ * Output's axes marked as reduction will be reduced to produce an output
+ * tensor. The output tensors size will be the size of all
+ * non-reduction/non-broadcast dimensions.
+ */
 class TORCH_CUDA_CU_API ReductionOp : public Expr {
  public:
-  ReductionOp(BinaryOpType reduction_op_type, Val* init, Val* out, Val* in);
+  ~ReductionOp() override = default;
+  ReductionOp(BinaryOpType _reduction_op_type, Val* _init, Val* _out, Val* _in);
 
   ReductionOp(const ReductionOp* src, IrCloner* ir_cloner);
 
+  ReductionOp(const ReductionOp& other) = delete;
+  ReductionOp& operator=(const ReductionOp& other) = delete;
+
+  ReductionOp(ReductionOp&& other) = delete;
+  ReductionOp& operator=(ReductionOp&& other) = delete;
+
   Val* out() const {
     return out_;
   }
@@ -154,7 +174,7 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr {
     return reduction_op_type_;
   }
 
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const ReductionOp* const other) const;
 
  private:
   const BinaryOpType reduction_op_type_;
@@ -163,124 +183,19 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr {
   Val* const in_ = nullptr;
 };
 
-//! Welford Scan operation.
-class TORCH_CUDA_CU_API WelfordOp : public Expr {
- public:
-  WelfordOp(
-      Val* out_avg,
-      Val* out_var,
-      Val* out_N,
-      Val* init_avg,
-      Val* init_var,
-      Val* init_N,
-      Val* in_avg,
-      Val* in_var,
-      Val* in_N);
-
-  WelfordOp(const WelfordOp* src, IrCloner* ir_cloner);
-
-  Val* out() const {
-    return out_avg_;
-  }
-
-  Val* in() const {
-    return in_avg_;
-  }
-
-  Val* init() const {
-    return init_avg_;
-  }
-
-  bool sameAs(const Statement* const other) const override;
-
-  // Welford Accessors
-  // TODO clean up
-  Val* outAvg() const {
-    return out_avg_;
-  }
-
-  Val* outVar() const {
-    return out_var_;
-  }
-
-  Val* outN() const {
-    return out_N_;
-  }
-
-  Val* inAvg() const {
-    return in_avg_;
-  }
-
-  Val* inVar() const {
-    return in_var_;
-  }
-
-  Val* inN() const {
-    return in_N_;
-  }
-
-  Val* initAvg() const {
-    return init_avg_;
-  }
-
-  Val* initVar() const {
-    return init_var_;
-  }
-
-  Val* initN() const {
-    return init_N_;
-  }
-
-  bool singleValue() const {
-    return in_N_->isOneInt();
-  }
-
-  bool hasInit() const {
-    return !init_N_->isZeroInt();
-  }
-
- private:
-  Val* const out_avg_;
-  Val* const out_var_;
-  Val* const out_N_;
-  Val* const init_avg_;
-  Val* const init_var_;
-  Val* const init_N_;
-  Val* const in_avg_;
-  Val* const in_var_;
-  Val* const in_N_;
-};
-
-class TORCH_CUDA_CU_API TransposeOp : public Expr {
- public:
-  TransposeOp(TensorView* out, TensorView* in, std::vector<int> new2old);
-
-  TransposeOp(const TransposeOp* src, IrCloner* ir_cloner);
-
-  TensorView* out() const {
-    return out_;
-  }
-
-  TensorView* in() const {
-    return in_;
-  }
-
-  const std::vector<int>& new2old() const {
-    return new2old_;
-  }
-
- private:
-  TensorView* const out_ = nullptr;
-  TensorView* const in_ = nullptr;
-  const std::vector<int> new2old_;
-};
-
 class TORCH_CUDA_CU_API TernaryOp : public Expr {
  public:
-  TernaryOp(TernaryOpType type, Val* out, Val* in1, Val* in2, Val* in3);
+  ~TernaryOp() override = default;
+  TernaryOp(TernaryOpType _type, Val* _out, Val* _in1, Val* _in2, Val* _in3);
 
   TernaryOp(const TernaryOp* src, IrCloner* ir_cloner);
 
+  TernaryOp(const TernaryOp& other) = delete;
+  TernaryOp& operator=(const TernaryOp& other) = delete;
+
+  TernaryOp(TernaryOp&& other) = delete;
+  TernaryOp& operator=(TernaryOp&& other) = delete;
+
   Val* out() const {
     return out_;
   }
@@ -299,7 +214,7 @@ class TORCH_CUDA_CU_API TernaryOp : public Expr {
     return ternary_op_type_;
   }
 
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const TernaryOp* other) const;
 
  private:
   const TernaryOpType ternary_op_type_;
@@ -309,102 +224,22 @@ class TORCH_CUDA_CU_API TernaryOp : public Expr {
   Val* const in3_ = nullptr;
 };
 
-//! Shift
-class TORCH_CUDA_CU_API ShiftOp : public Expr {
- public:
-  //! \param out
-  //! \param in
-  //! \param offsets
-  ShiftOp(Val* out, Val* in, std::vector<int> offsets);
-
-  ShiftOp(const ShiftOp* src, IrCloner* ir_cloner);
-
-  Val* out() const {
-    return out_;
-  }
-  Val* in() const {
-    return in_;
-  }
-
-  int offset(size_t dim) const {
-    return offsets_.at(dim);
-  }
-
-  const std::vector<int>& offsets() const {
-    return offsets_;
-  }
-
-  bool sameAs(const Statement* other) const override;
-
- private:
-  Val* const out_ = nullptr;
-  Val* const in_ = nullptr;
-  //! Each of the root axes is shifted by the corresponding value of
-  //! offsets_. The sign of each value indicates the direction of
-  //! shifting.
-  const std::vector<int> offsets_;
-};
-
-//! Gather a window around each element.
-class TORCH_CUDA_CU_API GatherOp : public Expr {
- public:
-  GatherOp(
-      Val* out,
-      Val* in,
-      std::vector<Int*> window_shape,
-      std::vector<std::vector<Int*>> pad_width);
-
-  GatherOp(const GatherOp* src, IrCloner* ir_cloner);
-
-  Val* out() const {
-    return out_;
-  }
-  Val* in() const {
-    return in_;
-  }
-
-  const auto& windowShape() const {
-    return window_shape_;
-  }
-
-  //! Returns the gather axis that corresponds to an input axis
-  int gatherAxis(int axis) const;
-
-  const auto& padWidth() const {
-    return pad_width_;
-  }
-
-  bool sameAs(const Statement* other) const override;
-
- private:
-  Val* const out_ = nullptr;
-  Val* const in_ = nullptr;
-  //! Shape of a window gathered for each element.
-  std::vector<Int*> window_shape_;
-  //! The size of zero-padding of each axis.
-  std::vector<std::vector<Int*>> pad_width_;
-};
-
-// Friends for direct access to split
-class TensorDomain;
-class ReplayTransformations;
-class IndexReferenceReplay;
-//! Simply a representation of an annotated 1D iterable from start to extent.
-//! TensorDomains which represent how to iterate over a tensor is made up of
-//! IterDomains to form an ND iterable. We directly set parallization strategies
-//! on IterDomains.
+// Simply a representation of an annotated 1D iterable from start to extent.
+// TensorDomains which represent how to iterate over a tensor is made up of
+// IterDomains to form an ND iterable. We directly set parallization strategies
+// on IterDomains.
 class TORCH_CUDA_CU_API IterDomain : public Val {
  public:
   IterDomain(
-      Val* start,
-      Val* extent,
-      ParallelType parallel_type = ParallelType::Serial,
-      IterType iter_type = IterType::Iteration,
-      bool is_rfactor_domain = false);
+      Val* _start,
+      Val* _extent,
+      ParallelType _parallel_type = ParallelType::Serial,
+      IterType _iter_type = IterType::Iteration,
+      bool _is_rfactor_domain = false);
 
   IterDomain(const IterDomain* src, IrCloner* ir_cloner);
 
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const IterDomain* const other) const;
 
   // Returns a new IterDomain matching properties of this
   // TODO: parallel_method->getParallelType
@@ -417,15 +252,18 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
         isRFactorProduct());
   }
 
-  //! Clone a vector domains
-  static std::vector<IterDomain*> clone(
-      const std::vector<IterDomain*>& domains);
-
   static IterDomain* merge(IterDomain* outer, IterDomain* inner);
 
-  //! Run concretization pass and return the concretized domain of broadcast id
+  // TODO: Make protected and friend TensorDomain so only it can call into this
+  // directly, users should not be able to use this call
+  static std::pair<IterDomain*, IterDomain*> split(IterDomain* in, Val* factor);
+
+  // Run concretization pass and return the concretized domain of broadcast id
   static const IterDomain* concretizeDomain(IterDomain* bcast_dom);
 
+  // Attempt to prove 2 IterDomains are equal in start and rawExtent
+  static bool proveEquivalent(IterDomain* a, IterDomain* b);
+
   bool isReduction() const {
     return getIterType() == IterType::Reduction;
   }
@@ -439,50 +277,48 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
         getIterType() == IterType::BroadcastWithoutStride;
   }
 
-  bool isGather() const {
-    return getIterType() == IterType::Gather;
-  }
-
   bool isParallelized() const {
     return getParallelType() != ParallelType::Serial;
   }
 
-  //! Return if this iter domain is mapped to a grid dimension
+  // Return if this iter domain is mapped to a grid dimension
   bool isBlockDim() const {
-    return isParallelTypeBlockDim(getParallelType());
+    return (
+        getParallelType() == ParallelType::BIDz ||
+        getParallelType() == ParallelType::BIDy ||
+        getParallelType() == ParallelType::BIDx);
   }
 
-  //! Return if this iter domain is mapped to a block dimension
+  // Return if this iter domain is mapped to a block dimension
   bool isThreadDim() const {
-    return isParallelTypeThreadDim(getParallelType());
+    return (
+        getParallelType() == ParallelType::TIDz ||
+        getParallelType() == ParallelType::TIDy ||
+        getParallelType() == ParallelType::TIDx);
   }
 
-  //! Return if this iter domain is either mapped to a block or grid dimension
+  // Return if this iter domain is either mapped to a block or grid dimension
   bool isThread() const {
     return (isBlockDim() || isThreadDim());
   }
 
-  //! Convert to strided broadcast, used for supporting broadcast on output
-  void toStridedBroadcast() {
-    TORCH_INTERNAL_ASSERT(
-        isBroadcast(),
-        "toStridedBroadCast: converting an non-broadcast iterdomain",
-        this);
-    iter_type_ = IterType::BroadcastWithStride;
-  }
+  void parallelize(ParallelType t) {
+    parallel_type_ = t;
 
-  // Convert a serial iterdomain to broadcast, used for implicit broadcast
-  void convertToBroadcast() {
-    TORCH_INTERNAL_ASSERT(
-        !isBroadcast() && !isReduction(),
-        "convertToBroadcast: converting an non-serial iterdomain",
-        this);
+    TORCH_CHECK(
+        t != ParallelType::Vectorize, "Vectorization not yet supported.");
 
-    iter_type_ = IterType::BroadcastWithStride;
+    if (t == ParallelType::Unroll)
+      TORCH_CHECK(
+          start()->isZeroInt() && extent()->isConstScalar(),
+          "Unrolling only supported with start = 0 and extent as a const int, but got ",
+          "a start of ",
+          start(),
+          " and extent ",
+          extent(),
+          " .");
   }
 
-  void parallelize(ParallelType t);
-
   ParallelType getParallelType() const {
     return parallel_type_;
   }
@@ -494,40 +330,17 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
   Val* start() const {
     return start_;
   }
+  Val* extent() const;
 
-  Val* extent() const {
-    TORCH_INTERNAL_ASSERT(extent_ != nullptr);
+  Val* rawExtent() const {
     return extent_;
   }
 
-  //! Check if IterDomain is a broadcast axis with compile-time
-  //! known extent. This is the case with all size-1 IterDomains on
-  //! a TensorView's root domain when the TensorView is created.
-  bool isImplicitBroadcast() const {
-    return isBroadcast() && extent()->isOneInt();
-  }
-
-  //! Check if IterDomain is a reduction axis with size of 1, i.e.
-  //! a "squeeze" operator.
-  //!
-  //! NOTE: Detection of trivial reduction here is not
-  //! comprehensive. See detectTrivialReductionDerivedDomains for more
-  //! comprehensive analysis. We typically use this for root domain trivial
-  //! reduction checks. So we ship to the correct scheduler. It may
-  //! not be incredibly robust, but it makes sense to keep it for now.
-  bool isTrivialReduction() const {
-    return isReduction() && extent()->isOneInt();
-  }
-
- protected:
-  friend TensorDomain;
-  friend ReplayTransformations;
-  friend IndexReferenceReplay;
+  IterDomain(const IterDomain& other) = delete;
+  IterDomain& operator=(const IterDomain& other) = delete;
 
-  static std::pair<IterDomain*, IterDomain*> split(
-      IterDomain* in,
-      Val* factor,
-      bool inner_split);
+  IterDomain(IterDomain&& other) = delete;
+  IterDomain& operator=(IterDomain&& other) = delete;
 
  private:
   Val* const start_ = nullptr;
@@ -537,36 +350,44 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
   bool is_rfactor_domain_ = false;
 };
 
-//! TensorDomain holds a vector of IterDomains. It holds an IterDomain for every
-//! logical axis in its associated tensor. TensorDomain does not directly hold
-//! the Tensor it is associated with, and in theory could be associated with
-//! multiple tensors. TensorDomain's primary responsibility is to provide a
-//! mechanism to access history of transformations that were used to generate
-//! it. This is done through the normal interaction of Expr/Val in Fusion. i.e.
-//! if we want to know the previous operation generating a particular
-//! TensorDomain we can simply call:
-//!
-//!     FusionGuard::getCurFusion()->definition(a_tensor_domain)
-//!
-//! which should give us an operation in the list [split, merge] or similar
-//! operations that take in a TensorDomain, applies a transformation and outputs
-//! a tensor domain.
+/*
+ * TensorDomain holds a vector of IterDomains. It holds an IterDomain for every
+ * logical axis in its associated tensor. TensorDomain does not directly hold
+ * the Tensor it is associated with, and in theory could be associated with
+ * multiple tensors. TensorDomain's primary responsibility is to provide a
+ * mechanism to access history of transformations that were used to generate it.
+ * This is done through the normal interaction of Expr/Val in Fusion. i.e. if we
+ * want to know the previous operation generating a particular TensorDomain we
+ * can simply call FusionGuard::getCurFusion()->origin(a_tensor_domain) which
+ * should give us an operation in the list [split, merge] or similar
+ * operations that take in a TensorDomain, applies a transformation and outputs
+ * a tensor domain.
+ */
 class TORCH_CUDA_CU_API TensorDomain : public Val {
  public:
+  TensorDomain() = delete;
+  ~TensorDomain() override = default;
+
+  TensorDomain(const TensorDomain& other) = delete;
+  TensorDomain& operator=(const TensorDomain& other) = delete;
+
+  TensorDomain(TensorDomain&& other) = delete;
+  TensorDomain& operator=(TensorDomain&& other) = delete;
+
   explicit TensorDomain(
-      std::vector<IterDomain*> root_domain,
-      std::vector<bool> contiguity = std::vector<bool>());
+      std::vector<IterDomain*> _domain,
+      std::vector<bool> _contiguity = std::vector<bool>());
 
   TensorDomain(
-      std::vector<IterDomain*> root_domain,
-      std::vector<IterDomain*> domain,
-      std::vector<bool> contiguity = std::vector<bool>());
+      std::vector<IterDomain*> _root_domain,
+      std::vector<IterDomain*> _domain,
+      std::vector<bool> _contiguity = std::vector<bool>());
 
   TensorDomain(
-      std::vector<IterDomain*> root_domain,
-      std::vector<IterDomain*> rfactor_domain,
-      std::vector<IterDomain*> domain,
-      std::vector<bool> contiguity = std::vector<bool>());
+      std::vector<IterDomain*> _root_domain,
+      std::vector<IterDomain*> _rfactor_domain,
+      std::vector<IterDomain*> _domain,
+      std::vector<bool> _contiguity = std::vector<bool>());
 
   TensorDomain(const TensorDomain* src, IrCloner* ir_cloner);
 
@@ -579,7 +400,7 @@ class TORCH_CUDA_CU_API TensorDomain : public Val {
     return domain_.size();
   }
 
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const TensorDomain* const other) const;
 
   static bool sameAs(
       const std::vector<IterDomain*>& lhs,
@@ -604,9 +425,9 @@ class TORCH_CUDA_CU_API TensorDomain : public Val {
   bool hasReduction() const;
   bool hasBlockReduction() const;
   bool hasGridReduction() const;
+  bool hasBlockBroadcast() const;
   bool hasBroadcast() const;
   bool hasRFactor() const;
-  bool hasVectorize() const;
 
   c10::optional<unsigned int> getReductionAxis() const;
 
@@ -635,7 +456,6 @@ class TORCH_CUDA_CU_API TensorDomain : public Val {
   void resetDomains() {
     no_reduction_domain_ = noReductions(domain_);
     no_bcast_domain_ = noBroadcasts(domain_);
-    has_nontrivial_reduction_ = hasNontrivialReduction(domain_);
   }
 
   // i here is int, as we want to accept negative value and ::size_type can be a
@@ -644,15 +464,12 @@ class TORCH_CUDA_CU_API TensorDomain : public Val {
 
   size_t posOf(IterDomain* id) const;
 
-  // Split "axis" into 2 axes
-  //! inner_split dictates if the factor section of the split should be inside
-  //! the
-  //! remainer or outside.
-  //! e.g. split(0, 4, inner_split = true) will result in:
-  //! tv[id{extent}] -> tv[id{ceilDiv(extent, factor)}, id{factor}]
-  //! e.g. split(0, 4, inner_split = false) will result in:
-  //! tv[id{extent}] -> tv[id{factor}, id{ceilDiv(extent, factor)}]
-  void split(int axis_, Val* factor, bool inner_split);
+  // Split "axis" into 2 axes where the inner axes is size of "factor"
+  // and outer axis is size axis.size() / factor. Allow factor to be symbolic
+  // value instead of constant.
+  // TODO: Make protected and friend TensorDomain so only it can call into this
+  // directly, users should not be able to use this call
+  void split(int axis_, Val* factor);
 
   // Merge axis_o and axis_i. axis_i is the fast changing dimension. Resulting
   // axis is by default placed at original position axis_o
@@ -670,7 +487,58 @@ class TORCH_CUDA_CU_API TensorDomain : public Val {
 
   static bool hasBroadcast(const std::vector<IterDomain*>&);
   static bool hasReduction(const std::vector<IterDomain*>&);
-  static bool hasNontrivialReduction(const std::vector<IterDomain*>&);
+
+  // return std::pair<producer_id, consumer_id> representing
+  // the mapping between corresponding axes. Not all axes have
+  // corresponding mapping, e.g., broadcast axis in consumer
+  // does not have any corresponding axis in producer.
+  static std::vector<std::pair<int, int>> mapDomainPandC(
+      const std::vector<IterDomain*>& producer,
+      const std::vector<IterDomain*>& consumer);
+
+  // Create a map between producer root IterDomains and consumer root
+  // IterDomains.
+  static std::vector<std::pair<IterDomain*, IterDomain*>> mapRootPandC(
+      const TensorDomain* producer,
+      const TensorDomain* consumer);
+
+  // Create a map from consumer root IterDomains -> producer root IterDomains.
+  // Only those root consumer IDs present in consumer_root_dims_to_map
+  // will be attempted to map to their corresponding producer IDs.
+  static std::unordered_map<IterDomain*, IterDomain*> mapRootCtoP(
+      const TensorDomain* consumer,
+      const TensorDomain* producer,
+      const std::unordered_set<IterDomain*>& consumer_root_dims_to_map);
+
+  static std::unordered_map<IterDomain*, IterDomain*> mapRootCtoP(
+      const TensorDomain* consumer,
+      const TensorDomain* producer) {
+    return mapRootCtoP(
+        consumer,
+        producer,
+        std::unordered_set<IterDomain*>(
+            consumer->getRootDomain().begin(),
+            consumer->getRootDomain().end()));
+  }
+
+  // Create a map from producer root IterDomains -> consumer root IterDomains.
+  // Only those root producer IDs present in producer_maybe_rfactor_dims_to_map
+  // will be attempted to map to their corresponding consumer IDs.
+  static std::unordered_map<IterDomain*, IterDomain*> mapRootPtoC(
+      const TensorDomain* producer,
+      const TensorDomain* consumer,
+      const std::unordered_set<IterDomain*>&
+          producer_maybe_rfactor_dims_to_map);
+
+  static std::unordered_map<IterDomain*, IterDomain*> mapRootPtoC(
+      const TensorDomain* producer,
+      const TensorDomain* consumer) {
+    auto p_root = producer->getMaybeRFactorDomain();
+    return mapRootPtoC(
+        producer,
+        consumer,
+        std::unordered_set<IterDomain*>(p_root.begin(), p_root.end()));
+  }
 
   // pair is in order where second is the consumer of first
   std::pair<TensorDomain*, TensorDomain*> rFactor(const std::vector<int>& axes);
@@ -682,20 +550,23 @@ class TORCH_CUDA_CU_API TensorDomain : public Val {
   std::vector<IterDomain*> no_reduction_domain_;
   const std::vector<IterDomain*> rfactor_domain_;
   const std::vector<bool> contiguity_;
-  bool has_nontrivial_reduction_;
 };
 
-//! Representation a split on an IterDomain by "factor"
-//! inner_split dictates if the factor section of the split should be inside the
-//! remainer or outside.
+/*
+ * Representation a split on an IterDomain by "factor"
+ * TODO: Implement split by nparts
+ */
 class TORCH_CUDA_CU_API Split : public Expr {
  public:
-  Split(
-      IterDomain* outer,
-      IterDomain* inner,
-      IterDomain* in,
-      Val* factor,
-      bool inner_split = true);
+  ~Split() override = default;
+
+  Split(const Split& other) = delete;
+  Split& operator=(const Split& other) = delete;
+
+  Split(Split&& other) = delete;
+  Split& operator=(Split&& other) = delete;
+
+  Split(IterDomain* _outer, IterDomain* _inner, IterDomain* _in, Val* _factor);
 
   Split(const Split* src, IrCloner* ir_cloner);
 
@@ -711,34 +582,35 @@ class TORCH_CUDA_CU_API Split : public Expr {
   Val* factor() const {
     return factor_;
   }
-
-  bool innerSplit() const {
-    return inner_split_;
-  }
-
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const Split* const other) const;
 
  private:
   IterDomain* const outer_ = nullptr;
   IterDomain* const inner_ = nullptr;
   IterDomain* const in_ = nullptr;
   Val* const factor_ = nullptr;
-  bool inner_split_ = true;
 };
 
-//! Merge the IterDomains outer and inner into one domain, outer and inner
-//! dictate which will be traversed first (inner). Both IterDomains must be of
-//! the same iter or reduction type, as well as the same parallelization
-//! strategy if there is one
-//!
-//! \todo Should this be a unary op type?
-//!
+/*
+ * Merge the IterDomains outer and inner into one domain, outer and inner
+ * dictate which will be traversed first (inner). Both IterDomains must be of
+ * the same iter or reduction type, as well as the same parallelization strategy
+ * if there is one.
+ * TODO: Should this be a unary op type?
+ */
 class TORCH_CUDA_CU_API Merge : public Expr {
  public:
-  Merge(IterDomain* out, IterDomain* outer, IterDomain* inner);
+  ~Merge() override = default;
+  Merge(IterDomain* _out, IterDomain* _outer, IterDomain* _inner);
 
   Merge(const Merge* src, IrCloner* ir_cloner);
 
+  Merge(const Merge& other) = delete;
+  Merge& operator=(const Merge& other) = delete;
+
+  Merge(Merge&& other) = delete;
+  Merge& operator=(Merge&& other) = delete;
+
   IterDomain* out() const {
     return out_;
   }
@@ -749,7 +621,7 @@ class TORCH_CUDA_CU_API Merge : public Expr {
     return inner_;
   }
 
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const Merge* const other) const;
 
  private:
   IterDomain* const out_ = nullptr;
@@ -757,40 +629,50 @@ class TORCH_CUDA_CU_API Merge : public Expr {
   IterDomain* const inner_ = nullptr;
 };
 
-//! Integer value which has a special name
-//!
-//! These could be:
-//! - threadIdx.x
-//! - blockIdx.y
-//! - blockDim.z
-//! - T3.stride[2]
-//!
+/*
+ * Integer value which has a special name. These could be:
+ * - threadIdx.x
+ * - blockIdx.y
+ * - blockDim.z
+ * - T3.stride[2]
+ */
 class TORCH_CUDA_CU_API NamedScalar : public Val {
  public:
+  ~NamedScalar() override = default;
+  NamedScalar() = delete;
+
   // NOLINTNEXTLINE(modernize-pass-by-value)
-  NamedScalar(std::string name, DataType dtype)
-      : Val(ValType::NamedScalar, dtype), name_(name) {}
+  NamedScalar(std::string _name, DataType dtype)
+      : Val(ValType::NamedScalar, dtype), name_(_name) {}
 
   NamedScalar(const NamedScalar* src, IrCloner* ir_cloner);
 
+  NamedScalar(const NamedScalar& other) = delete;
+  NamedScalar& operator=(const NamedScalar& other) = delete;
+
+  NamedScalar(NamedScalar&& other) = delete;
+  NamedScalar& operator=(NamedScalar&& other) = delete;
+
   const std::string& name() const {
     return name_;
   }
 
-  bool sameAs(const Statement* other) const override;
+  bool sameAs(const NamedScalar* const other) const {
+    return other->name().compare(name()) == 0;
+  }
 
-  //! Return the named scalar extent of a parallel dimension (e.g. blockDim.x)
+  // Return the named scalar extent of a parallel dimension (e.g. blockDim.x)
   static NamedScalar* getParallelDim(ParallelType p_type);
 
-  //! Return the named scalar index of a parallel dimension (e.g. threadIdx.x)
+  // Return the named scalar index of a parallel dimension (e.g. threadIdx.x)
   static NamedScalar* getParallelIndex(ParallelType p_type);
 
-  //! Return the parallel type of this NamedScalar if it is an extent of a
-  //! parallel dimension
+  // Return the parallel type of this NamedScalar if it is an extent of a
+  // parallel dimension
   c10::optional<ParallelType> getParallelDim() const;
 
-  //! Return the parallel type of this NamedScalar if it is an index of a
-  //! parallel dimension
+  // Return the parallel type of this NamedScalar if it is an index of a
+  // parallel dimension
   c10::optional<ParallelType> getParallelIndex() const;
 
  private:
index d5ae0cb..0259383 100644 (file)
@@ -1,10 +1,8 @@
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
 
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/lower_utils.h>
 
 #include <c10/util/irange.h>
@@ -65,31 +63,33 @@ void IrPrinter::handle(const TensorDomain* td) {
 
 void IrPrinter::handle(const TensorView* tv) {
   if (tv->nDims() == 0) {
-    os_ << typePrefix(tv->getDataType().value()) << tv->name();
-  } else {
-    os_ << "T" << tv->name();
-    switch (tv->getMemoryType()) {
-      case MemoryType::Global:
-        os_ << "_g";
+    switch (tv->getDataType().value()) {
+      case DataType::Bool:
+        os_ << "b";
+        break;
+      case DataType::Float:
+        os_ << "f";
         break;
-      case MemoryType::Shared:
-        os_ << "_s";
+      case DataType::Half:
+        os_ << "h";
         break;
-      case MemoryType::Local:
-        os_ << "_l";
+      case DataType::Int:
+        os_ << "i";
         break;
+      default:
+        TORCH_INTERNAL_ASSERT(
+            false, "Did not recognize type ", tv->getDataType().value());
     }
+    os_ << tv->name();
+
+  } else {
+    os_ << "T" << tv->name();
     handle(tv->domain());
 
-    if (tv->getComputeAtPosition() > 0) {
-      os_ << " ca_pos( ";
-      os_ << tv->getComputeAtPosition();
-      os_ << " )";
-    }
-    if (tv->getMaxProducerPosition() > 0) {
-      os_ << " produce_pos( ";
-      os_ << tv->getMaxProducerPosition();
-      os_ << ")";
+    if (tv->getComputeAtView() != nullptr) {
+      os_ << " compute_at( ";
+      os_ << "T" << tv->getComputeAtView()->name();
+      os_ << ", " << tv->getRelativeComputeAtAxis() << " )";
     }
   }
 }
@@ -110,9 +110,9 @@ void IrPrinter::handle(const IterDomain* id) {
 }
 
 void IrPrinter::handle(const Bool* b) {
-  if (print_inline_ && b->definition() != nullptr) {
+  if (print_inline_ && FusionGuard::getCurFusion()->origin(b) != nullptr) {
     os_ << "( ";
-    handle(b->definition());
+    handle(FusionGuard::getCurFusion()->origin(b));
     os_ << " )";
     return;
   }
@@ -124,27 +124,42 @@ void IrPrinter::handle(const Bool* b) {
   }
 }
 
-void IrPrinter::handle(const Double* d) {
-  if (print_inline_ && d->definition() != nullptr) {
+void IrPrinter::handle(const Float* f) {
+  if (print_inline_ && FusionGuard::getCurFusion()->origin(f) != nullptr) {
     os_ << "( ";
-    handle(d->definition());
+    handle(FusionGuard::getCurFusion()->origin(f));
     os_ << " )";
     return;
   }
 
-  if (d->isSymbolic()) {
-    os_ << "d" << d->name();
+  if (f->isSymbolic()) {
+    os_ << "f" << f->name();
   } else {
-    os_ << "double("
+    os_ << "float("
         << std::setprecision(
-               std::numeric_limits<Double::ScalarType>::max_digits10)
-        << *(d->value()) << ")";
+               std::numeric_limits<Float::ScalarType>::max_digits10)
+        << *(f->value()) << ")";
+  }
+}
+
+void IrPrinter::handle(const Half* h) {
+  if (print_inline_ && FusionGuard::getCurFusion()->origin(h) != nullptr) {
+    os_ << "( ";
+    handle(FusionGuard::getCurFusion()->origin(h));
+    os_ << " )";
+    return;
+  }
+
+  if (h->isSymbolic()) {
+    os_ << "h" << h->name();
+  } else {
+    os_ << "__float2half(" << *(h->value()) << ")";
   }
 }
 
 void IrPrinter::handle(const Int* i) {
   if (print_inline_) {
-    if (auto def = i->definition()) {
+    if (auto def = FusionGuard::getCurFusion()->origin(i)) {
       os_ << "( ";
       handle(def);
       os_ << " )";
@@ -163,8 +178,45 @@ void IrPrinter::handle(const NamedScalar* i) {
   os_ << i->name();
 }
 
+void IrPrinter::handle(const kir::Bool* b) {
+  os_ << "kir::Bool (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::Float* f) {
+  os_ << "kir::Float (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::Half* h) {
+  os_ << "kir::Half (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::Int* i) {
+  os_ << "kir::Int (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::NamedScalar*) {
+  os_ << "kir::NamedScalar (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::TensorIndex*) {
+  os_ << "kir::TensorIndex (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::IterDomain*) {
+  os_ << "kir::IterDomain (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::TensorDomain*) {
+  os_ << "kir::TensorDomain (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::TensorView*) {
+  os_ << "kir::TensorView (use kir::toString() to print Kernel IR nodes)";
+}
+
 static bool isTV(const Val* val) {
-  return val->getValType().value() == ValType::TensorView;
+  return val->getValType().value() == ValType::TensorView ||
+      val->getValType().value() == ValType::TensorIndex;
 }
 
 // Check if we're a TensorView op that we can generate code for.
@@ -187,36 +239,23 @@ void IrPrinter::handle(const UnaryOp* uop) {
     checkInlineable(uop);
   }
 
-  auto op_type = uop->getUnaryOpType();
-
-  if (auto inline_uop = inline_op_str(op_type)) {
+  if (auto inline_uop = inline_op_str(uop->getUnaryOpType())) {
     os_ << inline_uop.value();
     handle(uop->in());
   } else {
-    if (op_type == UnaryOpType::Cast) {
+    if (uop->getUnaryOpType() == UnaryOpType::Cast) {
       c10::optional<std::string> cast_str = cast_func_str(std::make_pair(
           uop->in()->getDataType().value(), uop->out()->getDataType().value()));
       TORCH_INTERNAL_ASSERT(cast_str != c10::nullopt, "Unsupported Cast");
       os_ << cast_str.value();
     } else {
-      if (alsoBooleanOperator(op_type) &&
-          uop->out()->getDataType().value() == DataType::Bool) {
-        os_ << stringifyBooleanOp(op_type);
-      } else {
-        os_ << op_type;
-      }
-      if (uop->out()->getDataType().value() == DataType::Float &&
-          needFloatSuffix(op_type)) {
-        os_ << "f";
-      }
+      os_ << uop->getUnaryOpType();
     }
-    if (op_type == UnaryOpType::RandLike) {
-      os_ << "(";
-      handle(uop->in());
-    } else {
-      os_ << "(";
+    os_ << "(";
+    if (uop->getUnaryOpType() == UnaryOpType::RandLike)
+      os_ << "rnd";
+    else
       handle(uop->in());
-    }
     os_ << ")";
   }
 
@@ -245,8 +284,7 @@ void IrPrinter::handle(const BinaryOp* bop) {
     checkInlineable(bop);
   }
 
-  auto op_type = bop->getBinaryOpType();
-  if (auto inline_bop = inline_op_str(op_type)) {
+  if (auto inline_bop = inline_op_str(bop->getBinaryOpType())) {
     handle(bop->lhs());
     if (istvop) {
       os_ << "\n";
@@ -255,17 +293,7 @@ void IrPrinter::handle(const BinaryOp* bop) {
     os_ << " " << inline_bop.value() << " ";
     handle(bop->rhs());
   } else {
-    if (alsoBooleanOperator(op_type) &&
-        bop->out()->getDataType().value() == DataType::Bool) {
-      os_ << stringifyBooleanOp(op_type);
-    } else {
-      os_ << op_type;
-    }
-    if (bop->out()->getDataType().value() == DataType::Float &&
-        needFloatSuffix(op_type)) {
-      os_ << "f";
-    }
-    os_ << "(";
+    os_ << bop->getBinaryOpType() << "(";
     handle(bop->lhs());
     if (istvop) {
       os_ << "\n";
@@ -324,73 +352,62 @@ void IrPrinter::handle(const TernaryOp* top) {
     os_ << ";\n";
 }
 
+void IrPrinter::handle(const kir::UnaryOp* uop) {
+  os_ << "kir::UnaryOp (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::BinaryOp* bop) {
+  os_ << "kir::BinaryOp (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::TernaryOp* top) {
+  os_ << "kir::TernaryOp (use kir::toString() to print Kernel IR nodes)";
+}
+
 void IrPrinter::handle(const ReductionOp* rop) {
+  TORCH_CHECK(rop->out()->getValType() != ValType::TensorIndex);
   indent();
   os_ << rop->out() << " = reduction( " << rop->in()
       << ", op = " << rop->getReductionOpType()
       << ", initial value = " << rop->init() << " )\n";
 }
 
-void IrPrinter::handle(const WelfordOp* wop) {
-  indent();
-  os_ << wop->outAvg() << "(Avg),\n"
-      << wop->outVar() << "(Var),\n"
-      << wop->outN() << "(Count)"
-      << "\n = Welford ( ";
-  if (wop->singleValue()) {
-    os_ << wop->inAvg() << "(Avg), ";
-  } else {
-    os_ << wop->inAvg() << "(Avg)\n  " << wop->inVar() << "(Var)\n  "
-        << wop->inN() << "(Count)";
-  }
-  if (wop->hasInit()) {
-    os_ << "\n  initial value = " << wop->initAvg() << "(Avg)\n  "
-        << wop->initVar() << "(Var)\n  " << wop->initN() << "(N)";
-  }
-  os_ << " )\n";
+void IrPrinter::handle(const kir::ReductionOp* rop) {
+  os_ << "kir::ReductionOp (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::GridReduction* gr) {
+  os_ << "kir::GridReduction (use kir::toString() to print Kernel IR nodes)";
 }
 
 void IrPrinter::handle(const BroadcastOp* bop) {
+  TORCH_CHECK(bop->out()->getValType() != ValType::TensorIndex);
   indent();
   os_ << bop->out() << " = broadcast( " << bop->in() << " )\n";
 }
 
-void IrPrinter::handle(const TransposeOp* top) {
-  indent();
-  os_ << top->out() << " = transpose( " << top->in() << " )\n";
+void IrPrinter::handle(const kir::BroadcastOp*) {
+  os_ << "kir::BroadcastOp (use kir::toString() to print Kernel IR nodes)";
 }
 
-void IrPrinter::handle(const ShiftOp* sop) {
-  indent();
-  os_ << sop->out() << " = shift( " << sop->in() << ", {" << sop->offsets()
-      << "} )\n";
+void IrPrinter::handle(const kir::ForLoop* fl) {
+  os_ << "kir::ForLoop (use kir::toString() to print Kernel IR nodes)";
 }
 
-void IrPrinter::handle(const GatherOp* op) {
-  indent();
-  os_ << op->out() << " = gather( " << op->in() << ", {";
-  bool no_comma = true;
-  for (const auto& s : op->windowShape()) {
-    if (!no_comma) {
-      os_ << ", ";
-    }
-    os_ << s;
-    no_comma = false;
-  }
-  os_ << "}, {";
-  no_comma = true;
-  for (const auto& pad : op->padWidth()) {
-    if (!no_comma) {
-      os_ << ", ";
-    }
-    os_ << "{" << pad[0] << ", " << pad[1] << "}";
-    no_comma = false;
-  }
-  os_ << "} )\n";
+void IrPrinter::handle(const kir::IfThenElse* ite) {
+  os_ << "kir::IfThenElse (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::Allocate* a) {
+  os_ << "kir::Allocate (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::Sync* a) {
+  os_ << "kir::Sync (use kir::toString() to print Kernel IR nodes)";
 }
 
 void IrPrinter::handle(const Split* s) {
-  os_ << (s->innerSplit() ? "Split: " : "Outer split: ");
+  os_ << "Split: ";
   handle(s->in());
   os_ << " by factor " << s->factor() << " -> ";
   handle(s->outer());
@@ -409,37 +426,6 @@ void IrPrinter::handle(const Merge* m) {
   os_ << "\n";
 }
 
-void IrTransformPrinter::handle(Fusion* f) {
-  auto all_vals = f->usedMathVals();
-
-  for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
-    IrPrinter::handle(tv);
-    os() << "\n";
-    printTransforms(tv);
-  }
-}
-
-void IrTransformPrinter::printTransforms(TensorView* tv) {
-  auto root_domain = tv->getMaybeRFactorDomain();
-  auto all_exp = DependencyCheck::getAllExprsBetween(
-      {root_domain.begin(), root_domain.end()},
-      {tv->domain()->domain().begin(), tv->domain()->domain().end()});
-
-  os() << " root domain : (";
-  for (size_t root_idx = 0; root_idx < root_domain.size(); root_idx++) {
-    IrPrinter::handle(root_domain[root_idx]);
-    if (root_idx + 1 < root_domain.size()) {
-      os() << ",";
-    }
-  }
-  os() << ")\n";
-
-  for (auto exp : all_exp) {
-    os() << "    ";
-    IrPrinter::handle(exp);
-  }
-}
-
 std::ostream& operator<<(std::ostream& os, const Statement* stmt) {
   IrPrinter p(os);
   p.handle(stmt);
index fde0fd2..433779b 100644 (file)
@@ -61,7 +61,8 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
   void handle(const IterDomain*) override;
 
   void handle(const Bool*) override;
-  void handle(const Double*) override;
+  void handle(const Float*) override;
+  void handle(const Half*) override;
   void handle(const Int*) override;
   void handle(const NamedScalar*) override;
 
@@ -69,11 +70,30 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
   void handle(const BinaryOp*) override;
   void handle(const TernaryOp*) override;
   void handle(const ReductionOp*) override;
-  void handle(const WelfordOp*) override;
   void handle(const BroadcastOp*) override;
-  void handle(const TransposeOp*) override;
-  void handle(const ShiftOp*) override;
-  void handle(const GatherOp*) override;
+
+  void handle(const kir::Bool*) override;
+  void handle(const kir::Float*) override;
+  void handle(const kir::Half*) override;
+  void handle(const kir::Int*) override;
+  void handle(const kir::NamedScalar*) override;
+
+  void handle(const kir::TensorIndex*) override;
+  void handle(const kir::IterDomain*) override;
+  void handle(const kir::TensorDomain*) override;
+  void handle(const kir::TensorView*) override;
+
+  void handle(const kir::UnaryOp*) override;
+  void handle(const kir::BinaryOp*) override;
+  void handle(const kir::TernaryOp*) override;
+  void handle(const kir::ReductionOp*) override;
+  void handle(const kir::BroadcastOp*) override;
+
+  void handle(const kir::GridReduction*) override;
+  void handle(const kir::ForLoop*) override;
+  void handle(const kir::IfThenElse*) override;
+  void handle(const kir::Allocate*) override;
+  void handle(const kir::Sync*) override;
 
   void handle(const Split*) override;
   void handle(const Merge*) override;
@@ -85,11 +105,6 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
     print_inline_ = prev;
   }
 
- protected:
-  std::ostream& os() {
-    return os_;
-  }
-
  private:
   std::ostream& os_;
   bool print_inline_ = false;
@@ -103,6 +118,28 @@ TORCH_CUDA_CU_API std::ostream& operator<<(
 TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream& os, Fusion* f);
 TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream& os, Fusion& f);
 
+// TODO(kir): catch accidental << printing of Kernel IR nodes
+// (use kir::toString(node) instead)
+std::ostream& operator<<(std::ostream& os, const kir::Bool*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::Float*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::Half*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::Int*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::NamedScalar*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::TensorIndex*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::IterDomain*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::TensorDomain*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::TensorView*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::UnaryOp*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::BinaryOp*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::TernaryOp*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::ReductionOp*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::BroadcastOp*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::GridReduction*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::ForLoop*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::IfThenElse*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::Allocate*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::Sync*) = delete;
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
index 30bbc7d..5dbd691 100644 (file)
@@ -1,11 +1,9 @@
 #include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
 #include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
 #include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
 #include <torch/csrc/jit/codegen/cuda/transform_iter.h>
 #include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
 
@@ -20,9 +18,9 @@ namespace cuda {
 
 namespace {
 
-class ScalarCheck : OptInConstDispatch {
+class ScalarCheck : OptInDispatch {
  public:
-  static bool sameAs(const Val* v1, const Val* v2) {
+  static bool sameAs(Val* v1, Val* v2) {
     if (v1 == v2)
       return true;
 
@@ -37,29 +35,33 @@ class ScalarCheck : OptInConstDispatch {
   }
 
  private:
-  void handle(const Bool* b) override {
+  void handle(Bool* b) override {
     same_ = v1_->as<Bool>()->sameAs(v2_->as<Bool>());
   }
 
-  void handle(const Double* d) override {
-    same_ = v1_->as<Double>()->sameAs(v2_->as<Double>());
+  void handle(Float* f) override {
+    same_ = v1_->as<Float>()->sameAs(v2_->as<Float>());
   }
 
-  void handle(const Int* i) override {
+  void handle(Half* h) override {
+    same_ = v1_->as<Half>()->sameAs(v2_->as<Half>());
+  }
+
+  void handle(Int* i) override {
     same_ = v1_->as<Int>()->sameAs(v2_->as<Int>());
   }
 
-  void handle(const NamedScalar* ns) override {
+  void handle(NamedScalar* ns) override {
     same_ = v1_->as<NamedScalar>()->sameAs(v2_->as<NamedScalar>());
   }
 
-  ScalarCheck(const Val* _v1, const Val* _v2) : v1_(_v1), v2_(_v2) {
-    OptInConstDispatch::handle(v1_);
+  ScalarCheck(Val* _v1, Val* _v2) : v1_(_v1), v2_(_v2) {
+    OptInDispatch::handle(v1_);
   }
 
  private:
-  const Val* v1_ = nullptr;
-  const Val* v2_ = nullptr;
+  Val* v1_ = nullptr;
+  Val* v2_ = nullptr;
   bool same_ = false;
 };
 
@@ -72,57 +74,43 @@ bool areEqualScalars(Val* v1, Val* v2) {
 Bool::Bool(const Bool* src, IrCloner* ir_cloner)
     : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
 
-bool Bool::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<Bool>()) {
-    return false;
-  }
-  const auto other_bool = other->as<Bool>();
-  if (isConst() && other_bool->isConst()) {
-    return *value() == *(other_bool->value());
-  }
-  return false;
+bool Bool::sameAs(const Bool* const other) const {
+  if (isConst() && other->isConst())
+    return *value() == *(other->value());
+  return this == other;
 }
 
-Double::Double(const Double* src, IrCloner* ir_cloner)
+Float::Float(const Float* src, IrCloner* ir_cloner)
     : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
 
-bool Double::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<Double>()) {
-    return false;
-  }
-  const auto other_double = other->as<Double>();
-  if (isConst() && other_double->isConst())
-    return *value() == *(other_double->value());
-  return false;
+bool Float::sameAs(const Float* const other) const {
+  if (isConst() && other->isConst())
+    return *value() == *(other->value());
+  return this == other;
+}
+
+Half::Half(const Half* src, IrCloner* ir_cloner)
+    : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
+
+bool Half::sameAs(const Half* const other) const {
+  if (isConst() && other->isConst())
+    return *value() == *(other->value());
+  return this == other;
 }
 
 Int::Int(const Int* src, IrCloner* ir_cloner)
     : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
 
-bool Int::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<Int>()) {
-    return false;
-  }
-  const auto other_int = other->as<Int>();
-  if (isConst() && other_int->isConst()) {
-    return *value() == *(other_int->value());
-  }
-  return false;
+bool Int::sameAs(const Int* const other) const {
+  if (isConst() && other->isConst())
+    return *value() == *(other->value());
+  return this == other;
 }
 
-UnaryOp::UnaryOp(UnaryOpType type, Val* out, Val* in)
-    : Expr(ExprType::UnaryOp), unary_op_type_{type}, out_{out}, in_{in} {
-  addOutput(out);
-  addInput(in);
+UnaryOp::UnaryOp(UnaryOpType _type, Val* _out, Val* _in)
+    : Expr(ExprType::UnaryOp), unary_op_type_{_type}, out_{_out}, in_{_in} {
+  addOutput(_out);
+  addInput(_in);
   name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
@@ -132,28 +120,21 @@ UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner)
       out_(ir_cloner->clone(src->out_)),
       in_(ir_cloner->clone(src->in_)) {}
 
-bool UnaryOp::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<UnaryOp>()) {
+bool UnaryOp::sameAs(const UnaryOp* const other) const {
+  if (type() != other->type())
     return false;
-  }
-  const auto other_op = other->as<UnaryOp>();
-  if (getUnaryOpType() != other_op->getUnaryOpType())
-    return false;
-  return Expr::sameAs(other);
+  return as<Expr>()->sameAs(other);
 }
 
-BinaryOp::BinaryOp(BinaryOpType type, Val* out, Val* lhs, Val* rhs)
+BinaryOp::BinaryOp(BinaryOpType _type, Val* _out, Val* _lhs, Val* _rhs)
     : Expr(ExprType::BinaryOp),
-      binary_op_type_{type},
-      out_{out},
-      lhs_{lhs},
-      rhs_{rhs} {
-  addOutput(out);
-  addInput(lhs);
-  addInput(rhs);
+      binary_op_type_{_type},
+      out_{_out},
+      lhs_{_lhs},
+      rhs_{_rhs} {
+  addOutput(_out);
+  addInput(_lhs);
+  addInput(_rhs);
   name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
@@ -164,30 +145,30 @@ BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner)
       lhs_(ir_cloner->clone(src->lhs_)),
       rhs_(ir_cloner->clone(src->rhs_)) {}
 
-bool BinaryOp::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<BinaryOp>()) {
+bool BinaryOp::sameAs(const BinaryOp* other) const {
+  if (getBinaryOpType() != other->getBinaryOpType())
     return false;
-  }
-  const auto other_op = other->as<BinaryOp>();
-  if (getBinaryOpType() != other_op->getBinaryOpType())
+  if (!(lhs()->sameAs(other->lhs()) && rhs()->sameAs(other->rhs())))
     return false;
-  return Expr::sameAs(other);
+  return true;
 }
 
-TernaryOp::TernaryOp(TernaryOpType type, Val* out, Val* in1, Val* in2, Val* in3)
+TernaryOp::TernaryOp(
+    TernaryOpType _type,
+    Val* _out,
+    Val* _in1,
+    Val* _in2,
+    Val* _in3)
     : Expr(ExprType::TernaryOp),
-      ternary_op_type_{type},
-      out_{out},
-      in1_{in1},
-      in2_{in2},
-      in3_{in3} {
-  addOutput(out);
-  addInput(in1);
-  addInput(in2);
-  addInput(in3);
+      ternary_op_type_{_type},
+      out_{_out},
+      in1_{_in1},
+      in2_{_in2},
+      in3_{_in3} {
+  addOutput(_out);
+  addInput(_in1);
+  addInput(_in2);
+  addInput(_in3);
   name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
@@ -199,227 +180,117 @@ TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner)
       in2_(ir_cloner->clone(src->in2_)),
       in3_(ir_cloner->clone(src->in3_)) {}
 
-bool TernaryOp::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<TernaryOp>()) {
+bool TernaryOp::sameAs(const TernaryOp* other) const {
+  if (getTernaryOpType() != other->getTernaryOpType())
     return false;
-  }
-  const auto other_op = other->as<TernaryOp>();
-  if (getTernaryOpType() != other_op->getTernaryOpType())
+  if (!(in1()->sameAs(other->in1()) && in2()->sameAs(other->in2()) &&
+        in3()->sameAs(other->in3())))
     return false;
-  return Expr::sameAs(other);
+  return true;
 }
 
-BroadcastOp::BroadcastOp(Val* out, Val* in, std::vector<bool> is_broadcast_dims)
-    : Expr(ExprType::BroadcastOp),
-      out_(out),
-      in_(in),
-      is_broadcast_dims_(std::move(is_broadcast_dims)) {
-  // clang-tidy complains about out_ that it may be null.
-  TORCH_INTERNAL_ASSERT(out_ != nullptr);
-  TORCH_INTERNAL_ASSERT(in_ != nullptr);
-
-  auto out_type = out->getValType().value();
-  auto in_type = in->getValType().value();
+BroadcastOp::BroadcastOp(Val* _out, Val* _in)
+    : Expr(ExprType::BroadcastOp), out_(_out), in_(_in) {
+  auto out_type = _out->getValType().value();
+  auto in_type = _in->getValType().value();
 
   TORCH_INTERNAL_ASSERT(
       out_type == ValType::TensorView && in_type == ValType::TensorView,
       "Cannot braodcast a non-tensor object.");
 
-  addOutput(out);
-  addInput(in);
-  name_ = FusionGuard::getCurFusion()->registerExpr(this);
-
   // This is a generic check that root dims of a consumer and producer match.
   // Maybe we shouldn't relegate it to this constructor.
-  const auto c_tv = out_->as<TensorView>();
-  const auto p_tv = in_->as<TensorView>();
+  const auto c_tv = out()->as<TensorView>();
+  const auto p_tv = in()->as<TensorView>();
 
   const auto& c_root = c_tv->getRootDomain();
   const auto& p_root = p_tv->getMaybeRFactorDomain();
 
-  const auto root_p2c =
-      PairwiseRootDomainMap(p_tv, c_tv)
-          .mapProducerToConsumer(p_tv->domain(), c_tv->domain());
-
-  for (auto id : p_root) {
-    if (root_p2c.find(id) == root_p2c.end()) {
-      TORCH_INTERNAL_ASSERT(
-          id->isReduction(),
-          "Invalid broadcast op: ",
-          id,
-          ". Non-reduction input dim does't match to output.");
-    }
-  }
+  const auto root_p2c = TensorDomain::mapDomainPandC(p_root, c_root);
+
+  std::vector<bool> c_mapped(c_root.size(), false);
+  std::vector<bool> p_mapped(p_root.size(), false);
 
-  std::unordered_set<IterDomain*> c_mapped;
   for (auto pair_entry : root_p2c) {
-    c_mapped.insert(pair_entry.second);
+    auto p_i = pair_entry.first;
+    p_mapped[p_i] = true;
+    auto c_i = pair_entry.second;
+    c_mapped[c_i] = true;
   }
 
-  for (size_t i = 0; i < c_root.size(); ++i) {
-    const auto c_id = c_root[i];
-    if (c_mapped.find(c_id) != c_mapped.end()) {
-      continue;
+  bool bad_mismatch = false;
+
+  for (const auto i : c10::irange(c_root.size())) {
+    if (!c_mapped[i]) {
+      if (!c_root[i]->isBroadcast()) {
+        bad_mismatch = true;
+      }
+    }
+  }
+
+  for (const auto i : c10::irange(p_root.size())) {
+    if (!p_mapped[i]) {
+      if (!p_root[i]->isReduction()) {
+        bad_mismatch = true;
+      }
     }
-    TORCH_INTERNAL_ASSERT(
-        c_id->isBroadcast() && is_broadcast_dims_[i],
-        "Invalid broadcast op: ",
-        c_id,
-        ". Non-broadcasted output dim isn't matched from input.");
   }
+
+  TORCH_INTERNAL_ASSERT(
+      !bad_mismatch,
+      "Invalid broadcast op. Non-broadcasted dims don't match from input to output.");
+
+  addOutput(_out);
+  addInput(_in);
+  name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
 BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner)
     : Expr(src, ir_cloner),
       out_(ir_cloner->clone(src->out_)),
-      in_(ir_cloner->clone(src->in_)),
-      is_broadcast_dims_(src->is_broadcast_dims_) {}
+      in_(ir_cloner->clone(src->in_)) {}
 
-bool BroadcastOp::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<BroadcastOp>()) {
-    return false;
-  }
-  const auto other_op = other->as<BroadcastOp>();
-  if (getBroadcastDimFlags() != other_op->getBroadcastDimFlags()) {
-    return false;
-  }
-  return Expr::sameAs(other);
+bool BroadcastOp::sameAs(const BroadcastOp* const other) const {
+  return other->in() == in() && other->out() == out();
 }
 
 ReductionOp::ReductionOp(
-    BinaryOpType reduction_op_type,
-    Val* init,
-    Val* out,
-    Val* in)
+    BinaryOpType _reduction_op_type,
+    Val* _init,
+    Val* _out,
+    Val* _in)
     : Expr(ExprType::ReductionOp),
-      reduction_op_type_(reduction_op_type),
-      init_(init),
-      out_(out),
-      in_(in) {
-  TORCH_CHECK(out->getValType().value() == ValType::TensorView);
-
-  TORCH_INTERNAL_ASSERT(
-      in->getValType() == ValType::TensorView &&
-          out->getValType() == ValType::TensorView,
-      "Reduction operation was created that does not have tensor inputs and outputs.");
-
-  TORCH_INTERNAL_ASSERT(
-      TensorDomain::noReductions(in->as<TensorView>()->getMaybeRFactorDomain())
-              .size() == out->as<TensorView>()->getRootDomain().size(),
-      "Reduction operation created with mismatched domains.");
-
-  TORCH_INTERNAL_ASSERT(
-      init->isConstScalar(),
-      "Tried to create a reduction operation whith an initial value that isn't a constant.");
-
-  addOutput(out);
-  addInput(in);
-  name_ = FusionGuard::getCurFusion()->registerExpr(this);
-}
-
-WelfordOp::WelfordOp(
-    Val* out_avg,
-    Val* out_var,
-    Val* out_N,
-    Val* init_avg,
-    Val* init_var,
-    Val* init_N,
-    Val* in_avg,
-    Val* in_var,
-    Val* in_N)
-    : Expr(ExprType::WelfordOp),
-      out_avg_(out_avg),
-      out_var_(out_var),
-      out_N_(out_N),
-      init_avg_(init_avg),
-      init_var_(init_var),
-      init_N_(init_N),
-      in_avg_(in_avg),
-      in_var_(in_var),
-      in_N_(in_N) {
-  // Check output type
-  TORCH_INTERNAL_ASSERT(out_avg->getValType().value() == ValType::TensorView);
-  TORCH_INTERNAL_ASSERT(out_var->getValType().value() == ValType::TensorView);
-  TORCH_INTERNAL_ASSERT(out_N->getValType().value() == ValType::TensorView);
-
-  // check initial value
-  TORCH_INTERNAL_ASSERT(init_N->getValType().value() == ValType::Scalar);
-  if (!init_N->isZeroInt()) {
-    // when initial count is zero, no initial variance or average is needed
-    // initial value with a count of 1 is un-common enough that I'll push
-    // the responsibility of creating all-zero var tensors to the user
-    TORCH_INTERNAL_ASSERT(
-        init_avg && init_avg->getValType().value() == ValType::TensorView);
+      reduction_op_type_(_reduction_op_type),
+      init_(_init),
+      out_(_out),
+      in_(_in) {
+  if (_out->getValType().value() == ValType::TensorView) {
     TORCH_INTERNAL_ASSERT(
-        init_var && init_var->getValType().value() == ValType::TensorView);
-  }
+        _in->getValType() == ValType::TensorView &&
+            _out->getValType() == ValType::TensorView,
+        "Reduction operation was created that does not have tensor inputs and outputs.");
 
-  TORCH_INTERNAL_ASSERT(
-      in_avg && in_avg->getValType().value() == ValType::TensorView);
-  // check input
-  TORCH_INTERNAL_ASSERT(
-      in_N->getValType().value() == ValType::Scalar ||
-      in_N->getValType().value() == ValType::TensorView);
-  if (!in_N->isOneInt()) {
-    // when input is only one value, only the value is required through avg
-    // input the var part is implicitly 0 and codegen will handle that.
     TORCH_INTERNAL_ASSERT(
-        in_var && in_var->getValType().value() == ValType::TensorView);
-  }
-
-  addOutput(out_avg);
-  addOutput(out_var);
-  addOutput(out_N);
+        TensorDomain::noReductions(
+            _in->as<TensorView>()->getMaybeRFactorDomain())
+                .size() == _out->as<TensorView>()->getRootDomain().size(),
+        "Reduction operation created with mismatched domains.");
 
-  addInput(in_avg);
-  // Conditionally adding this input?
-  if (!in_N->isOneInt()) {
-    addInput(in_var);
+  } else {
+    TORCH_INTERNAL_ASSERT(
+        _in->getValType() == ValType::TensorIndex &&
+            _out->getValType() == ValType::TensorIndex,
+        "Reduction operation was created that does not have tensor inputs and outputs.");
   }
-  addInput(in_N);
+  TORCH_INTERNAL_ASSERT(
+      _init->isConstScalar(),
+      "Tried to create a reduction operation whith an initial value that isn't a constant.");
 
+  addOutput(_out);
+  addInput(_in);
   name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
-WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner)
-    : Expr(src, ir_cloner),
-      out_avg_(ir_cloner->clone(src->out_avg_)),
-      out_var_(ir_cloner->clone(src->out_var_)),
-      out_N_(ir_cloner->clone(src->out_N_)),
-      init_avg_(src->init_avg_ ? ir_cloner->clone(src->init_avg_) : nullptr),
-      init_var_(src->init_var_ ? ir_cloner->clone(src->init_var_) : nullptr),
-      init_N_(ir_cloner->clone(src->init_N_)),
-      in_avg_(ir_cloner->clone(src->in_avg_)),
-      in_var_(src->in_var_ ? ir_cloner->clone(src->in_var_) : nullptr),
-      in_N_(ir_cloner->clone(src->in_N_)) {}
-
-namespace {
-inline bool sameOptionalVal(Val* a, Val* b) {
-  return ((a == nullptr && b == nullptr)) || ((a && b) && (a->sameAs(b)));
-}
-} // namespace
-
-bool WelfordOp::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (auto other_wop = dynamic_cast<const WelfordOp*>(other)) {
-    return in_avg_->sameAs(other_wop->in_avg_) &&
-        sameOptionalVal(in_var_, other_wop->in_var_) &&
-        in_N_->sameAs(other_wop->in_N_) &&
-        sameOptionalVal(init_avg_, other_wop->init_avg_) &&
-        sameOptionalVal(init_var_, other_wop->init_var_) &&
-        init_N_->sameAs(other_wop->init_N_);
-  }
-  return false;
-}
-
 ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner)
     : Expr(src, ir_cloner),
       reduction_op_type_(src->reduction_op_type_),
@@ -427,250 +298,57 @@ ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner)
       out_(ir_cloner->clone(src->out_)),
       in_(ir_cloner->clone(src->in_)) {}
 
-bool ReductionOp::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<ReductionOp>()) {
-    return false;
-  }
-  const auto other_op = other->as<ReductionOp>();
-  // Note that init is not part of input vals, so it must be checked separately.
+bool ReductionOp::sameAs(const ReductionOp* other) const {
   return (
-      Expr::sameAs(other) &&
-      getReductionOpType() == other_op->getReductionOpType() &&
-      init()->sameAs(other_op->init()));
-}
-
-TransposeOp::TransposeOp(
-    TensorView* out,
-    TensorView* in,
-    std::vector<int> new2old)
-    : Expr(ExprType::TransposeOp),
-      out_(out),
-      in_(in),
-      new2old_(std::move(new2old)) {
-  // Sanity check of the input parameters. Maybe not necessary as they
-  // should be checked at function transpose.
-
-  TORCH_INTERNAL_ASSERT(
-      !in->hasRFactor(), "Transposing rFactor tensors is not supported.");
-
-  TORCH_INTERNAL_ASSERT(
-      TensorDomain::noReductions(in->getRootDomain()).size() ==
-      out->getRootDomain().size());
-
-  TORCH_INTERNAL_ASSERT(new2old_.size() == out->getRootDomain().size());
-
-  // Make sure the entries of new2old are unique and range from 0 to
-  // N-1, where N == new2old.size().
-  std::set<int> old_positions(new2old_.begin(), new2old_.end());
-  TORCH_INTERNAL_ASSERT(old_positions.size() == new2old_.size());
-  // old_positions is sorted, so the first entry must be 0.
-  TORCH_INTERNAL_ASSERT(
-      *(old_positions.begin()) == 0,
-      "Invalid new2old vector detected: ",
-      new2old_);
-  // The last entry must be N-1, since old_positions is sorted, starts
-  // with 0, and its length is N.
-  TORCH_INTERNAL_ASSERT(
-      *(old_positions.rbegin()) == (int)(new2old_.size() - 1),
-      "Invalid new2old vector detected: ",
-      new2old_);
-
-  addOutput(out);
-  addInput(in);
-  name_ = FusionGuard::getCurFusion()->registerExpr(this);
-}
-
-TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner)
-    : Expr(src, ir_cloner),
-      out_(ir_cloner->clone(src->out_)),
-      in_(ir_cloner->clone(src->in_)),
-      new2old_(src->new2old_) {}
-
-ShiftOp::ShiftOp(Val* out, Val* in, std::vector<int> offsets)
-    : Expr(ExprType::ShiftOp),
-      out_(out),
-      in_(in),
-      offsets_(std::move(offsets)) {
-  // clang-tidy complains about out_ that it may be null.
-  TORCH_INTERNAL_ASSERT(out_ != nullptr);
-  TORCH_INTERNAL_ASSERT(in_ != nullptr);
-
-  auto out_type = out->getValType().value();
-  auto in_type = in->getValType().value();
-
-  TORCH_INTERNAL_ASSERT(
-      out_type == ValType::TensorView && in_type == ValType::TensorView,
-      "Cannot shift a non-tensor object.");
-
-  TORCH_INTERNAL_ASSERT(
-      offsets_.size() ==
-          TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain())
-              .size(),
-      "Invalid offset vector: ",
-      offsets_);
-
-  addOutput(out);
-  addInput(in);
-  name_ = FusionGuard::getCurFusion()->registerExpr(this);
-}
-
-ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner)
-    : Expr(src, ir_cloner),
-      out_(ir_cloner->clone(src->out_)),
-      in_(ir_cloner->clone(src->in_)),
-      offsets_(src->offsets_) {}
-
-bool ShiftOp::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<ShiftOp>()) {
-    return false;
-  }
-  const auto other_op = other->as<ShiftOp>();
-  if (offsets() != other_op->offsets()) {
-    return false;
-  }
-  return Expr::sameAs(other);
-}
-
-GatherOp::GatherOp(
-    Val* out,
-    Val* in,
-    std::vector<Int*> window_shape,
-    std::vector<std::vector<Int*>> pad_width)
-    : Expr(ExprType::GatherOp),
-      out_(out),
-      in_(in),
-      window_shape_(std::move(window_shape)),
-      pad_width_(std::move(pad_width)) {
-  // clang-tidy complains about out_ that it may be null.
-  TORCH_INTERNAL_ASSERT(out_ != nullptr);
-  TORCH_INTERNAL_ASSERT(in_ != nullptr);
-
-  auto out_type = out->getValType().value();
-  auto in_type = in->getValType().value();
-
-  TORCH_INTERNAL_ASSERT(
-      out_type == ValType::TensorView && in_type == ValType::TensorView,
-      "Cannot shift a non-tensor object.");
-
-  const auto ndims =
-      TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain()).size();
-
-  TORCH_INTERNAL_ASSERT(
-      window_shape_.size() == ndims,
-      "Invalid window_shape vector: ",
-      window_shape_);
-  TORCH_INTERNAL_ASSERT(
-      pad_width_.size() == ndims, "Invalid pad_width vector: ", pad_width_);
-
-  for (const auto& pad : pad_width_) {
-    TORCH_INTERNAL_ASSERT(
-        pad.size() == 2, "Padding size for each axis must have two Int vals.");
-  }
-
-  addOutput(out);
-  addInput(in);
-  name_ = FusionGuard::getCurFusion()->registerExpr(this);
-}
-
-GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner)
-    : Expr(src, ir_cloner),
-      out_(ir_cloner->clone(src->out_)),
-      in_(ir_cloner->clone(src->in_)) {
-  std::transform(
-      src->window_shape_.begin(),
-      src->window_shape_.end(),
-      std::back_inserter(window_shape_),
-      [&ir_cloner](const auto& x) { return ir_cloner->clone(x); });
-  for (const auto& pad : src->pad_width_) {
-    std::vector<Int*> pad_clone;
-    std::transform(
-        pad.begin(),
-        pad.end(),
-        std::back_inserter(pad_clone),
-        [&ir_cloner](const auto& x) { return ir_cloner->clone(x); });
-    pad_width_.push_back(pad_clone);
-  }
-}
-
-bool GatherOp::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<GatherOp>()) {
-    return false;
-  }
-  const auto other_op = other->as<GatherOp>();
-  if (windowShape().size() != other_op->windowShape().size()) {
-    return false;
-  }
-  for (size_t i = 0; i < windowShape().size(); ++i) {
-    if (!windowShape()[i]->sameAs(other_op->windowShape()[i])) {
-      return false;
-    }
-  }
-  if (padWidth().size() != other_op->padWidth().size()) {
-    return false;
-  }
-  for (size_t i = 0; padWidth().size(); ++i) {
-    if (!padWidth()[i][0]->sameAs(other_op->padWidth()[i][0]) ||
-        !padWidth()[i][1]->sameAs(other_op->padWidth()[i][1])) {
-      return false;
-    }
-  }
-  return Expr::sameAs(other);
-}
-
-int GatherOp::gatherAxis(int axis) const {
-  if (axis < 0) {
-    axis += out()->as<TensorView>()->nDims();
-  }
-  TORCH_INTERNAL_ASSERT(
-      axis >= 0 && axis < (int)windowShape().size(), "Invalid axis: ", axis);
-  return int(windowShape().size()) + axis;
+      in()->sameAs(other->in()) &&
+      getReductionOpType() == other->getReductionOpType() &&
+      init()->sameAs(other->init()));
 }
 
 IterDomain::IterDomain(
-    Val* start,
-    Val* extent,
-    ParallelType parallel_type,
-    IterType iter_type,
-    bool is_rfactor_domain)
+    Val* _start,
+    Val* _extent,
+    ParallelType _parallel_type,
+    IterType _iter_type,
+    bool _is_rfactor_domain)
     : Val(ValType::IterDomain, DataType::Int, false),
-      start_(start),
-      extent_(extent),
-      parallel_type_(parallel_type),
-      iter_type_(iter_type),
-      is_rfactor_domain_(is_rfactor_domain) {
+      start_(_start),
+      extent_(_extent),
+      parallel_type_(_parallel_type),
+      iter_type_(_iter_type),
+      is_rfactor_domain_(_is_rfactor_domain) {
   TORCH_CHECK(
       !(isRFactorProduct() && isBroadcast()),
       "IterDomain cannot be both a broadcast and rfactor domain.");
 
   TORCH_INTERNAL_ASSERT(
-      extent->isAnInt(),
+      _extent->isAnInt(),
       "Cannot create an iter domain over an extent that is not an int but received ",
-      extent,
+      _extent,
       " .");
 
   TORCH_INTERNAL_ASSERT(
-      start->isAnInt(),
+      _start->isAnInt(),
       "Cannot create an iter domain with a start that is not an int but received ",
-      start,
+      _extent,
       " .");
 
   // Check that all for-loops iterate from zero to some positive integer
   // lower_insert_syncs uses this assumption for correctness.
   TORCH_INTERNAL_ASSERT(
-      start->isZeroInt(),
+      _start->isZeroInt(),
       "Cannot create an iter domain with a start that is non-zero but received ",
-      start,
+      _extent,
       " .");
 
+  TORCH_INTERNAL_ASSERT(
+      !_extent->isZeroInt(),
+      "Cannot create an iter domain with a extent that is zero but received ",
+      _extent,
+      " .");
+
+  // TORCH_INTERNAL_ASSERT(!kir::isLoweredVal(_extent));
+
   name_ = fusion_->registerVal(this);
 }
 
@@ -682,52 +360,28 @@ IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner)
       iter_type_(src->iter_type_),
       is_rfactor_domain_(src->is_rfactor_domain_) {}
 
-bool IterDomain::sameAs(const Statement* other) const {
-  if (other == this) {
+bool IterDomain::sameAs(const IterDomain* const other) const {
+  if (other == this)
     return true;
-  }
-
-  if (!other->isA<IterDomain>()) {
-    return false;
-  }
-
-  const IterDomain* other_id = other->as<IterDomain>();
 
-  bool is_same = isReduction() == other_id->isReduction() &&
-      getParallelType() == other_id->getParallelType();
-  is_same = is_same && ScalarCheck::sameAs(extent(), other_id->extent());
-  is_same = is_same && ScalarCheck::sameAs(start(), other_id->start());
+  bool is_same = isReduction() == other->isReduction() &&
+      getParallelType() == other->getParallelType();
+  is_same = is_same && ScalarCheck::sameAs(extent(), other->extent());
+  is_same = is_same && ScalarCheck::sameAs(start(), other->start());
 
   return is_same;
 }
 
-std::vector<IterDomain*> IterDomain::clone(
-    const std::vector<IterDomain*>& domains) {
-  std::vector<IterDomain*> cloned_domains;
-  std::transform(
-      domains.begin(),
-      domains.end(),
-      std::back_inserter(cloned_domains),
-      [](auto id) { return id->clone(); });
-  return cloned_domains;
-}
-
 IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
   TORCH_CHECK(
       outer->start()->isZeroInt() && inner->start()->isZeroInt(),
       "Merging IterDomains with starting values that aren't 0 is not supported at this time.");
   TORCH_CHECK(
-      !outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(),
-      "Merging IterDomains with ending values that are 0 is not supported at this time.");
-  TORCH_CHECK(
-      outer->isReduction() == inner->isReduction() ||
-          (!outer->isReduction() && inner->extent()->isOneInt()) ||
-          (outer->extent()->isOneInt() && !inner->isReduction()),
+      outer->isReduction() == inner->isReduction(),
       "Merging IterDomains requires that their iteration types match.");
   TORCH_CHECK(
-      (outer->isGather() && inner->isGather()) ||
-          (!outer->isGather() && !inner->isGather()),
-      "Merging gather and non-gather domains is not supported.");
+      outer->getParallelType() == inner->getParallelType(),
+      "Merging IterDomains requires that their parallel types match.");
 
   Val* merged_id_size = mul(outer->extent(), inner->extent());
 
@@ -744,13 +398,6 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
     itype = IterType::Iteration;
   }
 
-  // Merging trivial reduction with iter domain, that's fine, just make it an
-  // iter domain.
-  if ((outer->isReduction() || inner->isReduction()) &&
-      (!outer->isReduction() || !inner->isReduction())) {
-    itype = IterType::Iteration;
-  }
-
   IterDomain* merged_id = new IterDomain(
       new Int(0),
       merged_id_size->as<Int>(),
@@ -765,15 +412,16 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
 
 std::pair<IterDomain*, IterDomain*> IterDomain::split(
     IterDomain* in,
-    Val* factor,
-    bool inner_split) {
+    Val* factor) {
   TORCH_CHECK(
       in->start()->isZeroInt(),
       "Splitting IterDomains with starting values that aren't 0 is not supported at this time.");
 
-  TORCH_CHECK(
-      !in->extent()->isZeroInt(),
-      "Splitting IterDomains with ending values that are 0 is not supported at this time.");
+  if (in->getParallelType() != ParallelType::Serial)
+    TORCH_CHECK(
+        false,
+        "Splitting an axis of non-Serial iteration is not supported at this time."
+        " Parallelization strategy must be set after calling split.");
 
   TORCH_CHECK(factor->isAnInt(), "Cannot split by non-integer value ", factor);
 
@@ -792,12 +440,12 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
   }
 
   // outer loop size
-  Val* remainder = ceilDiv(in->extent(), factor);
+  Val* vo = ceilDiv(in->extent(), factor);
 
   // outer loop IterDomain
   IterDomain* ido = new IterDomain(
       new Int(0),
-      inner_split ? remainder->as<Int>() : factor,
+      vo->as<Int>(),
       in->getParallelType(),
       in->getIterType(),
       in->isRFactorProduct());
@@ -805,41 +453,36 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
   // inner loop IterDomain
   IterDomain* idi = new IterDomain(
       new Int(0),
-      inner_split ? factor : remainder->as<Int>(),
+      factor,
       in->getParallelType(),
       in->getIterType(),
       in->isRFactorProduct());
 
-  new Split(ido, idi, in, factor, inner_split);
+  new Split(ido, idi, in, factor);
   return {ido, idi};
 }
 
-// TODO: We should change parallelize interface to be on tensorview or at least
-// vectorize should be done on tensorview. This would let us check that we don't
-// vectorize to the left of the computeAt domain, and could allow us to do some
-// simple validation of vectorize as it's inputs are right most and contiguous.
-void IterDomain::parallelize(ParallelType t) {
-  parallel_type_ = t;
-  if (t == ParallelType::Unroll || isParallelTypeVectorize(t)) {
-    TORCH_CHECK(
-        start()->isZeroInt() && extent()->isConstScalar(),
-        "Vectorization, unrolling, and unswitching are only supported with start = 0 and extent as a const int, but got ",
-        "a start of ",
-        start(),
-        " and extent ",
-        extent(),
-        " .");
+// TODO(kir): review if this is still needed in the Fusion IR
+Val* IterDomain::extent() const {
+  if (isThread()) {
+    // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
+    if (extent_->getValType() == ValType::Scalar)
+      if (extent_->as<Int>()->isConst())
+        return extent_;
+
+    return NamedScalar::getParallelDim(getParallelType());
   }
+  return extent_;
 }
 
 TensorDomain::TensorDomain(
-    std::vector<IterDomain*> root_domain,
-    std::vector<bool> contiguity)
-    : Val(ValType::TensorDomain, DataType::Null, false),
-      root_domain_(std::move(root_domain)),
+    std::vector<IterDomain*> _domain,
+    std::vector<bool> _contiguity)
+    : Val(ValType::TensorDomain),
+      root_domain_(std::move(_domain)),
       contiguity_(
-          contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
-                             : std::move(contiguity)) {
+          _contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
+                              : std::move(_contiguity)) {
   TORCH_CHECK(
       contiguity_.size() == root_domain_.size(),
       "Invalid contiguity information provided, incorrect size. Recieved vector of size ",
@@ -847,23 +490,20 @@ TensorDomain::TensorDomain(
       " but needed one of size ",
       root_domain_.size());
 
-  // Just due to clang-tidy, correct value set in resetDomains
-  has_nontrivial_reduction_ = false;
   domain_ = root_domain_;
   resetDomains();
-  name_ = fusion_->registerVal(this);
 }
 
 TensorDomain::TensorDomain(
-    std::vector<IterDomain*> root_domain,
-    std::vector<IterDomain*> domain,
-    std::vector<bool> contiguity)
+    std::vector<IterDomain*> _root_domain,
+    std::vector<IterDomain*> _domain,
+    std::vector<bool> _contiguity)
     : Val(ValType::TensorDomain, DataType::Null, false),
-      root_domain_(std::move(root_domain)),
-      domain_(std::move(domain)),
+      root_domain_(std::move(_root_domain)),
+      domain_(std::move(_domain)),
       contiguity_(
-          contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
-                             : std::move(contiguity)) {
+          _contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
+                              : std::move(_contiguity)) {
   TORCH_CHECK(
       contiguity_.size() == root_domain_.size(),
       "Invalid contiguity information provided, incorrect size. Recieved vector of size ",
@@ -874,7 +514,7 @@ TensorDomain::TensorDomain(
   std::vector<Val*> domain_vals(domain_.begin(), domain_.end());
   auto inps = IterVisitor::getInputsTo(domain_vals);
 
-  // Validate that the root domain consists of all inputs to domain
+  // Validate that the root domain consists of all inputs to _domain
   // Uncertain if this will hold for RFactor
 
   std::unordered_set<Val*> root_vals(root_domain_.begin(), root_domain_.end());
@@ -886,24 +526,23 @@ TensorDomain::TensorDomain(
         " is an input of domain, but it is not found in the root domain.");
   });
 
-  // Just due to clang-tidy, correct value set in resetDomains
-  has_nontrivial_reduction_ = false;
   resetDomains();
+
   name_ = fusion_->registerVal(this);
 }
 
 TensorDomain::TensorDomain(
-    std::vector<IterDomain*> root_domain,
-    std::vector<IterDomain*> rfactor_domain,
-    std::vector<IterDomain*> domain,
-    std::vector<bool> contiguity)
+    std::vector<IterDomain*> _root_domain,
+    std::vector<IterDomain*> _rfactor_domain,
+    std::vector<IterDomain*> _domain,
+    std::vector<bool> _contiguity)
     : Val(ValType::TensorDomain, DataType::Null, false),
-      root_domain_(std::move(root_domain)),
-      domain_(std::move(domain)),
-      rfactor_domain_(std::move(rfactor_domain)),
+      root_domain_(std::move(_root_domain)),
+      domain_(std::move(_domain)),
+      rfactor_domain_(std::move(_rfactor_domain)),
       contiguity_(
-          contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
-                             : std::move(contiguity)) {
+          _contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
+                              : std::move(_contiguity)) {
   TORCH_CHECK(
       contiguity_.size() == root_domain_.size(),
       "Invalid contiguity information provided, incorrect size. Recieved vector of size ",
@@ -914,7 +553,7 @@ TensorDomain::TensorDomain(
   auto inps = IterVisitor::getInputsTo(
       std::vector<Val*>(domain_.begin(), domain_.end()));
 
-  // Validate that the root domain consists of all inputs to domain
+  // Validate that the root domain consists of all inputs to _domain
   // Uncertain if this will hold for RFactor
 
   std::unordered_set<Val*> root_vals(root_domain_.begin(), root_domain_.end());
@@ -936,8 +575,6 @@ TensorDomain::TensorDomain(
         " is an input of the rfactor domain, but it is not found in the root domain.");
   });
 
-  // Just due to clang-tidy, correct value set in resetDomains
-  has_nontrivial_reduction_ = false;
   resetDomains();
   name_ = fusion_->registerVal(this);
 }
@@ -949,8 +586,7 @@ TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner)
       no_bcast_domain_(ir_cloner->clone(src->no_bcast_domain_)),
       no_reduction_domain_(ir_cloner->clone(src->no_reduction_domain_)),
       rfactor_domain_(ir_cloner->clone(src->rfactor_domain_)),
-      contiguity_(src->contiguity()),
-      has_nontrivial_reduction_(src->has_nontrivial_reduction_) {}
+      contiguity_(src->contiguity()) {}
 
 bool TensorDomain::operator==(const TensorDomain& other) const {
   // Checks equality of each class field. Should not be necessary to
@@ -961,44 +597,25 @@ bool TensorDomain::operator==(const TensorDomain& other) const {
       contiguity_ == other.contiguity_;
 }
 
-bool TensorDomain::sameAs(const Statement* const other) const {
-  if (this == other) {
-    return true;
-  }
-
-  if (!other->isA<TensorDomain>()) {
+bool TensorDomain::sameAs(const TensorDomain* const other) const {
+  if (nDims() != other->nDims())
     return false;
-  }
-
-  const TensorDomain* other_td = other->as<TensorDomain>();
-
-  if (nDims() != other_td->nDims()) {
+  if (getRootDomain().size() != other->getRootDomain().size())
     return false;
-  }
-  if (getRootDomain().size() != other_td->getRootDomain().size()) {
-    return false;
-  }
-  if (getRFactorDomain().size() != other_td->getRFactorDomain().size()) {
+  if (getRFactorDomain().size() != other->getRFactorDomain().size())
     return false;
-  }
 
-  for (size_t i = 0; i < nDims(); i++) {
-    if (!(axis(i)->sameAs(other_td->axis(i)))) {
+  for (const auto i : c10::irange(nDims()))
+    if (!(axis(i)->sameAs(other->axis(i))))
       return false;
-    }
-  }
 
-  for (size_t i = 0; i < getRootDomain().size(); i++) {
-    if (!(getRootDomain()[i]->sameAs(other_td->getRootDomain()[i]))) {
+  for (const auto i : c10::irange(getRootDomain().size()))
+    if (!(getRootDomain()[i]->sameAs(other->getRootDomain()[i])))
       return false;
-    }
-  }
 
-  for (size_t i = 0; i < getRFactorDomain().size(); i++) {
-    if (!(getRFactorDomain()[i]->sameAs(other_td->getRFactorDomain()[i]))) {
+  for (const auto i : c10::irange(getRFactorDomain().size()))
+    if (!(getRFactorDomain()[i]->sameAs(other->getRFactorDomain()[i])))
       return false;
-    }
-  }
 
   return true;
 }
@@ -1017,7 +634,7 @@ bool TensorDomain::sameAs(
 }
 
 bool TensorDomain::hasReduction() const {
-  return has_nontrivial_reduction_;
+  return no_reduction_domain_.size() != domain_.size();
 }
 
 bool TensorDomain::hasBlockReduction() const {
@@ -1032,6 +649,12 @@ bool TensorDomain::hasGridReduction() const {
   });
 }
 
+bool TensorDomain::hasBlockBroadcast() const {
+  return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
+    return id->isBroadcast() && id->isThreadDim();
+  });
+}
+
 bool TensorDomain::hasBroadcast() const {
   return no_bcast_domain_.size() != domain_.size();
 }
@@ -1040,13 +663,6 @@ bool TensorDomain::hasRFactor() const {
   return !rfactor_domain_.empty();
 }
 
-bool TensorDomain::hasVectorize() const {
-  return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
-    return id->getParallelType() == ParallelType::Vectorize ||
-        id->getParallelType() == ParallelType::MisalignedVectorize;
-  });
-}
-
 c10::optional<unsigned int> TensorDomain::getReductionAxis() const {
   auto it = std::find_if(domain_.begin(), domain_.end(), [](const auto& id) {
     return id->isReduction();
@@ -1085,7 +701,7 @@ size_t TensorDomain::posOf(IterDomain* id) const {
   TORCH_CHECK(false, "Provided id is not part of this domain.");
 }
 
-void TensorDomain::split(int axis_, Val* factor, bool inner_split) {
+void TensorDomain::split(int axis_, Val* factor) {
   TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim domain");
   if (axis_ < 0)
     axis_ += nDims();
@@ -1095,7 +711,7 @@ void TensorDomain::split(int axis_, Val* factor, bool inner_split) {
       "Tried to split on axis outside TensorDomain's range.");
 
   IterDomain* id = axis(axis_);
-  auto split_ids = IterDomain::split(id, factor, inner_split);
+  auto split_ids = IterDomain::split(id, factor);
   domain_.erase(domain_.begin() + axis_);
   domain_.insert(domain_.begin() + axis_, split_ids.second);
   domain_.insert(domain_.begin() + axis_, split_ids.first);
@@ -1156,7 +772,96 @@ std::vector<IterDomain*> TensorDomain::orderedAs(
   // Eventhough these checks are already in TensorView, we want to redo them as
   // we can enter this function from other places, not through TensorView
 
-  auto new2old = ir_utils::normalizeOld2New(old2new_, dom.size());
+  // adjust based on negative values (any negative values gets nDims added to
+  // it)
+  std::unordered_map<int, int> old2new;
+  auto ndims = dom.size();
+  std::transform(
+      old2new_.begin(),
+      old2new_.end(),
+      std::inserter(old2new, old2new.begin()),
+      [ndims](std::unordered_map<int, int>::value_type entry) {
+        return std::unordered_map<int, int>::value_type({
+            entry.first < 0 ? entry.first + ndims : entry.first,
+            entry.second < 0 ? entry.second + ndims : entry.second,
+        });
+      });
+
+  // Check if any adjusted values are < 0, or >= nDims, which are invalid
+
+  TORCH_CHECK(
+      std::none_of(
+          old2new.begin(),
+          old2new.end(),
+          [ndims](std::unordered_map<int, int>::value_type entry) {
+            return entry.first < 0 || (unsigned int)entry.first >= ndims ||
+                entry.second < 0 || (unsigned int)entry.second >= ndims;
+          }),
+      "Reorder axes are not within the number of dimensions of the provided domain.");
+
+  // Going to use sets, to see if any duplicate values are in the map.
+
+  std::set<int> old_pos_set;
+  std::transform(
+      old2new.begin(),
+      old2new.end(),
+      std::inserter(old_pos_set, old_pos_set.begin()),
+      [](std::unordered_map<int, int>::value_type entry) {
+        return entry.first;
+      });
+
+  std::set<int> new_pos_set;
+  std::transform(
+      old2new.begin(),
+      old2new.end(),
+      std::inserter(new_pos_set, new_pos_set.begin()),
+      [](std::unordered_map<int, int>::value_type entry) {
+        return entry.second;
+      });
+
+  // Error out if duplicate values are found.
+  TORCH_CHECK(
+      old_pos_set.size() == old2new.size() &&
+          new_pos_set.size() == old2new.size(),
+      "Duplicate entries in transformation map sent to TensorView reorder.");
+
+  // END VALIDATION CHECKS
+
+  std::vector<int> new2old(ndims, -1);
+
+  // Go through each old and new position, make sure they're within [0, ndims)
+  for (std::pair<int, int> elem : old2new) {
+    int old_pos = elem.first;
+    int new_pos = elem.second;
+    new2old[new_pos] = old_pos;
+  }
+
+  // old_positions that already have a new position
+  std::set<int> old_positions(new2old.begin(), new2old.end());
+  old_positions.erase(-1);
+
+  // All available new positions
+  std::set<int> all_positions;
+  for (const auto i : c10::irange(ndims))
+    all_positions.insert(i);
+
+  // Check what positions haven't been specified.
+  std::set<int> positions_left;
+  std::set_difference(
+      all_positions.begin(),
+      all_positions.end(),
+      old_positions.begin(),
+      old_positions.end(),
+      std::inserter(positions_left, positions_left.end()));
+
+  // Fill in positions that weren't specified, in relative order,
+  // in empty spots in the set of new positions.
+  // new2old[new_position] = old_position
+  auto it = positions_left.begin(); // old positions left
+  std::transform(
+      new2old.begin(), new2old.end(), new2old.begin(), [&it](int i) -> int {
+        return i == -1 ? *it++ : i;
+      });
 
   std::vector<IterDomain*> reordered_domain;
   std::transform(
@@ -1171,13 +876,13 @@ std::vector<IterDomain*> TensorDomain::orderedAs(
 std::vector<IterDomain*> TensorDomain::noReductions(
     const std::vector<IterDomain*>& td) {
   size_t size_out = 0;
-  for (auto id : td)
+  for (const auto& id : td)
     if (!id->isReduction())
       size_out++;
   std::vector<IterDomain*> noReductionDomain(size_out);
 
   int it = 0;
-  for (auto id : td)
+  for (const auto& id : td)
     if (!id->isReduction())
       noReductionDomain[it++] = id;
 
@@ -1187,13 +892,13 @@ std::vector<IterDomain*> TensorDomain::noReductions(
 std::vector<IterDomain*> TensorDomain::noBroadcasts(
     const std::vector<IterDomain*>& td) {
   size_t size_out = 0;
-  for (auto id : td)
+  for (const auto& id : td)
     if (!id->isBroadcast())
       size_out++;
   std::vector<IterDomain*> noBroadcastDomain(size_out);
 
   int it = 0;
-  for (auto id : td)
+  for (const auto& id : td)
     if (!id->isBroadcast())
       noBroadcastDomain[it++] = id;
 
@@ -1201,29 +906,86 @@ std::vector<IterDomain*> TensorDomain::noBroadcasts(
 }
 
 bool TensorDomain::hasBroadcast(const std::vector<IterDomain*>& td) {
-  for (auto id : td)
+  for (const auto& id : td)
     if (id->isBroadcast())
       return true;
   return false;
 }
-
 bool TensorDomain::hasReduction(const std::vector<IterDomain*>& td) {
-  for (auto id : td)
+  for (const auto& id : td)
     if (id->isReduction())
       return true;
   return false;
 }
 
-bool TensorDomain::hasNontrivialReduction(const std::vector<IterDomain*>& td) {
-  for (auto id : td) {
-    if (id->isReduction() && !id->isTrivialReduction()) {
-      return true;
+std::vector<std::pair<int, int>> TensorDomain::mapDomainPandC(
+    const std::vector<IterDomain*>& producer,
+    const std::vector<IterDomain*>& consumer) {
+  std::vector<std::pair<int, int>> dom_map;
+
+  size_t itc = 0, itp = 0;
+  while (itc < consumer.size() && itp < producer.size()) {
+    if (consumer[itc]->isBroadcast() && !producer[itp]->isBroadcast()) {
+      itc++;
+      continue;
+    }
+    if (producer[itp]->isReduction()) {
+      itp++;
+      continue;
+    }
+
+    dom_map.emplace_back(std::make_pair(itp, itc));
+    itc++;
+    itp++;
+  }
+  return dom_map;
+}
+
+std::vector<std::pair<IterDomain*, IterDomain*>> TensorDomain::mapRootPandC(
+    const TensorDomain* producer,
+    const TensorDomain* consumer) {
+  auto consumer_root = consumer->getRootDomain();
+  auto producer_root = producer->getMaybeRFactorDomain();
+  std::vector<std::pair<IterDomain*, IterDomain*>> root_id_map;
+  for (const auto& m : mapDomainPandC(producer_root, consumer_root)) {
+    auto producer_axis = producer_root[m.first];
+    auto consumer_axis = consumer_root[m.second];
+    root_id_map.emplace_back(std::make_pair(producer_axis, consumer_axis));
+  }
+  return root_id_map;
+}
+
+std::unordered_map<IterDomain*, IterDomain*> TensorDomain::mapRootCtoP(
+    const TensorDomain* consumer,
+    const TensorDomain* producer,
+    const std::unordered_set<IterDomain*>& consumer_root_dims_to_map) {
+  std::unordered_map<IterDomain*, IterDomain*> root_id_map;
+  for (const auto& kv : mapRootPandC(producer, consumer)) {
+    auto producer_axis = kv.first;
+    auto consumer_axis = kv.second;
+    if (consumer_root_dims_to_map.find(consumer_axis) !=
+        consumer_root_dims_to_map.end()) {
+      root_id_map[consumer_axis] = producer_axis;
     }
   }
-  return false;
+  return root_id_map;
 }
 
-// TODO: Rfactor a Welford
+std::unordered_map<IterDomain*, IterDomain*> TensorDomain::mapRootPtoC(
+    const TensorDomain* producer,
+    const TensorDomain* consumer,
+    const std::unordered_set<IterDomain*>& producer_maybe_rfactor_dims_to_map) {
+  std::unordered_map<IterDomain*, IterDomain*> root_id_map;
+  for (const auto& kv : mapRootPandC(producer, consumer)) {
+    auto producer_axis = kv.first;
+    auto consumer_axis = kv.second;
+    if (producer_maybe_rfactor_dims_to_map.find(producer_axis) !=
+        producer_maybe_rfactor_dims_to_map.end()) {
+      root_id_map[producer_axis] = consumer_axis;
+    }
+  }
+  return root_id_map;
+}
 
 // pair is in order where second is the consumer of first
 std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor(
@@ -1232,7 +994,7 @@ std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor(
 
   std::vector<int> axes(axes_.size());
 
-  auto ndims = nDims();
+  const auto ndims = nDims();
   std::transform(axes_.begin(), axes_.end(), axes.begin(), [ndims](int i) {
     return i < 0 ? i + ndims : i;
   });
@@ -1253,7 +1015,7 @@ std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor(
 
   bool rfactor_found = false;
   bool reduction_found = false;
-  for (decltype(nDims()) i{0}; i < nDims(); i++) {
+  for (const auto i : c10::irange(nDims())) {
     if (axis(i)->isReduction()) {
       if (axes_set.find(i) != axes_set.end()) {
         rfactor_found = true;
@@ -1274,6 +1036,118 @@ std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor(
 
 namespace {
 
+//! Container class DisjointSet models equivalence relationships
+//!
+//! Each instance of this class keeps a set of equivalent classes
+//! DisjointSet::join(a,b) makes the full class of a and b equivalent
+//! DisjointSet::areEqual(a,b) checks if a and b belong same class
+//!
+//! \note The template type T is assumed to be hashable
+template <typename T>
+class DisjointSet {
+ public:
+  DisjointSet() = default;
+
+  //! Joins the equivalent class that a and b belong to
+  //! areEqual(a',b') will be true for each a'=a and b'=b
+  //!
+  //! \param a An element from a equivalent class
+  //!          will create a new equivalent class if a does
+  //!          not belong to any
+  //! \param b An element from another equivalent class
+  //!          will create a new equivalent class if b does
+  //!          not belong to any
+  void join(T a, T b) {
+    // cases where either of the quiv class doesn't exist
+    if (!entry_map.count(a) && !entry_map.count(b)) {
+      createPoint(a);
+      entry_map[b] = fixedPoint(a);
+    } else if (!entry_map.count(a)) {
+      entry_map[a] = fixedPoint(b);
+    } else if (!entry_map.count(b)) {
+      entry_map[b] = fixedPoint(a);
+    } else {
+      // case where both equiv classes exist and need to join
+      const int i0 = fixedPoint(a);
+      const int i1 = fixedPoint(b);
+      int new_parent = 0;
+      int new_child = 0;
+
+      // Either order here is correct but joining larger class to smaller class
+      // tend to be faster
+      std::tie(new_parent, new_child) = (weights[i0] < weights[i1])
+          ? std::make_pair(i0, i1)
+          : std::make_pair(i1, i0);
+      weights[new_parent] += weights[new_child];
+      set_map[new_child] = new_parent;
+    }
+  }
+
+  //! Checks if a and b belong to the same equivalent class
+  //!
+  //! \param a An element from a equivalent class
+  //! \param b An element from another equivalent class
+  //! \returns Boolean value representing if a and b are
+  //!          recorded to be in the same equivalent class
+  //!          will return false if any of a or b doesn't
+  //!          have an equivalent class recorded
+  bool areEquivalent(T a, T b) const {
+    if (!entry_map.count(a) || !entry_map.count(b)) {
+      return false;
+    }
+    return fixedPoint(a) == fixedPoint(b);
+  }
+
+ private:
+  // Internal fixed point implementation:
+  //  Returns the equivalent class that e belongs to
+  int fixedPoint(int e) const {
+    TORCH_INTERNAL_ASSERT(static_cast<int>(set_map.size()) > e);
+    while (set_map[e] != e) {
+      // Chasing to fixed point
+      e = set_map[e];
+    }
+    return e;
+  }
+
+  //! Utility to check the class i belongs to:
+  //!
+  //! Will create a new class if no match seen
+  //! \param e element e to find the equiv class for
+  //! \returns the equivalent class that e belongs to
+  //!
+  int fixedPoint(T e) const {
+    // Handles case when i doesn't have an equivalence class
+    TORCH_INTERNAL_ASSERT(entry_map.count(e));
+
+    // Use fixed point as a representation for the equiv class
+    return fixedPoint(entry_map.at(e));
+  }
+
+  //! Utility to create a new equiv class for i
+  //
+  //! \param i Element i to create the equiv class for
+  void createPoint(T i) {
+    entry_map[i] = next_index_;
+    set_map.push_back(next_index_++);
+    weights.push_back(1);
+  }
+
+ private:
+  // Internal representation of the equivalence class as integers
+  // set_map implements the "parent" relationship
+  std::vector<int> set_map;
+  // Weights is used for preliminary perf optimization
+  std::vector<int> weights;
+
+  // Map the input of type T to its equivalence class
+  std::unordered_map<T, int> entry_map;
+
+  // Running counter for generating new index when
+  // Creating new equiv classes
+  int next_index_ = 0;
+};
+
 //! Concretize broadcast axes, i.e. identifying a non-broadcast
 //! IterDomain that the broadcast IterDomain can map to.
 //!
@@ -1375,6 +1249,124 @@ void ConcretizeDomain::concretizePwOp(Expr* e) {
   }
 }
 
+//! Models equivalence provable by the graph
+//!
+//! This traversal processes root domains only,
+//! equalities , e.g. :
+//!    T2 [i0,i1] = T1[i2,i3] + T0[i4,i5]
+//! will prove that i2 and i4 are equal in the sense that
+//!    i2.start = i4.start, i2.extent = i4.extent
+//! Depends on ConcretizeDomain, and equalities involving
+//! broadcast domains are defined based on the concretized version
+class ProveValEqual : private IterVisitor {
+ public:
+  explicit ProveValEqual(Fusion* fusion) : cd_(fusion) {
+    traverseFrom(fusion, fusion->outputs(), false);
+  }
+
+  //! Checks if two scalars are equal
+  //!
+  //! First checks if ScalarCheck has them equal,
+  //! next try to prove them equal from
+  //! the graph_traversal result
+  //!
+  //! \param a A symbolic value
+  //! \param b Another value from the same fusion
+  //! \returns Boolean representing if they are proven to be
+  //!          equal based on scalar check and graph traversal
+  bool areEqual(Val* a, Val* b) const {
+    if (ScalarCheck::sameAs(a, b)) {
+      return true;
+    }
+    if (eq_set_.areEquivalent(a, b)) {
+      return true;
+    }
+    return false;
+  }
+
+  //! Checks if two iterdomains are equal
+  //!
+  //! Equality defined as equal start and equal extent
+  //! true means a and b are equal
+  //! false only means that they cannot be proven equal based
+  //! on scalar check and graph traversal
+  //!
+  //! \param a An iterdomain
+  //! \param b Another iterdomain from the same fusion
+  //! \returns Boolean representing if they are proven to be
+  //!          equivalent in the sense that they have equal
+  //!          start and extent
+  bool areEquivalent(IterDomain* a, IterDomain* b) const {
+    if (a->sameAs(b)) {
+      return true;
+    }
+
+    // Abort on un-concretized domains, this can appear once we
+    // allow broadcast on fusion output
+    if (!cd_.canConcretize(a) || !cd_.canConcretize(b)) {
+      return false;
+    }
+
+    auto ac = cd_.concretized(a);
+    auto bc = cd_.concretized(b);
+    return areEqual(ac->start(), bc->start()) &&
+        areEqual(ac->rawExtent(), bc->rawExtent());
+  }
+
+ private:
+  // Utility class to record new equality found
+  void proveId(IterDomain* a, IterDomain* b) {
+    if (!a->sameAs(b)) {
+      eq_set_.join(a->start(), b->start());
+      eq_set_.join(a->rawExtent(), b->rawExtent());
+    }
+  }
+
+  // Inspect a pointwise op and record the identified equality
+  void provePwOp(Expr* e) {
+    if (e->output(0)->getValType() != ValType::TensorView) {
+      return;
+    }
+
+    TORCH_INTERNAL_ASSERT(e->outputs().size() == 1);
+    TensorView* tv = e->output(0)->as<TensorView>();
+    const std::vector<IterDomain*>& io = tv->getRootDomain();
+
+    // Record equalities from output to all the inputs
+    // ignores un-concretizable broadcasts
+    for (auto* i : ir_utils::filterByType<TensorView>(e->inputs())) {
+      std::vector<IterDomain*> ii =
+          TensorDomain::noReductions(i->getMaybeRFactorDomain());
+
+      for (const auto it : c10::irange(ii.size()))
+        if (cd_.canConcretize(ii[it]) && cd_.canConcretize(io[it]))
+          proveId(cd_.concretized(ii[it]), cd_.concretized(io[it]));
+    }
+  }
+
+  using IterVisitor::handle;
+
+  void handle(ReductionOp* rop) override {
+    provePwOp(rop);
+  }
+
+  void handle(UnaryOp* uop) override {
+    provePwOp(uop);
+  }
+
+  void handle(BinaryOp* bop) override {
+    provePwOp(bop);
+  }
+
+  void handle(TernaryOp* top) override {
+    provePwOp(top);
+  }
+
+ private:
+  ConcretizeDomain cd_;
+  DisjointSet<const Val*> eq_set_;
+};
+
 } // namespace
 
 // API call to return the concretized axis of a broadcast axis
@@ -1382,26 +1374,31 @@ const IterDomain* IterDomain::concretizeDomain(IterDomain* bcast_dom) {
   return ConcretizeDomain::getConcreteDomain(bcast_dom);
 }
 
+// API call to check if two IterDomains are equal
+// checks start and extent, contains both scalar check and graph traversal
+// broadcast domains are concretized before comparing
+bool IterDomain::proveEquivalent(IterDomain* a, IterDomain* b) {
+  TORCH_INTERNAL_ASSERT(a->fusion() == b->fusion());
+  ProveValEqual pve(a->fusion());
+  return pve.areEquivalent(a, b);
+}
+
 Split::Split(
-    IterDomain* outer,
-    IterDomain* inner,
-    IterDomain* in,
-    Val* factor,
-    bool inner_split)
+    IterDomain* _outer,
+    IterDomain* _inner,
+    IterDomain* _in,
+    Val* _factor)
     : Expr(ExprType::Split),
-      outer_{outer},
-      inner_{inner},
-      in_{in},
-      factor_{factor},
-      inner_split_{inner_split} {
+      outer_{_outer},
+      inner_{_inner},
+      in_{_in},
+      factor_{_factor} {
   TORCH_INTERNAL_ASSERT(
       factor_->isAnInt(),
       "Attempted to create a Split node with a non-integer factor.");
-  addOutput(outer);
-  addOutput(inner);
-  addInput(in);
-  // TODO add factor as an input, need to check Split::Split during validation
-  // and need to check BestEffortReplay::findFirstMismatchedID addInput(factor);
+  addOutput(_outer);
+  addOutput(_inner);
+  addInput(_in);
   name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
@@ -1410,26 +1407,19 @@ Split::Split(const Split* src, IrCloner* ir_cloner)
       outer_(ir_cloner->clone(src->outer_)),
       inner_(ir_cloner->clone(src->inner_)),
       in_(ir_cloner->clone(src->in_)),
-      factor_(ir_cloner->clone(src->factor_)),
-      inner_split_(src->inner_split_) {}
+      factor_(ir_cloner->clone(src->factor_)) {}
 
-bool Split::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<Split>()) {
-    return false;
-  }
-  return Expr::sameAs(other) &&
-      factor()->sameAs(other->as<Split>()->factor()) &&
-      innerSplit() == other->as<Split>()->innerSplit();
+bool Split::sameAs(const Split* const other) const {
+  return (
+      outer()->sameAs(other->outer()) && inner()->sameAs(other->inner()) &&
+      in()->sameAs(other->in()) && factor()->sameAs(other->factor()));
 }
 
-Merge::Merge(IterDomain* out, IterDomain* outer, IterDomain* inner)
-    : Expr(ExprType::Merge), out_{out}, outer_{outer}, inner_{inner} {
-  addOutput(out);
-  addInput(outer);
-  addInput(inner);
+Merge::Merge(IterDomain* _out, IterDomain* _outer, IterDomain* _inner)
+    : Expr(ExprType::Merge), out_{_out}, outer_{_outer}, inner_{_inner} {
+  addOutput(_out);
+  addInput(_outer);
+  addInput(_inner);
   name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
@@ -1439,29 +1429,15 @@ Merge::Merge(const Merge* src, IrCloner* ir_cloner)
       outer_(ir_cloner->clone(src->outer_)),
       inner_(ir_cloner->clone(src->inner_)) {}
 
-bool Merge::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<Merge>()) {
-    return false;
-  }
-  return Expr::sameAs(other);
+bool Merge::sameAs(const Merge* const other) const {
+  return (
+      out()->sameAs(other->out()) && outer()->sameAs(other->outer()) &&
+      inner()->sameAs(other->inner()));
 }
 
 NamedScalar::NamedScalar(const NamedScalar* src, IrCloner* ir_cloner)
     : Val(src, ir_cloner), name_(src->name_) {}
 
-bool NamedScalar::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<NamedScalar>()) {
-    return false;
-  }
-  return other->as<NamedScalar>()->name().compare(name()) == 0;
-}
-
 NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) {
   std::string parallel_dim = stringifyThreadSize(p_type);
   return new NamedScalar(parallel_dim, DataType::Int);
index 5c87cb1..0d421b8 100644 (file)
@@ -3,7 +3,6 @@
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
 
 #include <iostream>
 
@@ -46,10 +45,21 @@ class TORCH_CUDA_CU_API IrTransformPrinter : public IrPrinter {
  public:
   IrTransformPrinter(std::ostream& os) : IrPrinter(os) {}
 
-  void handle(Fusion* f) override;
+  void handle(const UnaryOp* const uop) override {
+    if (printInline()) {
+      IrPrinter::handle(uop);
+    }
+  }
+
+  void handle(const BinaryOp* const bop) override {
+    if (printInline()) {
+      IrPrinter::handle(bop);
+    }
+  }
 
- private:
-  void printTransforms(TensorView* tv);
+  void handle(Fusion* f) override {
+    IrPrinter::handle(f);
+  }
 };
 
 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
deleted file mode 100644 (file)
index e48c2bd..0000000
+++ /dev/null
@@ -1,391 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-
-#include <set>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-namespace ir_utils {
-
-std::vector<int> normalizeOld2New(
-    const std::unordered_map<int, int>& old2new_in,
-    size_t ndims) {
-  // adjust based on negative values (any negative values gets nDims added to
-  // it)
-  std::unordered_map<int, int> old2new;
-  std::transform(
-      old2new_in.begin(),
-      old2new_in.end(),
-      std::inserter(old2new, old2new.begin()),
-      [ndims](std::unordered_map<int, int>::value_type entry) {
-        return std::unordered_map<int, int>::value_type({
-            entry.first < 0 ? entry.first + ndims : entry.first,
-            entry.second < 0 ? entry.second + ndims : entry.second,
-        });
-      });
-
-  // Check if any adjusted values are < 0, or >= nDims, which are invalid
-
-  TORCH_CHECK(
-      std::none_of(
-          old2new.begin(),
-          old2new.end(),
-          [ndims](std::unordered_map<int, int>::value_type entry) {
-            return entry.first < 0 || (unsigned int)entry.first >= ndims ||
-                entry.second < 0 || (unsigned int)entry.second >= ndims;
-          }),
-      "Reorder axes are not within the number of dimensions of the provided domain.");
-
-  // Going to use sets, to see if any duplicate values are in the map.
-
-  std::set<int> old_pos_set;
-  std::transform(
-      old2new.begin(),
-      old2new.end(),
-      std::inserter(old_pos_set, old_pos_set.begin()),
-      [](std::unordered_map<int, int>::value_type entry) {
-        return entry.first;
-      });
-
-  std::set<int> new_pos_set;
-  std::transform(
-      old2new.begin(),
-      old2new.end(),
-      std::inserter(new_pos_set, new_pos_set.begin()),
-      [](std::unordered_map<int, int>::value_type entry) {
-        return entry.second;
-      });
-
-  // Error out if duplicate values are found.
-  TORCH_CHECK(
-      old_pos_set.size() == old2new.size() &&
-          new_pos_set.size() == old2new.size(),
-      "Duplicate entries in transformation map sent to TensorView reorder.");
-
-  // END VALIDATION CHECKS
-
-  std::vector<int> new2old(ndims, -1);
-
-  // Go through each old and new position, make sure they're within [0, ndims)
-  for (std::pair<int, int> elem : old2new) {
-    int old_pos = elem.first;
-    int new_pos = elem.second;
-    new2old[new_pos] = old_pos;
-  }
-
-  // old_positions that already have a new position
-  std::set<int> old_positions(new2old.begin(), new2old.end());
-  old_positions.erase(-1);
-
-  // All available new positions
-  std::set<int> all_positions;
-  for (decltype(ndims) i{0}; i < ndims; i++)
-    all_positions.insert(i);
-
-  // Check what positions haven't been specified.
-  std::set<int> positions_left;
-  std::set_difference(
-      all_positions.begin(),
-      all_positions.end(),
-      old_positions.begin(),
-      old_positions.end(),
-      std::inserter(positions_left, positions_left.end()));
-
-  // Fill in positions that weren't specified, in relative order,
-  // in empty spots in the set of new positions.
-  // new2old[new_position] = old_position
-  auto it = positions_left.begin(); // old positions left
-  std::transform(
-      new2old.begin(), new2old.end(), new2old.begin(), [&it](int i) -> int {
-        return i == -1 ? *it++ : i;
-      });
-
-  return new2old;
-}
-
-namespace ValReplacement {
-// Create New Expr given producer - [an input for the expression]
-// Creates a new Expr substituting current with producer
-struct SubstituteInExpr : public OptInDispatch {
- public:
-  static Expr* subsitute(Expr* expr, Val* reference, Val* substitute) {
-    TORCH_INTERNAL_ASSERT(
-        expr != nullptr && reference != nullptr && substitute != nullptr,
-        "Nullptr arg found.");
-    SubstituteInExpr sie(reference, substitute);
-    sie.handle(expr);
-    TORCH_INTERNAL_ASSERT(
-        sie.expr_ != nullptr,
-        "Substitution failed of ",
-        reference,
-        " with ",
-        substitute);
-    return sie.expr_;
-  }
-
- private:
-  explicit SubstituteInExpr(Val* reference, Val* substitute)
-      : reference_(reference), substitute_(substitute) {}
-
-  void handle(Expr* expr) final {
-    OptInDispatch::handle(expr);
-  }
-
-  void handle(UnaryOp* unary_expr) final {
-    auto in =
-        reference_->sameAs(unary_expr->in()) ? substitute_ : unary_expr->in();
-    auto out =
-        reference_->sameAs(unary_expr->out()) ? substitute_ : unary_expr->out();
-    expr_ = new UnaryOp(unary_expr->getUnaryOpType(), out, in);
-  }
-
-  void handle(BinaryOp* binary_expr) final {
-    auto lhs = reference_->sameAs(binary_expr->lhs()) ? substitute_
-                                                      : binary_expr->lhs();
-    auto rhs = reference_->sameAs(binary_expr->rhs()) ? substitute_
-                                                      : binary_expr->rhs();
-    auto out = reference_->sameAs(binary_expr->out()) ? substitute_
-                                                      : binary_expr->out();
-
-    expr_ = new BinaryOp(binary_expr->getBinaryOpType(), out, lhs, rhs);
-  }
-
-  void handle(TernaryOp* ternary_expr) final {
-    auto in1 = reference_->sameAs(ternary_expr->in1()) ? substitute_
-                                                       : ternary_expr->in1();
-    auto in2 = reference_->sameAs(ternary_expr->in2()) ? substitute_
-                                                       : ternary_expr->in2();
-    auto in3 = reference_->sameAs(ternary_expr->in3()) ? substitute_
-                                                       : ternary_expr->in3();
-    auto out = reference_->sameAs(ternary_expr->out()) ? substitute_
-                                                       : ternary_expr->out();
-    expr_ = new TernaryOp(ternary_expr->getTernaryOpType(), out, in1, in2, in3);
-  }
-
-  void handle(ReductionOp* reduction_expr) final {
-    auto init = reference_->sameAs(reduction_expr->init())
-        ? substitute_
-        : reduction_expr->init();
-    auto out = reference_->sameAs(reduction_expr->out())
-        ? substitute_
-        : reduction_expr->out();
-    auto in = reference_->sameAs(reduction_expr->in()) ? substitute_
-                                                       : reduction_expr->in();
-
-    expr_ =
-        new ReductionOp(reduction_expr->getReductionOpType(), init, out, in);
-  }
-
-  void handle(BroadcastOp* broadcast_expr) final {
-    auto out = reference_->sameAs(broadcast_expr->out())
-        ? substitute_
-        : broadcast_expr->out();
-    auto in = reference_->sameAs(broadcast_expr->in()) ? substitute_
-                                                       : broadcast_expr->in();
-
-    expr_ = new BroadcastOp(out, in, broadcast_expr->getBroadcastDimFlags());
-  }
-
-  void handle(TransposeOp* transpose_expr) final {
-    TORCH_INTERNAL_ASSERT(
-        substitute_->isA<TensorView>(),
-        "All args to transpose must be tensor view, but received a non-TensorView for replacement: ",
-        substitute_);
-    auto out = reference_->sameAs(transpose_expr->out())
-        ? substitute_->as<TensorView>()
-        : transpose_expr->out();
-    auto in = reference_->sameAs(transpose_expr->in())
-        ? substitute_->as<TensorView>()
-        : transpose_expr->in();
-    expr_ = new TransposeOp(out, in, transpose_expr->new2old());
-  }
-
-  void handle(ShiftOp* shift_expr) final {
-    auto out =
-        reference_->sameAs(shift_expr->out()) ? substitute_ : shift_expr->out();
-    auto in =
-        reference_->sameAs(shift_expr->in()) ? substitute_ : shift_expr->in();
-
-    expr_ = new ShiftOp(out, in, shift_expr->offsets());
-  }
-
-  void handle(GatherOp* gather_expr) final {
-    auto out = reference_->sameAs(gather_expr->out()) ? substitute_
-                                                      : gather_expr->out();
-    auto in =
-        reference_->sameAs(gather_expr->in()) ? substitute_ : gather_expr->in();
-
-    expr_ = new GatherOp(
-        out, in, gather_expr->windowShape(), gather_expr->padWidth());
-  }
-
-  void handle(WelfordOp* welford_expr) final {
-    auto out_avg = reference_->sameAs(welford_expr->outAvg())
-        ? substitute_->as<TensorView>()
-        : welford_expr->outAvg();
-    auto out_var = reference_->sameAs(welford_expr->outVar())
-        ? substitute_->as<TensorView>()
-        : welford_expr->outVar();
-    auto out_N = reference_->sameAs(welford_expr->outN())
-        ? substitute_->as<TensorView>()
-        : welford_expr->outN();
-    auto in_avg = reference_->sameAs(welford_expr->inAvg())
-        ? substitute_->as<TensorView>()
-        : welford_expr->inAvg();
-    auto in_var =
-        welford_expr->inVar() && reference_->sameAs(welford_expr->inVar())
-        ? substitute_->as<TensorView>()
-        : welford_expr->inVar();
-    auto in_N = reference_->sameAs(welford_expr->inN()) ? substitute_
-                                                        : welford_expr->inN();
-    auto init_avg =
-        welford_expr->initAvg() && reference_->sameAs(welford_expr->initAvg())
-        ? substitute_->as<TensorView>()
-        : welford_expr->initAvg();
-    auto init_var =
-        welford_expr->initVar() && reference_->sameAs(welford_expr->initVar())
-        ? substitute_->as<TensorView>()
-        : welford_expr->initVar();
-    auto init_N =
-        welford_expr->initN() && reference_->sameAs(welford_expr->initN())
-        ? substitute_
-        : welford_expr->initN();
-    expr_ = new WelfordOp(
-        out_avg,
-        out_var,
-        out_N,
-        init_avg,
-        init_var,
-        init_N,
-        in_avg,
-        in_var,
-        in_N);
-  }
-
- private:
-  Val* reference_ = nullptr;
-  Val* substitute_ = nullptr;
-  Expr* expr_ = nullptr;
-};
-
-} // namespace ValReplacement
-
-Expr* replaceValInExpr(Expr* expr, Val* reference, Val* substitute) {
-  FusionGuard fg(expr->fusion());
-  return ValReplacement::SubstituteInExpr::subsitute(
-      expr, reference, substitute);
-}
-
-TensorView* rfactorHelper(
-    TensorView* reduction_tv,
-    const std::vector<int>& axes) {
-  TORCH_INTERNAL_ASSERT(reduction_tv->definition() != nullptr);
-  const bool is_welford = reduction_tv->definition()->isA<WelfordOp>();
-  if (!is_welford) {
-    return reduction_tv->rFactor(axes);
-  }
-  auto welford = reduction_tv->definition()->as<WelfordOp>();
-  auto w_avg = welford->outAvg()->as<TensorView>();
-  auto w_var = welford->outVar()->as<TensorView>();
-  auto w_n = welford->outN()->as<TensorView>();
-
-  WelfordResult rtvs = reduction_tv->rFactor(axes, w_avg, w_var, w_n);
-
-  // TODO: this can be more generic, using avg because
-  //      WelfordOp::out() returns the avg
-  return rtvs.avg;
-}
-
-namespace {
-
-std::vector<TensorView*> uniqueEntries(
-    const std::vector<TensorView*>& tv_deuqe) {
-  std::vector<TensorView*> unique_entries;
-  std::unordered_set<TensorView*> inserted;
-  for (auto tv_entry : tv_deuqe) {
-    if (inserted.emplace(tv_entry).second) {
-      unique_entries.emplace_back(tv_entry);
-    }
-  }
-  return unique_entries;
-}
-
-} // namespace
-
-std::vector<TensorView*> producerTvsOf(TensorView* tv) {
-  if (tv->definition() == nullptr) {
-    return {};
-  }
-  auto producer_vals =
-      ir_utils::filterByType<TensorView>(tv->definition()->inputs());
-  return uniqueEntries({producer_vals.begin(), producer_vals.end()});
-}
-
-std::vector<TensorView*> consumerTvsOf(TensorView* tv) {
-  std::vector<TensorView*> consumer_tvs;
-  for (auto use_expr : tv->uses()) {
-    auto outputs = ir_utils::filterByType<TensorView>(use_expr->outputs());
-    consumer_tvs.insert(consumer_tvs.end(), outputs.begin(), outputs.end());
-  }
-  return uniqueEntries(consumer_tvs);
-}
-
-std::vector<TensorView*> producerTvsOf(const std::vector<TensorView*>& tvs) {
-  std::vector<TensorView*> all_producer_tvs;
-  for (auto tv : tvs) {
-    auto producer_tvs = producerTvsOf(tv);
-    all_producer_tvs.insert(
-        all_producer_tvs.end(), producer_tvs.begin(), producer_tvs.end());
-  }
-
-  return uniqueEntries(all_producer_tvs);
-}
-
-std::vector<TensorView*> consumerTvsOf(const std::vector<TensorView*>& tvs) {
-  std::vector<TensorView*> all_consumer_tvs;
-  for (auto tv : tvs) {
-    auto consumer_tvs = consumerTvsOf(tv);
-    all_consumer_tvs.insert(
-        all_consumer_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
-  }
-
-  return uniqueEntries(all_consumer_tvs);
-}
-
-std::vector<TensorView*> inputTvsOf(TensorView* tv) {
-  return inputTvsOf(std::vector<TensorView*>{tv});
-}
-
-std::vector<TensorView*> outputTvsOf(TensorView* tv) {
-  return outputTvsOf(std::vector<TensorView*>{tv});
-}
-
-std::vector<TensorView*> inputTvsOf(std::vector<TensorView*> tvs) {
-  auto inp_vals = IterVisitor::getInputsTo({tvs.begin(), tvs.end()});
-  auto filtered = ir_utils::filterByType<TensorView>(inp_vals);
-  std::vector<TensorView*> inp_tvs(filtered.begin(), filtered.end());
-  return uniqueEntries(inp_tvs);
-}
-
-std::vector<TensorView*> outputTvsOf(std::vector<TensorView*> tvs) {
-  auto out_vals = DependencyCheck::getAllOutputsOf({tvs.begin(), tvs.end()});
-  auto filtered = ir_utils::filterByType<TensorView>(out_vals);
-  std::vector<TensorView*> out_tvs(filtered.begin(), filtered.end());
-  return uniqueEntries(out_tvs);
-}
-
-std::vector<TensorView*> allTvs(Fusion* fusion) {
-  auto used_vals = fusion->usedMathVals();
-  auto used_tvs = ir_utils::filterByType<TensorView>(used_vals);
-  return uniqueEntries({used_tvs.begin(), used_tvs.end()});
-}
-
-} // namespace ir_utils
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index 144b640..e5402da 100644 (file)
@@ -1,10 +1,8 @@
 #pragma once
 
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/type.h>
 
 #include <iterator>
-#include <unordered_map>
 
 namespace torch {
 namespace jit {
@@ -66,7 +64,7 @@ class FilterIterator {
 
  private:
   Iterator current_;
-  Iterator end_;
+  const Iterator end_;
 };
 
 // An iterable view to a given container of Val pointers. Only returns
@@ -111,60 +109,6 @@ auto filterByType(const ContainerType& inputs) {
   return filterByType<FilterType>(inputs.cbegin(), inputs.cend());
 }
 
-//! Returns a list of new-to-old mappings.
-//!
-//! The input map does not need to be complete. Missing axes are
-//! assumed not to be affected.
-//!
-//! This is used to preprocess broadcast and transpose arguments.
-//!
-//! Example: (N := ndims)
-//!   {{0, 1}} -> [1, 0, ...., N-1]
-//!   Transposes the first two axes with no other change.
-//!
-//!   {{0, -1}} -> [N-1, ...., 0]
-//!   Swaps the first and last axes.
-std::vector<int> normalizeOld2New(
-    const std::unordered_map<int, int>& old2new_in,
-    size_t ndims);
-
-// Replace all uses of reference with substitute in expr. Return the Expr.
-// Warning: Invalidates provided Expr.
-// Warning: Removes connection of reference through provided Expr.
-// Warning: Creates new Expr connecting substitue.
-// Reference is found through direct pointer comparison.
-Expr* replaceValInExpr(Expr* expr, Val* reference, Val* substitute);
-
-// Makes rfactor generic with reduction ops and Welford
-TensorView* rfactorHelper(TensorView* red_tv, const std::vector<int>& axes);
-
-// Return immediate producers of tv
-std::vector<TensorView*> producerTvsOf(TensorView* tv);
-
-// Return immediate consumers of tv
-std::vector<TensorView*> consumerTvsOf(TensorView* tv);
-
-// Return immediate producers of tvs (can return tvs input)
-std::vector<TensorView*> producerTvsOf(const std::vector<TensorView*>& tvs);
-
-// Return immediate consumers of tvs (can return tvs input)
-std::vector<TensorView*> consumerTvsOf(const std::vector<TensorView*>& tvs);
-
-// Returns producers of tv that are inputs of fusion
-std::vector<TensorView*> inputTvsOf(TensorView* tv);
-
-// Returns consumers of tv that are outputs of fusion
-std::vector<TensorView*> outputTvsOf(TensorView* tv);
-
-// Returns producers of tvs that are inputs of fusion
-std::vector<TensorView*> inputTvsOf(std::vector<TensorView*> tvs);
-
-// Returns consumers of tvs that are outputs of fusion
-std::vector<TensorView*> outputTvsOf(std::vector<TensorView*> tvs);
-
-// returns all tensor views in fusion that are used between outputs and inputs.
-TORCH_CUDA_CU_API std::vector<TensorView*> allTvs(Fusion* fusion);
-
 } // namespace ir_utils
 } // namespace cuda
 } // namespace fuser
index 8b96196..d1efc6d 100644 (file)
@@ -33,50 +33,6 @@ void remove_visited(
 
 } // namespace
 
-std::vector<Statement*> IterVisitor::next(Statement* stmt) {
-  if (stmt->isVal()) {
-    return next(stmt->as<Val>());
-  } else if (stmt->isExpr()) {
-    return next(stmt->as<Expr>());
-  } else {
-    TORCH_INTERNAL_ASSERT(
-        false, "IterVisitor could not detect type in next_dispatch.");
-  }
-}
-
-std::vector<Statement*> IterVisitor::next(Val* v) {
-  FusionGuard::getCurFusion()->assertInFusion(v, "Cannot traverse val, ");
-  if (v->definition() != nullptr) {
-    return {v->definition()};
-  }
-  return {};
-}
-
-std::vector<Statement*> IterVisitor::next(Expr* expr) {
-  FusionGuard::getCurFusion()->assertInFusion(expr, "Cannot traverse expr, ");
-  std::vector<Statement*> next_stmts{
-      expr->inputs().begin(), expr->inputs().end()};
-  return next_stmts;
-}
-
-// This handle functions is called on every Statement* in topological order,
-// starting from outputs to inputs.
-void IterVisitor::handle(Statement* s) {
-  OptOutDispatch::handle(s);
-}
-
-// This handle functions is called on every Expr* in topological order,
-// starting from outputs to inputs.
-void IterVisitor::handle(Expr* e) {
-  OptOutDispatch::handle(e);
-}
-
-// This handle functions is called on every Val* in topological order,
-// starting from outputs to inputs.
-void IterVisitor::handle(Val* v) {
-  OptOutDispatch::handle(v);
-}
-
 // Implementation details:
 // We start with an entry in stmt_stack that is the outputs we want to
 // process. We cannot process these outputs untill all Stmts in their history
@@ -120,14 +76,11 @@ void IterVisitor::traverseFrom(
 
     // If we just poped a stmt_stack level, we can finally visit it!
     if (all_inputs_visited) {
-      // stmt may have be already visited.
-      if (traverseAllPaths || visited.find(stmt) == visited.end()) {
-        // Mark visited
-        visited.insert(stmt);
+      // Mark visited
+      visited.insert(stmt);
 
-        // Actually visit stmt
-        handle(stmt);
-      }
+      // Actually visit stmt
+      handle(stmt);
 
       // Remove last value just visited
       current_inputs.pop_back();
@@ -157,21 +110,38 @@ void IterVisitor::traverseFrom(
   }
 }
 
-void IterVisitor::traverseHelper(Fusion* fusion, bool traverse_all_paths) {
+void IterVisitor::traverse_(
+    Fusion* fusion,
+    bool from_outputs_only,
+    bool traverse_all_paths) {
   FusionGuard fg(fusion);
 
-  auto term_val_outs = fusion->getTerminatingOutputs();
-  if (!term_val_outs.empty()) {
-    traverseFrom(fusion, term_val_outs, traverse_all_paths);
+  if (from_outputs_only) {
+    auto term_val_outs = fusion->getTerminatingOutputs();
+    if (!term_val_outs.empty()) {
+      traverseFrom(fusion, term_val_outs, traverse_all_paths);
+    }
+    return;
+  }
+
+  std::vector<Val*> leaves;
+  // Search for Vals with no uses (output edges)
+  for (Val* val : fusion->deterministic_vals())
+    if (!fusion->used(val)) {
+      leaves.push_back(val);
+    }
+
+  if (!leaves.empty()) {
+    traverseFrom(fusion, leaves, traverse_all_paths);
   }
 }
 
-void IterVisitor::traverse(Fusion* fusion) {
-  traverseHelper(fusion, false);
+void IterVisitor::traverse(Fusion* fusion, bool from_outputs_only) {
+  traverse_(fusion, from_outputs_only, false);
 }
 
-void IterVisitor::traverseAllPaths(Fusion* fusion) {
-  traverseHelper(fusion, true);
+void IterVisitor::traverseAllPaths(Fusion* fusion, bool from_outputs_only) {
+  traverse_(fusion, from_outputs_only, true);
 }
 
 namespace {
@@ -180,30 +150,29 @@ namespace {
 // expressions.
 class Inputs : public IterVisitor {
  private:
-  std::vector<Val*> inputs_;
+  std::unordered_set<Val*> inputs;
 
   void handle(Val* val) override {
-    if (val->definition() == nullptr) {
-      if (std::find(inputs_.begin(), inputs_.end(), val) == inputs_.end()) {
-        inputs_.push_back(val);
-      }
+    if (val->getOrigin() == nullptr) {
+      inputs.emplace(val);
     }
   }
 
  public:
-  static std::vector<Val*> getInputs(const std::vector<Val*>& of) {
+  static std::unordered_set<Val*> getInputs(const std::vector<Val*>& of) {
     if (of.empty()) {
-      return {};
+      return std::unordered_set<Val*>();
     }
     Inputs inps;
     inps.traverseFrom(of[0]->fusion(), of);
-    return inps.inputs_;
+    return inps.inputs;
   }
 };
 
 } // namespace
 
-std::vector<Val*> IterVisitor::getInputsTo(const std::vector<Val*>& vals) {
+std::unordered_set<Val*> IterVisitor::getInputsTo(
+    const std::vector<Val*>& vals) {
   return Inputs::getInputs(vals);
 }
 
@@ -269,18 +238,6 @@ std::vector<Statement*> BackwardVisitor::next(Val* val) {
   return next_stmts;
 }
 
-void BackwardVisitor::handle(Statement* stmt) {
-  OptOutDispatch::handle(stmt);
-}
-
-void BackwardVisitor::handle(Expr* expr) {
-  OptOutDispatch::handle(expr);
-}
-
-void BackwardVisitor::handle(Val* val) {
-  OptOutDispatch::handle(val);
-}
-
 void BackwardVisitor::traverseFrom(
     Fusion* fusion,
     const std::vector<Val*>& from,
@@ -308,14 +265,11 @@ void BackwardVisitor::traverseFrom(
   // All stmts we've called handle on
   std::unordered_set<Statement*> visited_stmts_;
 
-  if (must_cover_all_expr_outputs_) {
-    for (auto traversal_pair : traversal_exprs_) {
-      for (auto out : traversal_pair.first->outputs()) {
-        TORCH_INTERNAL_ASSERT(
-            vals.find(out) != vals.end(),
-            "Invalid backward traversal found. Some output paths were not provided:",
-            out);
-      }
+  for (auto traversal_pair : traversal_exprs_) {
+    for (auto out : traversal_pair.first->outputs()) {
+      TORCH_INTERNAL_ASSERT(
+          vals.find(out) != vals.end(),
+          "Invalid backward traversal found. Some output paths were not provided.");
     }
   }
 
@@ -370,64 +324,17 @@ namespace {
 // Looks for and returns all values in between dependencies and vals, including
 // them.
 struct Dependencies : public IterVisitor {
- private:
-  //! A given set of dependency Vals
-  const std::unordered_set<Val*> dependencies_;
-  //! Vals that are found between dependencies_ and of. Topologically
-  //! ordered.
-  std::vector<Val*> vals_;
-  //! Exprs that are found between dependencies_ and of. Topologically
-  //! ordered.
-  std::vector<Expr*> exprs_;
-  //! A set version of vals_
-  std::unordered_set<Val*> dependent_vals_;
-  //! A set version of exprs_
-  std::unordered_set<Expr*> dependent_exprs_;
+  std::unordered_set<Val*> dependencies_;
+  std::unordered_set<Val*> vals_;
 
- private:
   std::vector<Statement*> next(Val* v) override {
-    if (dependencies_.find(v) != dependencies_.end()) {
+    if (dependencies_.find(v) != dependencies_.end())
       return std::vector<Statement*>();
-    }
     return IterVisitor::next(v);
   }
 
   void handle(Val* val) override {
-    // val is included if:
-    // 1. it is one of the dependencies, or
-    // 2. its defining expression is included in the dependent expr set
-    if (dependencies_.find(val) != dependencies_.end()) {
-      TORCH_INTERNAL_ASSERT(
-          dependent_vals_.find(val) == dependent_vals_.end(),
-          "Trying to add already added val: ",
-          val);
-      vals_.push_back(val);
-      dependent_vals_.insert(val);
-    } else {
-      auto def = val->definition();
-      if (def != nullptr &&
-          dependent_exprs_.find(def) != dependent_exprs_.end()) {
-        TORCH_INTERNAL_ASSERT(
-            dependent_vals_.find(val) == dependent_vals_.end(),
-            "Trying to add already added val: ",
-            val);
-        vals_.push_back(val);
-        dependent_vals_.insert(val);
-      }
-    }
-  }
-
-  void handle(Expr* expr) override {
-    // Track which expr is depedent on the dependencies_ exprs.
-    if (std::any_of(
-            expr->inputs().begin(), expr->inputs().end(), [&](Val* input_val) {
-              return dependent_vals_.find(input_val) != dependent_vals_.end();
-            })) {
-      if (!dependent_exprs_.count(expr)) {
-        exprs_.push_back(expr);
-        dependent_exprs_.insert(expr);
-      }
-    }
+    vals_.emplace(val);
   }
 
   Dependencies(
@@ -438,27 +345,16 @@ struct Dependencies : public IterVisitor {
   };
 
  public:
-  static std::vector<Val*> getAllVals(
+  static std::unordered_set<Val*> getAllVals(
       const std::unordered_set<Val*>& dependencies,
       const std::vector<Val*>& of) {
     if (of.empty()) {
-      return {};
+      return std::unordered_set<Val*>();
     }
 
     Dependencies deps(dependencies, of);
     return deps.vals_;
   }
-
-  static std::vector<Expr*> getAllExprs(
-      const std::unordered_set<Val*>& dependencies,
-      const std::vector<Val*>& of) {
-    if (of.empty()) {
-      return {};
-    }
-
-    Dependencies deps(dependencies, of);
-    return deps.exprs_;
-  }
 };
 
 // Looks for and returns all output values with dependencies on `of`.
@@ -469,19 +365,18 @@ struct FindOutputs : public IterVisitor {
   void handle(Val* val) override {
     if (of_.find(val) != of_.end()) {
       Statement* out_stmt = stmt_stack.front().back();
-      TORCH_INTERNAL_ASSERT(out_stmt->isVal());
-      auto out_val = out_stmt->as<Val>();
-      if (of_.find(out_val) == of_.end()) {
-        outs_.emplace(out_val);
+      if (out_stmt->isVal()) {
+        auto out_val = out_stmt->as<Val>();
+        if (of_.find(out_val) == of_.end()) {
+          outs_.emplace(out_val);
+        }
       }
     }
   }
 
-  // TODO: Simply traverse through uses from of. Would be a lot faster than
-  // tracing all paths like this.
   FindOutputs(const std::unordered_set<Val*>& _of) : of_(_of) {
     auto fusion = (*of_.begin())->fusion();
-    traverseFrom(fusion, fusion->outputs(), true);
+    traverseFrom(fusion, fusion->outputs(), false);
   };
 
   static std::unordered_set<Val*> getAllOutputsOf(
@@ -495,66 +390,6 @@ struct FindOutputs : public IterVisitor {
   }
 };
 
-// Looks for and returns all values that depends on `of`.
-class DependentVals : public IterVisitor {
- private:
-  // Which nodes to find dependencies of
-  const std::unordered_set<Val*>& of_;
-
-  // Dependencies we have so far
-  std::unordered_set<Val*> outs_;
-
-  // Boundary where we want to stop searching beyond
-  std::unordered_set<Val*> boundary_;
-
-  std::vector<Statement*> next(Val* v) override {
-    if (boundary_.find(v) != boundary_.end())
-      return std::vector<Statement*>();
-    return IterVisitor::next(v);
-  }
-
-  void handle(Val* val) override {
-    if (val->isFusionInput() || val->definition() == nullptr ||
-        of_.count(val) || outs_.count(val)) {
-      return;
-    }
-
-    for (auto v : val->definition()->inputs()) {
-      if (of_.count(v) || outs_.count(v)) {
-        outs_.emplace(val);
-        return;
-      }
-    }
-  }
-
-  // optimization to limit search path
-  void createBoundary() {
-    for (auto v_of : of_) {
-      for (auto v_expr : v_of->uses()) {
-        for (auto v_in : v_expr->inputs()) {
-          boundary_.emplace(v_in);
-        }
-      }
-    }
-  }
-
-  DependentVals(const std::unordered_set<Val*>& _of) : of_(_of) {
-    createBoundary();
-    auto fusion = (*of_.begin())->fusion();
-    traverseFrom(fusion, fusion->outputs(), false);
-  };
-
- public:
-  static std::unordered_set<Val*> getAllDependentVals(
-      const std::unordered_set<Val*>& of) {
-    if (of.empty()) {
-      return std::unordered_set<Val*>();
-    }
-    DependentVals dependencies(of);
-    return dependencies.outs_;
-  }
-};
-
 class DependencyChains : public IterVisitor {
  public:
   std::deque<std::deque<Val*>> dep_chains;
@@ -583,9 +418,9 @@ class DependencyChains : public IterVisitor {
   DependencyChains(Val* _dependency, bool all_chains_ = false)
       : dependencies_({_dependency}) {
     if (all_chains_) {
-      traverseAllPaths(_dependency->fusion());
+      traverseAllPaths(_dependency->fusion(), false);
     } else {
-      traverse(_dependency->fusion());
+      traverse(_dependency->fusion(), false);
     }
   }
 
@@ -598,9 +433,9 @@ class DependencyChains : public IterVisitor {
     }
 
     if (all_chains_) {
-      traverseAllPaths((*dependencies_.begin())->fusion());
+      traverseAllPaths((*dependencies_.begin())->fusion(), false);
     } else {
-      traverse((*dependencies_.begin())->fusion());
+      traverse((*dependencies_.begin())->fusion(), false);
     }
   }
 
@@ -663,18 +498,12 @@ std::deque<std::deque<Val*>> DependencyCheck::getAllUseChains(Val* producer) {
   return DependencyChains::getAllUseChains(producer);
 }
 
-std::vector<Val*> DependencyCheck::getAllValsBetween(
+std::unordered_set<Val*> DependencyCheck::getAllValsBetween(
     const std::unordered_set<Val*>& dependencies,
     const std::vector<Val*>& of) {
   return Dependencies::getAllVals(dependencies, of);
 }
 
-std::vector<Expr*> DependencyCheck::getAllExprsBetween(
-    const std::unordered_set<Val*>& dependencies,
-    const std::vector<Val*>& of) {
-  return Dependencies::getAllExprs(dependencies, of);
-}
-
 std::unordered_set<Val*> DependencyCheck::getAllOutputsOf(
     const std::unordered_set<Val*>& of) {
   if (of.empty()) {
@@ -684,22 +513,13 @@ std::unordered_set<Val*> DependencyCheck::getAllOutputsOf(
   return FindOutputs::getAllOutputsOf(of);
 }
 
-std::unordered_set<Val*> DependencyCheck::getAllDependentVals(
-    const std::unordered_set<Val*>& of) {
-  if (of.empty()) {
-    return std::unordered_set<Val*>();
-  }
-  FusionGuard fg((*of.begin())->fusion());
-  return DependentVals::getAllDependentVals(of);
-}
-
 void ExprSort::handle(Expr* expr) {
   exprs.push_back(expr);
 }
 
-std::vector<Expr*> ExprSort::getExprs(Fusion* fusion) {
+std::vector<Expr*> ExprSort::getExprs(Fusion* fusion, bool from_outputs_only) {
   ExprSort es;
-  es.traverse(fusion);
+  es.traverse(fusion, from_outputs_only);
   return es.exprs;
 }
 
@@ -712,23 +532,14 @@ std::vector<Expr*> ExprSort::getExprs(
 }
 
 void InputsOf::handle(Val* v) {
-  if (v->definition() == nullptr) {
-    if (grabbed_inputs.emplace(v).second) {
-      ordered_inputs.push_back(v);
-    }
-  }
+  if (FusionGuard::getCurFusion()->origin(v) == nullptr)
+    inputs.emplace(v);
 }
 
-std::vector<Val*> InputsOf::output(Fusion* fusion, Val* output_) {
-  return outputs(fusion, {output_});
-}
-
-std::vector<Val*> InputsOf::outputs(
-    Fusion* fusion,
-    const std::vector<Val*>& outputs_) {
+std::unordered_set<Val*> InputsOf::output(Fusion* fusion, Val* output_) {
   InputsOf io;
-  io.traverseFrom(fusion, outputs_, false);
-  return io.ordered_inputs;
+  io.traverseFrom(FusionGuard::getCurFusion(), {output_}, false);
+  return io.inputs;
 }
 
 } // namespace cuda
index 31e5ee1..43590b1 100644 (file)
@@ -3,6 +3,10 @@
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <torch/csrc/jit/codegen/cuda/dispatch.h>
+
+#include <torch/csrc/jit/codegen/cuda/fusion.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
 #include <torch/csrc/jit/codegen/cuda/type.h>
 
 #include <deque>
@@ -14,11 +18,6 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-class Fusion;
-class Statement;
-class Expr;
-class Val;
-
 /*
  * IterVisitor starts from leaf nodes, fusion outputs, or the provided values.
  * It walks the DAG bacwkards from the starting nodes, to roots. Each node in
@@ -49,23 +48,49 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch {
   // These functions will start at outputs and propagate up through the DAG
   // to inputs based on depth first traversal. Next could be called on a node
   // multiple times.
-  virtual std::vector<Statement*> next(Statement* stmt);
-
-  virtual std::vector<Statement*> next(Val* v);
-
-  virtual std::vector<Statement*> next(Expr* expr);
+  virtual std::vector<Statement*> next(Statement* stmt) {
+    // NOLINTNEXTLINE(bugprone-branch-clone)
+    if (stmt->isVal()) {
+      return next(stmt->as<Val>());
+    } else if (stmt->isExpr()) {
+      return next(stmt->as<Expr>());
+    } else {
+      TORCH_INTERNAL_ASSERT(
+          false, "IterVisitor could not detect type in next_dispatch.");
+    }
+  }
+
+  virtual std::vector<Statement*> next(Val* v) {
+    FusionGuard::getCurFusion()->assertInFusion(v, "Cannot traverse val, ");
+    if (FusionGuard::getCurFusion()->origin(v) != nullptr) {
+      return {FusionGuard::getCurFusion()->origin(v)};
+    }
+    return {};
+  }
+
+  virtual std::vector<Statement*> next(Expr* expr) {
+    FusionGuard::getCurFusion()->assertInFusion(expr, "Cannot traverse expr, ");
+    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+    std::vector<Statement*> next_stmts{
+        expr->inputs().begin(), expr->inputs().end()};
+    return next_stmts;
+  }
 
   // This handle functions is called on every Statement* in topological order,
   // starting from outputs to inputs.
-  void handle(Statement* s) override;
-
+  void handle(Statement* s) override {
+    OptOutDispatch::handle(s);
+  }
   // This handle functions is called on every Expr* in topological order,
   // starting from outputs to inputs.
-  void handle(Expr* e) override;
-
+  void handle(Expr* e) override {
+    OptOutDispatch::handle(e);
+  }
   // This handle functions is called on every Val* in topological order,
   // starting from outputs to inputs.
-  void handle(Val* v) override;
+  void handle(Val* v) override {
+    OptOutDispatch::handle(v);
+  }
 
   // The entire stack during traversal. stmt_stack.back().back() is the node
   // that is being called in handle(). stmt_stack.back() contains siblings (not
@@ -80,7 +105,10 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch {
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   std::unordered_set<Statement*> termination_stmts;
 
-  void traverseHelper(Fusion* fusion, bool traverse_all_paths = false);
+  void traverse_(
+      Fusion* fusion,
+      bool from_outputs_only = false,
+      bool traverse_all_paths = false);
 
  public:
   // Starts at nodes provided in from, traverses from these nodes to inputs.
@@ -96,16 +124,17 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch {
       const std::vector<Val*>& from,
       bool traverseAllPaths = false);
 
-  // Iterates from terminating outputs registered with the fusion. Terminating
-  // means value is not used to generate any other value used in producing
-  // registered outputs.
-  void traverse(Fusion* fusion);
+  // from_outputs_only = true start from outputs registered with fusion,
+  // from_outputs_only = false start from all leaf nodes. Calls into
+  // traverseFrom.
+  void traverse(Fusion* fusion, bool from_outputs_only = false);
 
-  // Same as traverse put it traverses every edge, meaning it will traverse
-  // values more than once.
-  void traverseAllPaths(Fusion* fusion);
+  // from_outputs_only = true start from outputs registered with fusion,
+  // from_outputs_only = false start from all leaf nodes. Calls into
+  // traverseFrom.
+  void traverseAllPaths(Fusion* fusion, bool from_outputs_only = false);
 
-  static std::vector<Val*> getInputsTo(const std::vector<Val*>& vals);
+  static std::unordered_set<Val*> getInputsTo(const std::vector<Val*>& vals);
 };
 
 /*
@@ -121,27 +150,14 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch {
  *
  * The first step of BackwardVisitor is to make sure we've specified enough
  * outputs to guarentee that we will traverse all outputs of all exprs during
- * the backward traversal. In the case where we don't require visiting all
- * outputs of some exprs, example being the `N` output of welford ops.
- * `must_cover_all_expr_outputs` is added to disable the check, and in
- * this case the visitor pass need be aware
- *  1. Exprs with any output that has a use chain that ends with a final
- * consumer in the `from` list `will be` visited.
- *  2. Vals that doesn't have a use chain that ends with a final
- * consumer in the `from` list `will not be` visited, even though its
- * definition expr might be visited. An example is if the `N` output
- * of an welford op is unused, but other outputs are, the welford op
- * will be visited but the `N` output will not.
- *
+ * the backward traversal.
  */
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 class TORCH_CUDA_CU_API BackwardVisitor : public OptOutDispatch {
- protected:
-  // NOLINTNEXTLINE(modernize-use-override)
-  virtual ~BackwardVisitor() = default;
+ public:
+  ~BackwardVisitor() override = default;
 
-  BackwardVisitor(bool must_cover_all_expr_outputs = true)
-      : must_cover_all_expr_outputs_(must_cover_all_expr_outputs) {}
+  BackwardVisitor() = default;
 
   BackwardVisitor(const BackwardVisitor& other) = default;
   BackwardVisitor& operator=(const BackwardVisitor& other) = default;
@@ -161,18 +177,19 @@ class TORCH_CUDA_CU_API BackwardVisitor : public OptOutDispatch {
 
   // This handle functions is called on every Statement* in topological order,
   // starting from outputs to inputs.
-  // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
-  virtual void handle(Statement* stmt) override;
-
+  void handle(Statement* stmt) override {
+    OptOutDispatch::handle(stmt);
+  }
   // This handle functions is called on every Expr* in topological order,
   // starting from outputs to inputs.
-  // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
-  virtual void handle(Expr* expr) override;
-
+  void handle(Expr* expr) override {
+    OptOutDispatch::handle(expr);
+  }
   // This handle functions is called on every Val* in topological order,
   // starting from outputs to inputs.
-  // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
-  virtual void handle(Val* val) override;
+  void handle(Val* val) override {
+    OptOutDispatch::handle(val);
+  }
 
   // All exprs that need to be visited in this traversal. Labeled in topological
   // order (size_t).
@@ -194,8 +211,6 @@ class TORCH_CUDA_CU_API BackwardVisitor : public OptOutDispatch {
       Fusion* fusion,
       const std::vector<Val*>& from,
       bool traverseAllPaths = false);
-
-  bool must_cover_all_expr_outputs_ = true;
 };
 
 class TORCH_CUDA_CU_API DependencyCheck {
@@ -219,37 +234,26 @@ class TORCH_CUDA_CU_API DependencyCheck {
   // Returns an empty deque if there are no uses of dependency found.
   static std::deque<std::deque<Val*>> getAllUseChains(Val* dependency);
 
-  // Grab all values that exist between and including provided
-  // vals. Returned values are topologicaly ordered.
-  static std::vector<Val*> getAllValsBetween(
-      const std::unordered_set<Val*>& dependencies,
-      const std::vector<Val*>& of);
-
-  // Returns all dependent exprs that exist between
-  //  the provided vals
-  static std::vector<Expr*> getAllExprsBetween(
+  // Grab all values that exist between and including provided vals
+  static std::unordered_set<Val*> getAllValsBetween(
       const std::unordered_set<Val*>& dependencies,
       const std::vector<Val*>& of);
 
   // Return registered outputs of the fusion that are a dependency of any val of
   static std::unordered_set<Val*> getAllOutputsOf(
       const std::unordered_set<Val*>& of);
-
-  // Return all Vals that depend on the given Vals
-  static std::unordered_set<Val*> getAllDependentVals(
-      const std::unordered_set<Val*>& of);
 };
 
 // Expr sort will take a fusion and return a topologically sorted list of
 // expressions.
 class ExprSort : public IterVisitor {
- protected:
+ private:
   std::vector<Expr*> exprs;
 
   void handle(Expr* expr) override;
 
  public:
-  static std::vector<Expr*> getExprs(Fusion* fusion);
+  static std::vector<Expr*> getExprs(Fusion* fusion, bool from_outputs_only);
 
   static std::vector<Expr*> getExprs(
       Fusion* fusion,
@@ -258,16 +262,12 @@ class ExprSort : public IterVisitor {
 
 class InputsOf : public IterVisitor {
  private:
-  std::unordered_set<Val*> grabbed_inputs;
-  std::vector<Val*> ordered_inputs;
+  std::unordered_set<Val*> inputs;
 
   void handle(Val* v) final;
 
  public:
-  static std::vector<Val*> output(Fusion* fusion, Val* output_);
-  static std::vector<Val*> outputs(
-      Fusion* fusion,
-      const std::vector<Val*>& outputs_);
+  static std::unordered_set<Val*> output(Fusion* fusion, Val* output_);
 };
 
 } // namespace cuda
index 56be39e..1ad8d06 100644 (file)
@@ -1,8 +1,8 @@
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
+
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
 
 #include <iostream>
 #include <unordered_set>
@@ -11,244 +11,148 @@ namespace torch {
 namespace jit {
 namespace fuser {
 namespace cuda {
-namespace kir {
 
 namespace {
 
 //! Scan all primary expressions in the Kernel IR and build
-//! lists of specialized nodes and other interesting information
-class KernelIrScanner : private kir::IrVisitor {
+//! list of specialized nodes
+//!
+//! \note primary expressions are expressions which are not subexpressions
+//!   in a larger expression (things like ForLoop or IfThenElse are not
+//!   real expressions)
+//!
+class KernelIrScanner : private OptOutDispatch {
  public:
-  explicit KernelIrScanner(const Kernel* kernel) {
-    for (const auto& ir_node : kernel->irNodes()) {
-      ir_node->accept(this);
+  // Use expression count to uniquely identify each expression
+  size_t all_expression_count = 0;
+
+  // Map expression id to war hazard sync
+  std::unordered_map<size_t, kir::Sync*> war_hazard_syncs;
+
+  std::vector<kir::Allocate*> global_allocations;
+  std::vector<kir::Allocate*> dynamic_allocations;
+  std::vector<kir::Allocate*> static_allocations;
+  std::unordered_set<Expr*> primary_expressions;
+
+ public:
+  explicit KernelIrScanner(const std::vector<Expr*>& exprs) {
+    TORCH_INTERNAL_ASSERT(!exprs.empty());
+    for (auto expr : exprs) {
+      handle(expr);
     }
   }
 
-  const auto& summary() const {
-    return summary_;
+ private:
+  void handle(Expr* expr) final {
+    TORCH_CHECK(primary_expressions.insert(expr).second);
+    ++all_expression_count;
+    OptOutDispatch::handle(expr);
   }
 
- private:
-  void visit(const kir::Sync* sync) final {
+  void handle(kir::Sync* sync) final {
     // TODO: Move to a dedicated validation pass
     // which is not on the common execution/compilation path
     if (sync->isWarHazardSync()) {
-      ++summary_.war_hazard_syncs_count;
+      war_hazard_syncs[all_expression_count] = sync;
     }
   }
 
-  void visit(const kir::Allocate* allocate) final {
-    switch (allocate->memoryType()) {
+  void handle(kir::ForLoop* fl) final {
+    for (auto expr : fl->body().exprs()) {
+      handle(expr);
+    }
+  }
+
+  void handle(kir::IfThenElse* ite) final {
+    for (auto expr : ite->thenBody().exprs()) {
+      handle(expr);
+    }
+    for (auto expr : ite->elseBody().exprs()) {
+      handle(expr);
+    }
+  }
+
+  void handle(kir::Allocate* a) final {
+    switch (a->getMemoryType()) {
       case MemoryType::Global:
-        summary_.global_allocations.push_back(allocate);
+        global_allocations.push_back(a);
         break;
       case MemoryType::Shared:
-        if (ExpressionEvaluator::isConst(allocate->size())) {
-          summary_.static_smem_allocations.push_back(allocate);
+        if (a->size()->isConstScalar()) {
+          static_allocations.push_back(a);
         } else {
-          summary_.dynamic_smem_allocations.push_back(allocate);
+          dynamic_allocations.push_back(a);
         }
         break;
       case MemoryType::Local:
-        if (!ExpressionEvaluator::isConst(allocate->size())) {
-          summary_.has_dynamic_local_memory_allocations = true;
-          summary_.dynamic_lmem_allocations.emplace_back(allocate);
-        }
         break;
     }
   }
+};
 
-  void visit(const kir::UnaryOp* unary_op) final {
-    if (unary_op->operation() == UnaryOpType::RandLike) {
-      // This kernel is using random numbers
-      summary_.is_stochastic = true;
-    }
-  }
-
-  void visit(const kir::TensorIndex* tensor_index) final {
-    const auto tv = tensor_index->view();
-    const auto domain = tv->domain();
-
-    // Do we have any reductions?
-    summary_.has_block_reductions =
-        summary_.has_block_reductions || domain->hasBlockReduction();
-
-    // Do we have block broadcasts?
-    summary_.has_block_broadcasts =
-        summary_.has_block_broadcasts || domain->hasBlockBroadcast();
-
-    // Update the largest smem data type
-    if (domain->hasBlockReduction() || domain->hasGridReduction() ||
-        tv->memoryType() == MemoryType::Shared) {
-      const auto data_type = tv->dtype();
-      const size_t type_size = dataTypeSize(data_type);
-      if (type_size > max_smem_type_size_) {
-        max_smem_type_size_ = type_size;
-        summary_.largest_smem_data_type = data_type;
-      }
-    }
-
-    // Update Welford
-    if (tensor_index->definition() != nullptr &&
-        tensor_index->definition()->isA<kir::WelfordOp>()) {
-      summary_.has_welford = true;
-      summary_.has_block_welford =
-          summary_.has_block_welford || domain->hasBlockReduction();
-      summary_.has_grid_welford =
-          summary_.has_grid_welford || domain->hasGridReduction();
-    }
-  }
-
-  void visit(const kir::GridWelford* grid_welford) final {
-    const auto dom = grid_welford->welford_op()
-                         ->out()
-                         ->as<kir::TensorIndex>()
-                         ->view()
-                         ->domain();
-    updateGridReductionInLoop(dom);
-  }
-
-  void visit(const kir::GridReduction* grid_reduction) final {
-    const auto dom = grid_reduction->reduction_op()
-                         ->out()
-                         ->as<kir::TensorIndex>()
-                         ->view()
-                         ->domain();
-    updateGridReductionInLoop(dom);
-  }
+} // namespace
 
- private:
-  size_t max_smem_type_size_ = 0;
-  KernelSummary summary_;
+// TODO(kir): Kernel IR validation
+void Kernel::finalize(
+    std::vector<Expr*> top_level_exprs,
+    ThreadPredicateMap predicate_map) {
+  TORCH_CHECK(top_level_exprs_.empty());
+  TORCH_CHECK(!predicate_map_);
+  top_level_exprs_ = std::move(top_level_exprs);
+  predicate_map_ =
+      std::make_unique<ThreadPredicateMap>(std::move(predicate_map));
+  analyze();
+}
 
- private:
-  void updateGridReductionInLoop(TensorDomain* dom) {
-    ++summary_.number_of_grid_reductions;
+void Kernel::analyze() {
+  FUSER_PERF_SCOPE("Kernel::analyze");
 
-    const auto gpu_lower = GpuLower::current();
-    for (size_t i = 0; i < dom->nDims(); ++i) {
-      const auto id =
-          gpu_lower->caParallelMap().getConcreteMappedID(dom->domain()[i]);
-      summary_.has_grid_reduction_in_loop =
-          summary_.has_grid_reduction_in_loop ||
-          !(id->isThread() || id->extent()->isOneInt());
-    }
-  }
-};
+  const KernelIrScanner ir_scanner(top_level_exprs_);
 
-//! Make sure tensors have valid allocations even when parallelized
-//! loops potentially have larger iteration counts than the number of
-//! threads.
-//!
-//! When an IterDomain of a tensor is parallelized, the IterDomain
-//! may not contribute to the allocation of the tensor. For example,
-//! it is assumed that an allocation of a local-memory tensor does not
-//! need to be accounted for an parallelied IterDomain. This is true
-//! when it is guaranteed that each thread only needs to execute the
-//! loop body once. However, if not, the allocation is invalid as it
-//! only has a space for one value per thread.
-//!
-//! ValidateAllocation checks all tensor allocations and sees if any
-//! tensor may have a parallelized loop whose iteration count may
-//! be larger than the number of threads. If so, an error is thrown if
-//! the tensor is not allocated on thread-shared memories. Note that
-//! when allocated on a shared memory (i.e., MemoryType::Shared or
-//! MemoryType::Global for tensors parallelized with threadIdx, or
-//! MemoryType::Global for tensors parallelized with blockIdx), it is
-//! assumed that allocation is properly extended for the iteration
-//! count.
-class ValidateAllocation : private kir::IrVisitor {
- public:
-  static void validate(const Kernel* kernel) {
-    ValidateAllocation validate_allocation(kernel);
-  }
+  // Cache the list of buffers used within the kernel
+  summary_.war_hazard_syncs = ir_scanner.war_hazard_syncs;
+  summary_.global_allocations = ir_scanner.global_allocations;
+  summary_.dynamic_smem_allocations = ir_scanner.dynamic_allocations;
+  summary_.static_smem_allocations = ir_scanner.static_allocations;
 
- private:
-  explicit ValidateAllocation(const Kernel* kernel) {
-    live_allocations_.emplace_back(std::vector<const Allocate*>());
-    for (const auto& ir_node : kernel->topLevelExprs()) {
-      ir_node->accept(this);
+  // Figure out if the kernel uses random numbers
+  for (auto expr : ir_scanner.primary_expressions) {
+    if (expr->getExprType() == ExprType::KirUnaryOp) {
+      if (expr->as<kir::UnaryOp>()->getUnaryOpType() == UnaryOpType::RandLike) {
+        summary_.is_stochastic = true;
+        break;
+      }
     }
-    live_allocations_.pop_back();
-    TORCH_INTERNAL_ASSERT(live_allocations_.empty());
-  }
-
-  void visit(const kir::Allocate* allocate) final {
-    TORCH_INTERNAL_ASSERT(!live_allocations_.empty());
-    live_allocations_.back().push_back(allocate);
   }
 
-  // for_loop is parallelized and its stop value is not guaranteed to
-  // be <= the number of threads, which breaks an assumption made
-  // during in the allocation lowering if it's thread-parallel and not
-  // allocated on shared or global memories, or if it's block-parallel
-  // ando not allocated on global memory.
-  void validate(const kir::ForLoop* for_loop) {
-    const auto loop_id = for_loop->iter_domain();
-    const auto gpu_lower = GpuLower::current();
-    for (const auto& allocations : live_allocations_) {
-      for (const auto& allocate : allocations) {
-        const auto tv = allocate->buffer()->as<kir::TensorView>();
-        for (const auto& axis : tv->domain()->domain()) {
-          if (!gpu_lower->caParallelMap().areMapped(loop_id, axis)) {
-            continue;
-          }
-          if (isParallelTypeThreadDim(loop_id->parallelType())) {
-            TORCH_INTERNAL_ASSERT(
-                tv->memoryType() == MemoryType::Shared ||
-                tv->memoryType() == MemoryType::Global);
-          } else if (isParallelTypeBlockDim(loop_id->parallelType())) {
-            TORCH_INTERNAL_ASSERT(tv->memoryType() == MemoryType::Global);
+  // Look for reductions and shared memory buffers
+  size_t max_smem_type_size = 0;
+  for (auto expr : ir_scanner.primary_expressions) {
+    for (auto out : expr->outputs()) {
+      if (out->getValType() == ValType::TensorIndex) {
+        const auto tv = out->as<kir::TensorIndex>()->view();
+        const auto domain = tv->domain();
+
+        // Do we have any reductions?
+        summary_.has_block_reductions |= domain->hasBlockReduction();
+        summary_.has_grid_reductions |= domain->hasGridReduction();
+
+        // Do we have block broadcasts?
+        summary_.has_block_broadcasts |= domain->hasBlockBroadcast();
+
+        // Update the largest smem data type
+        if (domain->hasBlockReduction() || domain->hasGridReduction() ||
+            tv->memoryType() == MemoryType::Shared) {
+          const auto data_type = tv->getDataType().value();
+          const size_t type_size = dataTypeSize(data_type);
+          if (type_size > max_smem_type_size) {
+            max_smem_type_size = type_size;
+            summary_.largest_smem_data_type = data_type;
           }
         }
       }
     }
   }
-
-  void visit(const kir::ForLoop* for_loop) final {
-    if (for_loop->stop() != for_loop->iter_domain()->extent() &&
-        isParallelTypeThread(for_loop->iter_domain()->parallelType())) {
-      validate(for_loop);
-    }
-
-    live_allocations_.emplace_back(std::vector<const Allocate*>());
-    for (const auto& expr : for_loop->body().exprs()) {
-      expr->accept(this);
-    }
-    live_allocations_.pop_back();
-  }
-
-  void visit(const kir::IfThenElse* ite) final {
-    for (const auto& expr : ite->thenBody().exprs()) {
-      expr->accept(this);
-    }
-    for (const auto& expr : ite->elseBody().exprs()) {
-      expr->accept(this);
-    }
-  }
-
- private:
-  std::vector<std::vector<const Allocate*>> live_allocations_;
-};
-
-} // namespace
-
-// TODO(kir): Kernel IR validation
-void Kernel::finalize(std::vector<kir::Expr*> top_level_exprs) {
-  TORCH_CHECK(top_level_exprs_.empty());
-  top_level_exprs_ = std::move(top_level_exprs);
-  predicate_map_ = std::make_unique<ThreadPredicateMap>(
-      GpuLower::current()->threadPredMap());
-  ValidateAllocation::validate(this);
-  analyze();
-}
-
-void Kernel::analyze() {
-  FUSER_PERF_SCOPE("Kernel::analyze");
-
-  const KernelIrScanner ir_scanner(this);
-  summary_ = ir_scanner.summary();
 }
 
 void Kernel::print() const {
@@ -256,7 +160,6 @@ void Kernel::print() const {
   ir_printer.printKernel(this);
 }
 
-} // namespace kir
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
index a6522ed..d4493a6 100644 (file)
@@ -13,22 +13,24 @@ namespace torch {
 namespace jit {
 namespace fuser {
 namespace cuda {
-namespace kir {
 
 //! Summary of interesting facts about the kernel
+//!
+//! TODO(kir): const node ptrs
+//!
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 struct KernelSummary {
-  //! Count of WAR (write-after-read) hazard barriers
-  int war_hazard_syncs_count = 0;
+  //! List of Write-After-Read (WAR) synchronization barriers
+  std::unordered_map<size_t, kir::Sync*> war_hazard_syncs;
 
   //! List of global buffers
-  std::vector<const kir::Allocate*> global_allocations;
+  std::vector<kir::Allocate*> global_allocations;
 
   //! List of dynamic shared memory buffers
-  std::vector<const kir::Allocate*> dynamic_smem_allocations;
+  std::vector<kir::Allocate*> dynamic_smem_allocations;
 
   //! List of static shared memory buffers
-  std::vector<const kir::Allocate*> static_smem_allocations;
+  std::vector<kir::Allocate*> static_smem_allocations;
 
   //! Indicate the need to generate random numbers
   bool is_stochastic = false;
@@ -36,33 +38,14 @@ struct KernelSummary {
   //! Do we have any block reductions?
   bool has_block_reductions = false;
 
-  //! Number of static grid reductions
-  int number_of_grid_reductions = 0;
-
-  //! Do we have any grid reduction in a loop?
-  bool has_grid_reduction_in_loop = false;
+  //! Do we have any grid reductions?
+  bool has_grid_reductions = false;
 
   //! Do we have any block broadcasts?
   bool has_block_broadcasts = false;
 
-  //! Do we have any welford op?
-  bool has_welford = false;
-
-  //! Do we have any welford op?
-  bool has_block_welford = false;
-
-  //! Do we have any welford op?
-  bool has_grid_welford = false;
-
   //! Largest shared memory buffer base type
   DataType largest_smem_data_type = DataType::Null;
-
-  //! Do we have allocations of dynamic local memory?
-  bool has_dynamic_local_memory_allocations = false;
-
-  //! List of dynamic local memory buffers.
-  //! Only used for debugging.
-  std::vector<const kir::Allocate*> dynamic_lmem_allocations;
 };
 
 //! Container for a lowered Kernel IR
@@ -81,18 +64,18 @@ class TORCH_CUDA_CU_API Kernel final : public NonCopyable {
   //! At this point we have a complete kernel definition and we can
   //! run analysis passes to build a KernelSummary
   //!
-  void finalize(std::vector<kir::Expr*> top_level_exprs);
+  void finalize(
+      std::vector<Expr*> top_level_exprs,
+      ThreadPredicateMap predicate_map);
 
   //! Register input as an input of the kernel
   void addInput(Val* input) {
     inputs_.push_back(input);
-    input_set_.insert(input);
   }
 
   //! Register output as an output of the kernel
   void addOutput(Val* output) {
     outputs_.push_back(output);
-    output_set_.insert(output);
   }
 
   const auto& inputs() const {
@@ -103,22 +86,10 @@ class TORCH_CUDA_CU_API Kernel final : public NonCopyable {
     return outputs_;
   }
 
-  bool isInput(Val* val) const {
-    return input_set_.find(val) != input_set_.end();
-  }
-
-  bool isOutput(Val* val) const {
-    return output_set_.find(val) != output_set_.end();
-  }
-
   const auto& topLevelExprs() const {
     return top_level_exprs_;
   }
 
-  const auto& irNodes() const {
-    return ir_nodes_;
-  }
-
   const KernelSummary& summary() const {
     return summary_;
   }
@@ -132,17 +103,10 @@ class TORCH_CUDA_CU_API Kernel final : public NonCopyable {
   //! \note This is a specialized helper for kir::IrBuilder, not
   //!   intendted for general use
   //!
-  void registerIrNode(kir::Passkey passkey, std::unique_ptr<kir::Node> node) {
-    TORCH_CHECK(passkey.kernel == this);
+  void registerIrNode(std::unique_ptr<Statement> node) {
     ir_nodes_.push_back(std::move(node));
   }
 
-  //! Allocates a new value identifier
-  kir::ValueId newValueId(kir::Passkey passkey) {
-    TORCH_CHECK(passkey.kernel == this);
-    return next_value_id_++;
-  }
-
   //! Debug dump of the Kernel IR
   void print() const;
 
@@ -152,19 +116,17 @@ class TORCH_CUDA_CU_API Kernel final : public NonCopyable {
 
  private:
   // Kernel IR nodes
-  std::vector<std::unique_ptr<kir::Node>> ir_nodes_;
+  std::vector<std::unique_ptr<Statement>> ir_nodes_;
 
-  // Top level statements
-  std::vector<kir::Expr*> top_level_exprs_;
+  // Map from value to its definition expression
+  std::unordered_map<const Val*, Expr*> definitions_;
+
+  // Top level expressions
+  std::vector<Expr*> top_level_exprs_;
 
   // Kernel inputs and outputs
   std::vector<Val*> inputs_;
   std::vector<Val*> outputs_;
-  std::unordered_set<Val*> input_set_;
-  std::unordered_set<Val*> output_set_;
-
-  // Used to allocate unique value IDs
-  kir::ValueId next_value_id_ = 1;
 
   // Summary of interesting kernel data
   KernelSummary summary_;
@@ -174,7 +136,6 @@ class TORCH_CUDA_CU_API Kernel final : public NonCopyable {
   std::unique_ptr<ThreadPredicateMap> predicate_map_;
 };
 
-} // namespace kir
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
index 67938c9..9bfbdf9 100644 (file)
@@ -3,7 +3,7 @@
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/parser.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
+#include <torch/csrc/jit/codegen/cuda/scheduler.h>
 #include <torch/csrc/jit/runtime/graph_executor.h>
 
 #include <c10/util/irange.h>
@@ -30,7 +30,8 @@ int getCommonDeviceCUDA(const at::ArrayRef<IValue>& inputs) {
     if (index != -1 && index != cur_index) {
       return -1;
     }
-    index = (int)cur_index; // NOLINT
+    // NOLINTNEXTLINE(bugprone-signed-char-misuse)
+    index = cur_index;
   }
   return index;
 }
@@ -90,22 +91,13 @@ void debugPrint(const TensorTypePtr& type) {
 }
 #pragma clang diagnostic pop
 
-at::DimVector graphReductionAxes(
-    const std::shared_ptr<Graph>& graph,
-    bool& simple_reduction) {
+at::DimVector graphReductionAxes(const std::shared_ptr<Graph>& graph) {
   FUSER_PERF_SCOPE("graphReductionAxes");
-  simple_reduction = true;
 
   at::DimVector reduction_axes;
   // TODO: let check that we have only single reduction node in the graph.
-  int reduction_count = 0;
   for (const auto& n : graph->nodes()) {
-    if (isReductionToSizeNode(n)) {
-      // TODO: we don't support permutation with ReductionToSize;
-      simple_reduction = false;
-      reduction_axes.clear();
-      return reduction_axes;
-    } else if (isReductionNode(n)) {
+    if (isReductionNode(n)) {
       // TODO: we should return empty when `keepdim` is True?
       auto dims_list = constant_as<c10::List<int64_t>>(n->input(1));
       TORCH_INTERNAL_ASSERT(
@@ -113,17 +105,12 @@ at::DimVector graphReductionAxes(
       for (const auto dim : dims_list->vec()) {
         reduction_axes.emplace_back(static_cast<int>(dim));
       }
-      ++reduction_count;
       // we should return here, but we don't!
       // We continue the traversal and check for other reduction node. Because
-      // our permutation doesn't really support intermediate reduction, hence we
-      // mark simple_reduction as false;
-      if (reduction_count != 1) {
-        simple_reduction = false;
-        return reduction_axes;
-      }
+      // our permutation doesn't really support intermediate reduction; Continue
+      // traversal would trigger the `TORCH_INTERNAL_ASSERT`, it's not ideal but
+      // at least it's not silent error.
     }
-    // TODO: this doesn't apply any more, clean it up
   }
   return reduction_axes;
 }
@@ -220,54 +207,45 @@ at::DimVector inversePermutation(
   }
 }
 
-void encodeBuffer(size_t value, std::string& buffer) {
-  const char* v = reinterpret_cast<char*>(&value);
-  for (size_t i = 0; i < sizeof(size_t); i++) {
-    buffer.push_back(*(v++));
-  }
-}
-
 } // namespace
 
 InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId(
-    const at::ArrayRef<IValue>& inputs,
-    const SchedulerRuntimeInfo* additional_info) {
+    const at::ArrayRef<IValue>& inputs) {
   IdLookupReturn ret;
-
-  // lock mutex_ because we are touching encoding_
-  std::lock_guard<std::mutex> guard(mutex_);
-  encoding_.clear();
+  std::stringstream encoded_inputs;
   for (const auto& input : inputs) {
     if (input.isTensor()) {
       auto& input_tensor = input.toTensor();
 
+      encoded_inputs << ";";
+      auto sep = "";
       for (auto size : input_tensor.sizes()) {
-        encodeBuffer(size, encoding_);
-        encoding_.push_back(' ');
+        encoded_inputs << sep << size;
+        sep = ",";
       }
-      encoding_.push_back('X');
-      encoding_.push_back(' ');
+      encoded_inputs << "@";
+      sep = "";
       for (auto stride : input_tensor.strides()) {
-        encodeBuffer(stride, encoding_);
-        encoding_.push_back(' ');
+        encoded_inputs << sep << stride;
+        sep = ",";
       }
-      encoding_.push_back('d');
-      encodeBuffer(input_tensor.device().index(), encoding_);
+      encoded_inputs << "@" << input_tensor.device().str();
     } else {
       // encode s for scalar;
-      encoding_.push_back('s');
+      encoded_inputs << ";s";
     }
-    encoding_.push_back(';');
-  }
-  if (additional_info) {
-    encodeBuffer(additional_info->getCommonAlignmentSize(), encoding_);
   }
+  auto& id_iter_pair = encoding_lookup_[encoded_inputs.str()];
 
-  auto& entry = encoding_lookup_[encoding_];
+  // short-cut to leave LRU entry as is;
+  if (id_iter_pair.lru_iter == used_entry_.begin()) {
+    ret.id = id_iter_pair.id;
+    return ret;
+  }
 
-  if (entry.id == 0) {
+  if (id_iter_pair.id == 0) {
     // no entry existed for given input set, set id for given entry
-    entry.id = current_id_++;
+    id_iter_pair.id = current_id_++;
     if (used_entry_.size() == max_cache_size_) {
       // pop least recently used cache;
       const auto& remove_iter = encoding_lookup_.find(used_entry_.back());
@@ -277,434 +255,165 @@ InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId(
       encoding_lookup_.erase(remove_iter);
     }
   } else {
-    // short-cut to leave LRU entry as is
-    if (entry.lru_iter == used_entry_.begin()) {
-      ret.id = entry.id;
-      return ret;
-    }
-
-    used_entry_.erase(entry.lru_iter);
+    used_entry_.erase(id_iter_pair.lru_iter);
   }
 
-  ret.id = entry.id;
-  entry.lru_iter = used_entry_.insert(used_entry_.begin(), encoding_);
+  ret.id = id_iter_pair.id;
+  id_iter_pair.lru_iter =
+      used_entry_.insert(used_entry_.begin(), encoded_inputs.str());
   return ret;
 }
 
-FusionExecutorCache::FusionExecutorCache(std::unique_ptr<Fusion> fusion)
-    : fusion_(std::move(fusion)) {}
+// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+FusionExecutorCache::FusionExecutorCache(std::unique_ptr<Fusion>&& fusion)
+    : fusion_(std::move(fusion)) {
+  FUSER_PERF_SCOPE("FusionExecutorCache::FusionExecutorCache");
+  // avoid putting `has_reduction_` in the initializer list
+  has_reduction_ = fusion_->hasReduction();
+}
 
 std::vector<at::Tensor> FusionExecutorCache::runFusionWithInputs(
     const at::ArrayRef<IValue>& inputs) {
-  FUSER_PERF_SCOPE("FusionExecutorCache::runFusionWithInputs");
+  FUSER_PERF_SCOPE("runFusionWithInputs");
 
-  SchedulerRuntimeInfo runtime_info(fusion(), inputs);
-
-  auto id_lookup_ret = inputs_id_lookup_.lookupId(inputs, &runtime_info);
+  // get unique id `unique_id` for given input set `inputs`;
+  auto id_lookup_ret = inputs_id_lookup_.lookupId(inputs);
   if (id_lookup_ret.eviction) {
     evictCache(id_lookup_ret.evict_id);
   }
 
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   const size_t unique_id = id_lookup_ret.id;
-  auto kernel_runtime = getKernelRuntimeFor(inputs, unique_id);
-  most_recent_runtime_ = kernel_runtime;
-  return kernel_runtime->runWithInput(inputs, unique_id);
-}
-
-void FusionExecutorCache::evictCache(size_t cache_id) {
-  auto it = id_to_kernel_runtime_.find(cache_id);
-  TORCH_INTERNAL_ASSERT(it != id_to_kernel_runtime_.end());
-  it->second->evictCache(cache_id);
-  id_to_kernel_runtime_.erase(it);
-}
-
-FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor(
-    const at::ArrayRef<IValue>& inputs,
-    size_t unique_id) {
-  // Check for id hit case
-  auto id_it = id_to_kernel_runtime_.find(unique_id);
-  if (id_it != id_to_kernel_runtime_.end()) {
-    return id_it->second;
-  }
-
-  // Access kernels associated with the common device id
-  auto dev_id = getCommonDeviceCUDA(inputs);
-  TORCH_INTERNAL_ASSERT(dev_id >= 0);
-  auto& kernel_runtimes = kernel_runtimes_[dev_id];
-
-  // Check for re-use hit case
-  //  a kernel runtime is re-usable if all the compiled
-  //  kernels have the same heuristic parameters
-  std::unique_ptr<FusionHeuristics> new_heuristics;
-
-  auto reuse_it = std::find_if(
-      kernel_runtimes.begin(),
-      kernel_runtimes.end(),
-      [&inputs, &new_heuristics](auto& kernel_runtime) {
-        auto maybe_heuristics = kernel_runtime->getMaybeHeuristicsFor(inputs);
-        if (!maybe_heuristics.has_value()) {
-          return false;
-        }
-        new_heuristics = std::move(maybe_heuristics.value());
-        return true;
-      });
-
-  FusionKernelRuntime* kernel_runtime;
-  if (reuse_it != kernel_runtimes.end()) {
-    kernel_runtime = reuse_it->get();
-    kernel_runtime->updateHeuristicsLaunchParams(new_heuristics.get());
-  } else {
-    // graph miss, need to re-build an optimized graph for this case
-    kernel_runtimes.emplace_back(
-        std::make_unique<FusionKernelRuntime>(fusion_.get(), inputs));
-    kernel_runtime = kernel_runtimes.back().get();
-    if (profiling_) {
-      kernel_runtime->profile(true);
-    }
-  }
-
-  id_to_kernel_runtime_[unique_id] = kernel_runtime;
-  return kernel_runtime;
-}
-
-FusionKernelRuntime::FusionKernelRuntime(
-    Fusion* fusion,
-    const at::ArrayRef<IValue>& inputs) {
-  FUSER_PERF_SCOPE("FusionKernelRuntime::FusionKernelRuntime");
-
-  // Make a copy of fusion and do segmentation and translation
-  //  on this copy
-  auto fusion_copy = std::make_unique<Fusion>(*fusion);
-
-  // Run segmentation on the copied fusion
-  SchedulerRuntimeInfo runtime_info(fusion_copy.get(), inputs, true);
-
-  //! Try to schedule the complete fusion
-  const auto maybe_complete_fusion_heuristic =
-      SchedulerEntry::proposeHeuristics(fusion_copy.get(), runtime_info);
-
-  //! Decide if this fusion is segmented or not
-  const bool segmented = !maybe_complete_fusion_heuristic.has_value();
-
-  if (segmented) {
-    // Take ownership and segment transformed fusion
-    segmented_fusion_ =
-        SegmentCandidateFinder::segment(std::move(fusion_copy), inputs);
-    heuristics_ = segmented_fusion_->makeInitialHeuristics(inputs);
-    executors_ =
-        std::vector<FusionExecutor>(segmented_fusion_->groups().size());
-    if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) {
-      segmented_fusion_->print();
-    }
-  } else {
-    auto complete_fusion_heuristic = maybe_complete_fusion_heuristic.value();
-
-    // Translate welfords if apply
-    if (fusion_copy->hasWelford()) {
-      bool translated = SegmentCandidateFinder::TranslateWelfordInFusion(
-          fusion_copy.get(), inputs);
-      if (translated) {
-        complete_fusion_heuristic = ScheduleHeuristic::Normalization;
-      }
-    }
-    // Take ownership of the transformed fusion
-    single_kernel_fusion_ = std::move(fusion_copy);
-
-    single_kernel_fusion_data_cache_ = std::make_unique<HeuristicSummary>(
-        single_kernel_fusion_.get(), complete_fusion_heuristic, runtime_info);
-
-    heuristics_ = std::make_unique<FusionHeuristics>(
-        complete_fusion_heuristic,
-        runtime_info,
-        single_kernel_fusion_data_cache_.get());
-
-    executors_ = std::vector<FusionExecutor>(1);
-    // In the case that the fusion isn't segmented but user
-    //  wants segmented fusion in the debug print. Will
-    //  print math of the composite fusion as placeholder
-    if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) {
-      single_kernel_fusion_->printMath();
-    }
-  }
-
-  is_segmented_ = segmented;
-}
-
-std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
-    const at::ArrayRef<IValue>& inputs,
-    size_t input_id,
-    SegmentedGroup* sg) {
-  FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput");
-  // This function will be called once on un-segmented fusion,
-  //  for segmented fusion, this function will be called on each segment
-  //  In the case of segmented fusion, segmented group needs to be given so
-  //   a kernel is compiled and run for a segmented group
-  //  In the case of complete fusion, sg = nullptr, and the original fusion
-  //   is complied and run
-  auto group_id = sg ? sg->groupId() : 0;
   const int device_index = getCommonDeviceCUDA(inputs);
   TORCH_CHECK(device_index >= 0, "device is not coherent for fusion inputs");
 
   LaunchParams launch_params;
+  if (code_to_fe_lookup_.count(unique_id) == 0) {
+    // enter when we get a new input set. We need to search for compatible
+    // entries in cached `FusionExecutor` or compile new one as needed.
+
+    // caching strategy is different for pw-fusion and reduction-fusion.
+    if (has_reduction_) {
+      // Grab the fusion to analyze for heuristics
+      FusionGuard fg(fusion_.get());
+
+      TensorView* reduction_tv = nullptr;
+      // Use dependency check to find the reduction tv as it returns used values
+      // instead of exprs.
+
+      // The call is relatively heavy weight, consider caching
+      auto used_vals = DependencyCheck::getAllValsBetween(
+          {fusion_->inputs().begin(), fusion_->inputs().end()},
+          fusion_->outputs());
+
+      // Find the reduction tensor view, make sure there's only one
+      for (auto val : used_vals) {
+        if (val->getValType().value() == ValType::TensorView) {
+          auto tv = val->as<TensorView>();
+          if (tv->hasReduction()) {
+            TORCH_INTERNAL_ASSERT(
+                reduction_tv == nullptr,
+                "Already found a reduction tensorview, cannot handle fusion of multiple reductions.");
+            reduction_tv = tv;
+          }
+        }
+      }
 
-  auto scheduler_entry = schedulers()[group_id].get();
-
-  // Check that the heuristics are matched, in the case of segmented fusion
-  TORCH_INTERNAL_ASSERT(!sg || scheduler_entry->heuristc() == sg->heuristic());
-
-  if (!executors_[group_id].compiled()) {
-    FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput::Compile");
-    std::unique_ptr<Fusion> fusion_to_run;
-    if (sg) {
-      // Running a segment group as a single kernel,
-      //  make a fusion to run from segmented fusion
-      fusion_to_run = segmented_fusion_->makeFusion(sg);
-    } else {
-      // Without a segmented group defaults to compiling the
-      //  complete fusion
-      fusion_to_run = std::make_unique<Fusion>(*single_kernel_fusion_);
-    }
-    CompileOptions options;
-    options.device = c10::Device(DeviceType::CUDA, device_index);
-    options.index_mode = scheduler_entry->indexMode();
-    FusionGuard fg(fusion_to_run.get());
-    scheduler_entry->schedule(fusion_to_run.get());
-    // Load launch params for reduction and normalization kernels
-    if (scheduler_entry->hasReductionParam()) {
-      launch_params = scheduler_entry->reductionParams().lparams;
-    } else {
-      launch_params = scheduler_entry->pointwiseParams().lparams;
-    }
-    executors_[group_id].compileFusion(
-        fusion_to_run.get(), options, inputs, launch_params);
-  } else {
-    FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput::FetchFromCache");
-    // Load launch params for reduction and normalization kernels
-    if (scheduler_entry->hasReductionParam()) {
-      launch_params = scheduler_entry->reductionParams().lparams;
-    } else {
-      launch_params = scheduler_entry->pointwiseParams().lparams;
-    }
-  }
-
-  if (profiling_) {
-    FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput::profiling_");
-    most_recent_executor_log_.fusion_executor = &executors_[group_id];
-    most_recent_executor_log_.launch_constraints = launch_params;
-    if (scheduler_entry->hasReductionParam()) {
-      most_recent_executor_log_.reduction_params =
-          scheduler_entry->reductionParams();
-    } else {
-      most_recent_executor_log_.pointwise_params =
-          scheduler_entry->pointwiseParams();
-    }
-  }
-
-  return executors_[group_id].runFusion(inputs, launch_params, input_id);
-}
+      TORCH_INTERNAL_ASSERT(
+          reduction_tv != nullptr,
+          "Could not find the reduction tensor view in the fusion.");
 
-std::vector<at::Tensor> FusionKernelRuntime::runMultiKernelWithInput(
-    const at::ArrayRef<IValue>& inputs,
-    size_t input_id) {
-  FUSER_PERF_SCOPE("FusionKernelRuntime::runMultiKernelWithInput");
+      // Generate the reduction parameters
+      auto reduction_params =
+          getReductionHeuristics(fusion_.get(), inputs, reduction_tv);
 
-  TORCH_INTERNAL_ASSERT(
-      inputs.size() == segmented_fusion_->inputs().size(),
-      "Inputs were not set up correctly, recieved ",
-      inputs.size(),
-      " inputs but expecting ",
-      segmented_fusion_->inputs().size());
-
-  // Map to keep track of currently available tensors
-  std::unordered_map<Val*, IValue> tensor_map;
-
-  // Bind input in the tensor_map
-  for (size_t i = 0; i < inputs.size(); i++) {
-    tensor_map.emplace(segmented_fusion_->inputs()[i], inputs[i]);
-
-    // Bind tensorview inputs values in case some segmented group
-    //  needs it down the road.
-    // TODO: we probably have done this already up to this point
-    //      should consider caching the expression evaluators, both
-    //      more convenient and safer than replication
-    if (inputs[i].isTensor()) {
-      auto aten_tensor = inputs[i].toTensor();
       TORCH_INTERNAL_ASSERT(
-          segmented_fusion_->inputs()[i]->getValType() == ValType::TensorView);
-      auto input_tv = segmented_fusion_->inputs()[i]->as<TensorView>();
-      auto root_dom = TensorDomain::noReductions(input_tv->getRootDomain());
-      for (size_t dim = 0; dim < root_dom.size(); dim++) {
-        const auto extent = root_dom[dim]->extent();
-        const auto value = aten_tensor.sizes()[dim];
-        tensor_map.emplace(extent, value);
-      }
-    }
-  }
+          reduction_params.has_value(),
+          "Error getting reduction heuristics for scheduling.");
 
-  // Keep track of groups that has run
-  std::vector<bool> group_ran(segmented_fusion_->groups().size(), false);
+      launch_params = reduction_params.value().lparams;
 
-  while (!std::all_of(
-      group_ran.begin(), group_ran.end(), [](bool b) { return b; })) {
-    bool one_ran = false;
+      auto fusion_executor =
+          &red_fusion_executor_cache_[device_index][reduction_params.value()];
 
-    // Find the first segment with all inputs available to run
-    for (size_t group_i = 0; group_i < segmented_fusion_->groups().size();
-         group_i++) {
-      auto& group = segmented_fusion_->groups()[group_i];
-      if (group_ran[group_i]) {
-        continue;
-      }
-      const auto& group_inputs = group->inputs();
-      bool ready_to_run = std::all_of(
-          group_inputs.begin(), group_inputs.end(), [&tensor_map](Val* val) {
-            return tensor_map.find(val) != tensor_map.end();
-          });
-
-      if (ready_to_run) {
-        std::vector<IValue> group_runtime_inputs;
-        group_runtime_inputs.reserve(group_inputs.size());
-
-        // Prepare input vector
-        for (auto input : group_inputs) {
-          group_runtime_inputs.push_back(tensor_map.at(input));
-        }
+      if (!fusion_executor->compiled()) {
+        // HEURISTIC NOT COMPILED, COMPILE A KERNEL
+        Fusion fusion = *fusion_;
 
-        // Run graph segment
-        auto group_runtime_outputs =
-            runKernelWithInput(group_runtime_inputs, input_id, group);
+        FusionGuard fg(&fusion);
 
-        const auto& group_outputs = group->outputs();
+        // Heavy weight call
+        auto used_vals = DependencyCheck::getAllValsBetween(
+            {fusion.inputs().begin(), fusion.inputs().end()}, fusion.outputs());
 
-        // Insert graph segment output to tensor map
-        for (size_t group_out_i = 0; group_out_i < group_outputs.size();
-             group_out_i++) {
-          tensor_map.emplace(
-              group_outputs[group_out_i], group_runtime_outputs[group_out_i]);
-        }
-        group_ran[group_i] = true;
-        one_ran = true;
-      }
-    }
-    TORCH_INTERNAL_ASSERT(
-        one_ran,
-        "Couldn't run all groups, something must have gone wrong in segmentation.");
-  }
+        TensorView* reduction_tv = nullptr;
 
-  // Produce final global output
-  std::vector<IValue> fusion_outputs;
-  for (auto output : segmented_fusion_->outputs()) {
-    const auto iter = tensor_map.find(output);
-    if (iter != tensor_map.end()) {
-      fusion_outputs.push_back(iter->second);
-    } else {
-      // This is the check for an empty tensor;
-      TORCH_INTERNAL_ASSERT(
-          output->as<TensorView>()->nDims() == 0 &&
-              output->getDataType().has_value() &&
-              output->getDataType().value() == DataType::Float,
-          "Non empty tensor cannot be found at tensor_map in ",
-          __FUNCTION__);
-      fusion_outputs.emplace_back(at::Tensor());
-    }
-  }
+        for (auto val : used_vals) {
+          if (val->getValType().value() == ValType::TensorView) {
+            auto tv = val->as<TensorView>();
+            if (tv->hasReduction()) {
+              TORCH_INTERNAL_ASSERT(
+                  reduction_tv == nullptr,
+                  "Already found a reduction tensorview, cannot handle fusion of multiple reductions.");
+              reduction_tv = tv;
+            }
+          }
+        }
 
-  std::vector<at::Tensor> fusion_output_tensors;
-  std::transform(
-      fusion_outputs.begin(),
-      fusion_outputs.end(),
-      std::back_inserter(fusion_output_tensors),
-      [](IValue ival) {
         TORCH_INTERNAL_ASSERT(
-            ival.isTensor(), "Cannot output non-tensor objects from a fusion.");
-        return ival.toTensor();
-      });
-
-  return fusion_output_tensors;
-}
-
-const std::vector<FusionKernelRuntime::SchedulerEntryPtr>& FusionKernelRuntime::
-    schedulers() {
-  return heuristics_->heuristicsList();
-}
+            reduction_tv != nullptr,
+            "Could not find the reduction tensor view in the fusion.");
+
+        // Heavy weight call
+        auto outputsOfReduction =
+            DependencyCheck::getAllOutputsOf({reduction_tv});
+
+        auto tv_entries =
+            ir_utils::filterByType<TensorView>(outputsOfReduction);
+
+        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+        std::vector<TensorView*> tvOutputsOfReduction(
+            tv_entries.begin(), tv_entries.end());
+
+        scheduleReduction(
+            &fusion,
+            reduction_params.value(),
+            reduction_tv,
+            tvOutputsOfReduction);
+
+        // This means we have not found a previously generated kernel that's
+        // compatible with the new reduction params. We need to finish codegen.
+        CompileOptions options;
+        options.device = c10::Device(DeviceType::CUDA, device_index);
+        fusion_executor->compileFusion(&fusion, options);
+      }
+      // record new short cut to `FusionExecutor`
+      code_to_fe_lookup_[unique_id] = fusion_executor;
 
-void FusionKernelRuntime::updateHeuristicsLaunchParams(
-    FusionHeuristics* update_heuristics) {
-  FUSER_PERF_SCOPE("FusionKernelRuntime::updateHeuristicsLaunchParams");
-  auto scheduler_list_length = heuristics_->heuristicsList().size();
-  TORCH_INTERNAL_ASSERT(
-      update_heuristics->heuristicsList().size() == scheduler_list_length);
-  for (size_t i = 0; i < scheduler_list_length; i++) {
-    auto& schedulerPtr = heuristics_->heuristicsList()[i];
-    if (schedulerPtr->hasReductionParam()) {
-      schedulerPtr->updateLaunchConstraint(
-          update_heuristics->heuristicsList()[i]->reductionParams().lparams);
     } else {
-      schedulerPtr->updateLaunchConstraint(
-          update_heuristics->heuristicsList()[i]->pointwiseParams().lparams);
-    }
-  }
-}
-
-c10::optional<FusionKernelRuntime::HeuristicsPtr> FusionKernelRuntime::
-    getMaybeHeuristicsFor(const at::ArrayRef<IValue>& inputs) {
-  FUSER_PERF_SCOPE("FusionKernelRuntime::getMaybeHeuristicsFor");
-  auto complete_fusion = is_segmented_ ? segmented_fusion_->completeFusion()
-                                       : single_kernel_fusion_.get();
-  SchedulerRuntimeInfo runtime_info(complete_fusion, inputs, true);
-
-  c10::optional<FusionKernelRuntime::HeuristicsPtr> ret;
-  // Segmented case, need to iterate over all segmented groups
-  if (is_segmented_) {
-    ret = std::make_unique<FusionHeuristics>();
-    size_t total_groups = segmented_fusion_->groups().size();
-    for (size_t group_index = 0; group_index < total_groups; group_index++) {
-      auto group = segmented_fusion_->groups()[group_index];
-
-      auto maybe_scheduler_entry = group->getMaybeSchedulerEntry(runtime_info);
-      if (!maybe_scheduler_entry.has_value()) {
-        return c10::nullopt;
+      // Handle pointwise operations
+      if (pw_fusion_executor_cache_.count(device_index) == 0) {
+        pw_fusion_executor_cache_[device_index] =
+            std::make_unique<FusionExecutor>();
+        CompileOptions options;
+        options.device = c10::Device(DeviceType::CUDA, device_index);
+        // no need to copy fusion_, as we are not generating more than 1 kernel
+        // for PW.
+        scheduleFusion(fusion_.get(), inputs);
+        pw_fusion_executor_cache_[device_index]->compileFusion(
+            fusion_.get(), options);
       }
-      auto scheduler_entry = std::move(maybe_scheduler_entry.value());
-      if (!scheduler_entry->sameAs(
-              heuristics_->heuristicsList()[group_index].get())) {
-        return c10::nullopt;
-      }
-      ret.value()->emplaceBack(std::move(scheduler_entry));
+      // record new short cut to `FusionExecutor`
+      code_to_fe_lookup_[unique_id] =
+          pw_fusion_executor_cache_[device_index].get();
     }
-
-    return ret;
   }
 
-  // Un-segmented case, just check the complete fusion
-  auto& complete_fusion_scheduler = schedulers()[0];
-  auto complete_fusion_heuristic = complete_fusion_scheduler->heuristc();
-  if (!SchedulerEntry::canSchedule(
-          complete_fusion_heuristic,
-          complete_fusion,
-          runtime_info,
-          single_kernel_fusion_data_cache_.get())) {
-    return c10::nullopt;
-  }
-
-  ret = std::make_unique<FusionHeuristics>(
-      complete_fusion_heuristic,
-      runtime_info,
-      single_kernel_fusion_data_cache_.get());
-  if (!complete_fusion_scheduler->sameAs(
-          ret.value()->heuristicsList()[0].get())) {
-    return c10::nullopt;
-  }
-
-  return ret;
+  return code_to_fe_lookup_[unique_id]->runFusion(
+      inputs, launch_params, unique_id);
 }
 
 bool GraphCache::requiresPermutation() {
-  if (!support_permutation_) {
-    return false;
-  }
-
   const size_t input_rank = input_permutation_.size();
   for (const auto i : c10::irange(input_rank)) {
     if (input_permutation_[i] != (long)i) {
@@ -788,6 +497,7 @@ void GraphCache::createFusion(const std::shared_ptr<Graph>& graph) {
           permuted_vec_optional_stride,
           type->requires_grad());
     }; // closing lambda
+
     for (auto input : graph->inputs()) {
       if (auto input_type = input->type()->cast<TensorType>()) {
         input->setType(type_permute_fn(input_type));
@@ -841,34 +551,30 @@ GraphCache::GraphCache(const std::shared_ptr<Graph>& graph) {
   // 2. adjust reduction axes for the permutation;
   //    permute changes the semantics of axes, we need to update the reduction
   //    axes in the graph in order to match the behavior;
-  reduction_axes_ = graphReductionAxes(graph, support_permutation_);
-
-  // TODO: reduction with permutation is tricky now as we might support complex
-  // topology in graph with segmented fusion.
-  if (support_permutation_) {
-    // run over inputs to extract common types;
-    TensorTypePtr acc_type = TensorType::get();
-    for (const auto& input : graph->inputs()) {
-      // only check tensor types;
-      if (auto input_type = input->type()->cast<TensorType>()) {
-        if (acc_type->dim().has_value()) {
-          // TODO: I think merge cannot handle broadcast - Go verify it later;
-          // TODO: Since we are only handling permutation here, we should just
-          //       merge the stride_index_;
-          acc_type = acc_type->merge(*input_type);
-        } else {
-          acc_type = input_type;
-        }
+  reduction_axes_ = graphReductionAxes(graph);
+
+  // run over inputs to extract common types;
+  TensorTypePtr acc_type = TensorType::get();
+  for (const auto& input : graph->inputs()) {
+    // only check tensor types;
+    if (auto input_type = input->type()->cast<TensorType>()) {
+      if (acc_type->dim().has_value()) {
+        // TODO: I think merge cannot handle broadcast - Go verify it later;
+        // TODO: Since we are only handling permutation here, we should just
+        //       merge the stride_index_;
+        acc_type = acc_type->merge(*input_type);
+      } else {
+        acc_type = input_type;
       }
     }
-    extractPermutation(acc_type);
   }
+  extractPermutation(acc_type);
   createFusion(graph);
 }
 
 std::vector<at::Tensor> GraphCache::runGraphWithInputs(
     const at::ArrayRef<IValue>& inputs) {
-  FUSER_PERF_SCOPE("GraphCache::runGraphWithInputs");
+  FUSER_PERF_SCOPE("runGraphWithInputs");
 
   // GraphCache need to permute inputs/outputs to accommodate dimension
   // coalescing
index a535098..1a24efe 100644 (file)
@@ -2,14 +2,11 @@
 
 #include <torch/csrc/jit/codegen/cuda/executor.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
+#include <torch/csrc/jit/codegen/cuda/scheduler.h>
 
 #include <c10/util/ArrayRef.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
-#include <mutex>
 #include <type_traits>
 #include <unordered_map>
 
@@ -18,154 +15,6 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-class SegmentedGroup;
-class FusionHeuristics;
-class SchedulerRuntimeInfo;
-
-// Utilities for benchmarking and profiling
-struct ExecutorLog {
-  c10::optional<ReductionParams> reduction_params = c10::nullopt;
-  c10::optional<PointwiseParams> pointwise_params = c10::nullopt;
-  c10::optional<LaunchParams> launch_constraints = c10::nullopt;
-  FusionExecutor* fusion_executor = nullptr;
-};
-
-//! FusionKernelRuntime is the unified interface from fusion graphs into
-//!  caching, compilation into kernels, and kernel launches.
-//!
-//! Each instance is also a cache entry tracked by FusionKernelRuntimeCache.
-//!
-//! Two types of instance can be created, one for complete/single-kernel fusion
-//!  and one for segmented/multi-kernel fusion.
-//! Conceptually this is a generalization of FusionExecutor that supports both
-//!  single-kernel and multi-kernel caching/compiling/launching
-class TORCH_CUDA_CU_API FusionKernelRuntime {
- public:
-  explicit FusionKernelRuntime(
-      Fusion* fusion,
-      const at::ArrayRef<IValue>& inputs);
-
-  //! Type notations within FusionKernelRuntime Context
-  using HashType = size_t;
-  using SchedulerEntryPtr = std::unique_ptr<SchedulerEntry>;
-
-  //! Evicts internally cached parameters based on input sizes.
-  //!  An interface used by runtime caches.
-  void evictCache(size_t input_id) {
-    for (auto& fe : executors_) {
-      fe.evictCache(input_id);
-    }
-  }
-
-  //! Unified interface to run the managed kernels with given input
-  std::vector<at::Tensor> runWithInput(
-      const at::ArrayRef<IValue>& inputs,
-      size_t input_id) {
-    if (is_segmented_) {
-      return runMultiKernelWithInput(inputs, input_id);
-    } else {
-      return runKernelWithInput(inputs, input_id);
-    }
-  }
-
-  //! Turn On/Off profiling
-  void profile(bool to_profile = true) {
-    profiling_ = to_profile;
-  }
-
-  //! Returns if this runtime is segmented
-  bool isSegmented() {
-    return is_segmented_;
-  }
-
-  //! Returns the fusion segments if applicable
-  SegmentedFusion* fusionSegments() {
-    TORCH_INTERNAL_ASSERT(is_segmented_);
-    return segmented_fusion_.get();
-  }
-
-  //! Returns the single kernel fusion if applicable
-  Fusion* singleKernelFusion() {
-    TORCH_INTERNAL_ASSERT(!is_segmented_);
-    return single_kernel_fusion_.get();
-  }
-
-  //! Returns the list of heuristics in this runtime
-  FusionHeuristics* schedulerHeuristics() {
-    return heuristics_.get();
-  }
-
-  //! Return the most recently used executor, corresponding to the
-  //!  most recent kernel launch.
-  //! TODO: have a interface for grabbing all recent logs. Need to put a buffer
-  //! space for recent logs
-  ExecutorLog getMostRecentExecutorLog() {
-    TORCH_INTERNAL_ASSERT(
-        profiling_, "Executor log is only produced in profiling mode");
-    return most_recent_executor_log_;
-  }
-
-  // Try to compute heuristics based on the SegmentedFusion managed
-  //  in this kernel runtime, and will return a nullopt if either
-  //  any segment cannot be scheduled or the parameters don't match
-  using HeuristicsPtr = std::unique_ptr<FusionHeuristics>;
-  c10::optional<HeuristicsPtr> getMaybeHeuristicsFor(
-      const at::ArrayRef<IValue>& inputs);
-
-  //! Copy the launch params given in the parameter heuristics to prepare
-  //!  for kernel launch for a new input dimension but same heuristics
-  void updateHeuristicsLaunchParams(FusionHeuristics* update_heuristics);
-
- private:
-  //! Interface to run a single kernel, either one kernel for single-kernel
-  //! fusions,
-  //!  or a kernel for a segmentedGrouup in a segmented fusion. Returns the
-  //!  kernel outputs.
-  std::vector<at::Tensor> runKernelWithInput(
-      const at::ArrayRef<IValue>& inputs,
-      size_t input_id,
-      SegmentedGroup* sg = nullptr);
-
-  //! Interface to run a the whole graph in a segmented fusion and return the
-  //! complete
-  //!  fusion outputs.
-  std::vector<at::Tensor> runMultiKernelWithInput(
-      const at::ArrayRef<IValue>& inputs,
-      size_t input_id);
-
-  //! Access the list of schedulers maintained in this runtime instance
-  const std::vector<SchedulerEntryPtr>& schedulers();
-
- private:
-  //! Entries indexed by groupID:
-  //! Executors holding compiled kernels
-  std::vector<FusionExecutor> executors_;
-
-  //! Heuristics object holding scheduler entries for all segments
-  std::unique_ptr<FusionHeuristics> heuristics_;
-
-  // Checks if this runtime instance is for a single-kernel fusion (false) or a
-  //  segmented fusion (true).
-  bool is_segmented_ = true;
-
-  //! Multi-Kernel fusion segment when applies
-  std::unique_ptr<SegmentedFusion> segmented_fusion_ = nullptr;
-
-  //! Single-Kernel fusion when applies
-  //!  TODO: unify the segmented and un-segmented code-path
-  std::unique_ptr<Fusion> single_kernel_fusion_ = nullptr;
-
-  //! Graph traversal datacache for the single kernel fusion
-  //!  TODO: unify the segmented and un-segmented code-path
-  std::unique_ptr<HeuristicSummary> single_kernel_fusion_data_cache_ = nullptr;
-
-  // States for profiling support
-  bool profiling_ = false;
-
-  // The heuristics and executor for most recent kernel launch
-  ExecutorLog most_recent_executor_log_;
-};
-
 //! Encoding an input set to unique id, which is used to short-cut cache entry
 //! selection in our nested cache implementation to cut off overhead.
 //!
@@ -177,11 +26,11 @@ class TORCH_CUDA_CU_API FusionKernelRuntime {
 //! \note the uniqueness of the ide generated for a given input set is only
 //!   local to the instance of `InputsIdLookup`.
 //!
-class TORCH_CUDA_CU_API InputsIdLookup : public NonCopyable {
+class TORCH_CUDA_CU_API InputsIdLookup {
  public:
   //! constructor where maximum cache size is fixed during init
-  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
-  explicit InputsIdLookup(size_t max_cache_size = 100)
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+  explicit InputsIdLookup(size_t max_cache_size = 10)
       : max_cache_size_(max_cache_size){};
 
   //! struct to hold return value for lookupId.
@@ -196,9 +45,7 @@ class TORCH_CUDA_CU_API InputsIdLookup : public NonCopyable {
   //! within the lookup cache. This is needed because lookup shortcut is also
   //! cached in nested `GraphCache`, `FusionExecutorCache` and `FusionExecutor`.
   //! see [ Note -- 2 level cache implementation ]
-  IdLookupReturn lookupId(
-      const at::ArrayRef<IValue>& inputs,
-      const SchedulerRuntimeInfo* additional_info = nullptr);
+  IdLookupReturn lookupId(const at::ArrayRef<IValue>& inputs);
 
   //! debugging API that returns the size of lookup table
   size_t size() const {
@@ -206,17 +53,10 @@ class TORCH_CUDA_CU_API InputsIdLookup : public NonCopyable {
   }
 
  private:
-  // string to store encoded input meta information. Reuse the buffer instead of
-  // stringtream gives few us perf gain.
-  std::string encoding_; // Note: shared state, guarded by mutex_
-
-  // mutex_ used to guard reused encoding_
-  std::mutex mutex_;
-
   //! entry stored in `encoding_lookup_` to implement LRU
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
   struct EncodingEntry {
-    size_t id = 0;
+    size_t id;
     std::list<std::string>::iterator lru_iter;
   };
 
@@ -283,88 +123,69 @@ class TORCH_CUDA_CU_API InputsIdLookup : public NonCopyable {
 //!     c) broadcasting semantics (size-1 or not);
 //!     d) rank;
 //!     e) scalar type;
-//!
-//!
-//! [ Note -- Segmented Fusion Tentative Design ]
-//! Segmentation adds an extra dimension in caching. Initial implementation,
-//! assumed graph partition strategy is independent of input pattern, which we
-//! can revisit once we have more advanced graph segmentation logic Each
-//! FusionExecutorCache corresponds to one graph and one graph segmentation.
-//!
-//!
-class TORCH_CUDA_CU_API FusionExecutorCache {
+
+class FusionExecutorCache {
  public:
   //! create new fusion executor cache at a given device to handle kernel
   //! generation of dynamic sizes;
   //! fusion executor is taking the ownership of `fusion`;
-  explicit FusionExecutorCache(std::unique_ptr<Fusion> fusion);
+  explicit FusionExecutorCache(std::unique_ptr<Fusion>&& fusion);
 
   //! Execute fusion graph with given inputs, create `FusionExecutor` as needed;
   std::vector<at::Tensor> runFusionWithInputs(
       const at::ArrayRef<IValue>& inputs);
 
-  Fusion* fusion() {
-    return fusion_.get();
-  }
-
-  void printFusion() {
-    fusion_->printMath();
-  }
-
-  FusionKernelRuntime* getMostRecentKernelRuntime() {
-    return most_recent_runtime_;
-  }
-
-  // TODO: in a follow up we need a global logging structure
-  //  to capture runtime profiling info. We also need to define
-  //  a suitable profiling window / buffer size.
-  ExecutorLog getMostRecentExecutorInfo() {
-    TORCH_INTERNAL_ASSERT(most_recent_runtime_ != nullptr);
-    return most_recent_runtime_->getMostRecentExecutorLog();
-  }
-
-  void profile(bool to_profile) {
-    profiling_ = to_profile;
-    for (auto& it : kernel_runtimes_) {
-      for (auto& kernel_runtime : it.second) {
-        kernel_runtime->profile(to_profile);
-      }
-    }
-  }
-
  private:
   //! evict cached short cut entry in `code_to_fe_lookup_` as well as cached
   //! entry in `FusionExecutor`
-  void evictCache(size_t cache_id);
-
-  FusionKernelRuntime* getKernelRuntimeFor(
-      const at::ArrayRef<IValue>& inputs,
-      size_t id);
+  void evictCache(size_t cache_id) {
+    auto iter = code_to_fe_lookup_.find(cache_id);
+    TORCH_INTERNAL_ASSERT(
+        iter != code_to_fe_lookup_.end(),
+        "evict cache failed to find an entry");
+    // evict nested lookup entry in nested `FusionExecutor`
+    (iter->second)->evictCache(cache_id);
+    code_to_fe_lookup_.erase(iter);
+  };
 
  private:
   //! original un-scheduled `Fusion`;
   std::unique_ptr<Fusion> fusion_;
 
+  // I'm trading the const model in favor of assigning `has_reduction_` in the
+  // body of constructor, instead of the initializer list;
+  // Because of the move statement used in the constructor, it's tricky to
+  // maintain the code if we have `has_reduction_` as a const member and
+  // initizlize it in the initializer list, where the order of initialization
+  // is controled by the order of declaration instead of their order in the list
+  //
+  //! cache fusion->hasReduction() because it's expensive;
+  bool has_reduction_;
+
+  //! TODO: ugly logic for now. We should integrate the hashing of cache for
+  //!       different kernels. (alternatively we could do so in scheduler).
+  //! ugly bits now:
+  //! The fact that we have heuristics only for reduction, but use a general
+  //! kernel for all point-wise fusion ended up with this:
+  //! 1. For point-wise fusion, we have a single `FusionExecutor` in
+  //!    `pw_fusion_executor_cache_`
+  //! 2. For reduction fusion we have a hash table with ReductionParams as entry
+  //!    pointing to the actual `FusionExecutor` in `red_fusion_executor_cache_`
+  //!
+  //! Both cache_ key on device_index, because `FusionExecutor` is designated to
+  //! a single device
+  std::unordered_map<int, std::unique_ptr<FusionExecutor>>
+      pw_fusion_executor_cache_;
+  std::unordered_map<
+      int,
+      std::unordered_map<ReductionParams, FusionExecutor, ReductionParamsHash>>
+      red_fusion_executor_cache_;
+
+  //! short cut to FusionExecutor for input set encoded with id;
+  std::unordered_map<size_t, FusionExecutor*> code_to_fe_lookup_;
+
   //! inputs to unique_id lookup table;
   InputsIdLookup inputs_id_lookup_;
-
-  //! Graphs after input dependent transfoms
-  std::unordered_map<size_t, std::vector<std::unique_ptr<FusionKernelRuntime>>>
-      kernel_runtimes_;
-
-  //! Logging state for most recent compilation
-  bool profiling_ = false;
-
-  //! Logging state for most recent compilation
-  ExecutorLog most_recent_executor_log_;
-
-  //! short-cut for cache hit
-  std::unordered_map<size_t, FusionKernelRuntime*> id_to_kernel_runtime_;
-
-  //! Profiling info:
-  //! TODO: this can be largely expanded to look at complete
-  //!   caching profiles. Currently it just makes it easier to test
-  FusionKernelRuntime* most_recent_runtime_ = nullptr;
 };
 
 class GraphCache {
@@ -385,7 +206,6 @@ class GraphCache {
   std::shared_ptr<Graph> graph_;
   //! TODO: poor name, we should use `eliminated_axes_` instead;
   at::DimVector reduction_axes_;
-  bool support_permutation_;
 
   //! helper function used at run-time to check whether a common permutation is
   //! present, this is used to take the short-cut to skip permutation logic.
diff --git a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp
deleted file mode 100644 (file)
index 47ea142..0000000
+++ /dev/null
@@ -1,135 +0,0 @@
-
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-
-#include <iostream>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-namespace kir {
-
-void ExpressionEvaluator::bind(
-    const Val* value,
-    Int::ScalarType concrete_value) {
-  TORCH_CHECK(value->isScalar());
-  TORCH_CHECK(value->dtype() == DataType::Int);
-  TORCH_CHECK(!value->isConst(), "Tried to bind to a constant value");
-  TORCH_CHECK(
-      value->definition() == nullptr,
-      "Tried to bind to a value that is computed in the kernel IR");
-  known_values_[value] = concrete_value;
-}
-
-c10::optional<Int::ScalarType> ExpressionEvaluator::evaluate(const Val* value) {
-  FUSER_PERF_SCOPE("kir::ExpressionEvaluator::evaluate");
-
-  TORCH_CHECK(value->isScalar());
-  TORCH_CHECK(value->dtype() == DataType::Int);
-
-  // Const scalar?
-  if (value->isScalar() && value->isConst()) {
-    return value->as<Int>()->value();
-  }
-
-  // Is the value known (either explicit binding or memoized)?
-  const auto pre_eval_it = known_values_.find(value);
-  if (pre_eval_it != known_values_.end()) {
-    return pre_eval_it->second;
-  }
-
-  value->accept(this);
-
-  const auto post_eval_it = known_values_.find(value);
-  return post_eval_it != known_values_.end()
-      ? c10::optional<Int::ScalarType>(post_eval_it->second)
-      : c10::nullopt;
-}
-
-bool ExpressionEvaluator::isConst(const Val* value) {
-  return ExpressionEvaluator().evaluate(value).has_value();
-}
-
-void ExpressionEvaluator::print() const {
-  std::cout << "\nEvaluation context\n";
-  std::cout << "--------------------\n";
-  for (const auto& kv : known_values_) {
-    std::cout << toString(kv.first) << " = " << kv.second << "\n";
-  }
-  std::cout << "--------------------\n\n";
-}
-
-void ExpressionEvaluator::unhandled(const void*) {
-  TORCH_INTERNAL_ASSERT(
-      false, "Kernel IR expression evaluation reached an unsupported node");
-}
-
-void ExpressionEvaluator::visit(const Int* value) {
-  TORCH_INTERNAL_ASSERT(!value->isConst());
-  if (auto def = value->definition()) {
-    def->accept(this);
-  }
-}
-
-void ExpressionEvaluator::visit(const NamedScalar* named_scalar) {
-  // It's a legal expresison node so we must handle it
-}
-
-void ExpressionEvaluator::visit(const UnaryOp* unary_op) {
-  const auto in = evaluate(unary_op->in());
-  if (in.has_value()) {
-    switch (unary_op->operation()) {
-      case UnaryOpType::Neg:
-        known_values_[unary_op->out()] = -*in;
-        break;
-      case UnaryOpType::Cast:
-        known_values_[unary_op->out()] = *in;
-        break;
-      default:
-        TORCH_CHECK(!"Unexpected operator type");
-    }
-  }
-}
-
-void ExpressionEvaluator::visit(const BinaryOp* binary_op) {
-  const auto lhs = evaluate(binary_op->lhs());
-  const auto rhs = evaluate(binary_op->rhs());
-  if (lhs.has_value() && rhs.has_value()) {
-    switch (binary_op->operation()) {
-      case BinaryOpType::Add:
-        known_values_[binary_op->out()] = *lhs + *rhs;
-        break;
-      case BinaryOpType::Sub:
-        known_values_[binary_op->out()] = *lhs - *rhs;
-        break;
-      case BinaryOpType::Mul:
-        known_values_[binary_op->out()] = *lhs * *rhs;
-        break;
-      case BinaryOpType::Div:
-        TORCH_CHECK(*rhs != 0);
-        known_values_[binary_op->out()] = *lhs / *rhs;
-        break;
-      case BinaryOpType::Mod:
-        TORCH_CHECK(*rhs != 0);
-        known_values_[binary_op->out()] = *lhs % *rhs;
-        break;
-      case BinaryOpType::CeilDiv:
-        TORCH_CHECK(*rhs != 0);
-        known_values_[binary_op->out()] = (*lhs + *rhs - 1) / *rhs;
-        break;
-      case BinaryOpType::And:
-        known_values_[binary_op->out()] = Int::ScalarType(*lhs && *rhs);
-        break;
-      default:
-        TORCH_CHECK(!"Unexpected operator type");
-    }
-  }
-}
-
-} // namespace kir
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h
deleted file mode 100644 (file)
index 3064c3e..0000000
+++ /dev/null
@@ -1,62 +0,0 @@
-
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <c10/util/Optional.h>
-
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-namespace kir {
-
-//! Calculate Kernel IR expressions
-//!
-//! How to evaluate Kernel IR expressions:
-//!
-//! ```cpp
-//!   kir::ExpressionEvaluator eval;
-//!   eval.bind(symbolic_value, concrete_value);
-//!   ... bind more values ...
-//!   const auto result = eval.evaluate(interesting_value);
-//!   if (result.has_value()) {
-//!     ... we have successfully calculated the result ...
-//!   } else {
-//!     ... expression can't be evaluated ...
-//!   }
-//! ```
-//!
-class TORCH_CUDA_CU_API ExpressionEvaluator : private IrVisitor {
- public:
-  //! Set a concrete value for a symbolic value
-  void bind(const Val* value, Int::ScalarType concrete_value);
-
-  //! Try to evaluate a Kernel IR value
-  c10::optional<Int::ScalarType> evaluate(const Val* value);
-
-  //! Returns true if `value` is known before binding kernel inputs
-  static bool isConst(const Val* value);
-
-  //! Debugging helper, prints all the currently known values
-  void print() const;
-
- private:
-  void unhandled(const void*) final;
-  void visit(const Int* value) final;
-  void visit(const NamedScalar* named_scalar) final;
-  void visit(const UnaryOp* unary_op) final;
-  void visit(const BinaryOp* binary_op) final;
-
- private:
-  std::unordered_map<const Val*, Int::ScalarType> known_values_;
-};
-
-} // namespace kir
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index 43d98bf..aef65a8 100644 (file)
@@ -1,96 +1,15 @@
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/lower_utils.h>
 #include <torch/csrc/jit/codegen/cuda/type.h>
 
-#include <iostream>
-
 namespace torch {
 namespace jit {
 namespace fuser {
 namespace cuda {
 namespace kir {
 
-void Node::print() const {
-  std::cout << "\n";
-  IrPrinter(std::cout).printNode(this);
-  std::cout << "\n";
-}
-
-Val::Val(Passkey passkey, DataType dtype) : Node(passkey), dtype_(dtype) {
-  // NOLINTNEXTLINE: https://bugs.llvm.org/show_bug.cgi?id=48534
-  id_ = passkey.kernel->newValueId(passkey);
-}
-
-namespace {
-
-// Traverse definition of all values involved in constructing the provided val.
-// Check if all values involved are constant values, meaning the provided
-// val is also a constant value.
-class ConstCheck : IrVisitor {
- private:
-  bool is_const_ = true;
-
-  using IrVisitor::visit;
-
-  void visit(const Bool* b) {
-    is_const_ = is_const_ && b->isConst();
-  }
-
-  void visit(const Double* d) {
-    is_const_ = is_const_ && d->isConst();
-  }
-
-  void visit(const Int* i) {
-    is_const_ = is_const_ && i->isConst();
-  }
-
-  void visit(const NamedScalar* ns) {
-    is_const_ = is_const_ && false;
-  }
-
-  void visit(const Expr* expr) {
-    for (auto inp : expr->inputs()) {
-      visit(inp);
-    }
-  }
-
-  void visit(const Val* val) {
-    if (val->definition() != nullptr) {
-      visit(val->definition());
-    } else {
-      val->accept(this);
-    }
-  }
-
- public:
-  static bool isConst(const Val* val) {
-    ConstCheck cc;
-    cc.visit(val);
-    return cc.is_const_;
-  }
-};
-
-} // namespace
-
-bool Val::isConstScalar() const {
-  if (!isScalar())
-    return false;
-  return ConstCheck::isConst(this);
-}
-
-Expr* Expr::parentScope() const {
-  if (scope()) {
-    return scope()->owner();
-  } else {
-    return nullptr;
-  }
-}
-
 NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) {
   std::string parallel_dim = stringifyThreadSize(p_type);
   kir::IrBuilder ir_builder(GpuLower::current()->kernel());
@@ -137,52 +56,50 @@ c10::optional<ParallelType> NamedScalar::getParallelIndex() const {
   return c10::nullopt;
 }
 
-IterDomain::IterDomain(Passkey passkey, Val* start, Val* extent)
-    : Val(passkey, DataType::Int), start_(start), extent_(extent) {}
+IterDomain::IterDomain(Passkey, Val* start, Val* extent)
+    : Val(ValType::KirIterDomain, DataType::Int, true, true),
+      start_(start),
+      extent_(extent) {}
 
-IterDomain::IterDomain(
-    Passkey passkey,
-    const fuser::cuda::IterDomain* iter_domain)
-    : Val(passkey, iter_domain->getDataType().value()),
-      start_(GpuLower::current()->lowerValue(iter_domain->start())),
-      extent_(GpuLower::current()->lowerValue(iter_domain->extent())),
+IterDomain::IterDomain(Passkey, const fuser::cuda::IterDomain* iter_domain)
+    : Val(iter_domain),
+      start_(GpuLower::lowerValue(iter_domain->start())),
+      extent_(GpuLower::lowerValue(iter_domain->rawExtent())),
       parallel_type_(iter_domain->getParallelType()),
       iter_type_(iter_domain->getIterType()),
-      is_rfactor_domain_(iter_domain->isRFactorProduct()),
-      is_simple_(iter_domain->definition() == nullptr) {
-  // preserve the fusion node's name
-  setName(iter_domain->name());
-}
+      is_rfactor_domain_(iter_domain->isRFactorProduct()) {}
 
-//! Note that the parallel dimension, if available, may be different
-//! from the actual extent of this IterDomain as the parallel
-//! dimension is determined by the largest extent of IterDomains
-//! sharing the same loop.
 Val* IterDomain::extent() const {
-  TORCH_INTERNAL_ASSERT(extent_ != nullptr);
+  TORCH_CHECK(isLoweredVal(extent_));
+  if (isThread()) {
+    // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
+    if (extent_->getValType() == ValType::KirScalar) {
+      if (extent_->as<kir::Int>()->isConst()) {
+        return extent_;
+      }
+    }
+    return NamedScalar::getParallelDim(getParallelType());
+  }
   return extent_;
 }
 
-TensorDomain::TensorDomain(Passkey passkey, std::vector<IterDomain*> domain)
-    : Val(passkey, DataType::Null), root_domain_(std::move(domain)) {
+TensorDomain::TensorDomain(Passkey, std::vector<IterDomain*> domain)
+    : Val(ValType::KirTensorDomain), root_domain_(std::move(domain)) {
   domain_ = root_domain_;
   resetDomains();
 }
 
 TensorDomain::TensorDomain(
-    Passkey passkey,
+    Passkey,
     const fuser::cuda::TensorDomain* tensor_domain)
-    : Val(passkey, DataType::Null), contiguity_(tensor_domain->contiguity()) {
-  // preserve the fusion node's name
-  setName(tensor_domain->name());
-
+    : Val(tensor_domain), contiguity_(tensor_domain->contiguity()) {
   const auto lowerIterDomains =
       [](const std::vector<fuser::cuda::IterDomain*>& domains) {
         std::vector<IterDomain*> lowered_domains;
         lowered_domains.reserve(domains.size());
         for (const auto iter_domain : domains) {
           lowered_domains.push_back(
-              GpuLower::current()->lowerValue(iter_domain)->as<IterDomain>());
+              GpuLower::lowerValue(iter_domain)->as<IterDomain>());
         }
         return lowered_domains;
       };
@@ -224,13 +141,6 @@ bool TensorDomain::hasRFactor() const {
   return !rfactor_domain_.empty();
 }
 
-bool TensorDomain::hasVectorize() const {
-  return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
-    return id->parallelType() == ParallelType::Vectorize ||
-        id->parallelType() == ParallelType::MisalignedVectorize;
-  });
-}
-
 IterDomain* TensorDomain::axis(int i) const {
   TORCH_INTERNAL_ASSERT(i >= 0 && i < int(domain_.size()));
   return domain_[i];
@@ -258,124 +168,67 @@ std::vector<IterDomain*> TensorDomain::noBroadcasts(
   return no_broadcast_domains;
 }
 
-TensorView::TensorView(Passkey passkey, const fuser::cuda::TensorView* tv)
-    : Val(passkey, tv->getDataType().value()), fuser_tv_(tv) {
-  setName(tv->name());
-  domain_ = GpuLower::current()->lowerValue(tv->domain())->as<TensorDomain>();
+TensorView::TensorView(Passkey, const fuser::cuda::TensorView* tv)
+    : Val(tv), fuser_tv_(tv) {
+  domain_ = GpuLower::lowerValue(tv->domain())->as<TensorDomain>();
   memory_type_ = tv->getMemoryType();
 }
 
-TensorView::TensorView(
-    Passkey passkey,
-    DataType dtype,
-    TensorDomain* domain,
-    MemoryType memory_type)
-    : Val(passkey, dtype), domain_(domain), memory_type_(memory_type) {}
-
-UnaryOp::UnaryOp(Passkey passkey, UnaryOpType operation, Val* out, Val* in)
-    : Expr(passkey), operation_(operation), out_(out), in_(in) {
+UnaryOp::UnaryOp(Passkey, UnaryOpType type, Val* out, Val* in)
+    : Expr(ExprType::KirUnaryOp), unary_op_type_{type}, out_{out}, in_{in} {
   addOutput(out);
   addInput(in);
+  name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
 }
 
-BinaryOp::BinaryOp(
-    Passkey passkey,
-    BinaryOpType operation,
-    Val* out,
-    Val* lhs,
-    Val* rhs)
-    : Expr(passkey), operation_(operation), out_(out), lhs_(lhs), rhs_(rhs) {
+BinaryOp::BinaryOp(Passkey, BinaryOpType type, Val* out, Val* lhs, Val* rhs)
+    : Expr(ExprType::KirBinaryOp),
+      binary_op_type_{type},
+      out_{out},
+      lhs_{lhs},
+      rhs_{rhs} {
   addOutput(out);
   addInput(lhs);
   addInput(rhs);
+  name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
 }
 
 TernaryOp::TernaryOp(
-    Passkey passkey,
-    TernaryOpType operation,
+    Passkey,
+    TernaryOpType type,
     Val* out,
     Val* in1,
     Val* in2,
     Val* in3)
-    : Expr(passkey),
-      operation_(operation),
-      out_(out),
-      in1_(in1),
-      in2_(in2),
-      in3_(in3) {
+    : Expr(ExprType::KirTernaryOp),
+      ternary_op_type_{type},
+      out_{out},
+      in1_{in1},
+      in2_{in2},
+      in3_{in3} {
   addOutput(out);
   addInput(in1);
   addInput(in2);
   addInput(in3);
+  name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
 }
 
 ReductionOp::ReductionOp(
-    Passkey passkey,
-    BinaryOpType operation,
+    Passkey,
+    BinaryOpType reduction_op_type,
     Val* init,
     Val* out,
-    Val* in)
-    : Expr(passkey), operation_(operation), init_(init), out_(out), in_(in) {
+    Val* in,
+    Bool* pred)
+    : Expr(ExprType::KirReductionOp),
+      reduction_op_type_(reduction_op_type),
+      init_(init),
+      out_(out),
+      in_(in),
+      pred_(pred) {
   addOutput(out);
   addInput(in);
-}
-
-WelfordOp::WelfordOp(
-    Passkey passkey,
-    Val* out_var,
-    Val* out_avg,
-    Val* out_N,
-    Val* init_var,
-    Val* init_avg,
-    Val* init_N,
-    Val* in_var,
-    Val* in_avg,
-    Val* in_N)
-    : Expr(passkey),
-      out_var_(out_var),
-      out_avg_(out_avg),
-      out_N_(out_N),
-      init_var_(init_var),
-      init_avg_(init_avg),
-      init_N_(init_N),
-      in_var_(in_var),
-      in_avg_(in_avg),
-      in_N_(in_N) {
-  addOutput(out_avg);
-  addOutput(out_var);
-  addOutput(out_N);
-
-  if (!in_N->isOneInt()) {
-    addInput(in_var);
-  }
-  addInput(in_avg);
-  addInput(in_N);
-}
-
-std::vector<IterDomain*> WelfordOp::getReductionDomains() const {
-  // out is a TensorIndex after lowering
-  const auto out_val = out()->as<kir::TensorIndex>()->view();
-
-  auto vec_domain = out_val->as<TensorView>()->domain()->domain();
-
-  vec_domain.erase(
-      std::remove_if(
-          vec_domain.begin(),
-          vec_domain.end(),
-          [](IterDomain* id) { return !id->isReduction(); }),
-      vec_domain.end());
-  return vec_domain;
-}
-
-std::unordered_map<ParallelType, IterDomain*, TypeHash> WelfordOp::
-    getParallelReductionDomains() const {
-  std::unordered_map<ParallelType, IterDomain*, TypeHash> parallel_domains;
-  for (auto d : getReductionDomains()) {
-    if (d->isThread()) {
-      parallel_domains.insert(std::make_pair(d->parallelType(), d));
-    }
-  }
-  return parallel_domains;
+  name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
 }
 
 std::vector<IterDomain*> ReductionOp::getReductionDomains() const {
@@ -398,211 +251,122 @@ std::unordered_map<ParallelType, IterDomain*, TypeHash> ReductionOp::
   std::unordered_map<ParallelType, IterDomain*, TypeHash> parallel_domains;
   for (auto d : getReductionDomains()) {
     if (d->isThread()) {
-      parallel_domains.insert(std::make_pair(d->parallelType(), d));
+      parallel_domains.insert(std::make_pair(d->getParallelType(), d));
     }
   }
   return parallel_domains;
 }
 
-BroadcastOp::BroadcastOp(Passkey passkey, Val* out, Val* in)
-    : Expr(passkey), out_(out), in_(in) {
-  TORCH_CHECK(in->isA<TensorIndex>() || in->isA<TensorView>());
-  TORCH_CHECK(out->isA<TensorIndex>() || out->isA<TensorView>());
+BroadcastOp::BroadcastOp(Passkey, Val* out, Val* in)
+    : Expr(ExprType::KirBroadcastOp), out_(out), in_(in) {
+  TORCH_CHECK(in->getValType().value() == ValType::TensorIndex);
+  TORCH_CHECK(out->getValType().value() == ValType::TensorIndex);
   addOutput(out);
   addInput(in);
+  name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
 }
 
 TensorIndex::TensorIndex(
-    Passkey passkey,
+    Passkey,
     const fuser::cuda::TensorView* view,
     std::vector<Val*> indices)
-    : Val(passkey, view->getDataType().value()),
-      view_(GpuLower::current()->lowerValue(view)->as<TensorView>()),
+    : Val(ValType::TensorIndex, view->getDataType().value(), true, true),
+      view_(GpuLower::lowerValue(view)->as<TensorView>()),
       indices_(indices) {
   TORCH_INTERNAL_ASSERT(
       std::all_of(
           indices.begin(),
           indices.end(),
-          [](Val* v) { return v->dtype() == DataType::Int; }),
+          [](Val* v) {
+            return (v->getValType() == ValType::KirScalar ||
+                    v->getValType() == ValType::KirNamedScalar) &&
+                v->getDataType() == DataType::Int;
+          }),
       "Cannot index with a value other than an int.");
-  indices_.erase(
-      std::remove_if(
-          indices_.begin(),
-          indices_.end(),
-          [](Val* index) { return index->isZeroInt(); }),
-      indices_.end());
-  // If indices becomes empty, just put one ZeroInt
-  if (indices_.empty()) {
-    indices_.push_back(kir::IrBuilder(GpuLower::current()->kernel()).zeroVal());
-  }
 }
 
-Sync::Sync(Passkey passkey, bool war_sync)
-    : Expr(passkey), war_sync_(war_sync) {}
-
-InitMagicZero::InitMagicZero(Passkey passkey) : Expr(passkey) {}
-
-UpdateMagicZero::UpdateMagicZero(Passkey passkey) : Expr(passkey) {}
-
-void Scope::insert(std::vector<Expr*>::const_iterator pos, Expr* expr) {
-  exprs_.insert(pos, expr);
-  expr->setScope(this);
+Sync::Sync(Passkey, bool war_sync) : Expr(ExprType::Sync), war_sync_(war_sync) {
+  name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
 }
 
 void Scope::insert_before(Expr* ref, Expr* expr) {
-  const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
-  TORCH_INTERNAL_ASSERT(
-      it != exprs_.end(),
-      "Tried to insert ",
-      expr,
-      " before the reference: ",
-      ref,
-      " however the reference was not found in this scope.");
-  insert(it, expr);
+  auto it = exprs_.begin();
+  while (it != exprs_.end()) {
+    if ((*it)->sameAs(ref))
+      break;
+    it++;
+  }
+  if (it != exprs_.end())
+    exprs_.insert(it, expr);
 }
 
 void Scope::insert_after(Expr* ref, Expr* expr) {
-  const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
-  TORCH_INTERNAL_ASSERT(
-      it != exprs_.end(),
-      "Tried to insert ",
-      expr,
-      " after the reference: ",
-      ref,
-      " however the reference was not found in this scope.");
-  insert(it + 1, expr);
-}
-
-void Scope::insert(size_t pos, Expr* expr) {
-  const auto it = exprs_.begin() + pos;
-  insert(it, expr);
-}
-
-void Scope::erase(std::vector<Expr*>::const_iterator pos) {
-  // Remove the scope of the expr if this is the scope
-  auto expr = *pos;
-  TORCH_INTERNAL_ASSERT(
-      expr->scope() == this,
-      "Inconsistent scoping of expression detected: ",
-      kir::toString(expr));
-  expr->setScope(nullptr);
-  exprs_.erase(pos);
+  auto it = exprs_.begin();
+  while (it != exprs_.end()) {
+    if (*it == ref)
+      break;
+    it++;
+  }
+  if (it != exprs_.end())
+    exprs_.insert(++it, expr);
 }
 
 void Scope::erase(Expr* ref) {
-  const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
-  if (it != exprs_.end()) {
-    erase(it);
+  auto it = exprs_.begin();
+  while (it != exprs_.end()) {
+    if (*it == ref)
+      break;
+    it++;
   }
-}
-
-void Scope::erase(size_t pos) {
-  TORCH_INTERNAL_ASSERT(pos < size());
-  erase(exprs_.begin() + pos);
+  if (it != exprs_.end())
+    exprs_.erase(it);
 }
 
 bool Scope::contains(Expr* expr) const {
-  const auto it = std::find(exprs_.begin(), exprs_.end(), expr);
-  return it != exprs_.end();
+  for (auto e : exprs_)
+    if (e == expr)
+      return true;
+  return false;
 }
 
 void Scope::clear() {
-  exprs_.clear();
+  exprs_ = std::vector<Expr*>();
 }
 
 ForLoop::ForLoop(
-    Passkey passkey,
-    IterDomain* iter_domain,
+    Passkey,
     Val* index,
-    Val* start,
-    Val* stop,
-    Val* step,
-    bool vectorize,
-    Val* vectorize_shift)
-    : Expr(passkey),
+    IterDomain* iter_domain,
+    Expr* parent_scope)
+    : Expr(ExprType::ForLoop),
+      index_{index},
       iter_domain_{iter_domain},
-      index_(index),
-      start_(start),
-      stop_(stop),
-      step_(step),
-      vectorize_(vectorize),
-      vectorize_shift_(vectorize_shift),
-      body_(this) {
-  TORCH_INTERNAL_ASSERT(index->dtype() == DataType::Int);
+      parent_scope_{parent_scope} {
+  TORCH_INTERNAL_ASSERT(index->isAnInt());
+  TORCH_INTERNAL_ASSERT(isLoweredScalar(index));
   addInput(index);
   addInput(iter_domain);
-  if (start_ == nullptr && iter_domain->isThread()) {
-    start_ =
-        IrBuilder(GpuLower::current()->kernel())
-            .create<kir::NamedScalar>(
-                stringifyThread(iter_domain->parallelType()), DataType::Int);
-  }
-  if (step_ == nullptr) {
-    if (iter_domain->isThread()) {
-      step_ = IrBuilder(GpuLower::current()->kernel())
-                  .create<kir::NamedScalar>(
-                      stringifyThreadSize(iter_domain->parallelType()),
-                      DataType::Int);
-    } else {
-      step_ = IrBuilder(GpuLower::current()->kernel()).oneVal();
-    }
-  }
-}
-
-ForLoop::ForLoop(Passkey passkey, IterDomain* iter_domain)
-    : ForLoop(
-          passkey,
-          iter_domain,
-          iter_domain->isBroadcast()
-              ? IrBuilder(GpuLower::current()->kernel()).zeroVal()
-              : IrBuilder(GpuLower::current()->kernel())
-                    .create<kir::Int>(c10::nullopt),
-          nullptr,
-          nullptr,
-          nullptr,
-          isParallelTypeVectorize(iter_domain->parallelType()),
-          nullptr) {}
-
-ForLoop::ForLoop(Passkey passkey, const ForLoop* other)
-    : ForLoop(
-          passkey,
-          other->iter_domain(),
-          other->index(),
-          other->start(),
-          other->stop(),
-          other->step(),
-          other->vectorize(),
-          other->vectorize_shift()) {}
-
-Val* ForLoop::start() const {
-  if (start_ != nullptr) {
-    return start_;
-  } else {
-    // clang-tidy complains without this
-    TORCH_INTERNAL_ASSERT(iter_domain_ != nullptr);
-    return iter_domain_->start();
-  }
+  name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
 }
 
-Val* ForLoop::stop() const {
-  if (stop_ != nullptr) {
-    return stop_;
-  } else {
-    // clang-tidy complains without this
-    TORCH_INTERNAL_ASSERT(iter_domain_ != nullptr);
-    return iter_domain_->extent();
-  }
+void ForLoop::setParentScope(Expr* scope) {
+  TORCH_INTERNAL_ASSERT(
+      !scope_utils::exprInScope(parentScope(), this),
+      "Cannot change parent scope if not already removed from previous parent.");
+  parent_scope_ = scope;
 }
 
-Val* ForLoop::step() const {
-  TORCH_INTERNAL_ASSERT(step_ != nullptr);
-  return step_;
+IfThenElse::IfThenElse(Passkey, Bool* cond, Expr* parent_scope)
+    : Expr(ExprType::IfThenElse), cond_{cond}, parent_scope_(parent_scope) {
+  addInput(cond);
+  name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
 }
 
-IfThenElse::IfThenElse(Passkey passkey, Predicate* cond)
-    : Expr(passkey), then_body_(this), else_body_(this) {
-  setPredicate(cond);
-  addInput(cond);
+void IfThenElse::setParentScope(Expr* scope) {
+  TORCH_INTERNAL_ASSERT(
+      !scope_utils::exprInScope(parentScope(), this),
+      "Cannot change parent scope if not already removed from previous parent.");
+  parent_scope_ = scope;
 }
 
 Val* TensorIndex::index(int i) const {
@@ -610,78 +374,73 @@ Val* TensorIndex::index(int i) const {
       nDims() > 0, "Tried to get an index of a 0-dim TensorIndex");
   if (i < 0)
     i += nDims();
-  TORCH_INTERNAL_ASSERT(i >= 0 && i < int(nDims()));
+  assert(i >= 0 && i < nDims());
   return indices_[i];
 }
 
 Allocate::Allocate(
-    Passkey passkey,
+    Passkey,
     Val* buffer,
     MemoryType memory_type,
-    std::vector<Val*> shape,
+    Val* size,
     bool zero_init)
-    : Expr(passkey),
+    : Expr(ExprType::Allocate),
       buffer_(buffer),
       memory_type_(memory_type),
-      shape_(std::move(shape)),
+      size_(size),
       zero_init_(zero_init) {
-  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-  if (!shape_.empty()) {
+  if (size_ != nullptr) {
     TORCH_INTERNAL_ASSERT(
-        (shape_.size() == 1 && shape_[0]->isOneInt()) ||
-        buffer_->isA<TensorView>());
+        size_->isOneInt() ||
+            buffer_->getValType().value() == ValType::KirTensorView,
+        "Cannot allocate a non-TensorView buffer with a size != 1, received buffer: ",
+        buffer_);
   } else {
-    TORCH_INTERNAL_ASSERT(buffer_->isA<TensorView>());
+    TORCH_INTERNAL_ASSERT(
+        buffer_->getValType().value() == ValType::KirTensorView);
     TORCH_INTERNAL_ASSERT(
         buffer_->as<TensorView>()->memoryType() == memory_type_);
+    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
     const auto domain = buffer_->as<TensorView>()->domain();
-    for (auto axis : domain->noReductions()) {
-      shape_.push_back(axis->extent());
+    size_ = domain->nDims() == 0 ? ir_builder.create<Int>(1)
+                                 : domain->axis(0)->extent();
+    for (size_t i = 1; i < domain->nDims(); i++) {
+      size_ = ir_builder.mulExpr(size_, domain->axis(i)->extent());
     }
   }
 
-  for (auto s : shape_) {
-    if (size_ == nullptr) {
-      size_ = s;
-    } else {
-      size_ = ir_builder.mulExpr(size_, s);
+  if (memory_type_ == MemoryType::Local) {
+    if (!size_->isConstScalar()) {
+      TORCH_INTERNAL_ASSERT(
+          false,
+          "Allocations must be based on constant integers for the memory type ",
+          memory_type_,
+          " but tried to alloc ",
+          buffer_,
+          " with symbolic size.");
     }
   }
 
-  if (size_ == nullptr) {
-    size_ = ir_builder.oneVal();
-  }
-
   addInput(size_);
+  name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
 }
 
-Allocate::Allocate(
-    Passkey passkey,
-    Val* buffer,
-    MemoryType memory_type,
-    Val* size,
-    bool zero_init)
-    : Allocate(
-          passkey,
-          buffer,
-          memory_type,
-          size == nullptr ? std::vector<Val*>{} : std::vector<Val*>{size},
-          zero_init) {}
-
-GridReduction::GridReduction(Passkey passkey, ReductionOp* reduction_op)
-    : Expr(passkey), reduction_op_(reduction_op) {
+GridReduction::GridReduction(Passkey, ReductionOp* reduction_op)
+    : Expr(ExprType::GridReduction), reduction_op_(reduction_op) {
   TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
 }
 
 GridReduction::GridReduction(
-    Passkey passkey,
+    Passkey,
     ReductionOp* reduction_op,
     Allocate* reduction_buffer,
-    Allocate* sync_buffer)
-    : Expr(passkey),
+    Allocate* sync_buffer,
+    Bool* pred)
+    : Expr(ExprType::GridReduction),
       reduction_op_(reduction_op),
       reduction_buffer_(reduction_buffer),
-      sync_buffer_(sync_buffer) {}
+      sync_buffer_(sync_buffer),
+      pred_(pred) {}
 
 std::string GridReduction::getPredicateFlagName(const TensorView* val) {
   std::stringstream ss;
@@ -697,34 +456,6 @@ std::string GridReduction::getPredicateFlagName(
   return ss.str();
 }
 
-GridWelford::GridWelford(
-    Passkey passkey,
-    WelfordOp* welford_op,
-    Allocate* var_buffer,
-    Allocate* avg_buffer,
-    Allocate* n_buffer,
-    Allocate* sync_buffer)
-    : Expr(passkey),
-      welford_op_(welford_op),
-      var_buffer_(var_buffer),
-      avg_buffer_(avg_buffer),
-      n_buffer_(n_buffer),
-      sync_buffer_(sync_buffer) {}
-
-std::string GridWelford::getPredicateFlagName(const TensorView* val) {
-  std::stringstream ss;
-  ss << "T" << val->name() << "_pred";
-  return ss.str();
-}
-
-// TODO(kir): remove this
-std::string GridWelford::getPredicateFlagName(
-    const fuser::cuda::TensorView* val) {
-  std::stringstream ss;
-  ss << "T" << val->name() << "_pred";
-  return ss.str();
-}
-
 } // namespace kir
 } // namespace cuda
 } // namespace fuser
index b00d817..9f24b8d 100644 (file)
@@ -1,18 +1,16 @@
 #pragma once
 
 #include <torch/csrc/jit/codegen/cuda/type.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
 
 // TODO(kir): remove these once the Kernel IR is separated from Fusion IR
+#include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_internal_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>
 
 #include <c10/util/Optional.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
-#include <cstdint>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -24,407 +22,27 @@ namespace cuda {
 namespace kir {
 
 class IrBuilder;
-class Kernel;
-
-// Abstract nodes
-class Node;
-class Val;
-class Expr;
-
-// Values
-class NamedScalar;
-class Predicate;
-class Bool;
-class Double;
-class Int;
-class IterDomain;
-class TensorDomain;
-class TensorView;
-class TensorIndex;
-
-// Expressions
-class UnaryOp;
-class BinaryOp;
-class TernaryOp;
-class ReductionOp;
-class WelfordOp;
-class BroadcastOp;
-
-// Statements
-class Allocate;
-class Sync;
-class InitMagicZero;
-class UpdateMagicZero;
-class ForLoop;
-class IfThenElse;
-class GridReduction;
-class GridWelford;
-
-// Expr container
-class Scope;
-
-using ValueId = int32_t;
-
-//! Token used to restrict the access to Kernel IR creation
-//!
-//! A token is associated with a kernel, which is passed with the key
-//! (Passkey::kernel)
+
+//! Token used to restrict the access to Kernel IR constructors
 //!
-//! It is a "granular friendship" token, used to implement the "passkey" idiom:
+//! Granular "friendship" token, used to implement the "passkey" idiom:
 //! https://www.spiria.com/en/blog/desktop-software/passkey-idiom-and-better-friendship-c
 //! https://arne-mertz.de/2016/10/passkey-idiom
 //!
 class Passkey {
   friend class IrBuilder;
-
- public:
-  Kernel* const kernel = nullptr;
-
- private:
-  explicit Passkey(Kernel* kernel) : kernel(kernel) {}
-};
-
-//! Kernel IR visitor interface
-class TORCH_CUDA_CU_API IrVisitor : public PolymorphicBase {
- public:
-  // TODO(kir): use Node* instead of void*
-  virtual void unhandled(const void* node) {}
-
-  // Values
-  virtual void visit(const NamedScalar* named_scalar) {
-    unhandled(named_scalar);
-  }
-  virtual void visit(const Predicate* value) {
-    unhandled(value);
-  }
-  virtual void visit(const Bool* value) {
-    unhandled(value);
-  }
-  virtual void visit(const Double* value) {
-    unhandled(value);
-  }
-  virtual void visit(const Int* value) {
-    unhandled(value);
-  }
-  virtual void visit(const IterDomain* iter_domain) {
-    unhandled(iter_domain);
-  }
-  virtual void visit(const TensorDomain* tensor_domain) {
-    unhandled(tensor_domain);
-  }
-  virtual void visit(const TensorView* tensor_view) {
-    unhandled(tensor_view);
-  }
-  virtual void visit(const TensorIndex* tensor_index) {
-    unhandled(tensor_index);
-  }
-
-  // Expressions
-  virtual void visit(const UnaryOp* node) {
-    unhandled(node);
-  }
-  virtual void visit(const BinaryOp* node) {
-    unhandled(node);
-  }
-  virtual void visit(const TernaryOp* node) {
-    unhandled(node);
-  }
-  virtual void visit(const ReductionOp* node) {
-    unhandled(node);
-  }
-  virtual void visit(const WelfordOp* node) {
-    unhandled(node);
-  }
-  virtual void visit(const BroadcastOp* node) {
-    unhandled(node);
-  }
-
-  // Statements
-  virtual void visit(const Allocate* node) {
-    unhandled(node);
-  }
-  virtual void visit(const Sync* node) {
-    unhandled(node);
-  }
-  virtual void visit(const InitMagicZero* node) {
-    unhandled(node);
-  }
-  virtual void visit(const UpdateMagicZero* node) {
-    unhandled(node);
-  }
-  virtual void visit(const ForLoop* node) {
-    unhandled(node);
-  }
-  virtual void visit(const IfThenElse* node) {
-    unhandled(node);
-  }
-  virtual void visit(const GridReduction* node) {
-    unhandled(node);
-  }
-  virtual void visit(const GridWelford* node) {
-    unhandled(node);
-  }
-};
-
-//! Kernel IR visitor interface
-class TORCH_CUDA_CU_API MutableIrVisitor : public PolymorphicBase {
- public:
-  // TODO(kir): use Node* instead of void*
-  virtual void unhandled(const void*) {}
-
-  // Values
-  virtual void visit(NamedScalar* named_scalar) {
-    unhandled(named_scalar);
-  }
-  virtual void visit(Predicate* value) {
-    unhandled(value);
-  }
-  virtual void visit(Bool* value) {
-    unhandled(value);
-  }
-  virtual void visit(Double* value) {
-    unhandled(value);
-  }
-  virtual void visit(Int* value) {
-    unhandled(value);
-  }
-  virtual void visit(IterDomain* iter_domain) {
-    unhandled(iter_domain);
-  }
-  virtual void visit(TensorDomain* tensor_domain) {
-    unhandled(tensor_domain);
-  }
-  virtual void visit(TensorView* tensor_view) {
-    unhandled(tensor_view);
-  }
-  virtual void visit(TensorIndex* tensor_index) {
-    unhandled(tensor_index);
-  }
-
-  // Expressions
-  virtual void visit(UnaryOp* node) {
-    unhandled(node);
-  }
-  virtual void visit(BinaryOp* node) {
-    unhandled(node);
-  }
-  virtual void visit(TernaryOp* node) {
-    unhandled(node);
-  }
-  virtual void visit(ReductionOp* node) {
-    unhandled(node);
-  }
-  virtual void visit(BroadcastOp* node) {
-    unhandled(node);
-  }
-
-  virtual void visit(WelfordOp* node) {
-    unhandled(node);
-  }
-
-  // Statements
-  virtual void visit(Allocate* node) {
-    unhandled(node);
-  }
-  virtual void visit(Sync* node) {
-    unhandled(node);
-  }
-  virtual void visit(InitMagicZero* node) {
-    unhandled(node);
-  }
-  virtual void visit(UpdateMagicZero* node) {
-    unhandled(node);
-  }
-  virtual void visit(ForLoop* node) {
-    unhandled(node);
-  }
-  virtual void visit(IfThenElse* node) {
-    unhandled(node);
-  }
-  virtual void visit(GridReduction* node) {
-    unhandled(node);
-  }
-
-  virtual void visit(GridWelford* node) {
-    unhandled(node);
-  }
-};
-
-//! Base class for Kernel IR nodes
-class TORCH_CUDA_CU_API Node : public NonCopyable, public PolymorphicBase {
- public:
-  explicit Node(Passkey) {}
-
-  //! IR Visitor double-dispatch interface
-  //! (https://en.wikipedia.org/wiki/Visitor_pattern)
-  virtual void accept(IrVisitor* visitor) const = 0;
-
-  //! Non constant IR Visitor
-  virtual void accept(MutableIrVisitor* visitor) = 0;
-
-  //! Debug helper, prints the textual representation of an IR node
-  void print() const;
-};
-
-//! Generic value (scalar or tensor)
-class TORCH_CUDA_CU_API Val : public Node {
- public:
-  Val(Passkey passkey, DataType dtype);
-
-  // TODO(kir): consider renaming
-  StmtNameType name() const {
-    return name_;
-  }
-
-  void setName(StmtNameType name) {
-    name_ = name;
-  }
-
-  ValueId id() const {
-    return id_;
-  }
-
-  DataType dtype() const {
-    return dtype_;
-  }
-
-  Expr* definition() const {
-    return definition_;
-  }
-
-  void setDefinition(Expr* expr) {
-    // TODO(kir): extra checks on changing existing definitions?
-    definition_ = expr;
-  }
-
-  virtual bool isScalar() const {
-    return false;
-  }
-
-  bool isConstScalar() const;
-
-  virtual bool isConst() const {
-    return false;
-  }
-
-  // TODO(kir): revisit and find a better interface
-  virtual bool isZeroInt() const {
-    return false;
-  }
-
-  virtual bool isOneInt() const {
-    return false;
-  }
-
- private:
-  const DataType dtype_;
-
-  // The expression which defines this value, or nullptr
-  Expr* definition_ = nullptr;
-
-  // This is a value name preserved from the Fusion IR (optional)
-  StmtNameType name_ = kInvalidStmName;
-
-  // All Kernel IR values have IDs (unique within the same Kernel)
-  ValueId id_ = -1;
-};
-
-//! Base class for expressions and statements
-//!
-//! Expressions consume inputs and produce outputs (depending on the context
-//! this may imply assignments). Currently some of the expressions
-//! don't actually produce any outputs (ForLoop, IfThenElse) and they
-//! model statements to be executed.
-//!
-//! TODO(kir): split the expressions, assignments and statements?
-//!
-class TORCH_CUDA_CU_API Expr : public Node {
- public:
-  explicit Expr(Passkey passkey) : Node(passkey) {}
-
-  const auto& inputs() const {
-    return inputs_;
-  }
-
-  const auto& outputs() const {
-    return outputs_;
-  }
-
-  Scope* scope() const {
-    return scope_;
-  }
-
-  //! Set the current scope
-  void setScope(Scope* scope) {
-    scope_ = scope;
-  }
-
-  Expr* parentScope() const;
-
-  Predicate* predicate() const {
-    return predicate_;
-  }
-
-  void setPredicate(Predicate* predicate) {
-    predicate_ = predicate;
-  }
-
-  Predicate* writePredicate() const {
-    return write_predicate_;
-  }
-
-  void setWritePredicate(Predicate* write_predicate) {
-    write_predicate_ = write_predicate;
-  }
-
- protected:
-  // TODO(kir): try to avoid this protected interface
-  void addInput(Val* input) {
-    inputs_.push_back(input);
-  }
-
-  void addOutput(Val* output) {
-    output->setDefinition(this);
-    outputs_.push_back(output);
-  }
-
- private:
-  // TODO(kir): can we avoid this?
-  std::vector<Val*> inputs_;
-  std::vector<Val*> outputs_;
-
-  // TODO(kir): revisit scope/nesting data structures
-  Scope* scope_ = nullptr;
-
-  Predicate* predicate_ = nullptr;
-  // Only used for reduction-related expressions
-  Predicate* write_predicate_ = nullptr;
+  Passkey() = default;
 };
 
-class TORCH_CUDA_CU_API NamedScalar final : public Val {
+class TORCH_CUDA_CU_API NamedScalar : public Val {
  public:
-  // NOLINTNEXTLINE(modernize-pass-by-value)
-  NamedScalar(Passkey passkey, std::string name, DataType dtype)
-      : Val(passkey, dtype), name_(name) {}
+  NamedScalar(Passkey, std::string name, DataType dtype)
+      : Val(ValType::KirNamedScalar, dtype, true, true),
+        name_(std::move(name)) {}
 
-  explicit NamedScalar(Passkey passkey, const fuser::cuda::NamedScalar* node)
-      : Val(passkey, node->getDataType().value()) {
-    name_ = node->name();
-  }
-
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
+  explicit NamedScalar(Passkey, const fuser::cuda::NamedScalar* node)
+      : Val(node), name_(node->name()) {}
 
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
-
-  bool isScalar() const override {
-    return true;
-  }
-
-  // TODO(kir): this is hiding and redefining Val::name()
   const std::string& name() const {
     return name_;
   }
@@ -447,214 +65,97 @@ class TORCH_CUDA_CU_API NamedScalar final : public Val {
   std::string name_;
 };
 
-class TORCH_CUDA_CU_API Predicate final : public Val {
+class TORCH_CUDA_CU_API Bool : public Val {
  public:
-  explicit Predicate(
-      Passkey passkey,
-      PredicateType ptype,
-      const Expr* expr = nullptr,
-      Bool* thread_pred = nullptr)
-      : Val(passkey, DataType::Bool),
-        ptype_(ptype),
-        expr_(expr),
-        thread_pred_(thread_pred) {
-    TORCH_INTERNAL_ASSERT(
-        ptype != PredicateType::Unswitch && ptype != PredicateType::Manual);
-  }
-
-  explicit Predicate(Passkey passkey, ForLoop* unrolled_loop)
-      : Val(passkey, DataType::Bool),
-        ptype_(PredicateType::Unswitch),
-        unrolled_loop_(unrolled_loop) {
-    TORCH_INTERNAL_ASSERT(unrolled_loop != nullptr);
-  }
-
-  explicit Predicate(Passkey passkey, Bool* value)
-      : Val(passkey, DataType::Bool),
-        ptype_(PredicateType::Manual),
-        value_(value) {
-    TORCH_INTERNAL_ASSERT(value != nullptr);
-  }
+  explicit Bool(Passkey, const c10::optional<bool>& value)
+      : Val(ValType::KirScalar, DataType::Bool, true, true),
+        maybe_value_(value) {}
 
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
+  explicit Bool(Passkey, const fuser::cuda::Bool* node)
+      : Val(node), maybe_value_(node->value()) {}
 
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
+  bool isSymbolic() const {
+    return !(maybe_value_.has_value());
   }
-
-  PredicateType predicate_type() const {
-    return ptype_;
-  }
-
-  const Expr* expr() const {
-    TORCH_INTERNAL_ASSERT(
-        ptype_ != PredicateType::Unswitch &&
-        ptype_ != PredicateType::Vectorize && ptype_ != PredicateType::Manual);
-    return expr_;
-  }
-
-  Bool* thread_pred() {
-    TORCH_INTERNAL_ASSERT(
-        ptype_ == PredicateType::Inline ||
-        ptype_ == PredicateType::Misaligned || ptype_ == PredicateType::Shift ||
-        ptype_ == PredicateType::Padding ||
-        ptype_ == PredicateType::ReductionWrite);
-    return thread_pred_;
-  }
-
-  ForLoop* unrolled_loop() const {
-    TORCH_INTERNAL_ASSERT(ptype_ == PredicateType::Unswitch);
-    return unrolled_loop_;
-  }
-
-  bool hasValue() const {
-    return value_ != nullptr;
-  }
-
-  Bool* value() const {
-    TORCH_INTERNAL_ASSERT(
-        value_ != nullptr,
-        "The conditional expression for this Predicate is invalid.");
-    return value_;
+  bool isConst() const {
+    return maybe_value_.has_value();
   }
-
-  void setValue(Bool* value) {
-    TORCH_INTERNAL_ASSERT(value != nullptr, "The Bool expression is invalid.");
-    value_ = value;
+  c10::optional<bool> value() const {
+    return maybe_value_;
   }
 
  private:
-  PredicateType ptype_ = PredicateType::Manual;
-
-  // For PredicateCompute::getInlinePredicate,
-  // ShiftPredicateInserter::getShiftPredicate and getPaddingPredicate
-  const Expr* expr_ = nullptr;
-
-  // For PredicateCompute::getInlinePredicate
-  Bool* thread_pred_ = nullptr;
-
-  // For ParallelType::Unswitch - UnswitchPredicate::get
-  ForLoop* unrolled_loop_ = nullptr;
-
-  // The Bool conditional value
-  // The value is nullptr until lower_predicate pass
-  Bool* value_ = nullptr;
+  const c10::optional<bool> maybe_value_;
 };
 
-class TORCH_CUDA_CU_API Bool final : public Val {
+class TORCH_CUDA_CU_API Float : public Val {
  public:
-  explicit Bool(Passkey passkey, const c10::optional<bool>& value)
-      : Val(passkey, DataType::Bool), maybe_value_(value) {}
-
-  explicit Bool(Passkey passkey, const fuser::cuda::Bool* node)
-      : Val(passkey, DataType::Bool), maybe_value_(node->value()) {
-    setName(node->name());
-  }
+  using ScalarType = double;
 
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
+  explicit Float(Passkey, const c10::optional<ScalarType>& value)
+      : Val(ValType::KirScalar, DataType::Float, true, true),
+        maybe_value_(value) {}
 
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
+  explicit Float(Passkey, const fuser::cuda::Float* node)
+      : Val(node), maybe_value_(node->value()) {}
 
-  bool isScalar() const override {
-    return true;
+  bool isSymbolic() const {
+    return !(maybe_value_.has_value());
   }
-
-  bool isConst() const override {
+  bool isConst() const {
     return maybe_value_.has_value();
   }
-
-  c10::optional<bool> value() const {
+  c10::optional<ScalarType> value() const {
     return maybe_value_;
   }
 
  private:
-  const c10::optional<bool> maybe_value_;
+  const c10::optional<ScalarType> maybe_value_;
 };
 
-class TORCH_CUDA_CU_API Double final : public Val {
+class TORCH_CUDA_CU_API Half : public Val {
  public:
-  using ScalarType = double;
+  explicit Half(Passkey, const c10::optional<float>& value)
+      : Val(ValType::KirScalar, DataType::Half, true, true),
+        maybe_value_(value) {}
 
-  explicit Double(Passkey passkey, const c10::optional<ScalarType>& value)
-      : Val(passkey, DataType::Double), maybe_value_(value) {}
+  explicit Half(Passkey, const fuser::cuda::Half* node)
+      : Val(node), maybe_value_(node->value()) {}
 
-  explicit Double(Passkey passkey, const fuser::cuda::Double* node)
-      : Val(passkey, DataType::Double), maybe_value_(node->value()) {
-    setName(node->name());
+  bool isSymbolic() const {
+    return !(maybe_value_.has_value());
   }
-
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
-
-  bool isScalar() const override {
-    return true;
-  }
-
-  bool isConst() const override {
+  bool isConst() const {
     return maybe_value_.has_value();
   }
-
-  c10::optional<ScalarType> value() const {
+  c10::optional<float> value() const {
     return maybe_value_;
   }
 
  private:
-  const c10::optional<ScalarType> maybe_value_;
+  const c10::optional<float> maybe_value_;
 };
 
-class TORCH_CUDA_CU_API Int final : public Val {
+class TORCH_CUDA_CU_API Int : public Val {
  public:
   using ScalarType = int64_t;
 
-  explicit Int(Passkey passkey, const c10::optional<ScalarType>& value)
-      : Val(passkey, DataType::Int), maybe_value_(value) {}
+  explicit Int(Passkey, const c10::optional<ScalarType>& value)
+      : Val(ValType::KirScalar, DataType::Int, true, true),
+        maybe_value_(value) {}
 
-  // SFINAE constructor to avoid 0 constant pointer ambiguity
-  template <
-      typename T,
-      typename = typename std::enable_if<
-          std::is_pointer<T>::value &&
-          std::is_convertible<T, const fuser::cuda::Int*>::value>::type>
-  explicit Int(Passkey passkey, T node)
-      : Val(passkey, DataType::Int), maybe_value_(node->value()) {
-    setName(node->name());
-  }
-
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
+  explicit Int(
+      Passkey,
+      const fuser::cuda::Int* node,
+      bool /*avoid_zero_ambiguity*/)
+      : Val(node), maybe_value_(node->value()) {}
 
-  bool isScalar() const override {
-    return true;
+  bool isSymbolic() const {
+    return !(maybe_value_.has_value());
   }
-
-  bool isConst() const override {
+  bool isConst() const {
     return maybe_value_.has_value();
   }
-
-  bool isZeroInt() const override {
-    return maybe_value_.has_value() && *maybe_value_ == 0;
-  }
-
-  bool isOneInt() const override {
-    return maybe_value_.has_value() && *maybe_value_ == 1;
-  }
-
   c10::optional<ScalarType> value() const {
     return maybe_value_;
   }
@@ -663,22 +164,14 @@ class TORCH_CUDA_CU_API Int final : public Val {
   const c10::optional<ScalarType> maybe_value_;
 };
 
-class TORCH_CUDA_CU_API IterDomain final : public Val {
+class TORCH_CUDA_CU_API IterDomain : public Val {
  public:
-  IterDomain(Passkey passkey, Val* start, Val* extent);
+  IterDomain(Passkey, Val* start, Val* extent);
 
   explicit IterDomain(Passkey, const fuser::cuda::IterDomain* iter_domain);
 
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
-
   bool isReduction() const {
-    return iterType() == IterType::Reduction;
+    return getIterType() == IterType::Reduction;
   }
 
   bool isRFactorProduct() const {
@@ -686,42 +179,40 @@ class TORCH_CUDA_CU_API IterDomain final : public Val {
   }
 
   bool isBroadcast() const {
-    return iterType() == IterType::BroadcastWithStride ||
-        iterType() == IterType::BroadcastWithoutStride;
-  }
-
-  bool isGather() const {
-    return iterType() == IterType::Gather;
+    return getIterType() == IterType::BroadcastWithStride ||
+        getIterType() == IterType::BroadcastWithoutStride;
   }
 
   bool isParallelized() const {
-    return parallelType() != ParallelType::Serial;
+    return getParallelType() != ParallelType::Serial;
   }
 
   // Return if this iter domain is mapped to a grid dimension
   bool isBlockDim() const {
-    return parallelType() == ParallelType::BIDz ||
-        parallelType() == ParallelType::BIDy ||
-        parallelType() == ParallelType::BIDx;
+    return (
+        getParallelType() == ParallelType::BIDz ||
+        getParallelType() == ParallelType::BIDy ||
+        getParallelType() == ParallelType::BIDx);
   }
 
   // Return if this iter domain is mapped to a block dimension
   bool isThreadDim() const {
-    return parallelType() == ParallelType::TIDz ||
-        parallelType() == ParallelType::TIDy ||
-        parallelType() == ParallelType::TIDx;
+    return (
+        getParallelType() == ParallelType::TIDz ||
+        getParallelType() == ParallelType::TIDy ||
+        getParallelType() == ParallelType::TIDx);
   }
 
   // Return if this iter domain is either mapped to a block or grid dimension
   bool isThread() const {
-    return isBlockDim() || isThreadDim();
+    return (isBlockDim() || isThreadDim());
   }
 
-  ParallelType parallelType() const {
+  ParallelType getParallelType() const {
     return parallel_type_;
   }
 
-  IterType iterType() const {
+  IterType getIterType() const {
     return iter_type_;
   }
 
@@ -731,8 +222,8 @@ class TORCH_CUDA_CU_API IterDomain final : public Val {
 
   Val* extent() const;
 
-  bool isSimple() const {
-    return is_simple_;
+  Val* rawExtent() const {
+    return extent_;
   }
 
  private:
@@ -741,37 +232,20 @@ class TORCH_CUDA_CU_API IterDomain final : public Val {
   ParallelType parallel_type_ = ParallelType::Serial;
   IterType iter_type_ = IterType::Iteration;
   bool is_rfactor_domain_ = false;
-
-  // An IterDomain is "simple" if the original Fusion IterDomain
-  // doesn't have a definition ("definition" expression)
-  //
-  // TODO(kir): this feels like a hack, revisit
-  //
-  bool is_simple_ = true;
 };
 
-// TODO(kir): is this really a value?
-class TORCH_CUDA_CU_API TensorDomain final : public Val {
+class TORCH_CUDA_CU_API TensorDomain : public Val {
  public:
   explicit TensorDomain(Passkey, std::vector<IterDomain*> domain);
 
   explicit TensorDomain(
-      Passkey passkey,
+      Passkey,
       const fuser::cuda::TensorDomain* tensor_domain);
 
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
-
   std::vector<IterDomain*>::size_type nDims() const {
     return domain_.size();
   }
 
-  // TODO(kir): rename this
   const std::vector<IterDomain*>& domain() const {
     return domain_;
   }
@@ -794,7 +268,6 @@ class TORCH_CUDA_CU_API TensorDomain final : public Val {
   bool hasBlockBroadcast() const;
   bool hasBroadcast() const;
   bool hasRFactor() const;
-  bool hasVectorize() const;
 
   const std::vector<IterDomain*>& noReductions() const {
     return no_reduction_domain_;
@@ -832,36 +305,21 @@ class TORCH_CUDA_CU_API TensorDomain final : public Val {
   const std::vector<bool> contiguity_;
 };
 
-class TORCH_CUDA_CU_API TensorView final : public Val {
+class TORCH_CUDA_CU_API TensorView : public Val {
  public:
   explicit TensorView(Passkey, const fuser::cuda::TensorView* tv);
 
-  TensorView(
-      Passkey,
-      DataType dtype,
-      TensorDomain* domain,
-      MemoryType memory_type);
-
   TensorDomain* domain() const {
     return domain_;
   }
 
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
-
   MemoryType memoryType() const {
     return memory_type_;
   }
 
-  fuser::cuda::TensorView* fuserTv() const {
+  const fuser::cuda::TensorView* fuserTv() const {
     TORCH_INTERNAL_ASSERT(fuser_tv_ != nullptr);
-    // TODO(kir): remove the need for const_cast
-    return const_cast<fuser::cuda::TensorView*>(fuser_tv_); // NOLINT
+    return fuser_tv_;
   }
 
  private:
@@ -872,17 +330,9 @@ class TORCH_CUDA_CU_API TensorView final : public Val {
   const fuser::cuda::TensorView* fuser_tv_ = nullptr;
 };
 
-class TORCH_CUDA_CU_API UnaryOp final : public Expr {
+class TORCH_CUDA_CU_API UnaryOp : public Expr {
  public:
-  UnaryOp(Passkey passkey, UnaryOpType operation, Val* out, Val* in);
-
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
+  UnaryOp(Passkey, UnaryOpType type, Val* out, Val* in);
 
   Val* out() const {
     return out_;
@@ -892,32 +342,19 @@ class TORCH_CUDA_CU_API UnaryOp final : public Expr {
     return in_;
   }
 
-  UnaryOpType operation() const {
-    return operation_;
+  UnaryOpType getUnaryOpType() const {
+    return unary_op_type_;
   }
 
  private:
-  const UnaryOpType operation_;
+  const UnaryOpType unary_op_type_;
   Val* const out_ = nullptr;
   Val* const in_ = nullptr;
 };
 
-class TORCH_CUDA_CU_API BinaryOp final : public Expr {
+class TORCH_CUDA_CU_API BinaryOp : public Expr {
  public:
-  BinaryOp(
-      Passkey passkey,
-      BinaryOpType operation,
-      Val* out,
-      Val* lhs,
-      Val* rhs);
-
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
+  BinaryOp(Passkey, BinaryOpType type, Val* out, Val* lhs, Val* rhs);
 
   Val* out() const {
     return out_;
@@ -931,35 +368,27 @@ class TORCH_CUDA_CU_API BinaryOp final : public Expr {
     return rhs_;
   }
 
-  BinaryOpType operation() const {
-    return operation_;
+  BinaryOpType getBinaryOpType() const {
+    return binary_op_type_;
   }
 
  private:
-  const BinaryOpType operation_;
+  const BinaryOpType binary_op_type_;
   Val* const out_ = nullptr;
   Val* const lhs_ = nullptr;
   Val* const rhs_ = nullptr;
 };
 
-class TORCH_CUDA_CU_API TernaryOp final : public Expr {
+class TORCH_CUDA_CU_API TernaryOp : public Expr {
  public:
   TernaryOp(
-      Passkey passkey,
-      TernaryOpType operation,
+      Passkey,
+      TernaryOpType type,
       Val* out,
       Val* in1,
       Val* in2,
       Val* in3);
 
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
-
   Val* out() const {
     return out_;
   }
@@ -976,34 +405,27 @@ class TORCH_CUDA_CU_API TernaryOp final : public Expr {
     return in3_;
   }
 
-  TernaryOpType operation() const {
-    return operation_;
+  TernaryOpType getTernaryOpType() const {
+    return ternary_op_type_;
   }
 
  private:
-  const TernaryOpType operation_;
+  const TernaryOpType ternary_op_type_;
   Val* const out_ = nullptr;
   Val* const in1_ = nullptr;
   Val* const in2_ = nullptr;
   Val* const in3_ = nullptr;
 };
 
-class TORCH_CUDA_CU_API ReductionOp final : public Expr {
+class TORCH_CUDA_CU_API ReductionOp : public Expr {
  public:
   ReductionOp(
-      Passkey passkey,
-      BinaryOpType operation,
+      Passkey,
+      BinaryOpType reduction_op_type,
       Val* init,
       Val* out,
-      Val* in);
-
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
+      Val* in,
+      Bool* pred = nullptr);
 
   Val* out() const {
     return out_;
@@ -1017,8 +439,12 @@ class TORCH_CUDA_CU_API ReductionOp final : public Expr {
     return init_;
   }
 
-  BinaryOpType operation() const {
-    return operation_;
+  Bool* pred() const {
+    return pred_;
+  }
+
+  BinaryOpType getReductionOpType() const {
+    return reduction_op_type_;
   }
 
   std::unordered_map<ParallelType, IterDomain*, TypeHash>
@@ -1028,113 +454,20 @@ class TORCH_CUDA_CU_API ReductionOp final : public Expr {
   std::vector<IterDomain*> getReductionDomains() const;
 
  private:
-  const BinaryOpType operation_;
+  const BinaryOpType reduction_op_type_;
   Val* const init_ = nullptr;
   Val* const out_ = nullptr;
   Val* const in_ = nullptr;
+  Bool* const pred_ = nullptr;
 };
 
-class TORCH_CUDA_CU_API WelfordOp final : public Expr {
- public:
-  WelfordOp(
-      Passkey passkey,
-      Val* out_var,
-      Val* out_avg,
-      Val* out_N,
-      Val* init_var,
-      Val* init_avg,
-      Val* init_N,
-      Val* in_var,
-      Val* in_avg,
-      Val* in_N);
-
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
-
-  Val* out() const {
-    return out_avg_;
-  }
-
-  Val* in() const {
-    return in_avg_;
-  }
-
-  // Welford Specific accessors
-  // Almost wanted to add a new struct for {var, avg, N}
-  Val* outVar() const {
-    return out_var_;
-  }
-
-  Val* outAvg() const {
-    return out_avg_;
-  }
-
-  Val* outN() const {
-    return out_N_;
-  }
-
-  Val* initVar() const {
-    return init_var_;
-  }
-
-  Val* initAvg() const {
-    return init_avg_;
-  }
-
-  Val* initN() const {
-    return init_N_;
-  }
-
-  Val* inVar() const {
-    return in_var_;
-  }
-
-  Val* inAvg() const {
-    return in_avg_;
-  }
-
-  Val* inN() const {
-    return in_N_;
-  }
-
-  std::unordered_map<ParallelType, IterDomain*, TypeHash>
-  getParallelReductionDomains() const;
-
- private:
-  std::vector<IterDomain*> getReductionDomains() const;
-
- private:
-  Val* const out_var_;
-  Val* const out_avg_;
-  Val* const out_N_;
-  Val* const init_var_;
-  Val* const init_avg_;
-  Val* const init_N_;
-  Val* const in_var_;
-  Val* const in_avg_;
-  Val* const in_N_;
-};
-
-class TORCH_CUDA_CU_API TensorIndex final : public Val {
+class TORCH_CUDA_CU_API TensorIndex : public Val {
  public:
   TensorIndex(
       Passkey,
       const fuser::cuda::TensorView* view,
       std::vector<Val*> indices);
 
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
-
   std::vector<Val*>::size_type nDims() const {
     return indices_.size();
   }
@@ -1145,10 +478,8 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val {
     return indices_;
   }
 
-  TensorView* view() const {
-    TORCH_INTERNAL_ASSERT(view_ != nullptr);
-    // TODO(kir): remove the need for const_cast
-    return const_cast<fuser::cuda::kir::TensorView*>(view_); // NOLINT
+  const TensorView* view() const {
+    return view_;
   }
 
  private:
@@ -1156,17 +487,9 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val {
   std::vector<Val*> indices_;
 };
 
-class TORCH_CUDA_CU_API BroadcastOp final : public Expr {
+class TORCH_CUDA_CU_API BroadcastOp : public Expr {
  public:
-  BroadcastOp(Passkey passkey, Val* out, Val* in);
-
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
+  BroadcastOp(Passkey, Val* out, Val* in);
 
   Val* out() const {
     return out_;
@@ -1181,49 +504,27 @@ class TORCH_CUDA_CU_API BroadcastOp final : public Expr {
   Val* const in_ = nullptr;
 };
 
-//! Allocate is a lower level Node that describes a buffer of memory that
-//! is required as an intermediate within a kernel. The extent is the expression
-//! of the size of the buffer that is generated from the TensorView that
-//! describes the output of an operation.
-//!
-//! TODO(kir): The components of Allocate like Type and Name could be separated
-//!   from the the assocated TensorView.  Perhaps that is more appropriate?
-//!
-class TORCH_CUDA_CU_API Allocate final : public Expr {
+// Allocate is a lower level Node that describes a buffer of memory that
+// is required as an intermediate within a kernel.  The extent is the expression
+// of the size of the buffer that is generated from the TensorView that
+// describes the output of an operation.
+//
+// TODO: The components of Allocate like Type and Name could be separated from
+// the the assocated TensorView.  Perhaps that is more appropriate?
+class TORCH_CUDA_CU_API Allocate : public Expr {
  public:
-  //! Allocation of a multi-dimensional buffer
-  //!
-  //! param shape Size of each dimension
   explicit Allocate(
-      Passkey passkey,
-      Val* buffer,
-      MemoryType memory_type,
-      std::vector<Val*> shape = {},
-      bool zero_init = false);
-
-  //! Allocation of a non-dimensional buffer
-  //!
-  //! param size Size of allocation
-  explicit Allocate(
-      Passkey passkey,
+      Passkey,
       Val* buffer,
-      MemoryType memory_type,
-      Val* size,
+      MemoryType memory_type = MemoryType::Local,
+      Val* size = nullptr,
       bool zero_init = false);
 
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
-
   Val* buffer() const {
     return buffer_;
   }
 
-  MemoryType memoryType() const {
+  MemoryType getMemoryType() const {
     return memory_type_;
   }
 
@@ -1231,53 +532,38 @@ class TORCH_CUDA_CU_API Allocate final : public Expr {
     return size_;
   }
 
-  const std::vector<Val*>& shape() const {
-    return shape_;
-  }
-
   bool zeroInit() const {
     return zero_init_;
   }
 
-  const Allocate* alias() const {
+  DataType buffer_type() const {
+    return buffer_->getDataType().value();
+  }
+
+  Allocate* alias() const {
     return alias_;
   }
 
-  void setAlias(const Allocate* alias) {
-    TORCH_INTERNAL_ASSERT(alias != this);
-    TORCH_INTERNAL_ASSERT(alias->memoryType() == memory_type_);
+  void setAlias(Allocate* alias) {
+    TORCH_INTERNAL_ASSERT(alias->getMemoryType() == memory_type_);
     alias_ = alias;
   }
 
  private:
   Val* buffer_ = nullptr;
   MemoryType memory_type_ = MemoryType::Local;
-  //! Size of each dimension
-  std::vector<Val*> shape_;
-  bool zero_init_ = false;
-  //! Total size
   Val* size_ = nullptr;
+  bool zero_init_ = false;
 
   // This alias tracks the next Allocate node in a linked chain of aliases
   // If the alias is nullptr, then the Allocate node uses memory in the kernel
-  const Allocate* alias_ = nullptr;
+  Allocate* alias_ = nullptr;
 };
 
 // Sync represents __syncthreads barrier for block level coordination.
-//
-// TODO(kir): change name to SyncThreads as we could have other barriers.
-//
-class TORCH_CUDA_CU_API Sync final : public Expr {
+class TORCH_CUDA_CU_API Sync : public Expr {
  public:
-  explicit Sync(Passkey passkey, bool war_sync = false);
-
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
+  explicit Sync(Passkey, bool war_sync = false);
 
   bool isWarHazardSync() const {
     return war_sync_;
@@ -1288,43 +574,26 @@ class TORCH_CUDA_CU_API Sync final : public Expr {
   bool war_sync_ = false;
 };
 
-// Simply prints "DEFINE_MAGIC_ZERO" in the code in accordance with magic_zero
-// in helpers.cu
-class TORCH_CUDA_CU_API InitMagicZero final : public Expr {
+// TODO(kir): promote to IR node
+// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+class TORCH_CUDA_CU_API Scope {
  public:
-  explicit InitMagicZero(Passkey passkey);
+  Scope() = default;
 
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
+  const std::vector<Expr*>& exprs() const {
+    return exprs_;
   }
-};
-
-// Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero
-// in helpers.cu
-class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr {
- public:
-  explicit UpdateMagicZero(Passkey passkey);
 
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
+  void push_back(Expr* e) {
+    exprs_.push_back(e);
   }
 
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
+  void insert(size_t pos, Expr* expr) {
+    exprs_.insert(exprs_.begin() + pos, expr);
   }
-};
 
-// TODO(kir): promote to IR node
-class TORCH_CUDA_CU_API Scope {
- public:
-  explicit Scope(Expr* owner) : owner_(owner) {}
-
-  const std::vector<Expr*>& exprs() const {
-    return exprs_;
+  void erase(size_t pos) {
+    exprs_.erase(exprs_.begin() + pos);
   }
 
   bool empty() const {
@@ -1343,100 +612,37 @@ class TORCH_CUDA_CU_API Scope {
     return exprs_[i];
   }
 
-  // Insert expr before expression at pos
-  void insert(size_t pos, Expr* expr);
-
   // Insert expr before ref
   void insert_before(Expr* ref, Expr* expr);
 
   // Insert expr after ref
   void insert_after(Expr* ref, Expr* expr);
 
-  void push_back(Expr* e) {
-    exprs_.push_back(e);
-    e->setScope(this);
-  }
-
-  // Erase expr at pos
-  void erase(size_t pos);
+  bool contains(Expr* expr) const;
 
-  // Erase expr ref
   void erase(Expr* ref);
 
-  bool contains(Expr* expr) const;
-
   void clear();
 
-  Expr* owner() const {
-    return owner_;
-  }
-
- private:
-  // Insert expr before pos
-  void insert(std::vector<Expr*>::const_iterator pos, Expr* expr);
-
-  // Erase expr at pos
-  void erase(std::vector<Expr*>::const_iterator pos);
-
  private:
   std::vector<Expr*> exprs_;
-
-  //! Owner exprssion of this scope, e.g., IfThenElse
-  Expr* owner_ = nullptr;
 };
 
-//! ForLoop provides scoping around an int iterator from 0 to range. Exprs
-//! placed in its body are considered inside the scope of the for loop. In the
-//! future the implementation should look quite different so that we can do
-//! proper dependency annalysis like in Fusion.
-//!
-//! TODO(kir): this is not a real expression
-//!
-//! ForLoop may represent a part of an iteration domain representend
-//! by iter_domain_. In that case, the loop extent field, extent_, may
-//! be smaller than the extent of iter_domain_.
-class TORCH_CUDA_CU_API ForLoop final : public Expr {
+// ForLoop provides scoping around an int iterator from 0 to range. Exprs placed
+// in its body are considered inside the scope of the for loop. In the future
+// the implementation should look quite different so that we can do proper
+// dependency annalysis like in Fusion.
+//
+// TODO(kir): this is not a real expression
+//
+class TORCH_CUDA_CU_API ForLoop : public Expr {
  public:
-  //! By default, start and stop are the same as those of iter_domain.
-  //! Step is one by default.
-  //!
-  //! TODO: cleaner way to set options?
-  ForLoop(
-      Passkey passkey,
-      IterDomain* iter_domain,
-      Val* index,
-      Val* start,
-      Val* stop,
-      Val* step,
-      bool vectorize,
-      Val* vectorize_shift);
-
-  ForLoop(Passkey passkey, IterDomain* iter_domain);
-
-  ForLoop(Passkey passkey, const ForLoop* other);
-
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
+  ForLoop(Passkey, Val* index, IterDomain* iter_domain, Expr* parent_scope);
 
   Val* index() const {
     return index_;
   }
 
-  Val* start() const;
-
-  Val* stop() const;
-
-  Val* step() const;
-
-  Val* vectorize_shift() const {
-    return vectorize_shift_;
-  }
-
   IterDomain* iter_domain() const {
     return iter_domain_;
   }
@@ -1449,55 +655,32 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr {
     return body_;
   }
 
-  bool vectorize() const {
-    return vectorize_;
+  Expr* parentScope() const {
+    return parent_scope_;
   }
 
-  // Returns if a loop could be unrolled. Start and stop must be constant, it
-  // must not be a broadcast dimension, cannot be bound to a parallel dimension,
-  // and returns false if start is 0 and stop is 1.
-  bool isUnrollable() const {
-    return start()->isConstScalar() && stop()->isConstScalar() &&
-        !iter_domain()->isThread() && !iter_domain()->isBroadcast() &&
-        !(start()->isZeroInt() && stop()->isOneInt()) &&
-        iter_domain()->parallelType() != ParallelType::Vectorize;
-  }
+  void setParentScope(Expr* scope);
 
  private:
-  IterDomain* const iter_domain_ = nullptr;
-
-  Val* index_ = nullptr;
-  Val* start_ = nullptr;
-  Val* stop_ = nullptr;
-  Val* step_ = nullptr;
-
-  // vectorize is true when the for-loop contains a vectorize set
-  // the flag is used to omit the for-loop from the kernel
-  bool vectorize_ = false;
-  // [pre | vectorize | post] <= inner-most, merged root domain
-  // shift_ is applied to vectorize and post sections.
-  Val* vectorize_shift_ = nullptr;
-
+  Val* const index_ = nullptr;
+  IterDomain* const iter_domain_;
   Scope body_;
+  Expr* parent_scope_ = nullptr;
 };
 
-//! IfThenElse provides scoping for an boolean operator. Exprs placed in its
-//! body are considered inside the scope of the if statement. In the future the
-//! implementation should look quite different so that we can do proper
-//! dependency annalysis like in Fusion.
-//!
-//! TODO(kir): this is not a real expression
-//!
-class TORCH_CUDA_CU_API IfThenElse final : public Expr {
+// IfThenElse provides scoping for an boolean operator. Exprs placed in its body
+// are considered inside the scope of the if statement. In the future the
+// implementation should look quite different so that we can do proper
+// dependency annalysis like in Fusion.
+//
+// TODO(kir): this is not a real expression
+//
+class TORCH_CUDA_CU_API IfThenElse : public Expr {
  public:
-  explicit IfThenElse(Passkey passkey, Predicate* cond);
+  explicit IfThenElse(Passkey, Bool* cond, Expr* parent_scope);
 
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
+  Bool* cond() const {
+    return cond_;
   }
 
   Scope& thenBody() {
@@ -1519,35 +702,33 @@ class TORCH_CUDA_CU_API IfThenElse final : public Expr {
     return !else_body_.empty();
   }
 
+  Expr* parentScope() const {
+    return parent_scope_;
+  }
+
+  void setParentScope(Expr* scope);
+
  private:
+  Bool* const cond_ = nullptr;
   Scope then_body_;
   Scope else_body_;
+  Expr* parent_scope_ = nullptr;
 };
 
-//! Grid reduction operation
-//!
-//! This node is used only after lowering a fusion to explicitly mark a grid
-//! reduction and the buffer allocation needed to do it.
-//!
-//! This node provides FusionExecutor the information it needs to allocate the
-//! reduction and sync buffers.
-class TORCH_CUDA_CU_API GridReduction final : public Expr {
+// Grid reduction operation, this node is used only after lowering a fusion to
+// explicitly mark a grid reduction and the buffer allocation needed to do it.
+// This node provides FusionExecutor the information it needs to allocate the
+// reduction and sync buffers.
+class TORCH_CUDA_CU_API GridReduction : public Expr {
  public:
-  explicit GridReduction(Passkey passkey, ReductionOp* reduction_op);
-
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
+  explicit GridReduction(Passkey, ReductionOp* reduction_op);
 
   GridReduction(
-      Passkey passkey,
+      Passkey,
       ReductionOp* reduction_op,
       Allocate* reduction_buffer,
-      Allocate* sync_buffer);
+      Allocate* sync_buffer,
+      Bool* pred = nullptr);
 
   ReductionOp* reduction_op() const {
     return reduction_op_;
@@ -1561,12 +742,8 @@ class TORCH_CUDA_CU_API GridReduction final : public Expr {
     return sync_buffer_;
   }
 
-  const ParallelTypeBitmap& threadPredicate() const {
-    return thread_predicate_;
-  }
-
-  void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
-    thread_predicate_ = thread_predicate;
+  Bool* pred() const {
+    return pred_;
   }
 
   static std::string getPredicateFlagName(const TensorView* val);
@@ -1576,78 +753,7 @@ class TORCH_CUDA_CU_API GridReduction final : public Expr {
   ReductionOp* reduction_op_ = nullptr;
   Allocate* reduction_buffer_ = nullptr;
   Allocate* sync_buffer_ = nullptr;
-  // gridReduce has template flags for thread predicates. In order to
-  // use them, the thread predicate is held here separately from
-  // Expr::predicate_.
-  ParallelTypeBitmap thread_predicate_;
-};
-
-//! Grid welford operation
-//!
-//! This node is used only after lowering a fusion to explicitly mark a grid
-//! reduction and the buffer allocation needed to do it.
-//!
-//! This node provides FusionExecutor the information it needs to allocate the
-//! reduction and sync buffers.
-class TORCH_CUDA_CU_API GridWelford final : public Expr {
- public:
-  void accept(IrVisitor* visitor) const override {
-    visitor->visit(this);
-  }
-
-  void accept(MutableIrVisitor* visitor) override {
-    visitor->visit(this);
-  }
-
-  GridWelford(
-      Passkey passkey,
-      WelfordOp* welford_op,
-      Allocate* var_buffer,
-      Allocate* avg_buffer,
-      Allocate* n_buffer,
-      Allocate* sync_buffer);
-
-  WelfordOp* welford_op() const {
-    return welford_op_;
-  }
-
-  Allocate* var_buffer() const {
-    return var_buffer_;
-  }
-
-  Allocate* avg_buffer() const {
-    return avg_buffer_;
-  }
-
-  Allocate* N_buffer() const {
-    return n_buffer_;
-  }
-
-  Allocate* sync_buffer() const {
-    return sync_buffer_;
-  }
-
-  const ParallelTypeBitmap& threadPredicate() const {
-    return thread_predicate_;
-  }
-
-  void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
-    thread_predicate_ = thread_predicate;
-  }
-
-  static std::string getPredicateFlagName(const TensorView* val);
-  static std::string getPredicateFlagName(const fuser::cuda::TensorView* val);
-
- private:
-  WelfordOp* welford_op_ = nullptr;
-  Allocate* var_buffer_ = nullptr;
-  Allocate* avg_buffer_ = nullptr;
-  Allocate* n_buffer_ = nullptr;
-  Allocate* sync_buffer_ = nullptr;
-  // gridReduce has template flags for thread predicates. In order to
-  // use them, the thread predicate is held here separately from
-  // Expr::predicate_.
-  ParallelTypeBitmap thread_predicate_;
+  Bool* pred_ = nullptr;
 };
 
 } // namespace kir
index 7914fa7..86bc00c 100644 (file)
@@ -6,12 +6,43 @@ namespace fuser {
 namespace cuda {
 namespace kir {
 
-Val* IrBuilder::newResult(DataType dtype) {
-  switch (dtype) {
+bool isLoweredScalar(const Val* val) {
+  switch (val->getValType().value()) {
+    case ValType::KirNamedScalar:
+    case ValType::KirScalar:
+      return true;
+    default:
+      return false;
+  }
+}
+
+bool isLoweredVal(const Val* val) {
+  switch (val->getValType().value()) {
+    case ValType::TensorIndex:
+    case ValType::KirNamedScalar:
+    case ValType::KirScalar:
+    case ValType::KirTensorDomain:
+    case ValType::KirIterDomain:
+    case ValType::KirTensorView:
+      return true;
+    default:
+      return false;
+  }
+}
+
+Val* IrBuilder::newResult(const Val* lhs, const Val* rhs) {
+  TORCH_CHECK(isLoweredScalar(lhs));
+  TORCH_CHECK(isLoweredScalar(rhs));
+  TORCH_CHECK(lhs->getDataType() == rhs->getDataType());
+
+  // Allocate a compatible result value
+  switch (lhs->getDataType().value()) {
     case DataType::Bool:
       return create<Bool>(c10::nullopt);
-    case DataType::Double:
-      return create<Double>(c10::nullopt);
+    case DataType::Float:
+      return create<Float>(c10::nullopt);
+    case DataType::Half:
+      return create<Half>(c10::nullopt);
     case DataType::Int:
       return create<Int>(c10::nullopt);
     default:
@@ -20,8 +51,7 @@ Val* IrBuilder::newResult(DataType dtype) {
 }
 
 Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
-  TORCH_CHECK(lhs->dtype() == rhs->dtype(), "Incompatible operand types");
-  auto result = newResult(lhs->dtype());
+  auto result = newResult(lhs, rhs);
   create<BinaryOp>(op_type, result, lhs, rhs);
   // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
   return result;
@@ -34,31 +64,6 @@ Val* IrBuilder::newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
   return result;
 }
 
-Val* IrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) {
-  TORCH_CHECK(lhs->dtype() == rhs->dtype(), "Incompatible operand types");
-  auto result = newResult(lhs->dtype());
-  create<TernaryOp>(TernaryOpType::Where, result, pred, lhs, rhs);
-  return result;
-}
-
-Val* IrBuilder::negExpr(Val* val) {
-  auto result = newResult(val->dtype());
-  create<UnaryOp>(UnaryOpType::Neg, result, val);
-  return result;
-}
-
-Val* IrBuilder::setExprNamedScalar(const std::string& name, Val* val) {
-  auto result = create<NamedScalar>(name, val->dtype());
-  create<UnaryOp>(UnaryOpType::Set, result, val);
-  return result;
-}
-
-Val* IrBuilder::addressExprNamedScalar(const std::string& name, Val* val) {
-  auto result = create<NamedScalar>(name, DataType::Int);
-  create<UnaryOp>(UnaryOpType::Address, result, val);
-  return result;
-}
-
 Val* IrBuilder::andExpr(Val* lhs, Val* rhs) {
   return newLogicExpr(BinaryOpType::And, lhs, rhs);
 }
@@ -67,22 +72,10 @@ Val* IrBuilder::eqExpr(Val* lhs, Val* rhs) {
   return newLogicExpr(BinaryOpType::Eq, lhs, rhs);
 }
 
-Val* IrBuilder::gtExpr(Val* lhs, Val* rhs) {
-  return newLogicExpr(BinaryOpType::GT, lhs, rhs);
-}
-
 Val* IrBuilder::ltExpr(Val* lhs, Val* rhs) {
   return newLogicExpr(BinaryOpType::LT, lhs, rhs);
 }
 
-Val* IrBuilder::leExpr(Val* lhs, Val* rhs) {
-  return newLogicExpr(BinaryOpType::LE, lhs, rhs);
-}
-
-Val* IrBuilder::geExpr(Val* lhs, Val* rhs) {
-  return newLogicExpr(BinaryOpType::GE, lhs, rhs);
-}
-
 Val* IrBuilder::addExpr(Val* lhs, Val* rhs) {
   return newArithmeticExpr(BinaryOpType::Add, lhs, rhs);
 }
@@ -107,49 +100,6 @@ Val* IrBuilder::modExpr(Val* lhs, Val* rhs) {
   return newArithmeticExpr(BinaryOpType::Mod, lhs, rhs);
 }
 
-Val* IrBuilder::maxExpr(Val* lhs, Val* rhs) {
-  return newArithmeticExpr(BinaryOpType::Max, lhs, rhs);
-}
-
-Val* IrBuilder::minExpr(Val* lhs, Val* rhs) {
-  return newArithmeticExpr(BinaryOpType::Min, lhs, rhs);
-}
-
-Int* IrBuilder::zeroVal() {
-  if (zero_ == nullptr) {
-    zero_ = create<kir::Int>(0);
-  }
-  return zero_;
-}
-
-Int* IrBuilder::oneVal() {
-  if (one_ == nullptr) {
-    one_ = create<kir::Int>(1);
-  }
-  return one_;
-}
-
-Bool* IrBuilder::falseVal() {
-  if (false_ == nullptr) {
-    false_ = create<kir::Bool>(false);
-  }
-  return false_;
-}
-
-Bool* IrBuilder::trueVal() {
-  if (true_ == nullptr) {
-    true_ = create<kir::Bool>(true);
-  }
-  return true_;
-}
-
-NamedScalar* IrBuilder::magicZeroVal() {
-  if (magic_zero_ == nullptr) {
-    magic_zero_ = create<kir::NamedScalar>("nvfuser_zero", DataType::Int);
-  }
-  return magic_zero_;
-}
-
 } // namespace kir
 } // namespace cuda
 } // namespace fuser
index 70925f1..0af37c8 100644 (file)
@@ -12,6 +12,10 @@ namespace fuser {
 namespace cuda {
 namespace kir {
 
+// Simple classification helpers
+bool isLoweredScalar(const Val* val);
+bool isLoweredVal(const Val* val);
+
 //! Kernel IR builder interface
 //!
 //! The only way to create new Kernel IR nodes is through the
@@ -35,7 +39,7 @@ namespace kir {
 //!   auto new_node = ir_builder.create<kir::Int>(1));
 //!   auto result = ir_builder.mulExpr(lhs, rhs);
 //!
-class TORCH_CUDA_CU_API IrBuilder {
+class IrBuilder {
  public:
   explicit IrBuilder(Kernel* kernel) : kernel_(kernel) {}
 
@@ -43,60 +47,32 @@ class TORCH_CUDA_CU_API IrBuilder {
   //! to the appropriate constructor
   template <class T, class... Args>
   T* create(Args&&... args) {
-    const kir::Passkey passkey(kernel_);
-    const auto node = new T(passkey, std::forward<Args>(args)...);
-    kernel_->registerIrNode(passkey, std::unique_ptr<T>(node));
-    return node;
+    // TODO(kir): switch this to Kernel registration
+    return new T(kir::Passkey(), std::forward<Args>(args)...);
   }
 
-  // Unary operations
-  Val* negExpr(Val* val);
-  Val* setExprNamedScalar(const std::string& name, Val* val);
-  Val* addressExprNamedScalar(const std::string& name, Val* val);
-
-  // Binary operations
+  // Binary expressions
   Val* andExpr(Val* lhs, Val* rhs);
   Val* eqExpr(Val* lhs, Val* rhs);
-  Val* gtExpr(Val* lhs, Val* rhs);
   Val* ltExpr(Val* lhs, Val* rhs);
-  Val* leExpr(Val* lhs, Val* rhs);
-  Val* geExpr(Val* lhs, Val* rhs);
   Val* addExpr(Val* lhs, Val* rhs);
   Val* subExpr(Val* lhs, Val* rhs);
   Val* mulExpr(Val* lhs, Val* rhs);
   Val* divExpr(Val* lhs, Val* rhs);
   Val* ceilDivExpr(Val* lhs, Val* rhs);
   Val* modExpr(Val* lhs, Val* rhs);
-  Val* maxExpr(Val* lhs, Val* rhs);
-  Val* minExpr(Val* lhs, Val* rhs);
-
-  // Ternary operations
-  Val* whereExpr(Val* pred, Val* lhs, Val* rhs);
-
-  // Shortcuts for frequently used vals
-  Int* zeroVal();
-  Int* oneVal();
-  Bool* falseVal();
-  Bool* trueVal();
-
-  NamedScalar* magicZeroVal();
 
  private:
-  Val* newResult(DataType dtype);
+  Val* newResult(const Val* lhs, const Val* rhs);
   Val* newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs);
   Val* newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs);
 
  private:
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunused-private-field"
   // Non-owning pointer to the kernel to be modified
   Kernel* kernel_ = nullptr;
-  // Frequently used constant vals
-  Int* zero_ = nullptr;
-  Int* one_ = nullptr;
-  Bool* false_ = nullptr;
-  Bool* true_ = nullptr;
-
-  // Magic zero corresponds to runtime/helpers.cu magic_zero
-  NamedScalar* magic_zero_ = nullptr;
+#pragma clang diagnostic pop
 };
 
 } // namespace kir
index e00da31..1b474e5 100644 (file)
@@ -4,7 +4,7 @@
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/type.h>
 
-#include <utility>
+#include <sstream>
 
 namespace torch {
 namespace jit {
@@ -12,28 +12,12 @@ namespace fuser {
 namespace cuda {
 namespace kir {
 
-namespace {
-
-const char* boolLiteral(bool value) {
+static std::string boolLiteral(bool value) {
   return value ? "true" : "false";
 }
 
-std::string varName(const kir::Val* val, const char* prefix) {
-  std::stringstream value_name;
-  if (val == nullptr) {
-    value_name << "$nullptr";
-  } else if (val->name() != kInvalidStmName) {
-    value_name << prefix << val->name();
-  } else {
-    value_name << "k" << prefix << val->id();
-  }
-  return value_name.str();
-}
-
-} // namespace
-
-void IrPrinter::printNode(const kir::Node* node) {
-  os_ << gen(node, true);
+void IrPrinter::printNode(const Statement* stmt) {
+  handle(stmt);
 }
 
 void IrPrinter::printKernel(const Kernel* kernel) {
@@ -59,7 +43,7 @@ void IrPrinter::printKernel(const Kernel* kernel) {
   // kernel body
   startBlock();
   for (auto expr : kernel->topLevelExprs()) {
-    os_ << gen(expr, true);
+    handle(expr);
   }
   endBlock();
   os_ << "END.\n\n";
@@ -68,69 +52,16 @@ void IrPrinter::printKernel(const Kernel* kernel) {
 std::ostream& IrPrinter::indent() {
   for (const auto i : c10::irange(indent_level_)) {
     (void)i; // Suppress unused variable warning
-    ir_str_ << kTab;
-  }
-  ir_str_ << margin_;
-  return ir_str_;
-}
-
-std::string IrPrinter::gen(const kir::Node* node, bool top_level) {
-  if (node == nullptr) {
-    return "$nullptr";
-  }
-
-  // If we're generatign a top level statement we expect to start
-  // with an empty set of uses
-  TORCH_INTERNAL_ASSERT(!implicit_definition_ || uses_.empty() || !top_level);
-
-  // Mark the node as generated
-  visited_.insert(node);
-
-  // Generate the node itself
-  std::stringstream node_str;
-  std::swap(node_str, ir_str_);
-  node->accept(this);
-  std::swap(node_str, ir_str_);
-
-  if (!implicit_definition_) {
-    return node_str.str();
-  }
-
-  if (top_level) {
-    // Implicitly mark top level nodes as used, so we
-    // get their definitions printed (useful for debugging)
-    if (auto val = dynamic_cast<const kir::Val*>(node)) {
-      uses_.insert(val);
-    }
-
-    // Make a copy of the node uses (and reset global state)
-    const auto node_uses = uses_;
-    uses_.clear();
-
-    std::stringstream top_level_str;
-
-    // Hoist implicit definitions
-    for (auto use : node_uses) {
-      const auto def = use->definition();
-      if (def && visited_.find(def) == visited_.end()) {
-        margin_ = "~ ";
-        top_level_str << gen(def, true);
-        margin_ = "";
-      }
-    }
-
-    top_level_str << node_str.str();
-    return top_level_str.str();
-  } else {
-    return node_str.str();
+    os_ << kTab;
   }
+  return os_;
 }
 
-std::string IrPrinter::use(const kir::Val* val) {
-  if (val != nullptr) {
-    uses_.insert(val);
-  }
-  return gen(val);
+std::string IrPrinter::gen(const Statement* stmt) {
+  std::stringstream ss;
+  IrPrinter ir_printer(ss);
+  ir_printer.handle(stmt);
+  return ss.str();
 }
 
 void IrPrinter::startBlock() {
@@ -143,258 +74,172 @@ void IrPrinter::endBlock() {
 }
 
 void IrPrinter::handleBlock(const kir::Scope& scope) {
-  // Save the uses of the parent scope
-  decltype(uses_) outer_uses;
-  std::swap(uses_, outer_uses);
-
   startBlock();
   for (auto expr : scope.exprs()) {
-    ir_str_ << gen(expr, true);
+    handle(expr);
   }
   endBlock();
+}
+
+void IrPrinter::handle(const Statement* s) {
+  OptInConstDispatch::handle(s);
+}
 
-  // Restore parent's uses
-  std::swap(uses_, outer_uses);
+void IrPrinter::handle(const Val* v) {
+  OptInConstDispatch::handle(v);
 }
 
-void IrPrinter::visit(const kir::Bool* node) {
-  if (node->isConst()) {
-    ir_str_ << boolLiteral(*node->value());
+void IrPrinter::handle(const Expr* e) {
+  OptInConstDispatch::handle(e);
+}
+
+void IrPrinter::handle(const kir::Bool* node) {
+  if (node->isSymbolic()) {
+    os_ << "b" << node->name();
   } else {
-    ir_str_ << varName(node, "b");
+    os_ << boolLiteral(*node->value());
   }
 }
 
-void IrPrinter::visit(const kir::Double* node) {
-  if (node->isConst()) {
-    const int digits = std::numeric_limits<Double::ScalarType>::max_digits10;
-    ir_str_ << "double(" << std::setprecision(digits) << *node->value() << ")";
+void IrPrinter::handle(const kir::Float* node) {
+  if (node->isSymbolic()) {
+    os_ << "f" << node->name();
   } else {
-    ir_str_ << varName(node, "d");
+    const int digits = std::numeric_limits<Float::ScalarType>::max_digits10;
+    os_ << "float(" << std::setprecision(digits) << *node->value() << ")";
   }
 }
 
-void IrPrinter::visit(const kir::Int* node) {
-  if (node->isConst()) {
-    ir_str_ << *node->value();
+void IrPrinter::handle(const kir::Half* node) {
+  if (node->isSymbolic()) {
+    os_ << "h" << node->name();
   } else {
-    ir_str_ << varName(node, "i");
+    os_ << "half(" << *node->value() << ")";
   }
 }
 
-void IrPrinter::visit(const kir::NamedScalar* node) {
-  ir_str_ << node->name();
+void IrPrinter::handle(const kir::Int* node) {
+  if (node->isSymbolic()) {
+    os_ << "i" << node->name();
+  } else {
+    os_ << *node->value();
+  }
 }
 
-void IrPrinter::visit(const kir::Predicate* node) {
-  switch (node->predicate_type()) {
-    case PredicateType::Inline: {
-      ir_str_ << "Inline";
-      break;
-    }
-    case PredicateType::Manual: {
-      ir_str_ << node->value();
-      break;
-    }
-    case PredicateType::Misaligned: {
-      ir_str_ << "Misaligned";
-      break;
-    }
-    case PredicateType::Padding: {
-      ir_str_ << "Padding";
-      break;
-    }
-    case PredicateType::Shift: {
-      ir_str_ << "Shift";
-      break;
-    }
-    case PredicateType::Unswitch: {
-      ir_str_ << "Unswitch";
-      break;
-    }
-    case PredicateType::Vectorize: {
-      ir_str_ << "Vectorize";
-      break;
-    }
-    default:
-      break;
-  }
+void IrPrinter::handle(const kir::NamedScalar* node) {
+  os_ << node->name();
 }
 
-void IrPrinter::visit(const kir::TensorIndex* node) {
-  ir_str_ << gen(node->view()) << "[";
+void IrPrinter::handle(const kir::TensorIndex* node) {
+  os_ << gen(node->view()) << "[";
   for (auto index : node->indices()) {
-    ir_str_ << use(index);
+    os_ << gen(index);
     if (index != node->indices().back()) {
-      ir_str_ << ", ";
+      os_ << ", ";
     }
   }
-  ir_str_ << "]";
+  os_ << "]";
 }
 
-void IrPrinter::visit(const kir::IterDomain* node) {
-  ir_str_ << varName(node, "id") << "[";
+void IrPrinter::handle(const kir::IterDomain* node) {
   if (node->isRFactorProduct()) {
-    ir_str_ << "rfactor.";
+    os_ << "rfactor.";
   }
-  ir_str_ << node->parallelType() << "." << node->iterType() << "("
-          << use(node->start()) << " .. " << use(node->extent()) << ")]";
+  os_ << node->getParallelType() << "." << node->getIterType() << "("
+      << gen(node->start()) << " .. " << gen(node->rawExtent()) << ")";
 }
 
-void IrPrinter::visit(const kir::TensorDomain*) {
+void IrPrinter::handle(const kir::TensorDomain*) {
   // TODO(kir): print Tensor shapes?
-  ir_str_ << "kir::TensorDomain";
+  os_ << "kir::TensorDomain";
 }
 
-void IrPrinter::visit(const kir::TensorView* node) {
-  // TODO(kir): print memory type too?
-  ir_str_ << varName(node, "T");
+void IrPrinter::handle(const kir::TensorView* node) {
+  // TODO(KIR): print memory type too?
+  os_ << "T" << node->name();
 }
 
-void IrPrinter::visit(const kir::UnaryOp* node) {
+void IrPrinter::handle(const kir::UnaryOp* node) {
   indent() << gen(node->out()) << " = ";
 
-  auto op_type = node->operation();
-
-  if (auto op = inline_op_str(op_type)) {
-    if (alsoBooleanOperator(op_type) &&
-        node->out()->dtype() == DataType::Bool) {
-      ir_str_ << stringifyBooleanOp(op_type) << gen(node->in());
-    } else {
-      ir_str_ << *op << gen(node->in());
-    }
+  if (auto op = inline_op_str(node->getUnaryOpType())) {
+    os_ << *op << gen(node->in());
   } else {
-    if (op_type == UnaryOpType::Cast) {
-      const auto cast_str =
-          cast_func_str({node->in()->dtype(), node->out()->dtype()});
-      ir_str_ << cast_str.value();
+    if (node->getUnaryOpType() == UnaryOpType::Cast) {
+      const auto cast_str = cast_func_str(
+          {node->in()->getDataType().value(),
+           node->out()->getDataType().value()});
+      os_ << cast_str.value();
     } else {
-      ir_str_ << op_type;
-      if (needFloatSuffix(op_type) && node->out()->dtype() == DataType::Float) {
-        ir_str_ << "f";
-      }
+      os_ << node->getUnaryOpType();
     }
 
-    if (op_type == UnaryOpType::RandLike) {
-      ir_str_ << "(RND";
+    os_ << "(";
+    if (node->getUnaryOpType() == UnaryOpType::RandLike) {
+      os_ << "RND";
     } else {
-      ir_str_ << "(";
-      ir_str_ << use(node->in());
+      os_ << gen(node->in());
     }
-    ir_str_ << ")";
+    os_ << ")";
   }
 
-  ir_str_ << "\n";
+  os_ << "\n";
 }
 
-void IrPrinter::visit(const kir::BinaryOp* node) {
+void IrPrinter::handle(const kir::BinaryOp* node) {
   indent() << gen(node->out()) << " = ";
 
-  const auto op_type = node->operation();
-  const auto lhs = use(node->lhs());
-  const auto rhs = use(node->rhs());
+  const auto op_type = node->getBinaryOpType();
+  const auto lhs = gen(node->lhs());
+  const auto rhs = gen(node->rhs());
 
   if (auto op = inline_op_str(op_type)) {
-    ir_str_ << lhs << " ";
-    if (alsoBooleanOperator(op_type) &&
-        node->out()->dtype() == DataType::Bool) {
-      ir_str_ << stringifyBooleanOp(op_type);
-    } else {
-      ir_str_ << *op;
-    }
-    ir_str_ << " " << rhs;
+    os_ << lhs << " " << *op << " " << rhs;
   } else {
-    ir_str_ << op_type;
-    if (needFloatSuffix(op_type) && node->out()->dtype() == DataType::Float) {
-      ir_str_ << "f";
-    }
-    ir_str_ << "(" << lhs << ", " << rhs << ")";
+    os_ << op_type << "(" << lhs << ", " << rhs << ")";
   }
 
-  ir_str_ << "\n";
+  os_ << "\n";
 }
 
-void IrPrinter::visit(const kir::TernaryOp* node) {
-  indent() << gen(node->out()) << " = " << node->operation() << "("
-           << use(node->in1()) << ", " << use(node->in2()) << ", "
-           << use(node->in3()) << ")\n";
+void IrPrinter::handle(const kir::TernaryOp* node) {
+  indent() << gen(node->out()) << " = " << node->getTernaryOpType() << "("
+           << gen(node->in1()) << ", " << gen(node->in2()) << ", "
+           << gen(node->in3()) << ")\n";
 }
 
-void IrPrinter::visit(const kir::ReductionOp* node) {
+void IrPrinter::handle(const kir::ReductionOp* node) {
   indent() << gen(node->out()) << " = "
-           << "REDUCTION(op='" << node->operation() << "'"
-           << ", in=" << use(node->in()) << ", init=" << use(node->init())
-           << ", pred=" << use(node->predicate()) << ")\n";
-}
-
-void IrPrinter::visit(const kir::WelfordOp* node) {
-  indent() << gen(node->outVar()) << "," << gen(node->outAvg()) << ","
-           << gen(node->outN()) << " = "
-           << "Welford( inAvg=" << use(node->inAvg());
-  if (!node->inN()->isOneInt()) {
-    indent() << " inVar=" << use(node->inVar());
-  }
-  indent() << " inN=" << use(node->inN());
-  if (!node->initN()->isZeroInt()) {
-    indent() << ", initVar=" << use(node->initVar())
-             << " initAvg=" << use(node->initAvg())
-             << " initN=" << use(node->initN());
-  }
-  indent() << ", pred=" << use(node->predicate()) << ")\n";
+           << "REDUCTION(op='" << node->getReductionOpType() << "'"
+           << ", in=" << gen(node->in()) << ", init=" << gen(node->init())
+           << ", pred=" << gen(node->pred()) << ")\n";
 }
 
-void IrPrinter::visit(const kir::GridReduction* node) {
+void IrPrinter::handle(const kir::GridReduction* node) {
   const auto* reduction_op = node->reduction_op();
   indent() << gen(reduction_op->out()) << " = "
-           << "GRID_REDUCTION(op='" << reduction_op->operation() << "'"
-           << ", in=" << use(reduction_op->in())
-           << ", init=" << use(reduction_op->init())
-           << ", pred=" << use(reduction_op->predicate()) << ")\n";
-  indent() << kTab << kTab
-           << ".reduction_buffer=" << use(node->reduction_buffer()->buffer())
+           << "GRID_REDUCTION(op='" << reduction_op->getReductionOpType() << "'"
+           << ", in=" << gen(reduction_op->in())
+           << ", init=" << gen(reduction_op->init())
+           << ", pred=" << gen(reduction_op->pred()) << ")\n";
+  indent() << kTab << ".reduction_buffer=" << gen(node->reduction_buffer())
            << "\n";
-  indent() << kTab << kTab
-           << ".sync_buffer=" << use(node->sync_buffer()->buffer()) << "\n";
-  indent() << kTab << kTab << ".grid_pred=" << use(node->predicate()) << "\n";
+  indent() << kTab << ".sync_buffer=" << gen(node->sync_buffer()) << "\n";
+  indent() << kTab << ".grid_pred=" << gen(node->pred()) << "\n";
 }
 
-void IrPrinter::visit(const kir::GridWelford* node) {
-  const auto* welford_op = node->welford_op();
-  indent() << gen(welford_op->outVar()) << "," << gen(welford_op->outAvg())
-           << "," << gen(welford_op->outN()) << " = "
-           << "GRID_WELFORD("
-           << "inAvg=" << use(welford_op->inAvg());
-  if (!welford_op->inN()->isOneInt()) {
-    indent() << ", inVar=" << use(welford_op->inVar());
-  }
-  indent() << ", inN=" << use(welford_op->inN());
-  if (!welford_op->initN()->isZeroInt()) {
-    indent() << ", initVar=" << use(welford_op->initVar())
-             << " initAvg=" << use(welford_op->initAvg())
-             << " initN=" << use(welford_op->initN());
-  }
-  indent() << ", pred=" << use(welford_op->predicate()) << ")\n";
-  indent() << kTab << kTab
-           << ".var_buffer=" << use(node->var_buffer()->buffer())
-           << ".avg_buffer=" << use(node->avg_buffer()->buffer())
-           << ".n_buffer=" << use(node->N_buffer()->buffer()) << "\n";
-  indent() << kTab << kTab
-           << ".sync_buffer=" << use(node->sync_buffer()->buffer()) << "\n";
-  indent() << kTab << kTab << ".grid_pred=" << use(node->predicate()) << "\n";
+void IrPrinter::handle(const kir::BroadcastOp* node) {
+  indent() << gen(node->out()) << " = BROADCAST(" << gen(node->in()) << ")\n";
 }
 
-void IrPrinter::visit(const kir::BroadcastOp* node) {
-  indent() << gen(node->out()) << " = BROADCAST(" << use(node->in()) << ")\n";
-}
-
-void IrPrinter::visit(const kir::ForLoop* node) {
+void IrPrinter::handle(const kir::ForLoop* node) {
   indent() << "FOR " << gen(node->index()) << " in " << gen(node->iter_domain())
            << ":\n";
   handleBlock(node->body());
 }
 
-void IrPrinter::visit(const kir::IfThenElse* node) {
-  indent() << "IF " << use(node->predicate()) << ":\n";
+void IrPrinter::handle(const kir::IfThenElse* node) {
+  indent() << "IF " << gen(node->cond()) << ":\n";
   handleBlock(node->thenBody());
   if (node->hasElse()) {
     indent() << "ELSE:\n";
@@ -402,48 +247,25 @@ void IrPrinter::visit(const kir::IfThenElse* node) {
   }
 }
 
-void IrPrinter::visit(const kir::Allocate* node) {
+void IrPrinter::handle(const kir::Allocate* node) {
   indent() << gen(node->buffer()) << " = ALLOCATE("
-           << "mem_type=" << node->memoryType() << ", "
-           << "size=" << use(node->size()) << ", "
+           << "mem_type=" << node->getMemoryType() << ", "
+           << "size=" << gen(node->size()) << ", "
            << "zero_init=" << boolLiteral(node->zeroInit()) << ")\n";
-  if (node->alias() != nullptr) {
-    indent() << kTab << kTab << ".alias=" << gen(node->alias()->buffer())
-             << "\n";
-  }
 }
 
-void IrPrinter::visit(const kir::Sync* node) {
+void IrPrinter::handle(const kir::Sync* node) {
   indent() << "SYNC(war_hazard=" << boolLiteral(node->isWarHazardSync())
            << ")\n";
 }
 
-void IrPrinter::visit(const kir::InitMagicZero* node) {
-  indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n";
-}
-
-void IrPrinter::visit(const kir::UpdateMagicZero* node) {
-  indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
-}
-
-std::string toString(const kir::Node* stmt, bool implicit_definitions) {
+std::string toString(const Statement* stmt) {
   std::stringstream ss;
-  IrPrinter ir_printer(ss, implicit_definitions);
+  IrPrinter ir_printer(ss);
   ir_printer.printNode(stmt);
   return ss.str();
 }
 
-std::string toString(
-    const std::vector<kir::Expr*>& exprs,
-    bool implicit_definitions) {
-  std::stringstream ss;
-  IrPrinter ir_printer(ss, implicit_definitions);
-  for (auto expr : exprs) {
-    ir_printer.printNode(expr);
-  }
-  return ss.str();
-}
-
 } // namespace kir
 } // namespace cuda
 } // namespace fuser
index c286a4b..f516638 100644 (file)
@@ -2,13 +2,12 @@
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
 #include <torch/csrc/jit/codegen/cuda/kernel.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
 
 #include <iostream>
-#include <sstream>
 #include <string>
-#include <unordered_set>
 
 namespace torch {
 namespace jit {
@@ -21,33 +20,21 @@ namespace kir {
 //! This class is intended for debug printing, so it attempts
 //! to handle invalid IR states as much as possible.
 //!
-//! implicit_definition_ = true will recurisvely print the definition of all
-//! inputs to an expression if they haven't been printed.
-class TORCH_CUDA_CU_API IrPrinter : private kir::IrVisitor {
-  static constexpr char const* kTab = "  ";
+class TORCH_CUDA_CU_API IrPrinter : private OptInConstDispatch {
+  static constexpr char* kTab = "  ";
 
  public:
   //! Constructs a new IrPrinter which outputs to the specified stream
-  explicit IrPrinter(std::ostream& os, bool implicit_definition = true)
-      : os_(os), implicit_definition_(implicit_definition) {}
+  explicit IrPrinter(std::ostream& os) : os_(os) {}
 
   //! Print a single Kernel IR node
-  void printNode(const kir::Node* node);
+  void printNode(const Statement* stmt);
 
   //! Print a complete Kernel definition
   void printKernel(const Kernel* kernel);
 
  private:
-  // Generates a string representation of an IR node
-  //
-  // If `top_level` is true, all the value uses are tracked and
-  // their definitions are implicitly printed before the node itself
-  //
-  std::string gen(const kir::Node* node, bool top_level = false);
-
-  // Generate a string representation of an used value
-  // (this helps automatically tracking the value uses)
-  std::string use(const kir::Val* val);
+  static std::string gen(const Statement* stmt);
 
   std::ostream& indent();
 
@@ -55,72 +42,40 @@ class TORCH_CUDA_CU_API IrPrinter : private kir::IrVisitor {
   void endBlock();
   void handleBlock(const kir::Scope& scope);
 
-  void visit(const kir::Bool*) final;
-  void visit(const kir::Double*) final;
-  void visit(const kir::Int*) final;
-  void visit(const kir::NamedScalar*) final;
-  void visit(const kir::Predicate*) final;
-
-  void visit(const kir::TensorIndex*) final;
-  void visit(const kir::IterDomain*) final;
-  void visit(const kir::TensorDomain*) final;
-  void visit(const kir::TensorView*) final;
-
-  void visit(const kir::UnaryOp*) final;
-  void visit(const kir::BinaryOp*) final;
-  void visit(const kir::TernaryOp*) final;
-  void visit(const kir::ReductionOp*) final;
-  void visit(const kir::WelfordOp*) final;
-  void visit(const kir::BroadcastOp*) final;
-
-  void visit(const kir::GridReduction*) final;
-  void visit(const kir::GridWelford*) final;
-  void visit(const kir::ForLoop*) final;
-  void visit(const kir::IfThenElse*) final;
-  void visit(const kir::Allocate*) final;
-  void visit(const kir::Sync*) final;
-  void visit(const kir::InitMagicZero*) final;
-  void visit(const kir::UpdateMagicZero*) final;
+  void handle(const Statement*) final;
+  void handle(const Val*) final;
+  void handle(const Expr*) final;
+
+  void handle(const kir::Bool*) final;
+  void handle(const kir::Float*) final;
+  void handle(const kir::Half*) final;
+  void handle(const kir::Int*) final;
+  void handle(const kir::NamedScalar*) final;
+
+  void handle(const kir::TensorIndex*) final;
+  void handle(const kir::IterDomain*) final;
+  void handle(const kir::TensorDomain*) final;
+  void handle(const kir::TensorView*) final;
+
+  void handle(const kir::UnaryOp*) final;
+  void handle(const kir::BinaryOp*) final;
+  void handle(const kir::TernaryOp*) final;
+  void handle(const kir::ReductionOp*) final;
+  void handle(const kir::BroadcastOp*) final;
+
+  void handle(const kir::GridReduction*) final;
+  void handle(const kir::ForLoop*) final;
+  void handle(const kir::IfThenElse*) final;
+  void handle(const kir::Allocate*) final;
+  void handle(const kir::Sync*) final;
 
  private:
   std::ostream& os_;
-
-  // Current indentation level
   int indent_level_ = 0;
-
-  // Internal IR generation stream
-  std::stringstream ir_str_;
-
-  // Tracks the set of nodes which have been printed
-  std::unordered_set<const kir::Node*> visited_;
-
-  // Optional left margin printed after the indentation
-  const char* margin_ = "";
-
-  // The set of values used by the current top-level IR node
-  std::unordered_set<const kir::Val*> uses_;
-
-  // If the definition of all inputs to an expression haven't been printed
-  // already implicit_definition_ = true will print them before printing the
-  // requested node.
-  bool implicit_definition_ = true;
 };
 
-//! Returns the string representation of a Kernel IR node. If the definition of
-//! all inputs to an expression haven't been printed already
-//! implicit_definition_ = true will print them before printing the requested
-//! node.
-TORCH_CUDA_CU_API std::string toString(
-    const kir::Node* stmt,
-    bool implicit_definitions = true);
-
-//! Returns the string representation of a vector of kir::Expr, convenient
-//! debugm echanism during lowering. If the definition of all inputs to an
-//! expression haven't been printed already implicit_definition_ = true will
-//! print them before printing the requested node.
-TORCH_CUDA_CU_API std::string toString(
-    const std::vector<kir::Expr*>& exprs,
-    bool implicit_definitions = true);
+//! Returns the string representation of a Kernel IR node
+std::string toString(const Statement* stmt);
 
 } // namespace kir
 } // namespace cuda
index d994012..73edb11 100644 (file)
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
 #include <torch/csrc/jit/codegen/cuda/lower_alias_memory.h>
-#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
-#include <torch/csrc/jit/codegen/cuda/lower_expr_sort.h>
 #include <torch/csrc/jit/codegen/cuda/lower_index.h>
 #include <torch/csrc/jit/codegen/cuda/lower_insert_syncs.h>
 #include <torch/csrc/jit/codegen/cuda/lower_loops.h>
-#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
-#include <torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h>
-#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
 #include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
 #include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
 #include <torch/csrc/jit/codegen/cuda/lower_utils.h>
 #include <torch/csrc/jit/codegen/cuda/lower_validation.h>
 
-#include <list>
-#include <unordered_map>
-#include <unordered_set>
-
 namespace torch {
 namespace jit {
 namespace fuser {
 namespace cuda {
 
 // TODO(kir): revisit this
-thread_local GpuLower* active_gpu_lower = nullptr; // NOLINT
-namespace {
-
-// Going to generate a map of tensor view root domain extents to reduce the
-// number used during lowering. For example if we have:
-//
-// T2[i0, i1] = T1[i0, i1] + T2[i2, i3]
-//
-// We know it would be safe to use:
-//
-// T2[i0, i1] = T1[i0, i1] + T2[i0, i1]
-//
-// And that way we don't generate T2.size[0] and T2.size[1], instead we will
-// reuse T1.size[0] and T1.size[1]
-// This is important when doing CSE as T2 and T1 would otherwise look like
-// they're using different values, even though we know they're the same
-//
-// There's some duplicate logic here that's in computeAt map, but it's not so
-// concice there to pull out. May want to consider making this mapping its own
-// class especially as it may be useful during scheduling.
-std::unordered_map<Val*, Val*> getSimplificationMap(Fusion* fusion) {
-  std::list<std::unordered_set<IterDomain*>> disjoint_root_sets;
-  std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>*>
-      id_to_disjoint_root_set;
-
-  auto map_root_ids = [&disjoint_root_sets, &id_to_disjoint_root_set](
-                          IterDomain* id0, IterDomain* id1) {
-    if (id0->isBroadcast() || id1->isBroadcast()) {
-      return;
-    }
-
-    auto disjoint_set_0_it = id_to_disjoint_root_set.find(id0);
-    auto disjoint_set_1_it = id_to_disjoint_root_set.find(id1);
-    bool set_0_found = disjoint_set_0_it != id_to_disjoint_root_set.end();
-    bool set_1_found = disjoint_set_1_it != id_to_disjoint_root_set.end();
-
-    if (set_0_found && set_1_found) {
-      if (disjoint_set_0_it->second == disjoint_set_1_it->second) {
-        return;
-      }
-      // merge second disjoint set into first
-      auto* set_0 = disjoint_set_0_it->second;
-      auto* set_1 = disjoint_set_1_it->second;
-      for (auto id : *set_1) {
-        set_0->emplace(id);
-        id_to_disjoint_root_set[id] = set_0;
-      }
-      // remove second set from disjoint_root_sets
-      disjoint_root_sets.erase(std::find(
-          disjoint_root_sets.begin(), disjoint_root_sets.end(), *set_1));
-    } else if (set_0_found || set_1_found) {
-      auto existing_set =
-          set_0_found ? disjoint_set_0_it->second : disjoint_set_1_it->second;
-      auto to_add_id = set_0_found ? id1 : id0;
-      existing_set->emplace(to_add_id);
-      id_to_disjoint_root_set[to_add_id] = existing_set;
-      // add entry into existing set
-    } else {
-      // create new set entry
-      disjoint_root_sets.push_back(std::unordered_set<IterDomain*>());
-      auto* new_set = &disjoint_root_sets.back();
-      new_set->emplace(id0);
-      new_set->emplace(id1);
-      id_to_disjoint_root_set[id0] = new_set;
-      id_to_disjoint_root_set[id1] = new_set;
-    }
-  };
-
-  auto fusion_vals = fusion->usedMathVals();
-  for (auto producer_tv : ir_utils::filterByType<TensorView>(fusion_vals)) {
-    auto consumer_tvs = ir_utils::consumerTvsOf(producer_tv);
-    for (auto consumer_tv : consumer_tvs) {
-      auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv);
-      auto c2p_root_map = pairwise_map.mapConsumerToProducer(
-          consumer_tv->domain(), producer_tv->domain());
-      for (auto entry : c2p_root_map) {
-        auto c_id = entry.first;
-        auto p_id = entry.second;
-        map_root_ids(p_id, c_id);
-      }
-    }
-  }
-
-  // Map each set to an input ID (if it exists) that has the smallest ->name()
-  // entry value
-  std::unordered_map<std::unordered_set<IterDomain*>*, IterDomain*>
-      set_to_input_id;
-
-  // Loop over the root domains, of the inputs to the fusion. Pick an input ID
-  // to use as the representative ID of the collected sets. Only consider inputs
-  // as those are the ones that map to values like "T0.size[1]". They are he
-  // ID's that propagated their extents into the problem. We could also check
-  // the outputs as we do have C++ examples of using output dimensions for the
-  // problem size instead of inputs. However, we don't do anything where we can
-  // translate to those kinds of kernels integrated into PyTorch.
-  for (auto input_tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
-    for (auto id :
-         TensorDomain::noReductions(input_tv->getMaybeRFactorDomain())) {
-      auto id_set_it = id_to_disjoint_root_set.find(id);
-      if (id_set_it == id_to_disjoint_root_set.end()) {
-        continue;
-      }
-      auto* id_set = id_set_it->second;
-      if (set_to_input_id.find(id_set) == set_to_input_id.end()) {
-        set_to_input_id[id_set] = id;
-      } else {
-        auto input_id_of_set = set_to_input_id.at(id_set);
-        // Swap id's if new name is less than previously set
-        bool swap_ids = id->name() < input_id_of_set->name();
-        // If new id is a const scalar but previously was'nt use the const
-        // scalar
-        swap_ids = swap_ids ||
-            (id->extent()->isConstScalar() &&
-             !input_id_of_set->extent()->isConstScalar());
-        // If previous scalar was const and new isn't, don't swap
-        swap_ids = swap_ids &&
-            !(input_id_of_set->extent()->isConstScalar() &&
-              !id->extent()->isConstScalar());
-
-        if (swap_ids) {
-          set_to_input_id[id_set] = id;
-        }
-      }
-    }
-  }
+thread_local GpuLower* active_gpu_lower = nullptr;
 
-  // Finally make map from ID extents to the representitive ID extent.
-  std::unordered_map<Val*, Val*> extent_to_min_input_id_extent;
-  for (auto entry : set_to_input_id) {
-    auto* set = entry.first;
-    auto input_id = entry.second;
-    for (auto id : *set) {
-      extent_to_min_input_id_extent[id->extent()] = input_id->extent();
-    }
-  }
-  return extent_to_min_input_id_extent;
-}
-
-} // namespace
 void GpuLower::replaceSymbolicSizes() {
-  FUSER_PERF_SCOPE("GpuLower::Lower::replaceSymbolicSizes");
+  FUSER_PERF_SCOPE("replaceSymbolicSizes");
 
   kir::IrBuilder ir_builder(kernel());
 
   // Grab inputs and outputs
+  // TODO: Only run through inputs for the size map, outputs don't actually set
+  // any sizes of the problem.
   std::vector<TensorView*> inputs_and_outputs;
   for (auto val : fusion_->inputs()) {
     if (ir_utils::isTV(val)) {
@@ -188,11 +40,14 @@ void GpuLower::replaceSymbolicSizes() {
     }
   }
 
-  // Generate map for all tensorview root domain values to map them to symbolic
-  // values. i.e. T0->getRootDomain()[0] would map to a named scalar
-  // "T0.size[0]". This map will be used when lowering fusion ir to kernel ir.
+  // Run through inputs and outputs first. Since we're replacing full
+  // tensorviews their names are going to change. We need  the new referenc
+  // name for the inputs/outputs. This way we won't reference the wrong tensor
+  // view. For example T0 may be translated to T9. We don't want our new
+  // variable to be T0->size[...] we need it to be T9->size[...]
   for (TensorView* tv : inputs_and_outputs) {
     // Replace the domain with one based on Ti.size[j]
+    std::vector<IterDomain*> new_domain_iters;
     const std::vector<IterDomain*>& root_td = tv->getRootDomain();
 
     size_t dim = 0;
@@ -201,50 +56,33 @@ void GpuLower::replaceSymbolicSizes() {
 
       // Output sizes could have reduction axes, which isn't what gets output.
       // NOLINTNEXTLINE(bugprone-branch-clone)
-      if (id->isReduction() ||
-          (id->getIterType() == IterType::BroadcastWithoutStride)) {
+      if (id->isReduction()) {
+        continue;
+      } else if (id->getIterType() == IterType::BroadcastWithoutStride) {
+        continue;
+        // NOLINTNEXTLINE(bugprone-branch-clone)
+      } else if (id->getIterType() == IterType::BroadcastWithStride) {
+        dim++;
         continue;
-      } else if (
-          // NOLINTNEXTLINE(bugprone-branch-clone)
-          (id->getIterType() == IterType::BroadcastWithStride) ||
-          orig_size->isConstScalar()) {
+      } else if (orig_size->isConstScalar()) {
         dim++;
         continue;
       }
 
       // TODO(kir): consider a different implementation which doesn't
-      //  hijack the kir_val_map_
-      // Currently turn off this part for inputs of segmented fusion,
-      //  since FusionKernelRuntime will provide these as integer inputs
-      if (kir_val_map_.find(orig_size) == kir_val_map_.end() &&
-          !orig_size->isFusionInput() && !orig_size->isConstScalar()) {
+      //  hijack the kir_map_
+      if (kir_map_.find(orig_size) == kir_map_.end()) {
         std::stringstream ss;
         ss << "T" << tv->name() << ".size[" << dim++ << "]";
-        kir_val_map_[orig_size] = ir_builder.create<kir::NamedScalar>(
+        kir_map_[orig_size] = ir_builder.create<kir::NamedScalar>(
             ss.str(), orig_size->getDataType().value());
-      } else {
-        dim++;
-      }
-    }
-  }
-
-  // Use a minimal number of sizes from provided tensors.
-  auto extent_simplification_map = getSimplificationMap(fusion_);
-  for (auto extent_entry : extent_simplification_map) {
-    auto orig_extent = extent_entry.first;
-    auto simplified_extent = extent_entry.second;
-    if (kir_val_map_.count(orig_extent)) {
-      if (kir_val_map_.count(simplified_extent)) {
-        kir_val_map_[orig_extent] = kir_val_map_[simplified_extent];
-      } else {
-        kir_val_map_[orig_extent] = lowerValue(simplified_extent);
       }
     }
   }
 }
 
 void GpuLower::lower() {
-  FUSER_PERF_SCOPE("GpuLower::lower");
+  FUSER_PERF_SCOPE("lower");
 
   TORCH_INTERNAL_ASSERT(fusion_ != nullptr);
   TORCH_INTERNAL_ASSERT(
@@ -263,293 +101,167 @@ void GpuLower::lower() {
   FusionGuard fg(fusion_);
 
   // Start with a fresh kernel
-  kernel_ = std::make_unique<kir::Kernel>();
+  kernel_ = std::make_unique<Kernel>();
 
   // prepare for lowering
   validateIr(fusion_);
   replaceSymbolicSizes();
-  trivial_reduction_info_.build(fusion_, this);
 
-  // In the future we may directly use this map, but for now it will propagate
-  // and validate (to some extent) the parallelization strategy.
-  // This is the first time nodes will be lowered to kir nodes. Since for now we
-  // propagate the parallel strategy in some instances, we need to do it before
-  // lowering.
-  ca_parallel_map_ = ComputeAtMap(ComputeAtMap::MappingMode::PARALLEL);
-  ca_parallel_map_.build(fusion_, current());
-
-  // Want to run this after parallel map is created
-  validateVectorize(fusion_);
-
-  // Generate mappings to generate indices
-  ca_index_map_ = ComputeAtMap(ComputeAtMap::MappingMode::INDEX);
-  ca_index_map_.build(fusion_, current());
+  // Compute thread predicates
+  ThreadPredicateMap preds(fusion_);
 
-  // Generate mappings to generate and map to loop nests
-  ca_loop_map_ = ComputeAtMap(ComputeAtMap::MappingMode::LOOP);
-  ca_loop_map_.build(fusion_, current());
+  // Run our passes keeping the lowered expressions and forwarding them
+  const auto lowered_exprs =
+      LoopNestGenerator::loweredExprs(fusion_, preds, fusion_->exprs(true));
 
-  validateParallelize(fusion_);
+  const auto unrolled_loops =
+      UnrollPass::runPass(fusion_, lowered_exprs, preds);
 
-  parallelDimensionMap().build(fusion_);
-  if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) {
-    std::cout << parallelDimensionMap().toString();
-  }
+  // Reuse memory locations if:
+  // TensorView is dynamic shared memory
+  // TensorViews have the same size
+  // Output TensorView is modified using Input TensorView
+  const auto reuse_mem_exprs = reuseMemoryAllocations(fusion_, unrolled_loops);
 
-  // Scan the whole fusion and build mappings about halo extensions of
-  // all IterDomains
-  haloInfo().build(fusion_);
+  // Insert SyncThreads at end of for-loop to avoid WAR race condition
+  const auto sync_exprs = insertThreadSynchronization(fusion_, reuse_mem_exprs);
 
-  // Compute thread predicates
-  thread_pred_map_.build(fusion_);
+  const auto indexed_loops =
+      IndexLowering::getIndexedExprs(fusion_, sync_exprs);
 
-  // Detects all exprssions that don't need predicates
-  predicateElimination().build(fusion_);
+  // We now have the lowered expressions, finalize the kernel IR
+  kernel_->finalize(indexed_loops, preds);
 
   // Set the kernel inputs & outputs
   for (auto input : fusion_->inputs()) {
     kernel_->addInput(GpuLower::lowerValue(input));
   }
-
   for (auto output : fusion_->outputs()) {
     kernel_->addOutput(GpuLower::lowerValue(output));
   }
-
-  // Run our passes keeping the lowered expressions and forwarding
-  // them
-
-  // Reorder expressions for loop-nest generation respecting computeAt
-  // relationships
-  auto sorted_exprs = reorderExprsForComputeAt();
-
-  // Generate loop-nests and place each expression at its
-  // corresponding loop
-  const auto lowered_exprs = LoopNestGenerator::loweredExprs(sorted_exprs);
-
-  // Insert allocations
-  const auto alloced_exprs = insertAllocations(lowered_exprs);
-
-  // Insert read after write smem syncs
-  const auto raw_sync_exprs = insertRawThreadSynchronization(alloced_exprs);
-
-  // Inserts predicates after this, need to be careful in later passes when
-  // inserting in loop nest structure as insertions could be on if then else
-  // instead of directly on a for loop
-  const auto unrolled_loops = UnrollPass::runPass(fusion_, raw_sync_exprs);
-
-  const auto unrolled_mv_loops =
-      processMisalignedVectorization(fusion_, unrolled_loops);
-
-  // Reuse memory locations
-  // TODO: Reenable once fixed.
-  // const auto reuse_mem_exprs = reuseMemoryAllocations(unrolled_mv_loops);
-
-  // Insert SyncThreads at end of for-loop to avoid WAR race condition
-  // const auto war_sync_exprs =
-  // insertWarThreadSynchronization(reuse_mem_exprs);
-  const auto war_sync_exprs = insertWarThreadSynchronization(unrolled_mv_loops);
-
-  const auto indexed_loops = IndexLowering::getIndexedExprs(war_sync_exprs);
-
-  const auto conditional_loops =
-      generateConditionalFromPredicate(fusion_, indexed_loops);
-
-  // Insert fake zero updates to make sure nvrtc doesn't blow out register use
-  // on index and predicate reuse
-  const auto register_adjusted = insertMagicZero(conditional_loops);
-
-  // We now have the lowered expressions, finalize the kernel IR
-  kernel_->finalize(register_adjusted);
 }
 
-kir::Kernel* GpuLower::kernel() const {
+Kernel* GpuLower::kernel() const {
   TORCH_CHECK(kernel_);
   return kernel_.get();
 }
 
 // Maps Fusion IR nodes to the Kernel IR counterparts
-class GpuLower::KernelIrMapper : private OptInConstDispatch {
+//
+// TODO(kir): this is a interim solution for easing the Kernel IR splitting
+//
+class TORCH_CUDA_CU_API GpuLower::KernelIrMapper : private OptInConstDispatch {
  public:
   explicit KernelIrMapper(GpuLower* gpu_lower)
       : gpu_lower_(gpu_lower), ir_builder_(gpu_lower->kernel()) {}
 
-  kir::Val* lowerValue(const Val* value) {
-    const auto it = gpu_lower_->kir_val_map_.find(value);
-    if (it != gpu_lower_->kir_val_map_.end()) {
+  Val* lower(const Val* value) {
+    const auto it = gpu_lower_->kir_map_.find(value);
+    if (it != gpu_lower_->kir_map_.end()) {
       return it->second;
     } else {
       handle(value);
-      const auto kir_value = gpu_lower_->kir_val_map_[value];
-      TORCH_CHECK(kir_value != nullptr);
+      const auto lowered_node = gpu_lower_->kir_map_[value];
+      TORCH_CHECK(lowered_node != nullptr);
+      TORCH_CHECK(kir::isLoweredVal(lowered_node));
 
-      // Lower the value definition, if any
+      // Lower the arithmetic expression defining the value, if any
       if (value->isScalar()) {
-        if (auto def = value->definition()) {
-          const auto kir_def = lowerExpr(def);
-          TORCH_INTERNAL_ASSERT(kir_value->definition() == kir_def);
+        if (auto def = value->getOrigin()) {
+          lowerDefinition(lowered_node, def);
         }
       }
 
-      return kir_value;
+      return lowered_node;
     }
   }
 
-  kir::Expr* lowerExpr(const Expr* expr) {
-    const auto it = gpu_lower_->kir_expr_map_.find(expr);
-    if (it != gpu_lower_->kir_expr_map_.end()) {
-      return it->second;
-    } else {
-      handle(expr);
-      const auto lowered_node = gpu_lower_->kir_expr_map_[expr];
-      TORCH_CHECK(lowered_node != nullptr);
-      return lowered_node;
+ private:
+  // TODO(kir): rewrite this
+  void lowerDefinition(Val* lowered_value, const Expr* def) {
+    switch (def->type()) {
+      case ExprType::UnaryOp: {
+        const auto op = def->as<UnaryOp>();
+        ir_builder_.create<kir::UnaryOp>(
+            op->getUnaryOpType(), lowered_value, lower(op->in()));
+        break;
+      }
+      case ExprType::BinaryOp: {
+        const auto op = def->as<BinaryOp>();
+        ir_builder_.create<kir::BinaryOp>(
+            op->getBinaryOpType(),
+            lowered_value,
+            lower(op->lhs()),
+            lower(op->rhs()));
+        break;
+      }
+      case ExprType::TernaryOp: {
+        const auto op = def->as<TernaryOp>();
+        ir_builder_.create<kir::TernaryOp>(
+            op->getTernaryOpType(),
+            lowered_value,
+            lower(op->in1()),
+            lower(op->in2()),
+            lower(op->in3()));
+        break;
+      }
+      default:
+        TORCH_CHECK(false, "Unexpected expression type");
     }
     // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
   }
 
- private:
-  void handle(const Statement* node) final {
+  void handle(const Statement* node) override {
     OptInConstDispatch::handle(node);
   }
 
-  void handle(const Val* node) final {
+  void handle(const Val* node) override {
     OptInConstDispatch::handle(node);
   }
 
-  void handle(const Expr* node) final {
+  void handle(const Expr* node) override {
     OptInConstDispatch::handle(node);
   }
 
-  void handle(const TensorDomain* node) final {
+  void handle(const TensorDomain* node) override {
     const auto lowered_node = ir_builder_.create<kir::TensorDomain>(node);
-    TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
+    TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
   }
 
-  void handle(const IterDomain* node) final {
+  void handle(const IterDomain* node) override {
     const auto lowered_node = ir_builder_.create<kir::IterDomain>(node);
-    TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
+    TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
   }
 
-  void handle(const TensorView* node) final {
+  void handle(const TensorView* node) override {
     const auto lowered_node = ir_builder_.create<kir::TensorView>(node);
-    TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
+    TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
   }
 
-  void handle(const Bool* node) final {
+  void handle(const Bool* node) override {
     const auto lowered_node = ir_builder_.create<kir::Bool>(node);
-    TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
+    TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
   }
 
-  void handle(const Double* node) final {
-    const auto lowered_node = ir_builder_.create<kir::Double>(node);
-    TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
+  void handle(const Float* node) override {
+    const auto lowered_node = ir_builder_.create<kir::Float>(node);
+    TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
   }
 
-  void handle(const Int* node) final {
-    const auto lowered_node = ir_builder_.create<kir::Int>(node);
-    TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
+  void handle(const Half* node) override {
+    const auto lowered_node = ir_builder_.create<kir::Half>(node);
+    TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
   }
 
-  void handle(const NamedScalar* node) final {
-    const auto lowered_node = ir_builder_.create<kir::NamedScalar>(
-        node->name(), node->getDataType().value());
-    TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
-  }
-
-  void handle(const UnaryOp* node) final {
-    const auto lowered_node = ir_builder_.create<kir::UnaryOp>(
-        node->getUnaryOpType(),
-        lowerValue(node->out()),
-        lowerValue(node->in()));
-    TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
-  }
-
-  void handle(const BinaryOp* node) final {
-    const auto lowered_node = ir_builder_.create<kir::BinaryOp>(
-        node->getBinaryOpType(),
-        lowerValue(node->out()),
-        lowerValue(node->lhs()),
-        lowerValue(node->rhs()));
-    TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
-  }
-
-  void handle(const TernaryOp* node) final {
-    const auto lowered_node = ir_builder_.create<kir::TernaryOp>(
-        node->getTernaryOpType(),
-        lowerValue(node->out()),
-        lowerValue(node->in1()),
-        lowerValue(node->in2()),
-        lowerValue(node->in3()));
-    TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
-  }
-
-  void handle(const ReductionOp* node) final {
-    auto out_tv = node->out()->as<TensorView>();
-    // If trivial reduction operation lower to set operation.
-    if (std::all_of(
-            out_tv->domain()->domain().begin(),
-            out_tv->domain()->domain().end(),
-            [&](IterDomain* id) {
-              // If id is a reduction axis, is it a trivial reduction?
-              if (id->isReduction()) {
-                return gpu_lower_->trivialReductionInfo().isDerived(id);
-              } else {
-                return true;
-              }
-            })) {
-      const auto lowered_node = ir_builder_.create<kir::UnaryOp>(
-          UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in()));
-      TORCH_CHECK(
-          gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
-      return;
-    }
-
-    const auto lowered_node = ir_builder_.create<kir::ReductionOp>(
-        node->getReductionOpType(),
-        lowerValue(node->init()),
-        lowerValue(node->out()),
-        lowerValue(node->in()));
-    TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
+  void handle(const Int* node) override {
+    const auto lowered_node = ir_builder_.create<kir::Int>(node, false);
+    TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
   }
 
-  void handle(const WelfordOp* node) final {
-    auto lowerOptional = [&](Val* v) { return v ? lowerValue(v) : nullptr; };
-    const auto lowered_node = ir_builder_.create<kir::WelfordOp>(
-        lowerValue(node->outVar()),
-        lowerValue(node->outAvg()),
-        lowerValue(node->outN()),
-        lowerValue(node->initVar()),
-        lowerValue(node->initAvg()),
-        lowerValue(node->initN()),
-        lowerOptional(node->inVar()),
-        lowerValue(node->inAvg()),
-        lowerValue(node->inN()));
-
-    TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
-  }
-
-  void handle(const BroadcastOp* node) final {
-    const auto lowered_node = ir_builder_.create<kir::BroadcastOp>(
-        lowerValue(node->out()), lowerValue(node->in()));
-    TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
-  }
-
-  void handle(const TransposeOp* node) final {
-    const auto lowered_node = ir_builder_.create<kir::UnaryOp>(
-        UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in()));
-    TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
-  }
-
-  void handle(const ShiftOp* node) final {
-    const auto lowered_node = ir_builder_.create<kir::UnaryOp>(
-        UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in()));
-    TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
-  }
-
-  void handle(const GatherOp* node) final {
-    const auto lowered_node = ir_builder_.create<kir::UnaryOp>(
-        UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in()));
-    TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
+  void handle(const NamedScalar* node) override {
+    const auto lowered_node = ir_builder_.create<kir::NamedScalar>(
+        node->name(), node->getDataType().value());
+    TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
   }
 
  private:
@@ -557,17 +269,20 @@ class GpuLower::KernelIrMapper : private OptInConstDispatch {
   kir::IrBuilder ir_builder_;
 };
 
-kir::Val* GpuLower::lowerValue(const Val* val) {
-  KernelIrMapper kir_mapper(this);
-  return kir_mapper.lowerValue(val);
+Val* GpuLower::lowerValue(const Val* val) {
+  TORCH_INTERNAL_ASSERT(!kir::isLoweredVal(val));
+  TORCH_INTERNAL_ASSERT(active_gpu_lower != nullptr);
+  KernelIrMapper kir_mapper(active_gpu_lower);
+  return kir_mapper.lower(val);
 }
 
-kir::Expr* GpuLower::lowerExpr(const Expr* expr) {
+Val* GpuLower::getLowerValue(const Val* val) {
   KernelIrMapper kir_mapper(this);
-  return kir_mapper.lowerExpr(expr);
+  return kir_mapper.lower(val);
 }
 
 GpuLower* GpuLower::current() {
+  TORCH_INTERNAL_ASSERT(active_gpu_lower != nullptr);
   return active_gpu_lower;
 }
 
index 871a09c..afd3481 100644 (file)
@@ -2,15 +2,9 @@
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
-#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/kernel.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
-#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
-#include <torch/csrc/jit/codegen/cuda/parallel_dimension_map.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
 
 #include <memory>
 #include <ostream>
@@ -20,10 +14,6 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-// TODO: we frequently use pairwise root mapping from consumers to producers.
-// This information is implicitly in the computeAtMaps, but there's no isolated
-// container for this information that we can reuse. Would be nice to generate
-// such a structure and propagate it through lowering.
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 class TORCH_CUDA_CU_API GpuLower {
   class KernelIrMapper;
@@ -36,62 +26,22 @@ class TORCH_CUDA_CU_API GpuLower {
     lower();
   }
 
-  kir::Kernel* kernel() const;
+  Kernel* kernel() const;
 
-  //! Converts a Fusion IR value into the Kernel IR equivalent
-  kir::Val* lowerValue(const Val* val);
+  // Converts a Fusion IR value into the Kernel IR equivalent
+  //
+  // TODO(kir): revisit this interface
+  //
+  static Val* lowerValue(const Val* val);
 
-  //! Converts a Fusion IR expression into the Kernel IR equivalent
-  kir::Expr* lowerExpr(const Expr* expr);
+  // TODO(kir): we have two methods which do almost the same thing
+  //
+  Val* getLowerValue(const Val* val);
 
   //! Returns the currently active lowering object
   //! (or nullptr if no lowering is in progress)
   static GpuLower* current();
 
-  const ThreadPredicateMap& threadPredMap() const {
-    return thread_pred_map_;
-  }
-
-  const ComputeAtMap& caLoopMap() const {
-    return ca_loop_map_;
-  }
-
-  const ComputeAtMap& caIndexMap() const {
-    return ca_index_map_;
-  }
-
-  const ComputeAtMap& caParallelMap() const {
-    return ca_parallel_map_;
-  }
-
-  const auto& trivialReductionInfo() const {
-    return trivial_reduction_info_;
-  }
-
-  const HaloInfo& haloInfo() const {
-    return halo_info_;
-  }
-
-  HaloInfo& haloInfo() {
-    return halo_info_;
-  }
-
-  const ParallelDimensionMap& parallelDimensionMap() const {
-    return parallel_dimension_map_;
-  }
-
-  ParallelDimensionMap& parallelDimensionMap() {
-    return parallel_dimension_map_;
-  }
-
-  PredicateElimination& predicateElimination() {
-    return pred_elimination_;
-  }
-
-  const PredicateElimination& predicateElimination() const {
-    return pred_elimination_;
-  }
-
  private:
   void lower();
 
@@ -105,21 +55,10 @@ class TORCH_CUDA_CU_API GpuLower {
 
  private:
   // Lowered Kernel IR
-  std::unique_ptr<kir::Kernel> kernel_;
+  std::unique_ptr<Kernel> kernel_;
 
   // Fusion IR node to Kernel IR node mapping
-  std::unordered_map<const Val*, kir::Val*> kir_val_map_;
-  std::unordered_map<const Expr*, kir::Expr*> kir_expr_map_;
-
-  // Some stateful information during lowering
-  ThreadPredicateMap thread_pred_map_;
-  PredicateElimination pred_elimination_;
-  ComputeAtMap ca_loop_map_;
-  ComputeAtMap ca_index_map_;
-  ComputeAtMap ca_parallel_map_;
-  TrivialReductionInfo trivial_reduction_info_;
-  HaloInfo halo_info_;
-  ParallelDimensionMap parallel_dimension_map_;
+  std::unordered_map<const Val*, Val*> kir_map_;
 
   Fusion* fusion_ = nullptr;
 };
index af6d6fe..12c5254 100644 (file)
@@ -1,16 +1,13 @@
 #include <torch/csrc/jit/codegen/cuda/lower_alias_memory.h>
 
+#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
+#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
+#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/lower_utils.h>
 
-#include <sstream>
-#include <unordered_map>
-#include <unordered_set>
-
 namespace torch {
 namespace jit {
 namespace fuser {
@@ -20,41 +17,52 @@ namespace {
 
 //! Get string representation of Allocate size for symbolic comparison
 //!
-class SymbolicSizePrinter : private kir::IrVisitor {
+class SymbolicSizePrinter final : private OptOutConstDispatch {
  public:
-  static std::string printSize(const kir::Allocate* allocate) {
+  static std::string print_size(const kir::Allocate* alloc) {
     SymbolicSizePrinter printer;
-    allocate->size()->accept(&printer);
+    printer.handle(alloc->size());
     return printer.os_.str();
   }
 
  private:
-  void visit(const kir::Int* node) final {
-    if (auto def = node->definition()) {
-      def->accept(this);
-    } else if (node->isConst()) {
-      os_ << *node->value();
-    } else {
-      os_ << "ki" << node->id();
-    }
+  void handle(const Val* v) final {
+    OptOutConstDispatch::handle(v);
   }
 
-  void visit(const kir::NamedScalar* named_scalar) final {
-    os_ << "@" << named_scalar->name();
+  void handle(const Expr* e) final {
+    OptOutConstDispatch::handle(e);
+  }
+
+  void handle(const kir::Int* node) final {
+    if (auto def = FusionGuard::getCurFusion()->origin(node)) {
+      os_ << "( ";
+      handle(def);
+      os_ << " )";
+      return;
+    } else if (node->isSymbolic()) {
+      os_ << "i" << node->name();
+    } else {
+      os_ << *node->value();
+    }
   }
 
-  void visit(const kir::UnaryOp* unary_op) final {
-    os_ << unary_op->operation() << "(";
-    unary_op->accept(this);
-    os_ << ")";
+  void handle(const kir::NamedScalar* node) final {
+    os_ << node->name();
   }
 
-  void visit(const kir::BinaryOp* binary_op) final {
-    os_ << binary_op->operation() << "(";
-    binary_op->lhs()->accept(this);
-    os_ << ",";
-    binary_op->rhs()->accept(this);
-    os_ << ")";
+  void handle(const kir::BinaryOp* node) final {
+    if (auto inline_bop = inline_op_str(node->getBinaryOpType())) {
+      handle(node->lhs());
+      os_ << " " << inline_bop.value() << " ";
+      handle(node->rhs());
+    } else {
+      os_ << node->getBinaryOpType() << "(";
+      handle(node->lhs());
+      os_ << ", ";
+      handle(node->rhs());
+      os_ << ")";
+    }
   }
 
  private:
@@ -63,12 +71,13 @@ class SymbolicSizePrinter : private kir::IrVisitor {
 
 //! Reuse Allocation nodes via pointer aliasing
 //!
-class AllocateReuseModifier {
-  // Alias local memory if it exceeds this threshold
-  static constexpr size_t kRegisterSizeThreshold = 1;
-
+class AllocateReuseModifier final : private OptOutDispatch {
  public:
-  void modify(const std::vector<kir::Expr*>& exprs) {
+  explicit AllocateReuseModifier(Fusion* fusion, size_t register_size_threshold)
+      : eval_evaluator_(fusion),
+        register_size_threshold_(register_size_threshold) {}
+
+  void modify(const std::vector<Expr*>& exprs) {
     // Find candidate TensorViews and collect analysis information
     for (auto expr : exprs) {
       handle(expr);
@@ -76,96 +85,102 @@ class AllocateReuseModifier {
 
     // Iterate over candidates to find match
     for (auto tv : candidate_alias_tv_) {
-      const auto def = tv->definition();
-      TORCH_INTERNAL_ASSERT(def != nullptr);
+      TORCH_INTERNAL_ASSERT(
+          map_tv_to_origin_expr_.find(tv) != map_tv_to_origin_expr_.end());
 
-      const auto alloc_it = map_tv_to_allocations_.find(tv->name());
-      TORCH_INTERNAL_ASSERT(alloc_it != map_tv_to_allocations_.end());
-      const auto output_alloc = alloc_it->second;
+      const auto& expr = map_tv_to_origin_expr_[tv];
+      const auto output = expr->output(0)->as<TensorView>();
 
-      const auto input_alloc = findCompatibleInputAllocate(
-          tv->dtype(), SymbolicSizePrinter::printSize(output_alloc), def);
+      TORCH_INTERNAL_ASSERT(
+          map_tv_to_allocations_.find(output->name()) !=
+          map_tv_to_allocations_.end());
 
+      auto output_alloc = map_tv_to_allocations_[output->name()];
+
+      auto input_alloc = findCompatibleInputAllocate(
+          SymbolicSizePrinter::print_size(output_alloc), expr);
       if (input_alloc != nullptr) {
+        // std::cout << "Alias Match\t" << output->getMemoryType() << std::endl;
         output_alloc->setAlias(input_alloc);
       }
     }
   }
 
  private:
-  // Do we have a true pointwise op?
-  // (ie. a TV op, excluding direct assignments and reductions)
-  static bool isPointwiseTvOp(const kir::Expr* expr) {
-    if (ir_utils::isTVOp(expr)) {
-      if (auto unary_op = dynamic_cast<const kir::UnaryOp*>(expr)) {
-        return unary_op->operation() != UnaryOpType::Set;
-      } else {
-        return expr->isA<kir::BinaryOp>() || expr->isA<kir::TernaryOp>();
-      }
-    }
+  // Check if we are a Pointwise TensorView op.
+  bool isPwiseTVOp(const Expr* expr) {
+    // Ignore set operations
+    if (expr->outputs().size() == 1 && ir_utils::isTV(expr->output(0)) &&
+        ((expr->getExprType().value() == ExprType::UnaryOp &&
+          expr->as<UnaryOp>()->getUnaryOpType() != UnaryOpType::Set) ||
+         expr->getExprType().value() == ExprType::BinaryOp ||
+         expr->getExprType().value() == ExprType::TernaryOp))
+      return true;
     return false;
   }
 
   // Find an Input Allocate that is compatible with the Output Allocate
-  const kir::Allocate* findCompatibleInputAllocate(
-      const DataType output_dtype,
+  kir::Allocate* findCompatibleInputAllocate(
       const std::string& output_size_str,
-      const kir::Expr* expr) {
+      Expr* expr) {
     // Stop searching if current op is not point-wise
-    if (!isPointwiseTvOp(expr)) {
+    if (!isPwiseTVOp(expr)) {
       return nullptr;
     }
 
-    const kir::TensorView* first_tv_input = nullptr;
-    for (const auto input : expr->inputs()) {
-      if (auto input_tv = dynamic_cast<const kir::TensorView*>(input)) {
-        if (first_tv_input == nullptr) {
-          first_tv_input = input_tv;
-        }
+    const auto& expr_inputs_iter =
+        ir_utils::filterByType<TensorView>(expr->inputs());
 
-        // input_alloc == nullptr implies that input_tv is a kernel input
-        const auto input_alloc = map_tv_to_allocations_[input_tv->name()];
-        if (input_alloc != nullptr) {
-          if (candidate_alias_tv_.find(input_tv) != candidate_alias_tv_.end() &&
-              output_size_str == SymbolicSizePrinter::printSize(input_alloc) &&
-              output_dtype == input_tv->dtype() &&
-              map_tv_to_last_usage_[input_tv] <= map_expr_to_pos_[expr]) {
-            return input_alloc;
-          }
+    std::vector<TensorView*> expr_inputs(
+        expr_inputs_iter.begin(), expr_inputs_iter.end());
+
+    for (const auto input : expr_inputs) {
+      auto input_alloc = map_tv_to_allocations_[input->name()];
+
+      // input_allocation == nullptr implies that input_tv is a fusion input.
+      if (input_alloc != nullptr) {
+        if (candidate_alias_tv_.find(input) != candidate_alias_tv_.end() &&
+            output_size_str == SymbolicSizePrinter::print_size(input_alloc) &&
+            map_tv_to_last_usage_[input] <= map_expr_to_pos_[expr]) {
+          return input_alloc;
         }
       }
     }
 
     // Assume the first argument contains the primary variable
     // Follow path along point-wise operations
-    if (first_tv_input != nullptr &&
-        map_tv_to_last_usage_[first_tv_input] <= map_expr_to_pos_[expr]) {
-      if (const auto def = first_tv_input->definition()) {
-        return findCompatibleInputAllocate(output_dtype, output_size_str, def);
+    if (!expr_inputs.empty()) {
+      auto first_input_argument_tv = expr_inputs.front()->getOrigin();
+      if (first_input_argument_tv != nullptr) {
+        return findCompatibleInputAllocate(
+            output_size_str, first_input_argument_tv);
       }
     }
-
     return nullptr;
   }
 
-  void handle(kir::Expr* expr) {
-    const size_t expr_index = map_expr_to_pos_.size();
+  void handle(Expr* expr) final {
+    size_t expr_index = map_expr_to_pos_.size();
     map_expr_to_pos_[expr] = expr_index;
 
     if (ir_utils::isTVOp(expr)) {
-      const auto output_tv = expr->outputs()[0]->as<kir::TensorView>();
+      const auto output = expr->output(0)->as<TensorView>();
+      map_tv_to_origin_expr_[output] = expr;
 
-      const auto alloc_it = map_tv_to_allocations_.find(output_tv->name());
-      if (alloc_it != map_tv_to_allocations_.end()) {
-        const bool smem_valid = (output_tv->memoryType() == MemoryType::Shared);
+      bool has_allocation = map_tv_to_allocations_.find(output->name()) !=
+          map_tv_to_allocations_.end();
+
+      if (has_allocation) {
+        bool smem_valid = output->getMemoryType() == MemoryType::Shared;
 
         bool local_valid = false;
-        if (output_tv->memoryType() == MemoryType::Local) {
-          const auto allocation = alloc_it->second;
-          const auto register_size =
-              expr_evaluator_.evaluate(allocation->size());
-          if (register_size.has_value()) {
-            local_valid = size_t(*register_size) > kRegisterSizeThreshold;
+        if (output->getMemoryType() == MemoryType::Local) {
+          auto allocation = map_tv_to_allocations_[output->name()];
+          auto inferred_register_size =
+              eval_evaluator_.inferValue(allocation->size());
+          if (inferred_register_size.has_value()) {
+            local_valid = inferred_register_size.value() >
+                static_cast<int64_t>(register_size_threshold_);
           }
         }
 
@@ -173,36 +188,34 @@ class AllocateReuseModifier {
         // its allocation size must exceed the threshold
         // OR be in shared memory
         if (smem_valid || local_valid) {
-          candidate_alias_tv_.insert(output_tv);
+          candidate_alias_tv_.insert(output);
         }
       }
 
-      for (auto input_tv :
-           ir_utils::filterByType<kir::TensorView>(expr->inputs())) {
-        map_tv_to_last_usage_[input_tv] = expr_index;
+      const auto& expr_inputs =
+          ir_utils::filterByType<TensorView>(expr->inputs());
+      for (const auto input : expr_inputs) {
+        map_tv_to_last_usage_[input] = expr_index;
       }
-    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
-      handle(ite);
-    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-      handle(for_loop);
-    } else if (auto allocate = dynamic_cast<kir::Allocate*>(expr)) {
-      handle(allocate);
+    } else {
+      OptOutDispatch::handle(expr);
     }
   }
 
-  void handle(kir::Allocate* allocate) {
-    if (auto tv = dynamic_cast<const kir::TensorView*>(allocate->buffer())) {
-      map_tv_to_allocations_[tv->name()] = allocate;
+  void handle(kir::Allocate* a) final {
+    if (a->buffer()->getValType().value() == ValType::KirTensorView) {
+      auto tv = a->buffer()->as<kir::TensorView>()->fuserTv();
+      map_tv_to_allocations_[tv->name()] = a;
     }
   }
 
-  void handle(const kir::ForLoop* for_loop) {
-    for (auto expr : for_loop->body().exprs()) {
+  void handle(kir::ForLoop* fl) final {
+    for (auto expr : fl->body().exprs()) {
       handle(expr);
     }
   }
 
-  void handle(const kir::IfThenElse* ite) {
+  void handle(kir::IfThenElse* ite) final {
     for (auto expr : ite->thenBody().exprs()) {
       handle(expr);
     }
@@ -213,29 +226,39 @@ class AllocateReuseModifier {
 
  private:
   // Expression Evaluator to infer size of register allocation
-  kir::ExpressionEvaluator expr_evaluator_;
+  StatefulExpressionEvaluator eval_evaluator_;
+
+  // Alias local memory if it exceeds this threshold
+  const size_t register_size_threshold_;
 
   // Map expression to unique position
-  // TODO: elaborate - position relative to what?
-  std::unordered_map<const kir::Expr*, size_t> map_expr_to_pos_;
+  std::unordered_map<Expr*, size_t> map_expr_to_pos_;
+
+  // Map TensorView to origin expression
+  std::unordered_map<const TensorView*, Expr*> map_tv_to_origin_expr_;
 
   // Map TensorView to last usage expression position
-  std::unordered_map<const kir::TensorView*, size_t> map_tv_to_last_usage_;
+  std::unordered_map<const TensorView*, size_t> map_tv_to_last_usage_;
 
   // Map TensorView name to Allocate node
-  std::unordered_map<StmtNameType, kir::Allocate*> map_tv_to_allocations_;
+  std::unordered_map<unsigned int, kir::Allocate*> map_tv_to_allocations_;
 
   // Track candidate TensorViews whose Allocate nodes
   // could potentially alias another Allocate node
-  std::unordered_set<const kir::TensorView*> candidate_alias_tv_;
+  std::unordered_set<const TensorView*> candidate_alias_tv_;
 };
 
 } // namespace
 
-std::vector<kir::Expr*> reuseMemoryAllocations(
-    const std::vector<kir::Expr*>& exprs) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::reuseMemoryAllocations");
-  AllocateReuseModifier arm;
+std::vector<Expr*> reuseMemoryAllocations(
+    Fusion* fusion,
+    const std::vector<Expr*>& exprs) {
+  FUSER_PERF_SCOPE("reuseMemoryAllocations");
+  FusionGuard fg(fusion);
+
+  // Alias local memory if it exceeds this threshold
+  const size_t register_size_threshold = 1;
+  AllocateReuseModifier arm(fusion, register_size_threshold);
   arm.modify(exprs);
   return exprs;
 }
index dfe75db..128fa39 100644 (file)
@@ -28,8 +28,9 @@ namespace cuda {
 //!          is not used after this op:
 //! then alias output Allocate to input Allocate.
 //!
-std::vector<kir::Expr*> reuseMemoryAllocations(
-    const std::vector<kir::Expr*>& exprs);
+std::vector<Expr*> reuseMemoryAllocations(
+    Fusion* fusion,
+    const std::vector<Expr*>& exprs);
 
 } // namespace cuda
 } // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp
deleted file mode 100644 (file)
index d53ba8f..0000000
+++ /dev/null
@@ -1,591 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
-
-#include <unordered_set>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-class AllocationInserter : public kir::MutableIrVisitor {
- private:
-  struct AllocationInformation {
-    // The for loop that the allocation must be placed in, nullptr if not within
-    // a loop
-    kir::ForLoop* for_loop = nullptr;
-
-    // The expression that this allocation must be placed before
-    kir::Expr* place_before = nullptr;
-
-    // The allocation position relative to buffer
-    size_t alloc_pos = 0;
-
-    // The buffer this allocation is for
-    kir::TensorView* buffer = nullptr;
-
-    // The allocation expression
-    kir::Allocate* alloc_expr = nullptr;
-
-    // Initialization
-    kir::Expr* init_expr = nullptr;
-  };
-
-  // Find allocation point
-  void findAllocationPosition(AllocationInformation& info, kir::Expr* expr) {
-    size_t alloc_pos = 0;
-    kir::ForLoop* for_loop = nullptr;
-    auto fuser_tv = info.buffer->fuserTv();
-    size_t fl_idx_next = 0;
-
-    for (auto fl : for_loops) {
-      if (alloc_pos == fuser_tv->getComputeAtPosition()) {
-        break;
-      }
-
-      if (fuser_tv->axis(alloc_pos)->isReduction()) {
-        const auto outputs =
-            FusionGuard::getCurFusion()->getTerminatingOutputs();
-        TORCH_INTERNAL_ASSERT(
-            std::find(outputs.begin(), outputs.end(), fuser_tv) !=
-                outputs.end(),
-            "Invalid computeAt of T",
-            fuser_tv->name(),
-            ". A reducation axis is detected within computeAt axes even though it is not an output tensor.");
-        break;
-      }
-
-      auto fl_id = fl->iter_domain();
-
-      if (fl_id->parallelType() == ParallelType::Unroll) {
-        break;
-      }
-
-      auto local_id = gpu_lower->lowerValue(fuser_tv->axis(alloc_pos))
-                          ->as<kir::IterDomain>();
-
-      if (gpu_lower->caLoopMap().areMapped(local_id, fl_id)) {
-        alloc_pos++;
-      }
-
-      for_loop = fl;
-      ++fl_idx_next;
-    }
-
-    info.alloc_pos = alloc_pos;
-    info.for_loop = for_loop;
-
-    if (info.for_loop == nullptr) {
-      info.place_before = for_loops.size() > 0 ? for_loops[0] : expr;
-    } else {
-      if (info.for_loop == for_loops.back()) {
-        // Inline allocation, place before expr
-        info.place_before = expr;
-      } else {
-        // Place allocation after the last computeAt axis
-        // TODO: may be more efficient to place before the first non-computeAt
-        // axis
-        info.place_before = for_loops.at(fl_idx_next);
-      }
-    }
-  }
-
-  // Create initialization expression if init_val is non-null.
-  void createInitExpr(AllocationInformation& info, kir::Val* init_val) {
-    if (init_val == nullptr) {
-      info.init_expr = nullptr;
-      return;
-    }
-
-    auto fuser_tv = info.buffer->fuserTv();
-
-    std::vector<kir::IterDomain*> init_dims;
-    for (size_t axis_i = info.alloc_pos; axis_i < fuser_tv->nDims(); axis_i++) {
-      if (info.buffer->fuserTv()->axis(axis_i)->isReduction() ||
-          info.buffer->fuserTv()->axis(axis_i)->isBroadcast()) {
-        continue;
-      }
-      auto concrete_id =
-          gpu_lower
-              ->lowerValue(gpu_lower->caParallelMap().getConcreteMappedID(
-                  fuser_tv->axis(axis_i)))
-              ->as<kir::IterDomain>();
-      init_dims.push_back(concrete_id);
-    }
-    kir::Expr* init_expr = ir_builder.create<kir::UnaryOp>(
-        UnaryOpType::Set, info.buffer, init_val);
-    for (auto init_loop_it = init_dims.rbegin();
-         init_loop_it != init_dims.rend();
-         ++init_loop_it) {
-      auto id = *init_loop_it;
-      kir::ForLoop* new_loop = nullptr;
-      auto extent_with_halo = gpu_lower->haloInfo().getExtent(id);
-      if (extent_with_halo) {
-        new_loop = ir_builder.create<kir::ForLoop>(
-            id,
-            ir_builder.create<kir::Int>(c10::nullopt),
-            nullptr,
-            extent_with_halo,
-            nullptr,
-            false,
-            nullptr);
-      } else {
-        new_loop = ir_builder.create<kir::ForLoop>(id);
-      }
-      new_loop->body().push_back(init_expr);
-      init_expr = new_loop;
-    }
-    info.init_expr = init_expr;
-  }
-
-  std::vector<kir::Val*> getGlobalAllocationSizes(AllocationInformation& info) {
-    const auto& domain = info.buffer->domain();
-    const auto& maybe_rfactor_domain =
-        domain->hasRFactor() ? domain->rfactorDomain() : domain->rootDomain();
-
-    std::vector<kir::Val*> alloc_dims;
-
-    for (const auto id : maybe_rfactor_domain) {
-      if (id->isReduction() ||
-          id->iterType() == IterType::BroadcastWithoutStride) {
-        continue;
-      }
-      auto extent = id->extent();
-      // Use halo-extended extent if found
-      auto halo_extent = gpu_lower->haloInfo().getRootAxisInfo(id);
-      if (halo_extent.hasHalo()) {
-        extent = ir_builder.addExpr(extent, halo_extent.width());
-      }
-      alloc_dims.push_back(extent);
-    }
-
-    return alloc_dims;
-  }
-
-  // Get allocation extents of root axes with halo
-  //
-  // Allocation can be done with leaf IDs with halo as well, but
-  // allocation size could be larger than necessary.
-  //
-  // For example, suppose the shift offset of an axis is 1. When it is
-  // split by N, the halo size of the inner output is N+1. When the
-  // allocation only has the inner split output, the allocation size
-  // would be N+1. Suppose that ID is further split by M, the output
-  // extents would be N/M and M+1. The allocation size based on the
-  // leaves would be N/M*(M+1) or N+N/M, which is larger than N+1.
-  //
-  // This function tries to propagate back halo informatin to root
-  // axes to avoid inflating allocations. It fails when merged domains
-  // are split and only one of the split outputs is used for
-  // allocations since in such a case we can't un-merge and properly
-  // determine the extents of the merge inputs. Currently, that
-  // results in an exception, but it may be more reasonable to simply
-  // fall back to the leaf-based allocation.
-  //
-  // See the FusionShiftDoubleSplit test for an example case.
-  std::vector<kir::Val*> getNonGlobalAllocExprWithHalo(
-      TensorView* tv,
-      const std::vector<IterDomain*>& alloc_domains) {
-    std::vector<Val*> start_vals;
-    std::transform(
-        alloc_domains.begin(),
-        alloc_domains.end(),
-        std::back_inserter(start_vals),
-        [](IterDomain* dom) { return dom->as<Val>(); });
-
-    // Get all exprs involved in generating the allocation IDs
-    auto exprs = ExprSort::getExprs(tv->fusion(), start_vals);
-
-    // Get the halo extent if found
-    auto getExtent = [this](IterDomain* id) {
-      auto extent = gpu_lower->haloInfo().getExtent(id);
-      if (extent == nullptr) {
-        extent = gpu_lower->lowerValue(id->extent());
-      }
-      return extent;
-    };
-
-    std::unordered_map<IterDomain*, kir::Val*> known_extents;
-
-    // IterDomains that are allocated fully. For example, if an ID is
-    // split and only one of them is used for allocation, that's not
-    // considered full. Only full domains can be unmerged, which is
-    // needed to propagate back the halo information to root domains.
-    std::unordered_set<IterDomain*> full_domains;
-
-    for (auto alloc_domain : alloc_domains) {
-      known_extents.insert({alloc_domain, getExtent(alloc_domain)});
-      full_domains.insert(alloc_domain);
-    }
-
-    for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
-      auto expr = *it;
-      if (auto merge = dynamic_cast<Merge*>(expr)) {
-        auto out_it = known_extents.find(merge->out());
-        // If nothing is know about the out id, no propagation can be
-        // done. Note that's not necessarily an error.
-        if (out_it == known_extents.end()) {
-          continue;
-        }
-        // Similarly, if the extent of the out id is not full extent,
-        // we can't un-merge it.
-        if (full_domains.find(merge->out()) == full_domains.end()) {
-          continue;
-        }
-        // Since the extent of the out id is full, the extent of each
-        // of the input axes is also full
-        known_extents.insert({merge->inner(), getExtent(merge->inner())});
-        full_domains.insert(merge->inner());
-        known_extents.insert({merge->outer(), getExtent(merge->outer())});
-        full_domains.insert(merge->outer());
-        known_extents.erase(out_it);
-      } else if (auto split = dynamic_cast<Split*>(expr)) {
-        auto inner = split->inner();
-        const auto inner_it = known_extents.find(inner);
-        auto outer = split->outer();
-        const auto outer_it = known_extents.find(outer);
-        if (inner_it != known_extents.end() &&
-            outer_it != known_extents.end()) {
-          if (full_domains.find(inner) != full_domains.end() &&
-              full_domains.find(outer) != full_domains.end()) {
-            known_extents.insert({split->in(), getExtent(split->in())});
-            full_domains.insert(split->in());
-          } else {
-            known_extents.insert(
-                {split->in(),
-                 ir_builder.mulExpr(outer_it->second, inner_it->second)});
-          }
-          known_extents.erase(inner_it);
-          known_extents.erase(outer_it);
-        } else if (inner_it != known_extents.end()) {
-          known_extents.insert({split->in(), inner_it->second});
-          known_extents.erase(inner_it);
-        } else if (outer_it != known_extents.end()) {
-          known_extents.insert({split->in(), outer_it->second});
-          known_extents.erase(outer_it);
-        }
-      } else {
-        TORCH_INTERNAL_ASSERT(false, "Unexpected expr: ", expr);
-      }
-    }
-
-    std::vector<kir::Val*> alloc_dims;
-
-    for (auto root_axis : tv->getRootDomain()) {
-      auto it = known_extents.find(root_axis);
-      if (it == known_extents.end()) {
-        continue;
-      }
-      alloc_dims.push_back(it->second);
-      known_extents.erase(it);
-    }
-
-    // known_extents should have only mappings for root axes, so
-    // if anything remains in the map, it's an error
-    if (!known_extents.empty()) {
-      std::stringstream ss;
-      for (auto kv : known_extents) {
-        ss << kv.first << " ";
-      }
-      TORCH_INTERNAL_ASSERT(
-          false, "Non-root axes found for TV", tv->name(), ": ", ss.str());
-    }
-
-    return alloc_dims;
-  }
-
-  std::vector<kir::Val*> getNonGlobalAllocExpr(AllocationInformation& info) {
-    auto fuser_tv = info.buffer->fuserTv();
-    const auto memory_type = info.buffer->memoryType();
-    TORCH_INTERNAL_ASSERT(
-        memory_type != MemoryType::Global,
-        "Invalid memory type: ",
-        memory_type);
-
-    std::vector<kir::Val*> alloc_dims;
-
-    bool has_halo = false;
-    std::vector<IterDomain*> alloc_domains;
-
-    for (size_t axis_i = 0; axis_i < fuser_tv->nDims(); axis_i++) {
-      const auto local_id =
-          gpu_lower->lowerValue(fuser_tv->axis(axis_i))->as<kir::IterDomain>();
-
-      if (
-          // If we're reducing this dimension, don't use it in the allocation
-          // computation
-          local_id->isReduction() ||
-          // If this is a broadcast dimension, don't use it in the allocation
-          // computation
-          local_id->isBroadcast()) {
-        continue;
-      }
-
-      auto concrete_id =
-          gpu_lower
-              ->lowerValue(gpu_lower->caParallelMap().getConcreteMappedID(
-                  fuser_tv->axis(axis_i)))
-              ->as<kir::IterDomain>();
-      const bool is_block_dim =
-          isParallelTypeBlockDim(concrete_id->parallelType());
-      const bool is_thread_dim =
-          isParallelTypeThreadDim(concrete_id->parallelType());
-      const bool is_thread = isParallelTypeThread(concrete_id->parallelType());
-
-      if (axis_i < info.alloc_pos) {
-        // Even when the axis is outside the allocation position, if the
-        // tensor is shared with respect to the axis, the buffer size
-        // needs to be expanded for the axis. Sharing occurs in two
-        // cases: 1) the tensor is on shared memory with the axis
-        // parallelized by TIDs, and 2) the tensor is on global memory
-        // with the axis parallelized by TIDs or BIDs.
-        if (!((memory_type == MemoryType::Shared && is_thread_dim) ||
-              (memory_type == MemoryType::Global && is_thread))) {
-          continue;
-        }
-        alloc_domains.push_back(fuser_tv->axis(axis_i));
-      } else {
-        if (
-            // If shared memory, don't use any IDs bound to a grid dimension
-            (memory_type == MemoryType::Shared && is_block_dim) ||
-            // If local memory, don't use any IDs bound to a grid or block
-            // dimension
-            (memory_type == MemoryType::Local && is_thread)) {
-          continue;
-        }
-        alloc_domains.push_back(fuser_tv->axis(axis_i));
-      }
-
-      auto extent = concrete_id->extent();
-
-      if (gpu_lower->haloInfo().getExtent(fuser_tv->axis(axis_i)) != nullptr) {
-        has_halo = true;
-      }
-
-      alloc_dims.push_back(extent);
-    }
-
-    // When an axis with halo extension is detected, propagate back
-    // the halo extents from leaf IDs to root IDs
-    if (has_halo) {
-      return getNonGlobalAllocExprWithHalo(fuser_tv, alloc_domains);
-    }
-
-    return alloc_dims;
-  }
-
-  void createAllocExpr(AllocationInformation& info, bool is_output) {
-    if (is_output) {
-      info.alloc_expr = nullptr;
-      return;
-    }
-
-    std::vector<kir::Val*> alloc_dims;
-    const MemoryType memory_type = info.buffer->memoryType();
-
-    if (memory_type == MemoryType::Global) {
-      alloc_dims = getGlobalAllocationSizes(info);
-    } else {
-      alloc_dims = getNonGlobalAllocExpr(info);
-    }
-
-    if (alloc_dims.size() == 0 &&
-        info.buffer->domain()->noReductions().size() != 0) {
-      alloc_dims.push_back(ir_builder.create<kir::Int>(1));
-    }
-
-    // Create the allocation node
-    info.alloc_expr = ir_builder.create<kir::Allocate>(
-        info.buffer, info.buffer->memoryType(), alloc_dims);
-  }
-
-  void handle(kir::Expr* expr) {
-    if (!ir_utils::isTVOp(expr) || expr->isA<kir::Allocate>()) {
-      expr->accept(this);
-      return;
-    }
-
-    // // Found where the allocation needs to be inserted
-
-    for (auto out : expr->outputs()) {
-      if (!out->isA<kir::TensorView>()) {
-        continue;
-      }
-
-      auto out_tv = out->as<kir::TensorView>();
-      auto default_val =
-          gpu_lower->predicateElimination().getInitValue(out_tv->fuserTv());
-
-      kir::Val* init = nullptr;
-      if (expr->isA<kir::ReductionOp>() && out_tv->fuserTv()->hasReduction()) {
-        TORCH_INTERNAL_ASSERT(
-            default_val == nullptr,
-            "Reduction should not have a default initialization value for predicate elimination.");
-        init = expr->as<kir::ReductionOp>()->init();
-      } else if (expr->isA<kir::WelfordOp>()) {
-        TORCH_INTERNAL_ASSERT(
-            default_val == nullptr,
-            "Welford should not have a default initialization value for predicate elimination.");
-        const auto welford = expr->as<kir::WelfordOp>();
-        if (out->id() == welford->outVar()->id()) {
-          init = welford->initVar() == nullptr
-              ? ir_builder.create<kir::Double>(0)
-              : welford->initVar();
-        } else if (out->id() == welford->outAvg()->id()) {
-          init = welford->initAvg() == nullptr
-              ? ir_builder.create<kir::Double>(0)
-              : welford->initAvg();
-        } else {
-          TORCH_INTERNAL_ASSERT(
-              out->id() == welford->outN()->id(), "Unreachable");
-          init = welford->initN();
-        }
-      } else if (default_val != nullptr) {
-        init = default_val;
-      }
-
-      const bool is_output = gpu_lower->kernel()->isOutput(out);
-
-      // Don't need to alloc outputs, and if we don't need to initialize we're
-      // done.
-      if (is_output && init == nullptr) {
-        continue;
-      }
-
-      AllocationInformation allocation;
-      allocation.buffer = out_tv;
-      findAllocationPosition(allocation, expr);
-      createAllocExpr(allocation, is_output);
-      createInitExpr(allocation, init);
-
-      allocs.push_back(allocation);
-    }
-  }
-
-  void visit(kir::ForLoop* fl) final {
-    for_loops.push_back(fl);
-    // Modifying in place, make a copy of the vector
-    const std::vector<kir::Expr*> exprs = fl->body().exprs();
-    for (auto expr : exprs) {
-      handle(expr);
-    }
-    for_loops.pop_back();
-  }
-
-  void visit(kir::IfThenElse*) final {
-    TORCH_INTERNAL_ASSERT(
-        false,
-        "Pass does not support conditional statements, ",
-        "this pass should be run before any conditionals are placed in code.");
-  }
-
-  AllocationInserter(std::vector<kir::Expr*> _loop_nests)
-      : loop_nests_(std::move(_loop_nests)),
-        gpu_lower(GpuLower::current()),
-        ir_builder(gpu_lower->kernel()) {
-    // Compute all allocations
-    const std::vector<kir::Expr*> exprs = loop_nests_;
-    for (auto expr : exprs) {
-      handle(expr);
-    }
-
-    // First, place allocations of dynamic smem tensors at the very
-    // beginning of the expr list. Traverse backward as they should be
-    // placed in topological order.
-    for (auto it = allocs.rbegin(); it != allocs.rend(); ++it) {
-      const auto& alloc = *it;
-      if (alloc.alloc_expr == nullptr) {
-        continue;
-      }
-      // Dynamic smem exprs need to be at the begining of the kernel outside for
-      // loops
-      if (alloc.buffer->memoryType() == MemoryType::Shared &&
-          !kir::ExpressionEvaluator::isConst(alloc.alloc_expr->size())) {
-        loop_nests_.insert(loop_nests_.begin(), alloc.alloc_expr);
-      }
-    }
-
-    // Place the remaining allocations.
-    for (const auto& alloc : allocs) {
-      if (alloc.alloc_expr == nullptr) {
-        continue;
-      }
-      if (alloc.buffer->memoryType() == MemoryType::Shared &&
-          !kir::ExpressionEvaluator::isConst(alloc.alloc_expr->size())) {
-        continue;
-      }
-      if (alloc.for_loop == nullptr) {
-        auto place_before_it = std::find(
-            loop_nests_.begin(), loop_nests_.end(), alloc.place_before);
-        TORCH_INTERNAL_ASSERT(
-            place_before_it != loop_nests_.end(),
-            "Could not figure out where to place allocation. ",
-            "Use of the buffer, ",
-            toString(alloc.buffer),
-            ", could not be found.",
-            toString(alloc.place_before));
-        loop_nests_.insert(place_before_it, alloc.alloc_expr);
-      } else {
-        alloc.for_loop->body().insert_before(
-            alloc.place_before, alloc.alloc_expr);
-      }
-    }
-
-    // Now that allocations are in place, place the initializations
-    for (const auto& alloc : allocs) {
-      if (alloc.init_expr == nullptr) {
-        continue;
-      }
-      if (alloc.for_loop == nullptr) {
-        auto place_before_it = std::find(
-            loop_nests_.begin(), loop_nests_.end(), alloc.place_before);
-        // Don't need a check here as if the allocation placement succeeded
-        // this will too
-        loop_nests_.insert(place_before_it, alloc.init_expr);
-      } else {
-        alloc.for_loop->body().insert_before(
-            alloc.place_before, alloc.init_expr);
-      }
-    }
-  }
-
- private:
-  std::deque<AllocationInformation> allocs;
-
-  std::vector<kir::ForLoop*> for_loops;
-
-  std::vector<kir::Expr*> loop_nests_;
-
-  GpuLower* gpu_lower;
-
-  kir::IrBuilder ir_builder;
-
- public:
-  static std::vector<kir::Expr*> insert(
-      const std::vector<kir::Expr*>& loop_nests) {
-    AllocationInserter inserter(loop_nests);
-    return inserter.loop_nests_;
-  }
-};
-
-} // namespace
-
-std::vector<kir::Expr*> insertAllocations(
-    const std::vector<kir::Expr*>& exprs) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::insertAllocations");
-  return AllocationInserter::insert(exprs);
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.h b/torch/csrc/jit/codegen/cuda/lower_allocation.h
deleted file mode 100644 (file)
index d3d2c02..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Insert buffer allocations
-std::vector<kir::Expr*> insertAllocations(const std::vector<kir::Expr*>& exprs);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp b/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp
deleted file mode 100644 (file)
index 427aa9d..0000000
+++ /dev/null
@@ -1,1327 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_expr_sort.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-
-#include <deque>
-#include <list>
-#include <sstream>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-// TODO: Review const model, and objects
-//  ExprSegmentationSorter
-//    Responsible for going through DAG and proposing things we could try to
-//    merge together, calls "supportedMerge" on these proposed groups to see
-//    if they should be merged together, then merges them if so.
-//  ExprGroup
-//    A group of exprs that are grouped together based on their loop nest
-//    structures.
-//  ExprGroupConnections
-//    Holds vals and what they connect. In other words it's a val that is an
-//    output of a ExprSegmentationSorter "from" and an input of
-//    ExprSegmentationSorter "to". There's nothing preventing from a val being
-//    between groups twice.
-//    TODO: make sure there's nothing wrong with grouping of nodes that
-//    have the same value input twice. i.e. (B = A*A)
-
-// Selecting segments to propose is based on the theorem 4.2 in the paper which
-// makes sure when segment the segmented graph will be a DAG (assumes Fusion is
-// already a DAG). The segmentation code relies on assumptions of DAG-ness
-// during segmentation, meaning proposed merging of groups must maintain the DAG
-// property of the graph.
-//
-// Julien Herrmann, Yusuf Ã–zkaya, Bora Uçar, Kamer Kaya, Umit Catalyurek.
-// Multilevel Algorithms for Acyclic Partitioning of Directed Acyclic Graphs.
-// SIAM Journal on Scientific Computing, Society for Industrial and Applied
-// Mathematics, 2019, 41 (4), pp.A2117-A2145. ff10.1137/18M1176865ff.
-// ffhal02306566f
-
-class ExprGroup;
-struct ExprGroupConnections;
-class ExprSegmentationSorter;
-
-// Debug printing disabled due to clang tidy, see below for definitions
-// std::ostream& operator<<(std::ostream& os, const ExprGroup* group);
-
-// Wrapper for values, these are edges between expr groups. Multiple edges can
-// exist between expr groups, and the same Val can show up more than once in
-// multiple edges.
-struct ExprGroupConnections {
-  ExprGroupConnections(
-      ExprGroup* group_from,
-      ExprGroup* group_to,
-      Val* producer_val,
-      Val* consumer_val)
-      : from(group_from),
-        to(group_to),
-        producer_val_(producer_val),
-        consumer_val_(consumer_val) {}
-  // Producer group from which the edge starts
-  ExprGroup* from;
-
-  // Consumer group from which the edge ends
-  ExprGroup* to;
-
-  // The value from the producer group connecting the groups
-  // This value helps us resolve the compute at position of expr groups
-
-  Val* producer_val_;
-
-  // The value that the producer val gets used to create on this edge
-  // This value helps us resolve the produce at position of expr groups
-  Val* consumer_val_;
-};
-
-struct ExprSortPayload : public PolymorphicBase {
-  // Need to track compute at domains as well as produce at domains. Produce at
-  // domains will be matched to producers compute at domains. Track the active
-  // domains that will be matched from inner most dim to outer most.
-  std::vector<IterDomain*> ca_domains_;
-  std::vector<IterDomain*> pa_domains_;
-
-  // Maximum path distance from an input expr group required for
-  // Theorem 4.2
-  int level = -1;
-
-  // Traversal marker, marks if this group has been visited by current pass
-  bool visited = false;
-
-  // Marks if this group is already selected to merge with another group, marks
-  // which group to merge with
-  ExprGroup* merge_with = nullptr;
-
-  // Marks if this group is already selected to merge with another group
-  bool merged = false;
-};
-
-// Groups together expressions which create a expr group
-class ExprGroup {
- public:
-  ExprGroup() : payload_(std::make_unique<ExprSortPayload>()) {}
-
-  ExprGroup(Expr* expr) : payload_(std::make_unique<ExprSortPayload>()) {
-    exprs_.push_back(expr);
-  }
-
-  ExprGroup(const ExprGroup& other)
-      : payload_(new ExprSortPayload(*(other.payload_))) {}
-
-  ExprGroup& operator=(const ExprGroup& other) {
-    *payload_ = *other.payload_;
-    exprs_ = other.exprs_;
-    return *this;
-  }
-
-  // Clears the traversal information in the payload
-  void clearTraversalInfo();
-
-  // Returns all neighbors, producers and consumers
-  std::vector<ExprGroup*> getNeighbors();
-
-  // Return neighbors of this proven to be safe nodes to merge with in regards
-  // to maining an acyclic graph. This looks at, neighbors  if merged, neighbors
-  // level, and merged neighbors of neighbors level. If fallback_mode_enabled
-  // will return the inverse set of ExprGroups that are proven to be safe
-  // merges.
-  std::vector<ExprGroup*> getMergeCandidates(
-      bool fallback_mode_enabled = false);
-
-  std::unique_ptr<ExprSortPayload>& payload() {
-    return payload_;
-  }
-
-  const auto& producerEdges() const {
-    return producer_edges_;
-  }
-
-  void addProducerEdge(ExprGroupConnections* edge) {
-    addEdge(producer_edges_, edge);
-  }
-
-  void removeProducerEdge(ExprGroupConnections* edge) {
-    removeEdge(producer_edges_, edge);
-  }
-
-  void clearProducerEdges() {
-    producer_edges_.clear();
-  }
-
-  const auto& consumerEdges() const {
-    return consumer_edges_;
-  }
-
-  void addConsumerEdge(ExprGroupConnections* edge) {
-    addEdge(consumer_edges_, edge);
-  }
-
-  void removeConsumerEdge(ExprGroupConnections* edge) {
-    removeEdge(consumer_edges_, edge);
-  }
-
-  void clearConsumerEdges() {
-    consumer_edges_.clear();
-  }
-
-  auto& exprs() {
-    return exprs_;
-  }
-
-  const auto& exprs() const {
-    return exprs_;
-  }
-
- private:
-  static void addEdge(
-      std::vector<ExprGroupConnections*>& edges,
-      ExprGroupConnections* edge_to_add) {
-    edges.push_back(edge_to_add);
-  }
-
-  static void removeEdge(
-      std::vector<ExprGroupConnections*>& edges,
-      ExprGroupConnections* edge_to_remove) {
-    auto it = std::find(edges.begin(), edges.end(), edge_to_remove);
-    TORCH_INTERNAL_ASSERT(it != edges.end(), "Could not find edge to remove.");
-    edges.erase(it);
-  }
-
- private:
-  // "Ancestor nodes", towards inputs of segmentedDAG
-  std::vector<ExprGroupConnections*> producer_edges_;
-
-  // "Descendent nodes", towards outputs of segmentedDAG
-  std::vector<ExprGroupConnections*> consumer_edges_;
-
-  // Exprs that make up the group
-  std::vector<Expr*> exprs_;
-
-  // Stateful traversal information
-  std::unique_ptr<ExprSortPayload> payload_;
-};
-
-// This class sorts expressions guarantees two things, 1) Tensors are produced
-// before they're consumed 2) If the production of two tensors are supposed to
-// share a for loop, they're in an order where they can. (1) is pretty standard
-// of ordering a DAG. (2) is where things get a bit complicated and why we do
-// this sorting through segmentation. Consider a section of a DAG: T4 = T3 + T2.
-// Where T2 and T3 are not inputs to the fusion, all tensors are 3D, and we want
-// the production of T3 to share the inner most loop of T4 and we want the
-// production of T2 to share the middle loop with T4. i.e. we're looking for
-// For(i:I){
-//   For(j: J){
-//     For(k: K){
-//       T2[i, j, k] = ...
-//     }
-//     For(k: K){
-//       T3[i, j, k] = ...
-//       T4[i, j, k] = T2[i, j, k] + T3[i, j, k]
-//     }
-//   }
-// }
-// The only valid ordering of expressions is producing T2, then T3, then T4. If
-// we swapped T3 and T2, then T3 and T4 couldn't share their inner most loop,
-// because T2 has its own inner most loop. If we swapped either tensor with T4,
-// then we'd try to be using T2 or T3 without producing them (back to gaurantee
-// 1).
-class ExprSegmentationSorter {
- public:
-  ExprSegmentationSorter(Fusion* fusion) : complete_fusion_(fusion) {}
-
-  void sort();
-
-  std::string toString(int verbosity = 0) const;
-
-  //! Returns a flattened list of sorted exprs
-  std::vector<Expr*> getExprs() const;
-
- private:
-  // Allocate an empty expr group and return it
-  ExprGroup* makeEmptyGroup();
-
-  // Allocate an expr group with the provided expr and return it
-  ExprGroup* makeEmptyGroup(Expr*);
-
-  // Returns if sg1 and sg2 should be merged together, is called if they can
-  // based on the current status of the DAG.
-  bool supportedMerge(ExprGroup* sg1, ExprGroup* sg2);
-
-  // Returns true if the graph will remain an acyclic graph after merging sg1
-  // and sg2
-  bool testStillDag(ExprGroup* sg1, ExprGroup* sg2);
-
-  // Merges two ExprGroups and returns the new ExprGroup
-  ExprGroup* makeMergedNode(ExprGroup* sg1, ExprGroup* sg2);
-
-  // This is called once no more groups can be merged together. This will lower
-  // the compute at position of a segment group if the last dimension of the
-  // segment group doesn't map to any of the dimensions of its neighbors.
-  bool interIterUpdate();
-
-  // Reset the ExprSortPayload of the groups so we can traverse and identify
-  // merge candidates.
-  void resetTraversal();
-
-  // Reset the set levels of each group. This is what's used to identify which
-  // nodes can be merged together.
-  void resetLevels();
-
-  // Go through groups that are marked as to merge and merge them.
-  void mergeNodes();
-
-  // Disconnect the edges connecting group to the rest of the graph, and return
-  // all the edges that were disconnected
-  std::unordered_set<ExprGroupConnections*> disconnectGroup(ExprGroup* group);
-
- private:
-  // Track how many groups we have from iteration to iteration so we can track
-  // when we've stopped merging nodes.
-  size_t n_groups_ = 0;
-
-  // Lifetime of the graph view of the fusion and segmentation. Use list to not
-  // invalidate any entries on insertion/deletion.
-  std::list<std::unique_ptr<ExprGroupConnections>> edges_;
-  std::list<std::unique_ptr<ExprGroup>> groups_;
-
-  std::deque<ExprGroup*> to_visit_;
-
-  std::unordered_set<ExprGroup*> to_merge_;
-
-  // Maintain my own fusion the state of which is not always the same as the
-  // original provided fusion.
-  Fusion* complete_fusion_;
-
-  // We use a theorem out of a paper mentioned in other comments. This theorem
-  // is good at identifying multiple expr groups to merge during a single
-  // iteration without producing a cyclic graph from an acyclic graph. This
-  // theorem is not guaranteed to find all possible nodes that can be merged
-  // together. We need to be able to group all disjoint groups of exprs or
-  // we fail to generate code. Therefore, if we can't find anything to make
-  // forward progress based on the theorem we fallback to manually looking if we
-  // can segmenet all combinations we haven't previously looked at.
-  bool fallback_mode_enabled_ = false;
-};
-
-// // Debug printing, disabled due to clang-tidy see above for declarations.
-// std::ostream& operator<<(std::ostream& os, ExprGroup* group) {
-//   os << "Group Start{\n  ca, pa ("
-//      << group->payload()->ca_domains_.size() << ", "
-//      << group->payload()->pa_domains_.size() << ")";
-//   os << " ca_ids {";
-//   for (size_t i = 0; i < group->payload()->ca_domains_.size(); i++) {
-//     os << group->payload()->ca_domains_[i];
-//     if (i + 1 != group->payload()->ca_domains_.size())
-//       os << ", ";
-//   }
-//   os << "} pa_ids {";
-//   for (size_t i = 0; i < group->payload()->pa_domains_.size(); i++) {
-//     os << group->payload()->pa_domains_[i];
-//     if (i + 1 != group->payload()->pa_domains_.size())
-//       os << ", ";
-//   }
-//   os << "}";
-//   os << "\nExprs {\n";
-//   for(auto expr : group->exprs()){
-//     os << expr;
-//   }
-//    os << "}Group End\n";
-//   return os;
-// }
-
-std::vector<ExprGroup*> ExprGroup::getNeighbors() {
-  std::vector<ExprGroup*> neighbors;
-  for (auto inp : producer_edges_) {
-    neighbors.push_back(inp->from);
-  }
-  for (auto out : consumerEdges()) {
-    neighbors.push_back(out->to);
-  }
-  return neighbors;
-}
-
-std::vector<ExprGroup*> ExprGroup::getMergeCandidates(
-    bool fallback_mode_enabled) {
-  std::vector<ExprGroup*> neighbors = getNeighbors();
-
-  // Don't look for candidates if already merged
-  if (payload()->merged) {
-    return {};
-  }
-
-  // Can this node be merged with another? Check if neighbors are merged, if
-  // so and merged neighbor is within 1 level or node merged with neighbor is
-  // within 1 level, can't merge this node with anything else.
-  bool can_merge_this = true;
-  bool neighbor_merged = false;
-  for (auto neighbor : neighbors) {
-    if (!neighbor->payload()->merged) {
-      continue;
-    }
-    neighbor_merged = true;
-    if (std::abs(neighbor->payload()->level - payload()->level) <= 1) {
-      can_merge_this = false;
-    }
-    if (std::abs(
-            neighbor->payload()->merge_with->payload()->level -
-            payload()->level) <= 1) {
-      can_merge_this = false;
-    }
-  }
-
-  // If something prevents us from merging this node, and we're not in fallback
-  // mode, return empty set.
-  if (!can_merge_this && !fallback_mode_enabled) {
-    return {};
-  }
-
-  // If fallback mode already detected a merge somewhere, we shouldn't still be
-  // traversing.
-  if (fallback_mode_enabled) {
-    TORCH_INTERNAL_ASSERT(
-        !neighbor_merged,
-        "Shouldn't still be traversing in fallback mode if a merge was found.");
-  }
-
-  std::vector<bool> can_merge(true, neighbors.size());
-
-  // Find neighbors with a level that is only 1 differant than this groups level
-  for (size_t i = 0; i < neighbors.size(); i++) {
-    if (std::abs(neighbors[i]->payload()->level - payload()->level) > 1) {
-      can_merge[i] = false;
-    }
-  }
-
-  // Check neighbor of neighbors we're considering, if any of them are merged
-  // with another node, make sure the resulting edge wouldn't have a level
-  // difference of 1
-  for (size_t i = 0; i < neighbors.size(); i++) {
-    if (!can_merge[i]) {
-      continue;
-    }
-
-    for (auto neighbor_neighbor : neighbors[i]->getNeighbors()) {
-      // Don't check self
-      if (neighbor_neighbor == neighbors[i]) {
-        continue;
-      }
-      if (neighbor_neighbor->payload()->merged) {
-        // check neighbor_neighbor level
-        if (std::abs(neighbor_neighbor->payload()->level - payload()->level) <=
-            1) {
-          can_merge[i] = false;
-        }
-        if (std::abs(
-                neighbor_neighbor->payload()->level -
-                neighbors[i]->payload()->level) <= 1) {
-          can_merge[i] = false;
-        }
-
-        // check neighbor_neighber->merged->level
-        if (std::abs(
-                neighbor_neighbor->payload()->merge_with->payload()->level -
-                payload()->level) <= 1) {
-          can_merge[i] = false;
-        }
-        if (std::abs(
-                neighbor_neighbor->payload()->merge_with->payload()->level -
-                neighbors[i]->payload()->level) <= 1) {
-          can_merge[i] = false;
-        }
-      }
-    }
-  }
-
-  std::vector<ExprGroup*> merge_candidates;
-  for (size_t i = 0; i < neighbors.size(); i++) {
-    if ((can_merge[i] && !fallback_mode_enabled) ||
-        (!can_merge[i] && fallback_mode_enabled)) {
-      merge_candidates.push_back(neighbors[i]);
-    }
-  }
-  return merge_candidates;
-}
-
-void ExprGroup::clearTraversalInfo() {
-  payload()->level = -1;
-  payload()->visited = false;
-  payload()->merge_with = nullptr;
-  payload()->merged = false;
-}
-
-void ExprSegmentationSorter::resetTraversal() {
-  for (auto& group : groups_) {
-    // Start traversal at input groups
-    if (group->producerEdges().empty()) {
-      to_visit_.push_back(group.get());
-    }
-    group->clearTraversalInfo();
-  }
-}
-
-// Level is maximum distance from inputs. It's the metric used to select what
-// nodes can be merged while maintaining a DAG
-void ExprSegmentationSorter::resetLevels() {
-  std::vector<ExprGroup*> next_to_visit;
-
-  while (!to_visit_.empty()) {
-    auto visit = to_visit_.front();
-    to_visit_.pop_front();
-
-    // All inputs processed?
-    bool ready = true;
-    if (!visit->producerEdges().empty()) {
-      ready = std::all_of(
-          visit->producerEdges().begin(),
-          visit->producerEdges().end(),
-          [&](ExprGroupConnections* dep) {
-            return dep->from->payload()->visited;
-          });
-    }
-
-    if (!ready) {
-      // In case traversal doesn't complete because there's an error in the
-      // DAG topology.
-      next_to_visit.push_back(visit);
-      continue;
-    }
-
-    visit->payload()->visited = true;
-
-    to_visit_.insert(
-        to_visit_.end(), next_to_visit.begin(), next_to_visit.end());
-    next_to_visit.clear();
-
-    for (auto out : visit->consumerEdges()) {
-      to_visit_.push_back(out->to);
-    }
-
-    visit->payload()->level = 0;
-    for (auto inp : visit->producerEdges()) {
-      visit->payload()->level =
-          std::max(visit->payload()->level, inp->from->payload()->level + 1);
-    }
-  }
-  TORCH_INTERNAL_ASSERT(next_to_visit.empty(), "Error in graph, is not a DAG.");
-}
-
-ExprGroup* ExprSegmentationSorter::makeEmptyGroup() {
-  groups_.push_back(std::make_unique<ExprGroup>());
-  return groups_.back().get();
-}
-
-ExprGroup* ExprSegmentationSorter::makeEmptyGroup(Expr* expr) {
-  auto group = makeEmptyGroup();
-  group->exprs().push_back(expr);
-  if (ir_utils::isTVOp(expr)) {
-    auto out_tv = expr->outputs()[0]->as<TensorView>();
-    // Grab all id's that are shared with other tensors.
-    for (size_t tv_i = 0; tv_i < out_tv->getComputeAtPosition(); tv_i++) {
-      group->payload()->ca_domains_.push_back(out_tv->axis(tv_i));
-    }
-    for (size_t tv_i = 0; tv_i < out_tv->getMaxProducerPosition(); tv_i++) {
-      group->payload()->pa_domains_.push_back(out_tv->axis(tv_i));
-    }
-  }
-  return group;
-}
-
-// Debug function that prints the current state of the sorter.
-std::string ExprSegmentationSorter::toString(int verbosity) const {
-  std::stringstream ss;
-  ss << "{\n";
-  for (auto& group : groups_) {
-    ss << "  " << group.get() << "\n";
-
-    if (verbosity > 1) {
-      if (group->producerEdges().size() > 0) {
-        ss << "Produced by groups with edges: { \n";
-        for (auto producer_edge : group->producerEdges()) {
-          ss << producer_edge->producer_val_ << " -> "
-             << producer_edge->consumer_val_ << "\n";
-        }
-        ss << "    }"
-           << "\n";
-      }
-    }
-
-    if (verbosity > 1) {
-      if (group->consumerEdges().size() > 0) {
-        ss << "Consumed by groups with edges: { \n";
-        for (auto consumer_edge : group->consumerEdges()) {
-          ss << consumer_edge->producer_val_ << " -> "
-             << consumer_edge->consumer_val_ << "\n";
-        }
-        ss << "    }"
-           << "\n";
-      }
-    }
-  }
-  ss << "}\n";
-  return ss.str();
-}
-
-namespace {
-
-// Concat's edges of sg1 and sg2, but removes any edges from/to sg1/sg2
-std::vector<ExprGroupConnections*> getMergedEdges(
-    const ExprGroup* sg1,
-    const std::vector<ExprGroupConnections*>& edges1,
-    const ExprGroup* sg2,
-    const std::vector<ExprGroupConnections*>& edges2) {
-  TORCH_INTERNAL_ASSERT(
-      sg1 != nullptr && sg2 != nullptr,
-      "This function doesn't handle trivial.");
-
-  auto merged_edges = edges1;
-  merged_edges.insert(merged_edges.end(), edges2.begin(), edges2.end());
-
-  // Remove intra edges
-  merged_edges.erase(
-      std::remove_if(
-          merged_edges.begin(),
-          merged_edges.end(),
-          [&sg1, &sg2](ExprGroupConnections* se) {
-            return (se->to == sg1 && se->from == sg2) ||
-                (se->to == sg2 && se->from == sg1);
-          }),
-      merged_edges.end());
-
-  return merged_edges;
-}
-
-// Concat's producer edges of sg1 and sg2, but removes any edges from/to sg1/sg2
-std::vector<ExprGroupConnections*> getMergedProducerEdges(
-    const ExprGroup* sg1,
-    const ExprGroup* sg2) {
-  return getMergedEdges(sg1, sg1->producerEdges(), sg2, sg2->producerEdges());
-}
-
-// Concat's consumer edges of sg1 and sg2, but removes any edges from/to sg1/sg2
-std::vector<ExprGroupConnections*> getMergedConsumerEdges(
-    const ExprGroup* sg1,
-    const ExprGroup* sg2) {
-  return getMergedEdges(sg1, sg1->consumerEdges(), sg2, sg2->consumerEdges());
-}
-
-// Assuming sg1 and sg2 are connected, figure out which is the consumer
-ExprGroup* getProducer(ExprGroup* sg1, ExprGroup* sg2) {
-  for (auto producer_edge : sg1->producerEdges()) {
-    if (producer_edge->from == sg2) {
-      return sg2;
-    }
-  }
-
-  for (auto consumer_edge : sg1->consumerEdges()) {
-    if (consumer_edge->to == sg2) {
-      return sg1;
-    }
-  }
-
-  return nullptr;
-}
-
-// Go through all expressions and compute a local ordering of loops. Since
-// overloading comparison operators for iter domains doesn't make a lot of
-// sense, we instead fake having a < operator by considering that every
-// expressions output domain must be relatively ordered correctly. So we use all
-// of the expressions in a group to get a "local" ordering of the output IDs in
-// the group. We can't rely on any single expression because it may or may not
-// have all loops in the group. We also can't break ties without all
-// expressions.
-//
-// For example two expressions may have domains: [I0], [I1] Yet we
-// won't know the ordering unless we see a domain with: [I0, I1]. This happened
-// in advancedIndexing9 test when merging T5 with the group containing T10
-// (cache of T5, which is post broadcasted output) and T6(pre broadcasted
-// output).
-// T5 had the domain [0, 1, 2, 3, 4] produce at 3
-// T6 had the domain [0, 3, 4] compute at 3
-// Merging [0, 1, 2] and [0, 3, 4] resulted in the domain [0, 3, 4, 1, 2]
-//
-// If ID's are not in filter, we don't care about their ordering and ignore
-// them. This is because we're really focused on loops we will have to merge
-// across groups.If the domain is not in a produce at position in the producer
-// edges, or a compute at position in the consumer edges, the expressions we
-// look at may not have a unique ordering.
-std::vector<IterDomain*> getLocalDomainOrdering(
-    const std::vector<Expr*>& exprs,
-    const ComputeAtMap& map,
-    const std::unordered_set<IterDomain*> filter) {
-  if (exprs.empty()) {
-    return std::vector<IterDomain*>();
-  }
-
-  std::vector<std::vector<IterDomain*>> domains;
-
-  for (auto expr : exprs) {
-    if (!ir_utils::isTVOp(expr)) {
-      continue;
-    }
-
-    auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
-    for (auto tv_input : tv_inputs) {
-      std::vector<IterDomain*> domain(
-          tv_input->domain()->domain().begin(),
-          tv_input->domain()->domain().begin() +
-              std::max(
-                  tv_input->getComputeAtPosition(),
-                  tv_input->getMaxProducerPosition()));
-
-      domain.erase(
-          std::remove_if(
-              domain.begin(),
-              domain.end(),
-              [&filter, &map](IterDomain* id) {
-                return filter.find(map.getConcreteMappedID(id)) == filter.end();
-              }),
-          domain.end());
-
-      domains.emplace_back(domain);
-    }
-  }
-
-  if (domains.size() == 1) {
-    return domains[0];
-  }
-
-  std::vector<IterDomain*> merged_domains;
-
-  // For each domain, keep an iterator to the current iter domain we're
-  // checking, and an iterator for the end of the domain.
-  typedef std::pair<
-      std::vector<IterDomain*>::const_iterator,
-      std::vector<IterDomain*>::const_iterator>
-      iter_pair_t;
-
-  std::vector<iter_pair_t> iterators(domains.size());
-  for (auto i : c10::irange(domains.size())) {
-    iterators[i] = std::make_pair(domains[i].begin(), domains[i].end());
-  }
-
-  auto empty = [](iter_pair_t& iter_pair) {
-    return iter_pair.first == iter_pair.second;
-  };
-
-  size_t candidate_i = 0;
-  size_t iterations_since_merge = 0;
-  IterDomain* last_id_checked = nullptr;
-
-  while (std::any_of(
-      iterators.begin(), iterators.end(), [](iter_pair_t iter_pair) {
-        return iter_pair.first != iter_pair.second;
-      })) {
-    TORCH_INTERNAL_ASSERT(
-        iterations_since_merge <= iterators.size(),
-        "Infinite loop detected in lower_expr_sort:mergeDomains.");
-    iterations_since_merge++;
-
-    if (candidate_i == iterators.size()) {
-      candidate_i = 0;
-    }
-    if (empty(iterators[candidate_i])) {
-      candidate_i++;
-      continue;
-    }
-
-    auto iter_dom_candidate = *iterators[candidate_i].first;
-    if (iter_dom_candidate == last_id_checked) {
-      candidate_i++;
-      continue;
-    }
-    last_id_checked = iter_dom_candidate;
-
-    bool candidate_is_next = true;
-
-    // Make sure this iter domain is in all first positions of all iter
-    // lists that contain it, otherwise it shouldn't be the next iter domain.
-    for (auto iterator : iterators) {
-      if (empty(iterator)) {
-        continue;
-      }
-      if (!map.areMapped(iter_dom_candidate, *iterator.first)) {
-        if (std::any_of(
-                iterator.first + 1,
-                iterator.second,
-                [&map, iter_dom_candidate](IterDomain* id) {
-                  return map.areMapped(iter_dom_candidate, id);
-                })) {
-          candidate_is_next = false;
-          break;
-        }
-      }
-    }
-
-    if (!candidate_is_next) {
-      candidate_i++;
-      continue;
-    }
-
-    merged_domains.emplace_back(map.getConcreteMappedID(iter_dom_candidate));
-
-    for (auto match_i : c10::irange(iterators.size())) {
-      if (empty(iterators[match_i])) {
-        continue;
-      }
-      if (map.areMapped(iter_dom_candidate, *iterators[match_i].first)) {
-        iterators[match_i] = std::make_pair(
-            iterators[match_i].first + 1, iterators[match_i].second);
-      }
-    }
-    iterations_since_merge = 0;
-  }
-
-  return merged_domains;
-}
-} // namespace
-
-// Disconect group from neighbors, and return edges that were disconnected
-std::unordered_set<ExprGroupConnections*> ExprSegmentationSorter::
-    disconnectGroup(ExprGroup* group) {
-  std::unordered_set<ExprGroupConnections*> removed_edges(
-      group->producerEdges().begin(), group->producerEdges().end());
-
-  for (auto edge : group->producerEdges()) {
-    edge->from->removeConsumerEdge(edge);
-  }
-
-  for (auto edge : group->consumerEdges()) {
-    edge->to->removeProducerEdge(edge);
-  }
-
-  group->clearProducerEdges();
-  group->clearConsumerEdges();
-
-  return removed_edges;
-}
-
-// TODO: This function may be sub optimial. If we find that an iteration domain
-// matches later in the other domain, we will hold all other iteration domains
-// until that one matches. There may be cases where duplicating that iteration
-// domain, and moving on could be more efficient.
-ExprGroup* ExprSegmentationSorter::makeMergedNode(
-    ExprGroup* sg1,
-    ExprGroup* sg2) {
-  // Keep Expr's sorted in topological order.
-  const auto producer = getProducer(sg1, sg2);
-  const auto consumer = sg1 == producer ? sg2 : sg1;
-
-  // Make the new joined node
-  auto joined_groups = makeEmptyGroup();
-
-  TORCH_INTERNAL_ASSERT(
-      producer != nullptr,
-      "Tried to merge expr's together that aren't neighbors.");
-
-  joined_groups->exprs() = producer->exprs();
-  joined_groups->exprs().insert(
-      joined_groups->exprs().end(),
-      consumer->exprs().begin(),
-      consumer->exprs().end());
-
-  auto producer_edges = getMergedProducerEdges(sg1, sg2);
-  // Connect joined group to resulting neighbors
-  for (auto& edge : producer_edges) {
-    auto from = edge->from;
-    auto producer_val = edge->producer_val_;
-    auto consumer_val = edge->consumer_val_;
-
-    edges_.push_back(std::make_unique<ExprGroupConnections>(
-        from, joined_groups, producer_val, consumer_val));
-
-    joined_groups->addProducerEdge(edges_.back().get());
-    from->addConsumerEdge(edges_.back().get());
-  }
-
-  auto consumer_edges = getMergedConsumerEdges(sg1, sg2);
-
-  for (auto& edge : consumer_edges) {
-    auto to = edge->to;
-    auto producer_val = edge->producer_val_;
-    auto consumer_val = edge->consumer_val_;
-
-    edges_.push_back(std::make_unique<ExprGroupConnections>(
-        joined_groups, to, producer_val, consumer_val));
-    joined_groups->addConsumerEdge(edges_.back().get());
-    edge->to->addProducerEdge(edges_.back().get());
-  }
-
-  // Merge the compute at domain of all edges going out from the newly joined
-  // group. The val's we're looking for are from our consumer edges, but we want
-  // to grab the producer val as that's the one we generate.
-  std::unordered_set<IterDomain*> ca_ids;
-  for (auto consumer_group_edge : joined_groups->consumerEdges()) {
-    auto producer_of_consumer_edge = consumer_group_edge->producer_val_;
-    if (producer_of_consumer_edge->isA<TensorView>()) {
-      auto tv = producer_of_consumer_edge->as<TensorView>();
-      for (size_t tv_i = 0; tv_i < tv->getComputeAtPosition(); tv_i++) {
-        ca_ids.emplace(GpuLower::current()->caLoopMap().getConcreteMappedID(
-            tv->axis(tv_i)));
-      }
-    }
-  }
-
-  // Merge the produce at domain of all edges coming into the newly joined
-  // group. The val's we're looking for are from our producer edges, but we want
-  // to grab the consumer val as that's the one we generate.
-  std::unordered_set<IterDomain*> pa_ids;
-  for (auto producer_group_edge : joined_groups->producerEdges()) {
-    auto consumer_of_producer_edge = producer_group_edge->consumer_val_;
-    if (consumer_of_producer_edge->isA<TensorView>()) {
-      auto tv = consumer_of_producer_edge->as<TensorView>();
-      for (size_t tv_i = 0; tv_i < tv->getMaxProducerPosition(); tv_i++) {
-        pa_ids.emplace(GpuLower::current()->caLoopMap().getConcreteMappedID(
-            tv->axis(tv_i)));
-      }
-    }
-  }
-
-  auto all_ca_pa_ids = ca_ids;
-  all_ca_pa_ids.insert(pa_ids.begin(), pa_ids.end());
-
-  auto ordered_ids = getLocalDomainOrdering(
-      joined_groups->exprs(), GpuLower::current()->caLoopMap(), all_ca_pa_ids);
-
-  for (auto id : ordered_ids) {
-    if (ca_ids.count(id)) {
-      joined_groups->payload()->ca_domains_.emplace_back(id);
-    }
-    if (pa_ids.count(id)) {
-      joined_groups->payload()->pa_domains_.emplace_back(id);
-    }
-  }
-
-  return joined_groups;
-}
-
-bool canReducePA(ExprGroup* group) {
-  if (group->payload()->pa_domains_.empty()) {
-    return false;
-  }
-
-  IterDomain* group_pa_last_id = group->payload()->pa_domains_.back();
-
-  // Look through producer edges to see if we can reduce our produce at domain
-  for (auto producer_edge : group->producerEdges()) {
-    auto producer_val = producer_edge->producer_val_;
-    auto consumer_val = producer_edge->consumer_val_;
-
-    // If producer isn't a tensor view it can't be mapped into a producer dim of
-    // this group
-    if (!(consumer_val->isA<TensorView>() && producer_val->isA<TensorView>())) {
-      continue;
-    }
-
-    // If the compute at domains of the producer group is empty, it can't map to
-    // the produce at domains of this group
-    auto producer_group = producer_edge->from;
-    if (producer_group->payload()->ca_domains_.empty()) {
-      continue;
-    }
-
-    auto producer_tv = producer_val->as<TensorView>();
-    auto consumer_tv = consumer_val->as<TensorView>();
-
-    // If this consumer_tv doesn't map to the last producer domain of this group
-    // it can't decide if it can be reduced
-    bool has_matching_pa = false;
-    for (size_t i = 0; i < consumer_tv->getMaxProducerPosition(); i++) {
-      if (GpuLower::current()->caLoopMap().areMapped(
-              consumer_tv->axis(i), group_pa_last_id)) {
-        has_matching_pa = true;
-        break;
-      }
-    }
-
-    if (!has_matching_pa) {
-      continue;
-    }
-
-    // If any compute at positions of producers directly map to the last produce
-    // at position it can't be lowered.
-    for (int producer_pos_i = producer_tv->getComputeAtPosition();
-         producer_pos_i > 0;
-         producer_pos_i--) {
-      if (GpuLower::current()->caLoopMap().areMapped(
-              producer_tv->axis(producer_pos_i - 1), group_pa_last_id)) {
-        return false;
-      }
-    }
-  }
-
-  return true;
-}
-
-// Update in between attempts to segment. This is called once no more groups
-// can be merged together. Typically we will want to remove compute at groups
-// that have finished being grouped together. However if no groups have been
-// merged after we've done this, we may need to stop as we could have multiple
-// disjoint groups that won't be merged.
-bool ExprSegmentationSorter::interIterUpdate() {
-  // Go through groups and lower either pa or ca domain return if anything was
-  // lowered
-  bool lowered_a_domain = false;
-  for (auto& group : groups_) {
-    if (canReducePA(group.get())) {
-      group->payload()->pa_domains_.pop_back();
-      lowered_a_domain = true;
-    }
-  }
-
-  // If we couldn't lower compute at domain any further, and we haven't merged
-  // any new groups after fallback_mode_enabled_ has been turned on, make sure
-  // we've finished successfully
-  if (!lowered_a_domain && n_groups_ == groups_.size()) {
-    // Make sure none of the groups are still connected, as that would mean we
-    // should have been able to merge them.
-    bool successfully_finished = std::all_of(
-        groups_.begin(), groups_.end(), [](std::unique_ptr<ExprGroup>& sg) {
-          return sg->producerEdges().empty() && sg->consumerEdges().empty();
-        });
-    if (successfully_finished) {
-      return false;
-    }
-    // If we didn't finish and we tried the fallback, throw.
-    TORCH_INTERNAL_ASSERT(
-        !fallback_mode_enabled_,
-        "Couldn't succcessfully sort out the fusion expressions. ",
-        "There are remaining connections of the heirarchical segmentation which should have been ",
-        "flattened to a single ordered group, or disjoint ordered groups.");
-    // We didn't finish, but we haven't tried the fallback, try again with that.
-    fallback_mode_enabled_ = true;
-  }
-
-  n_groups_ = groups_.size();
-  // Not done, continue.
-  return true;
-}
-
-void ExprSegmentationSorter::mergeNodes() {
-  std::unordered_set<ExprGroup*> clean_up_groups;
-  std::unordered_set<ExprGroupConnections*> clean_up_edges;
-
-  while (!to_merge_.empty()) {
-    auto group1 = *to_merge_.begin();
-    auto group2 = group1->payload()->merge_with;
-    to_merge_.erase(group1);
-    to_merge_.erase(group2);
-    clean_up_groups.emplace(group1);
-    clean_up_groups.emplace(group2);
-    makeMergedNode(group1, group2);
-  }
-
-  for (auto group : clean_up_groups) {
-    auto disconnected_edges = disconnectGroup(group);
-    clean_up_edges.insert(disconnected_edges.begin(), disconnected_edges.end());
-  }
-
-  edges_.remove_if([&](std::unique_ptr<ExprGroupConnections>& edge) {
-    return clean_up_edges.find(edge.get()) != clean_up_edges.end();
-  });
-
-  groups_.remove_if([&](std::unique_ptr<ExprGroup>& group) {
-    return clean_up_groups.find(group.get()) != clean_up_groups.end();
-  });
-}
-
-// Two expression groups can be merged together if there's a value produced by
-// producer group, consumed by consumer group, where the compute at position
-// maps to the inner most compute at domain of the producer group and maps to
-// the inner most produce at domain of the consumer. If this value doesn't exist
-// we can't be certain these domains share the "next" inner most loop.
-//
-// We're looking for this because we're starting at the inner most loops of all
-// expressions, and looking for neighboring expressions that share inner loops.
-// Once we've found all the inner most loops that expressions share, we merge
-// them together, then look at the next inner most loop of the group and figure
-// out which other groups share this next inner most loop.
-bool ExprSegmentationSorter::supportedMerge(ExprGroup* sg1, ExprGroup* sg2) {
-  auto producer_group = getProducer(sg1, sg2);
-  auto consumer_group = sg1 == producer_group ? sg2 : sg1;
-
-  if (producer_group->payload()->ca_domains_.size() <
-      producer_group->payload()->pa_domains_.size()) {
-    return false;
-  }
-
-  if (consumer_group->payload()->pa_domains_.size() <
-      consumer_group->payload()->ca_domains_.size()) {
-    return false;
-  }
-
-  const auto& producer_ca_domain = producer_group->payload()->ca_domains_;
-  const auto& consumer_pa_domain = consumer_group->payload()->pa_domains_;
-
-  if (producer_ca_domain.empty() && consumer_pa_domain.empty()) {
-    return true;
-  }
-
-  if (producer_ca_domain.empty() || consumer_pa_domain.empty()) {
-    return false;
-  }
-
-  const auto& loop_map = GpuLower::current()->caLoopMap();
-
-  for (auto edge : producer_group->consumerEdges()) {
-    if (edge->to != consumer_group) {
-      continue;
-    }
-    auto producer_val = edge->producer_val_;
-    auto consumer_val = edge->consumer_val_;
-
-    if (!producer_val->isA<TensorView>()) {
-      continue;
-    }
-
-    TORCH_INTERNAL_ASSERT(
-        consumer_val->isA<TensorView>(),
-        "Mismatched tensorview to non-tensorview in expression sorting. ",
-        producer_val,
-        " is consumed by ",
-        consumer_val);
-
-    auto producer_tv = producer_val->as<TensorView>();
-
-    auto compute_at_pos = producer_tv->getComputeAtPosition();
-    auto compute_at_dim = compute_at_pos > 0
-        ? producer_tv->axis((int)producer_tv->getComputeAtPosition() - 1)
-        : nullptr;
-
-    if (compute_at_dim == nullptr) {
-      continue;
-    }
-
-    if (!loop_map.areMapped(compute_at_dim, producer_ca_domain.back())) {
-      continue;
-    }
-
-    if (loop_map.areMapped(compute_at_dim, consumer_pa_domain.back())) {
-      return true;
-    }
-  }
-  return false;
-}
-
-bool ExprSegmentationSorter::testStillDag(ExprGroup* sg1, ExprGroup* sg2) {
-  std::deque<ExprGroup*> to_visit;
-  std::unordered_set<ExprGroup*> visited;
-  // Add consumers of sg1 if not sg2
-  for (auto sg1_consumer_edge : sg1->consumerEdges()) {
-    if (sg1_consumer_edge->to != sg2) {
-      to_visit.emplace_back(sg1_consumer_edge->to);
-    }
-  }
-
-  // Add consumers of sg2 if not sg1
-  for (auto sg2_consumer_edge : sg2->consumerEdges()) {
-    if (sg2_consumer_edge->to != sg1) {
-      to_visit.emplace_back(sg2_consumer_edge->to);
-    }
-  }
-
-  while (to_visit.size() > 0) {
-    auto group = to_visit.front();
-    // Arrived back at one of the original groups, merging these two groups
-    // would generate a cycle
-    if (group == sg1 || group == sg2) {
-      return false;
-    }
-    to_visit.pop_front();
-    if (visited.find(group) != visited.end()) {
-      continue;
-    }
-    visited.emplace(group);
-    for (auto consumer_edge : group->consumerEdges()) {
-      to_visit.emplace_back(consumer_edge->to);
-    }
-  }
-
-  // No cycles found, we're good.
-  return true;
-}
-
-void ExprSegmentationSorter::sort() {
-  // Need this for initialization of the DAG that is processed
-  std::unordered_map<Expr*, ExprGroup*> expr2group;
-
-  // Initialize DAG, convert each expr to a segment group
-  for (auto expr : complete_fusion_->exprs()) {
-    auto group = makeEmptyGroup(expr);
-    expr2group.insert(std::make_pair(expr, group));
-  }
-
-  // Create edges between the Exprs. Mark inputs and outputs of the fusion.
-  for (auto expr : complete_fusion_->exprs()) {
-    auto expr_group = expr2group.at(expr);
-    auto out = expr->outputs()[0];
-    for (auto inp : expr->inputs()) {
-      if (inp->isFusionInput()) {
-        continue;
-      }
-
-      // Could be something like a constant scalar, definition is nullptr, but
-      // isn't an "input" to the fusion. At least not one provided by an
-      // external source.
-      if (inp->definition() == nullptr) {
-        continue;
-      }
-
-      auto inp_def_group = expr2group.at(inp->definition());
-      edges_.push_back(std::make_unique<ExprGroupConnections>(
-          inp_def_group, expr_group, inp, out));
-      expr_group->addProducerEdge(edges_.back().get());
-      inp_def_group->addConsumerEdge(edges_.back().get());
-    }
-  }
-  bool inter_iter_update = true;
-  while (inter_iter_update) {
-    // If we didn't do any update, stop traversal, we're done.
-    if (!fallback_mode_enabled_) {
-      // Merge expressions in sorted order
-      bool merged_nodes = true;
-      while (merged_nodes) {
-        // Reset stateful traversal details in ExprGroups
-        resetTraversal();
-        resetLevels();
-
-        for (auto& group : groups_) {
-          if (group->payload()->merged) {
-            continue;
-          }
-          auto candidates = group->getMergeCandidates(fallback_mode_enabled_);
-          if (candidates.empty()) {
-            continue;
-          }
-
-          auto candidate_it = candidates.begin();
-          while (candidate_it != candidates.end() &&
-                 !supportedMerge(group.get(), *candidate_it)) {
-            candidate_it++;
-          }
-          if (candidate_it == candidates.end()) {
-            continue;
-          }
-
-          to_merge_.emplace(group.get());
-          to_merge_.emplace(*candidate_it);
-
-          group->payload()->merged = true;
-          group->payload()->merge_with = *candidate_it;
-
-          (*candidate_it)->payload()->merged = true;
-          (*candidate_it)->payload()->merge_with = group.get();
-        }
-
-        if (to_merge_.empty()) {
-          merged_nodes = false;
-        }
-
-        mergeNodes();
-
-        // Move compute at axes left
-        inter_iter_update = interIterUpdate();
-      }
-    } else {
-      // fallback_mode_enabled = true
-      // Reset stateful traversal details in ExprGroups as we'll exclude merge
-      // options that were already ruled out and therefore need traversal and
-      // levels reset.
-      resetTraversal();
-      resetLevels();
-
-      for (auto& group : groups_) {
-        if (group->payload()->merged) {
-          continue;
-        }
-        // Get merge candidates that weren't proven safe to merge with default
-        // algorithm.
-        auto candidates = group->getMergeCandidates(fallback_mode_enabled_);
-        if (candidates.empty()) {
-          continue;
-        }
-
-        auto candidate_it = candidates.begin();
-
-        while (candidate_it != candidates.end()) {
-          while (candidate_it != candidates.end() &&
-                 !supportedMerge(group.get(), *candidate_it)) {
-            candidate_it++;
-          }
-
-          if (candidate_it == candidates.end()) {
-            break;
-          }
-
-          if (testStillDag(group.get(), *candidate_it)) {
-            // Mark in same style as default algorithm for convenience even
-            // though we will only merge once with the fallback
-            to_merge_.emplace(group.get());
-            to_merge_.emplace(*candidate_it);
-
-            group->payload()->merged = true;
-            group->payload()->merge_with = *candidate_it;
-
-            (*candidate_it)->payload()->merged = true;
-            (*candidate_it)->payload()->merge_with = group.get();
-            break;
-          }
-
-          candidate_it++;
-        }
-
-        if (to_merge_.size() > 0) {
-          break;
-        }
-      }
-
-      // If we can merge something, merge it, disable fallback, and bail
-      if (to_merge_.size() > 0) {
-        mergeNodes();
-      }
-
-      // Move compute at axes left
-      // If fallback didn't work, interIterUpdate will catch that we failed.
-      inter_iter_update = interIterUpdate();
-      fallback_mode_enabled_ = false;
-    }
-  }
-}
-
-std::vector<Expr*> ExprSegmentationSorter::getExprs() const {
-  std::vector<Expr*> exprs;
-  for (auto& group : groups_) {
-    exprs.insert(exprs.end(), group->exprs().begin(), group->exprs().end());
-  }
-  return exprs;
-}
-
-} // namespace
-
-std::vector<Expr*> reorderExprsForComputeAt() {
-  auto fusion = FusionGuard::getCurFusion();
-  TORCH_INTERNAL_ASSERT(fusion != nullptr);
-  ExprSegmentationSorter sorter(fusion);
-  sorter.sort();
-  auto sorted_exprs = sorter.getExprs();
-  TORCH_INTERNAL_ASSERT(
-      sorted_exprs.size() > 0,
-      "Error during expression sorting, no expressions produced.");
-  return sorted_exprs;
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_expr_sort.h b/torch/csrc/jit/codegen/cuda/lower_expr_sort.h
deleted file mode 100644 (file)
index 4b44541..0000000
+++ /dev/null
@@ -1,15 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-std::vector<Expr*> reorderExprsForComputeAt();
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index dcc9c28..546ced2 100644 (file)
@@ -16,261 +16,251 @@ namespace cuda {
 
 IndexLowering::IndexLowering() : ir_builder_(GpuLower::current()->kernel()) {}
 
-kir::Val* IndexLowering::lowerSrcIndex(kir::Val* val, kir::Val* dst) const {
-  if (auto tv = dynamic_cast<kir::TensorView*>(val)) {
-    TORCH_INTERNAL_ASSERT(dst->isA<kir::TensorView>());
+Val* IndexLowering::lowerOperand(Val* op, Val* out) const {
+  if (ir_utils::isTV(op)) {
     return Index::getProducerIndex(
-        tv->fuserTv(),
-        dst->as<kir::TensorView>()->fuserTv(),
-        scope_utils::getLoops(active_scope_expr_));
+        ir_utils::asTV(op),
+        ir_utils::asTV(out),
+        scope_utils::getLoops(active_scope_expr));
   } else {
-    return val;
+    return GpuLower::lowerValue(op);
   }
 }
 
-kir::Val* IndexLowering::lowerDstIndex(kir::Val* dst) const {
-  if (auto tv = dynamic_cast<kir::TensorView*>(dst)) {
+Val* IndexLowering::lowerOutput(Expr* expr) const {
+  TORCH_CHECK(expr->outputs().size() == 1);
+  const auto out = expr->output(0);
+  if (ir_utils::isTVOp(expr)) {
     return Index::getConsumerIndex(
-        tv->fuserTv(), scope_utils::getLoops(active_scope_expr_));
+        ir_utils::asTV(out), scope_utils::getLoops(active_scope_expr));
   } else {
-    return dst;
+    return GpuLower::lowerValue(out);
   }
 }
 
-void IndexLowering::pushBack(kir::Expr* expr) {
-  if (active_scope_ == nullptr) {
-    lowered_exprs_.push_back(expr);
+void IndexLowering::pushBack(Expr* expr) {
+  if (active_scope == nullptr) {
+    lowered_exprs.push_back(expr);
   } else {
-    active_scope_->push_back(expr);
+    active_scope->push_back(expr);
   }
 }
 
-void IndexLowering::visit(const kir::IfThenElse* ite) {
-  const auto prev_scope_expr = active_scope_expr_;
-  const auto prev_scope = active_scope_;
+void IndexLowering::handle(kir::IfThenElse* ite) {
+  Expr* prev_scope_expr = active_scope_expr;
+  kir::Scope* prev_scope = active_scope;
 
-  // TODO(kir): try to avoid recreating new nodes and leaving old ones around
-  auto new_ite = ir_builder_.create<kir::IfThenElse>(ite->predicate());
+  auto new_ite =
+      ir_builder_.create<kir::IfThenElse>(ite->cond(), prev_scope_expr);
   pushBack(new_ite);
-
-  active_scope_expr_ = new_ite;
-  active_scope_ = &new_ite->thenBody();
+  active_scope_expr = new_ite;
+  active_scope = &new_ite->thenBody();
 
   for (auto expr : ite->thenBody().exprs()) {
-    expr->accept(this);
+    OptInDispatch::handle(expr);
   }
 
-  active_scope_ = &new_ite->elseBody();
+  active_scope = &new_ite->elseBody();
 
   for (auto expr : ite->elseBody().exprs()) {
-    expr->accept(this);
+    OptInDispatch::handle(expr);
   }
 
-  active_scope_ = prev_scope;
-  active_scope_expr_ = prev_scope_expr;
+  active_scope = prev_scope;
+  active_scope_expr = prev_scope_expr;
 }
 
-void IndexLowering::visit(const kir::ForLoop* for_loop) {
-  const auto prev_scope_expr = active_scope_expr_;
-  const auto prev_scope = active_scope_;
+void IndexLowering::handle(kir::ForLoop* fl) {
+  Expr* prev_scope_expr = active_scope_expr;
+  kir::Scope* prev_scope = active_scope;
 
-  auto new_for_loop = ir_builder_.create<kir::ForLoop>(for_loop);
-  pushBack(new_for_loop);
+  auto newFl = ir_builder_.create<kir::ForLoop>(
+      fl->index(), fl->iter_domain(), prev_scope_expr);
+  pushBack(newFl);
 
-  active_scope_expr_ = new_for_loop;
-  active_scope_ = &new_for_loop->body();
+  active_scope_expr = newFl;
+  active_scope = &newFl->body();
 
-  for (auto expr : for_loop->body().exprs()) {
-    expr->accept(this);
+  for (auto expr : fl->body().exprs()) {
+    OptInDispatch::handle(expr);
   }
 
-  active_scope_ = prev_scope;
-  active_scope_expr_ = prev_scope_expr;
+  active_scope = prev_scope;
+  active_scope_expr = prev_scope_expr;
 }
 
-void IndexLowering::visit(const kir::UnaryOp* uop) {
-  const auto in = lowerSrcIndex(uop->in(), uop->out());
-  const auto out = lowerDstIndex(uop->out());
-  pushBack(ir_builder_.create<kir::UnaryOp>(uop->operation(), out, in));
+void IndexLowering::handle(UnaryOp* uop) {
+  if (ir_utils::isTVOp(uop)) {
+    const auto in = lowerOperand(uop->in(), uop->out());
+    const auto out = lowerOutput(uop);
+    pushBack(ir_builder_.create<kir::UnaryOp>(uop->getUnaryOpType(), out, in));
+  } else {
+    // This will automatically lower the expression defining the value
+    pushBack(GpuLower::lowerValue(uop->out())->getOrigin());
+  }
 }
 
-void IndexLowering::visit(const kir::BinaryOp* bop) {
-  const auto lhs = lowerSrcIndex(bop->lhs(), bop->out());
-  const auto rhs = lowerSrcIndex(bop->rhs(), bop->out());
-  const auto out = lowerDstIndex(bop->out());
-  pushBack(ir_builder_.create<kir::BinaryOp>(bop->operation(), out, lhs, rhs));
+void IndexLowering::handle(BinaryOp* bop) {
+  if (ir_utils::isTVOp(bop)) {
+    const auto lhs = lowerOperand(bop->lhs(), bop->out());
+    const auto rhs = lowerOperand(bop->rhs(), bop->out());
+    const auto out = lowerOutput(bop);
+    pushBack(ir_builder_.create<kir::BinaryOp>(
+        bop->getBinaryOpType(), out, lhs, rhs));
+  } else {
+    // This will automatically lower the expression defining the value
+    pushBack(GpuLower::lowerValue(bop->out())->getOrigin());
+  }
 }
 
-void IndexLowering::visit(const kir::TernaryOp* top) {
-  const auto in1 = lowerSrcIndex(top->in1(), top->out());
-  const auto in2 = lowerSrcIndex(top->in2(), top->out());
-  const auto in3 = lowerSrcIndex(top->in3(), top->out());
-  const auto out = lowerDstIndex(top->out());
-  pushBack(
-      ir_builder_.create<kir::TernaryOp>(top->operation(), out, in1, in2, in3));
+void IndexLowering::handle(TernaryOp* top) {
+  if (ir_utils::isTVOp(top)) {
+    const auto in1 = lowerOperand(top->in1(), top->out());
+    const auto in2 = lowerOperand(top->in2(), top->out());
+    const auto in3 = lowerOperand(top->in3(), top->out());
+    const auto out = lowerOutput(top);
+    pushBack(ir_builder_.create<kir::TernaryOp>(
+        top->getTernaryOpType(), out, in1, in2, in3));
+  } else {
+    // This will automatically lower the expression defining the value
+    pushBack(GpuLower::lowerValue(top->out())->getOrigin());
+  }
 }
 
 namespace {
 
-void allocateGridReductionFlag(
-    kir::TensorView* out_tv,
-    kir::Expr* current_scope_expr) {
+void allocateGridReductionFlag(TensorView* out_tv, Expr* current_scope_expr) {
   kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-
-  const auto flag_name = kir::GridReduction::getPredicateFlagName(out_tv);
-  const auto flag_var = ir_builder.create<kir::Allocate>(
+  auto flag_name = kir::GridReduction::getPredicateFlagName(out_tv);
+  auto flag_var = ir_builder.create<kir::Allocate>(
       ir_builder.create<kir::NamedScalar>(flag_name, DataType::Bool),
       MemoryType::Local,
       ir_builder.create<kir::Int>(1));
-
   // When enclosed by IfThenElse, place the variable outside of the
   // IfThenElse. This IfThenElse is assumed to be the prediate for
   // this grid reduction expression.
-  if (current_scope_expr->isA<kir::IfThenElse>()) {
+  if (current_scope_expr->getExprType() == ExprType::IfThenElse) {
     scope_utils::insertBefore(
-        current_scope_expr->parentScope(), current_scope_expr, flag_var);
+        scope_utils::getParent(current_scope_expr),
+        current_scope_expr,
+        flag_var);
   } else {
-    TORCH_INTERNAL_ASSERT(current_scope_expr->isA<kir::ForLoop>());
-    current_scope_expr->as<kir::ForLoop>()->body().push_back(flag_var);
+    scope_utils::pushBack(current_scope_expr, flag_var);
   }
 }
 
 } // namespace
 
-void IndexLowering::visit(const kir::ReductionOp* rop) {
-  TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(rop));
+void IndexLowering::handle(ReductionOp* rop) {
+  TORCH_INTERNAL_ASSERT(
+      ir_utils::isTVOp(rop),
+      "Cannot have a reduction operation on something other than a tensor view, but received ",
+      rop);
 
-  const auto out_tv = rop->out()->as<kir::TensorView>();
-  const auto out_domain = out_tv->domain();
+  auto out_tv = ir_utils::asTV(rop->out());
 
-  const bool is_block_reduce = out_domain->hasBlockReduction();
-  const bool is_grid_reduce = out_domain->hasGridReduction();
+  const bool is_block_reduce = out_tv->hasBlockReduction();
+  const bool is_grid_reduce = out_tv->hasGridReduction();
 
   // If we do a grid reduction we can't have a reduction axis that is not bound
   // to a grid or block dim ()
   if (is_grid_reduce) {
     TORCH_INTERNAL_ASSERT(
         std::none_of(
-            out_domain->domain().begin(),
-            out_domain->domain().end(),
-            [](kir::IterDomain* id) {
-              return !id->isThread() && id->isReduction() &&
-                  !id->extent()->isOneInt();
+            out_tv->domain()->domain().begin(),
+            out_tv->domain()->domain().end(),
+            [](IterDomain* id) {
+              return !id->isThread() && id->isReduction();
             }),
-        "Found a reduction stage that has both a non-parallelized ",
-        "reduction and a grid reduction.  This is not supported, ",
-        "please use rfactor to do the serialized reduction first, ",
-        "then the grid reduction.");
+        "Found a reduction stage that has both a non-parallelized reduction and a grid reduction.",
+        " This is not supported, please use rfactor to do the serialized reduction first, then the grid reduction.");
   }
+  const auto loops = scope_utils::getLoops(active_scope_expr);
 
-  const auto out = lowerDstIndex(rop->out());
-  const auto in = lowerSrcIndex(rop->in(), rop->out());
+  kir::TensorIndex* out = Index::getConsumerIndex(out_tv, loops);
+  kir::TensorIndex* in = Index::getProducerIndex(
+      ir_utils::asTV(rop->in()), ir_utils::asTV(rop->out()), loops);
 
   kir::ReductionOp* block_reduction_op = nullptr;
-
   if (is_block_reduce) {
+    auto pred =
+        PredicateCompute::getInlinePredicate(rop, loops, nullptr, false);
+
     block_reduction_op = ir_builder_.create<kir::ReductionOp>(
-        rop->operation(), rop->init(), out, in);
-    if (rop->predicate()) {
-      block_reduction_op->setPredicate(rop->predicate());
-    }
-    if (rop->writePredicate()) {
-      block_reduction_op->setWritePredicate(rop->writePredicate());
-    }
+        rop->getReductionOpType(),
+        GpuLower::lowerValue(rop->init()),
+        out,
+        in,
+        pred);
     pushBack(block_reduction_op);
   }
 
   if (is_grid_reduce) {
     // First, declare a boolean flag variable storing the return value
-    // of the gridReduce() helper
-    allocateGridReductionFlag(out_tv, active_scope_expr_);
+    // of gridReduce.
+    allocateGridReductionFlag(out_tv, active_scope_expr);
 
-    auto buffer_ids = out_domain->domain();
+    std::vector<IterDomain*> buffer_ids(out_tv->domain()->domain());
     buffer_ids.erase(
         std::remove_if(
             buffer_ids.begin(),
             buffer_ids.end(),
-            [](kir::IterDomain* id) {
-              return id->isReduction() && !id->isBlockDim();
+            [](IterDomain* id) {
+              return id->isReduction() & !id->isBlockDim();
             }),
         buffer_ids.end());
 
-    kir::Val* buffer_size = buffer_ids.empty() ? ir_builder_.create<kir::Int>(1)
-                                               : buffer_ids[0]->extent();
-
+    Val* buffer_size =
+        buffer_ids.empty() ? new Int(1) : buffer_ids[0]->rawExtent();
     for (const auto i : c10::irange(1, buffer_ids.size())) {
-      buffer_size = ir_builder_.mulExpr(buffer_size, buffer_ids[i]->extent());
+      buffer_size = mul(buffer_size, buffer_ids[i]->rawExtent());
     }
 
-    auto sync_ids = out_domain->domain();
+    std::vector<IterDomain*> sync_ids(out_tv->domain()->domain());
     sync_ids.erase(
         std::remove_if(
             sync_ids.begin(),
             sync_ids.end(),
-            [](kir::IterDomain* id) {
+            [](IterDomain* id) {
               return id->isReduction() || !id->isBlockDim();
             }),
         sync_ids.end());
 
-    kir::Val* sync_size = sync_ids.empty() ? ir_builder_.create<kir::Int>(1)
-                                           : sync_ids[0]->extent();
-
+    Val* sync_size = sync_ids.empty() ? new Int(1) : sync_ids[0]->rawExtent();
     for (const auto i : c10::irange(1, sync_ids.size())) {
-      sync_size = ir_builder_.mulExpr(sync_size, sync_ids[i]->extent());
+      sync_size = mul(sync_size, sync_ids[i]->rawExtent());
     }
 
-    const auto zero = ir_builder_.create<kir::Int>(0);
+    IterDomain* buffer_id = new IterDomain(new Int(0), buffer_size);
+    TensorView* reduce_buffer_tv = new TensorView(
+        new TensorDomain({buffer_id}),
+        out->getDataType().value(),
+        MemoryType::Global);
 
-    const std::vector<kir::IterDomain*> new_buffer_ids = {
-        ir_builder_.create<kir::IterDomain>(zero, buffer_size)};
-    const auto buffer_domain =
-        ir_builder_.create<kir::TensorDomain>(new_buffer_ids);
-    const auto reduce_buffer_tv = ir_builder_.create<kir::TensorView>(
-        out->dtype(), buffer_domain, MemoryType::Global);
-
-    const std::vector<kir::IterDomain*> new_sync_ids = {
-        ir_builder_.create<kir::IterDomain>(zero, sync_size)};
-    const auto sync_domain =
-        ir_builder_.create<kir::TensorDomain>(new_sync_ids);
-    const auto reduce_sync_tv = ir_builder_.create<kir::TensorView>(
-        DataType::Int, sync_domain, MemoryType::Global);
+    IterDomain* sync_id = new IterDomain(new Int(0), sync_size);
+    TensorView* reduce_sync_tv = new TensorView(
+        new TensorDomain({sync_id}), DataType::Int, MemoryType::Global);
 
     const auto reduce_buffer = ir_builder_.create<kir::Allocate>(
-        reduce_buffer_tv, reduce_buffer_tv->memoryType());
-
+        GpuLower::lowerValue(reduce_buffer_tv),
+        reduce_sync_tv->getMemoryType());
     const auto sync_buffer = ir_builder_.create<kir::Allocate>(
-        reduce_sync_tv, reduce_sync_tv->memoryType(), nullptr, true);
+        GpuLower::lowerValue(reduce_sync_tv),
+        reduce_sync_tv->getMemoryType(),
+        nullptr,
+        true);
 
-    const auto grid_reduction_op = (block_reduction_op == nullptr)
+    const auto grid_reduction_op = block_reduction_op == nullptr
         ? ir_builder_.create<kir::ReductionOp>(
-              rop->operation(), rop->init(), out, in)
+              rop->getReductionOpType(),
+              GpuLower::lowerValue(rop->init()),
+              out,
+              in)
         : block_reduction_op;
-
-    // The thread predicate for GridReduction needs to be set
-    // separately from the main predicate. Do not combine them like
-    // other expressions.
-    const auto& thread_pred =
-        GpuLower::current()->threadPredMap().at(out_tv->fuserTv()).pred;
-    auto grid_reduction = ir_builder_.create<kir::GridReduction>(
-        grid_reduction_op, reduce_buffer, sync_buffer);
-    grid_reduction->setThreadPredicate(thread_pred);
-
-    if (rop->predicate()) {
-      // If preceded by a blockReduce, all thread blocks should have
-      // valid inputs to gridReduce. In fact, using the original
-      // predicate does not work when the write predicate of the
-      // blockReduce is different from the read predicate.
-      if (is_block_reduce) {
-        grid_reduction->setPredicate(
-            ir_builder_.create<kir::Predicate>(ir_builder_.trueVal()));
-      } else {
-        grid_reduction->setPredicate(rop->predicate());
-      }
-    }
-
-    if (rop->writePredicate()) {
-      grid_reduction->setWritePredicate(rop->writePredicate());
-    }
+    auto pred =
+        PredicateCompute::getInlinePredicate(rop, loops, nullptr, false);
+    const auto grid_reduction = ir_builder_.create<kir::GridReduction>(
+        grid_reduction_op, reduce_buffer, sync_buffer, pred);
 
     pushBack(reduce_buffer);
     pushBack(sync_buffer);
@@ -278,186 +268,41 @@ void IndexLowering::visit(const kir::ReductionOp* rop) {
   }
 
   if (!is_block_reduce && !is_grid_reduce) {
-    // TODO(kir): this breaks our "SSA" form
-    pushBack(ir_builder_.create<kir::BinaryOp>(rop->operation(), out, out, in));
-  }
-}
-
-namespace {
-
-template <typename T>
-kir::Allocate* allocGlobalBuffer(
-    kir::IrBuilder& ir_builder,
-    const kir::TensorDomain* td,
-    T id_filter,
-    DataType dtype,
-    bool zero_init = false) {
-  auto buffer_ids = td->domain();
-  buffer_ids.erase(
-      std::remove_if(buffer_ids.begin(), buffer_ids.end(), id_filter),
-      buffer_ids.end());
-
-  kir::Val* buffer_size = buffer_ids.empty() ? ir_builder.create<kir::Int>(1)
-                                             : buffer_ids[0]->extent();
-  for (size_t i = 1; i < buffer_ids.size(); i++) {
-    buffer_size = ir_builder.mulExpr(buffer_size, buffer_ids[i]->extent());
+    pushBack(ir_builder_.create<kir::BinaryOp>(
+        rop->getReductionOpType(), out, out, in));
   }
-  const auto zero = ir_builder.create<kir::Int>(0);
-  const std::vector<kir::IterDomain*> new_buffer_ids = {
-      ir_builder.create<kir::IterDomain>(zero, buffer_size)};
-  const auto buffer_domain =
-      ir_builder.create<kir::TensorDomain>(new_buffer_ids);
-  const auto buffer_tv = ir_builder.create<kir::TensorView>(
-      dtype, buffer_domain, MemoryType::Global);
-  return ir_builder.create<kir::Allocate>(
-      buffer_tv, buffer_tv->memoryType(), nullptr, zero_init);
 }
 
-} // namespace
-
-void IndexLowering::visit(const kir::WelfordOp* wop) {
-  TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(wop));
-
-  const auto out_tv = wop->outAvg()->as<kir::TensorView>();
-  const auto out_domain = out_tv->domain();
+void IndexLowering::handle(BroadcastOp* bop) {
+  TORCH_INTERNAL_ASSERT(
+      ir_utils::isTVOp(bop),
+      "Cannot have a broadcast operation on something other than a tensor view, but received ",
+      bop);
 
-  const bool is_block_reduce = out_domain->hasBlockReduction();
-  const bool is_grid_reduce = out_domain->hasGridReduction();
+  auto loops = scope_utils::getLoops(active_scope_expr);
 
-  // If we do a grid reduction we can't have a reduction axis that is not bound
-  // to a grid or block dim ()
-  if (is_grid_reduce) {
-    TORCH_INTERNAL_ASSERT(
-        std::none_of(
-            out_domain->domain().begin(),
-            out_domain->domain().end(),
-            [](kir::IterDomain* id) {
-              return !id->isThread() && id->isReduction();
-            }),
-        "Found a reduction stage that has both a non-parallelized ",
-        "reduction and a grid reduction.  This is not supported, ",
-        "please use rfactor to do the serialized reduction first, ",
-        "then the grid reduction.");
-  }
-
-  // lower IO tensors
-  const auto in_var =
-      wop->inVar() ? lowerSrcIndex(wop->inVar(), wop->outAvg()) : nullptr;
-  const auto in_avg = lowerSrcIndex(wop->inAvg(), wop->outAvg());
-  auto in_N = wop->inN();
-
-  // in Rfactor-ed case, the input N is actually a TV
-  if (!in_N->isScalar()) {
-    in_N = lowerSrcIndex(in_N, wop->outN());
-  }
-
-  auto out_avg = lowerDstIndex(wop->outAvg());
-  auto out_var = lowerDstIndex(wop->outVar());
-  auto out_N = lowerDstIndex(wop->outN());
-
-  kir::WelfordOp* welford_op = ir_builder_.create<kir::WelfordOp>(
-      out_var,
-      out_avg,
-      out_N,
-      wop->initVar(),
-      wop->initAvg(),
-      wop->initN(),
-      in_var,
-      in_avg,
-      in_N);
-
-  kir::WelfordOp* block_welford_op = nullptr;
-
-  if (is_block_reduce) {
-    block_welford_op = welford_op;
-    if (wop->predicate()) {
-      block_welford_op->setPredicate(wop->predicate());
-    }
-    if (wop->writePredicate()) {
-      block_welford_op->setWritePredicate(wop->writePredicate());
-    }
-    pushBack(block_welford_op);
-  }
-
-  if (is_grid_reduce) {
-    // Allocate T_pred
-    allocateGridReductionFlag(out_tv, active_scope_expr_);
-
-    // Buffer allocation
-    auto buffer_filter = [](const kir::IterDomain* id) {
-      return id->isReduction() && !id->isBlockDim();
-    };
-    const auto out_var_buffer = allocGlobalBuffer(
-        ir_builder_, out_domain, buffer_filter, out_var->dtype());
-    const auto out_avg_buffer = allocGlobalBuffer(
-        ir_builder_, out_domain, buffer_filter, out_avg->dtype());
-    const auto out_N_buffer = allocGlobalBuffer(
-        ir_builder_, out_domain, buffer_filter, out_N->dtype());
-    const auto sync_buffer = allocGlobalBuffer(
-        ir_builder_, out_domain, buffer_filter, DataType::Int, true);
-
-    // Grid Welford instantiation
-    const auto grid_welford_op =
-        (block_welford_op == nullptr) ? welford_op : block_welford_op;
-
-    // The thread predicate for GridReduction needs to be set
-    // separately from the main predicate. Do not combine them like
-    // other expressions.
-    const auto& thread_pred =
-        GpuLower::current()->threadPredMap().at(out_tv->fuserTv()).pred;
-
-    auto grid_welford = ir_builder_.create<kir::GridWelford>(
-        grid_welford_op,
-        out_var_buffer,
-        out_avg_buffer,
-        out_N_buffer,
-        sync_buffer);
-
-    grid_welford->setThreadPredicate(thread_pred);
-
-    if (wop->predicate()) {
-      grid_welford->setPredicate(wop->predicate());
-    }
-
-    pushBack(out_var_buffer);
-    pushBack(out_avg_buffer);
-    pushBack(out_N_buffer);
-    pushBack(sync_buffer);
-    pushBack(grid_welford);
-  }
-
-  if (!is_block_reduce && !is_grid_reduce) {
-    pushBack(welford_op);
-  }
-}
-
-void IndexLowering::visit(const kir::BroadcastOp* bop) {
-  TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(bop));
-
-  const auto out = lowerDstIndex(bop->out());
-  const auto in = lowerSrcIndex(bop->in(), bop->out());
-  auto indexed_expr = ir_builder_.create<kir::BroadcastOp>(out, in);
-
-  if (bop->predicate()) {
-    indexed_expr->setPredicate(bop->predicate());
-  }
+  kir::TensorIndex* out =
+      Index::getConsumerIndex(ir_utils::asTV(bop->out()), loops);
 
-  pushBack(indexed_expr);
+  Val* in = bop->in();
+  if (ir_utils::isTV(in))
+    in = Index::getProducerIndex(
+        ir_utils::asTV(in), ir_utils::asTV(bop->out()), loops);
+  pushBack(ir_builder_.create<kir::BroadcastOp>(out, in));
 }
 
-void IndexLowering::visit(const kir::Allocate* allocate) {
-  // TODO(kir): remove the need for const_cast
-  pushBack(const_cast<kir::Allocate*>(allocate)); // NOLINT
+void IndexLowering::handle(kir::Allocate* allocate) {
+  pushBack(allocate);
 }
 
-void IndexLowering::visit(const kir::Sync* sync) {
-  // TODO(kir): remove the need for const_cast
-  pushBack(const_cast<kir::Sync*>(sync)); // NOLINT
+void IndexLowering::handle(kir::Sync* sync) {
+  pushBack(sync);
 }
 
-void IndexLowering::generate(const std::vector<kir::Expr*>& exprs) {
-  for (auto expr : exprs) {
-    expr->accept(this);
+void IndexLowering::generate(const std::vector<Expr*>& exprs) {
+  // Run through loop nests and further lower the expressions
+  for (auto* expr : exprs) {
+    OptInDispatch::handle(expr);
   }
 }
 
index d6139e9..7d7e861 100644 (file)
@@ -2,10 +2,10 @@
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
 
 #include <vector>
 
@@ -14,39 +14,47 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-class TORCH_CUDA_CU_API IndexLowering : private kir::IrVisitor {
+class TORCH_CUDA_CU_API IndexLowering : public OptInDispatch {
  public:
-  static std::vector<kir::Expr*> getIndexedExprs(
-      std::vector<kir::Expr*> incoming_exprs) {
-    FUSER_PERF_SCOPE("GpuLower::Lower::IndexLowering::getIndexedExprs");
+  static std::vector<Expr*> getIndexedExprs(
+      Fusion* fusion,
+      std::vector<Expr*> incoming_exprs) {
+    FUSER_PERF_SCOPE("IndexLowering::getIndexedExprs");
+    FusionGuard fg(fusion);
     IndexLowering il;
     il.generate(incoming_exprs);
-    return il.lowered_exprs_;
+    return il.lowered_exprs;
   }
 
  private:
   IndexLowering();
 
-  void pushBack(kir::Expr*);
+  // Wrap pushBack, if active_scope is null we want it to go
+  // straight to lower_exprs
+  void pushBack(Expr*);
 
-  void visit(const kir::ForLoop*) final;
-  void visit(const kir::IfThenElse*) final;
-  void visit(const kir::UnaryOp*) final;
-  void visit(const kir::BinaryOp*) final;
-  void visit(const kir::TernaryOp*) final;
-  void visit(const kir::ReductionOp*) final;
-  void visit(const kir::WelfordOp*) final;
-  void visit(const kir::BroadcastOp*) final;
-  void visit(const kir::Allocate*) final;
-  void visit(const kir::Sync*) final;
+  // Open the for loop.
+  void handle(kir::ForLoop*) final;
 
-  void generate(const std::vector<kir::Expr*>& exprs);
+  // Open the for loop.
+  void handle(kir::IfThenElse*) final;
 
-  kir::Val* lowerSrcIndex(kir::Val* val, kir::Val* dst) const;
-  kir::Val* lowerDstIndex(kir::Val* dst) const;
+  // Remake operations with TensorIndex
+  void handle(UnaryOp*) final;
+  void handle(BinaryOp*) final;
+  void handle(TernaryOp*) final;
+  void handle(ReductionOp*) final;
+  void handle(BroadcastOp*) final;
+  void handle(kir::Allocate*) final;
+  void handle(kir::Sync*) final;
+
+  void generate(const std::vector<Expr*>& exprs);
+
+  Val* lowerOperand(Val* op, Val* out) const;
+  Val* lowerOutput(Expr* expr) const;
 
  private:
-  std::vector<kir::Expr*> lowered_exprs_;
+  std::vector<Expr*> lowered_exprs;
 
   // This is a slight work around as scope has a couple definitions, we have the
   // Scope that's in ForLoop/IfThenElse which is really just a wrapper around
@@ -54,8 +62,8 @@ class TORCH_CUDA_CU_API IndexLowering : private kir::IrVisitor {
   // to be able to carry both around because when we push back to a scope it
   // could be either the body or else body of the IfThenElse. However, we want
   // to understand the nesting of IfThenElse/ForLoop nodes.
-  kir::Scope* active_scope_ = nullptr;
-  kir::Expr* active_scope_expr_ = nullptr;
+  kir::Scope* active_scope = nullptr;
+  Expr* active_scope_expr = nullptr;
 
   kir::IrBuilder ir_builder_;
 };
index 6c9a3c8..4326b83 100644 (file)
@@ -1,12 +1,12 @@
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <torch/csrc/jit/codegen/cuda/lower_insert_syncs.h>
+
+#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_insert_syncs.h>
-
-#include <unordered_set>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
 
 namespace torch {
 namespace jit {
@@ -15,109 +15,74 @@ namespace cuda {
 
 namespace {
 
-//! Scan through Kernel IR for-loops to insert Sync nodes to avoid
-//! Write-After-Read (WAR) race condition.
-//!
-//! Example:
-//!   for () {
-//!     smem_buf[threadIdx.x] = x;
-//!     __syncthreads();
-//!     buf[threadId.x] = smem_buf[threadIdx.x + 1];
-//!  }
+//! Scan through Kernel IR to insert Sync nodes to avoid
+//! Write-After-Read (WAR) race condition
 //!
-//! In this case, additional syncthreads is needed at the end of the
-//! loop body to avoid a hazard with smem_buf.
-class LocalSyncInserter {
-  using TvSet = std::unordered_set<const kir::TensorView*>;
-
+class LocalSyncInserter final : private OptOutDispatch {
  public:
-  //! Write-After-Read race conditions are only found within for-loops.
-  //! Sync nodes are inserted directly into the for-loops.
-  //! The expressions are modified in-place and exprs is const.
-  static void insertSyncs(const std::vector<kir::Expr*>& exprs) {
+  // Write-After-Read race conditions are only found within for-loops.
+  // Sync nodes are inserted directly into the for-loops.
+  // The expressions are modified in-place and exprs is const.
+  static void InsertSyncs(const std::vector<Expr*>& exprs) {
+    LocalSyncInserter sync_inserter;
     for (auto expr : exprs) {
-      if (auto fl = dynamic_cast<kir::ForLoop*>(expr)) {
-        LocalSyncInserter sync_inserter(fl);
-      }
-    }
-  }
-
- private:
-  //! Insert Sync nodes at the end of a given for-loop when a WAR
-  //! hazard may happen.
-  LocalSyncInserter(kir::ForLoop* fl) {
-    for (auto expr : fl->body().exprs()) {
-      handle(expr);
-    }
-
-    // No need to insert sync when the loop is not actually generated
-    if (fl->iter_domain()->isThread() || fl->iter_domain()->isBroadcast()) {
-      return;
-    }
-
-    // Determine if any smem TV is written to at beginning of the for-loop
-    // and whether that smem TV is read from at the end of the for-loop
-    // Insert new SyncThreads at end of for-loop to prevent WAR race condition
-    //
-    // TODO: replace __syncthreads with __threadfence for alias ops
-    //
-    if (detectIntersection(initial_, final_) &&
-        !fl->body().exprs().back()->isA<kir::Sync>() && !is_last_op_sync_) {
-      kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-      fl->body().push_back(ir_builder.create<kir::Sync>(true));
-      initial_sync_ = true;
-      is_last_op_sync_ = true;
-      final_.clear();
+      sync_inserter.handle(expr);
     }
   }
 
-  const auto& initial() const {
+  const std::unordered_set<const TensorView*>& initial() const {
     return initial_;
   }
 
-  const auto& final() const {
+  const std::unordered_set<const TensorView*>& final() const {
     return final_;
   }
 
-  const auto& all_smem_inputs() const {
+  const std::unordered_set<const TensorView*>& all_smem_inputs() const {
     return all_smem_inputs_;
   }
 
-  const auto& all_smem_outputs() const {
+  const std::unordered_set<const TensorView*>& all_smem_outputs() const {
     return all_smem_outputs_;
   }
 
-  void handle(kir::Expr* expr) {
-    if (ir_utils::isTVOp(expr)) {
-      is_last_op_sync_ = false;
+  const std::unordered_set<unsigned int>& all_aliased_allocations() const {
+    return all_alias_allocations_;
+  }
+
+ private:
+  explicit LocalSyncInserter(
+      const std::unordered_set<unsigned int>* parent_alias_allocations =
+          nullptr) {
+    if (parent_alias_allocations != nullptr) {
+      all_alias_allocations_.insert(
+          parent_alias_allocations->begin(), parent_alias_allocations->end());
+    }
+  }
 
+  void handle(Expr* expr) final {
+    if (ir_utils::isTVOp(expr)) {
       // For this SyncInserter
-      if (initial_sync_) {
-        addInputSmemTvs(expr, final_);
-      } else {
-        addInputSmemTvs(expr, final_);
-        addOutputSmemTvs(expr, initial_);
-      }
+      (!initial_sync_) ? hasOutputSmemExpr(expr, initial_)
+                       : hasInputSmemExpr(expr, final_);
 
       // For parent SyncInserter
-      addOutputSmemTvs(expr, all_smem_outputs_);
-      addInputSmemTvs(expr, all_smem_inputs_);
-    } else if (auto sync = dynamic_cast<kir::Sync*>(expr)) {
-      handle(sync);
-    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
-      handle(ite);
-    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-      handle(for_loop);
+      hasOutputSmemExpr(expr, all_smem_outputs_);
+      hasInputSmemExpr(expr, all_smem_inputs_);
+    } else {
+      OptOutDispatch::handle(expr);
     }
   }
 
-  void handle(kir::Sync* sync) {
-    is_last_op_sync_ = true;
-    initial_sync_ = true;
-    final_.clear();
+  void handle(kir::Allocate* a) final {
+    if (a->buffer()->getValType().value() == ValType::KirTensorView &&
+        a->alias() != nullptr && a->getMemoryType() == MemoryType::Shared) {
+      auto tv = a->buffer()->as<kir::TensorView>()->fuserTv();
+      all_alias_allocations_.insert(tv->name());
+    }
   }
 
-  void handle(kir::IfThenElse* ite) {
+  void handle(kir::IfThenElse* ite) final {
     for (auto expr : ite->thenBody().exprs()) {
       handle(expr);
     }
@@ -126,59 +91,103 @@ class LocalSyncInserter {
     }
   }
 
-  void handle(kir::ForLoop* fl) {
-    LocalSyncInserter child_sync_inserter(fl);
-
-    const auto& child_inputs = child_sync_inserter.all_smem_inputs();
-    const auto& child_outputs = child_sync_inserter.all_smem_outputs();
-    const bool maybe_skipped = !fl->start()->isZeroInt() &&
-        !isParallelTypeThread(fl->iter_domain()->parallelType());
-
-    // Default - Track all smem inputs / outputs
-    all_smem_inputs_.insert(child_inputs.begin(), child_inputs.end());
-    all_smem_outputs_.insert(child_outputs.begin(), child_outputs.end());
-
-    // Propagate the last_op_sync flag from the child loop. If the
-    // child is deterministically executed at least once, just set the
-    // flag with the child flag. Otherwise, conservatively set the
-    // flag, i.e., if the current flag is true and the child flag is
-    // also true, we can say the last op is still sync.
-    if (!maybe_skipped) {
-      is_last_op_sync_ = child_sync_inserter.is_last_op_sync_;
-    } else {
-      is_last_op_sync_ =
-          is_last_op_sync_ && child_sync_inserter.is_last_op_sync_;
-    }
-
-    // When the child is not guaranteed to have sync.
-    if (!child_sync_inserter.initial_sync_) {
-      // If no sync is yet found, add the child outputs to
-      // initial.
-      if (!initial_sync_) {
-        initial_.insert(child_outputs.begin(), child_outputs.end());
-      }
-      // Add the child inputs to final even when inital_sync is false,
-      // which only means sync may not be found yet.
-      final_.insert(child_inputs.begin(), child_inputs.end());
-    } else {
-      // Similar to the above case, but here, the child is guaranteed
-      // to have sync, so we only need to look at initial and final.
-      if (!initial_sync_) {
-        initial_.insert(
-            child_sync_inserter.initial().begin(),
-            child_sync_inserter.initial().end());
-      }
-      if (!maybe_skipped) {
+  void handle(kir::ForLoop* fl) final {
+    // Track if last op in body is sync in nested for-loop
+    bool is_last_op_sync_ = false;
+    for (auto expr : fl->body().exprs()) {
+      is_last_op_sync_ = false;
+      if (expr->getExprType().value() == ExprType::Sync) {
         initial_sync_ = true;
         final_.clear();
+      } else if (expr->getExprType().value() == ExprType::ForLoop) {
+        // Recursively handle nested for-loop
+        LocalSyncInserter child_sync_inserter(&all_alias_allocations_);
+        child_sync_inserter.handle(expr);
+        const auto& child_inputs = child_sync_inserter.all_smem_inputs();
+        const auto& child_outputs = child_sync_inserter.all_smem_outputs();
+        const auto& child_alias_allocations =
+            child_sync_inserter.all_aliased_allocations();
+
+        // Default - Track all smem inputs / outputs
+        all_smem_inputs_.insert(child_inputs.begin(), child_inputs.end());
+        all_smem_outputs_.insert(child_outputs.begin(), child_outputs.end());
+        all_alias_allocations_.insert(
+            child_alias_allocations.begin(), child_alias_allocations.end());
+
+        if (!initial_sync_) {
+          // Parent - None
+          if (!child_sync_inserter.initial_sync_) {
+            // Child - None
+            // Append All Child Outputs to Parent Initial
+            initial_.insert(child_outputs.begin(), child_outputs.end());
+          } else if (child_sync_inserter.has_war_hazard_sync_) {
+            // Child - WAR race
+            // Parent first sync
+            // Inherit Child Initial / Clear Parent Final
+            initial_sync_ = true;
+            is_last_op_sync_ = true;
+            initial_.insert(
+                child_sync_inserter.initial().begin(),
+                child_sync_inserter.initial().end());
+            final_.clear();
+          } else {
+            // Child - 1+
+            // Parent first sync
+            // Inherit Child Initial + Final
+            initial_sync_ = true;
+            initial_.insert(
+                child_sync_inserter.initial().begin(),
+                child_sync_inserter.initial().end());
+            final_.insert(
+                child_sync_inserter.final().begin(),
+                child_sync_inserter.final().end());
+          }
+        } else {
+          // Parent - 1+
+          if (!child_sync_inserter.initial_sync_) {
+            // Child - None
+            // Append All Child to Parent Last
+            final_.insert(child_inputs.begin(), child_inputs.end());
+          } else if (child_sync_inserter.has_war_hazard_sync_) {
+            // Child - WAR race
+            // Clear Parent Last / Discard Child Initial
+            is_last_op_sync_ = true;
+            final_.clear();
+          } else {
+            // Child - 1+
+            // Inherit Child Final / Discard Child Initial
+            final_.insert(
+                child_sync_inserter.final().begin(),
+                child_sync_inserter.final().end());
+          }
+        }
+      } else {
+        handle(expr);
+      }
+    }
+
+    // This level of the nested for-loop may not exist in the kernel.
+    // However, subsequent levels can exist, so we handle the body of the
+    // for-loop first.
+    if (!fl->iter_domain()->isThread() && !fl->iter_domain()->isBroadcast()) {
+      // Determine if any smem TV is written to at beginning of the for-loop
+      // and whether that smem TV is read from at the end of the for-loop
+      // Insert new SyncThreads at end of for-loop to prevent WAR race condition
+      // TODO: replace __syncthreads with __threadfence for alias ops
+      if (detect_intersection(initial_, final_) &&
+          fl->body().exprs().back()->getExprType().value() != ExprType::Sync &&
+          !is_last_op_sync_) {
+        // std::cout << "WAR race detected; Add Sync" << std::endl;
+        has_war_hazard_sync_ = true;
+        kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+        fl->body().push_back(ir_builder.create<kir::Sync>(true));
       }
-      final_.insert(
-          child_sync_inserter.final().begin(),
-          child_sync_inserter.final().end());
     }
   }
 
-  static bool detectIntersection(const TvSet& left, const TvSet& right) {
+  bool detect_intersection(
+      std::unordered_set<const TensorView*>& left,
+      std::unordered_set<const TensorView*>& right) {
     for (auto item : left) {
       if (right.find(item) != right.end()) {
         return true;
@@ -187,20 +196,26 @@ class LocalSyncInserter {
     return false;
   }
 
-  static void addOutputSmemTvs(const kir::Expr* expr, TvSet& set) {
+  void hasOutputSmemExpr(
+      Expr* expr,
+      std::unordered_set<const TensorView*>& set) {
     for (auto out : expr->outputs()) {
-      if (auto tv = dynamic_cast<kir::TensorView*>(out)) {
-        if (tv->memoryType() == MemoryType::Shared) {
+      if (ir_utils::isTV(out)) {
+        auto tv = out->as<TensorView>();
+        if (tv->getMemoryType() == MemoryType::Shared) {
           set.insert(tv);
         }
       }
     }
   }
 
-  static void addInputSmemTvs(const kir::Expr* expr, TvSet& set) {
-    for (auto in : expr->inputs()) {
-      if (auto tv = dynamic_cast<kir::TensorView*>(in)) {
-        if (tv->memoryType() == MemoryType::Shared) {
+  void hasInputSmemExpr(
+      Expr* expr,
+      std::unordered_set<const TensorView*>& set) {
+    for (auto inp : expr->inputs()) {
+      if (ir_utils::isTV(inp)) {
+        auto tv = inp->as<TensorView>();
+        if (tv->getMemoryType() == MemoryType::Shared) {
           set.insert(tv);
         }
       }
@@ -208,373 +223,41 @@ class LocalSyncInserter {
   }
 
  private:
+  // Track TensorViews for Allocate nodes that alias another memory location
+  std::unordered_set<unsigned int> all_alias_allocations_;
+
   // Track Shared Memory Inputs (Reads) for parent for-loop
-  TvSet all_smem_inputs_;
+  std::unordered_set<const TensorView*> all_smem_inputs_;
 
   // Track Shared Memory Outputs (Writes) for parent for-loop
-  TvSet all_smem_outputs_;
+  std::unordered_set<const TensorView*> all_smem_outputs_;
 
   // Shared Memory Writes at beginning of the for-loop
   // before first SyncThreads
-  TvSet initial_;
+  std::unordered_set<const TensorView*> initial_;
 
   // Shared Memory Reads at end of the for-loop
   // Cleared after each SyncThreads
-  TvSet final_;
+  std::unordered_set<const TensorView*> final_;
 
-  // Track first sync deterministically found in for-loop. Even when a
-  // child loop has a sync, if it may not be executed due to non-zero
-  // start value, this flag remains false.
+  // Track first sync found in for-loop
   bool initial_sync_ = false;
 
-  // Track if last op is sync
-  bool is_last_op_sync_ = false;
-};
-
-class ExprFlattener : private kir::IrVisitor {
- private:
-  void handle(kir::Expr* expr) {
-    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
-      expr->accept(this);
-    } else {
-      exprs_.push_back(expr);
-    }
-  }
-
-  void visit(const kir::ForLoop* fl) final {
-    for (auto expr : fl->body().exprs()) {
-      handle(expr);
-    }
-  }
-
-  void visit(const kir::IfThenElse* ite) final {
-    for (auto expr : ite->thenBody().exprs()) {
-      handle(expr);
-    }
-    for (auto expr : ite->elseBody().exprs()) {
-      handle(expr);
-    }
-  }
-
- private:
-  std::vector<kir::Expr*> exprs_;
-
- public:
-  //! Flattens scopes extracting out a single ordered list of exprs.
-  static std::vector<kir::Expr*> flatten(
-      const std::vector<kir::Expr*>& loop_nests) {
-    ExprFlattener flattener;
-    for (auto expr : loop_nests) {
-      flattener.handle(expr);
-    }
-    return flattener.exprs_;
-  }
-};
-
-class ValidatePlacementAfterWrites : private kir::IrVisitor {
- public:
-  //! Validate no expr in writes found under loop
-  static void validate(
-      kir::ForLoop* loop,
-      const std::unordered_set<kir::Expr*>& writes) {
-    ValidatePlacementAfterWrites validator(writes);
-    validator.handle(loop);
-  }
-
- private:
-  ValidatePlacementAfterWrites(const std::unordered_set<kir::Expr*>& writes)
-      : writes_(writes) {}
-
-  void handle(kir::Expr* expr) {
-    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
-      expr->accept(this);
-    } else {
-      TORCH_INTERNAL_ASSERT(
-          writes_.find(expr) == writes_.end(),
-          "Block sync must be placed after ",
-          kir::toString(expr));
-    }
-  }
-
-  void visit(const kir::ForLoop* fl) final {
-    for (auto expr : fl->body().exprs()) {
-      handle(expr);
-    }
-  }
-
-  void visit(const kir::IfThenElse* ite) final {
-    for (auto expr : ite->thenBody().exprs()) {
-      handle(expr);
-    }
-    for (auto expr : ite->elseBody().exprs()) {
-      handle(expr);
-    }
-  }
-
- private:
-  const std::unordered_set<kir::Expr*>& writes_;
-};
-
-class ReadAfterWriteSyncs : public kir::MutableIrVisitor {
- private:
-  //! Traverse up the loop stack from loops_it and if a halo loop is
-  //! found, place a given sync expr before the outer-most halo loop.
-  bool insertBeforeHaloLoop(
-      std::vector<kir::ForLoop*>::iterator loops_it,
-      kir::Sync* sync_expr,
-      const std::unordered_set<kir::Expr*>& writes) {
-    std::vector<kir::ForLoop*>::iterator halo_loop_it;
-    bool halo_loop_found = false;
-
-    while (true) {
-      if ((*loops_it)->iter_domain()->isThreadDim() &&
-          (*loops_it)->iter_domain()->extent() != (*loops_it)->stop()) {
-        halo_loop_found = true;
-        halo_loop_it = loops_it;
-      }
-
-      if (loops_it == for_loops_.begin()) {
-        break;
-      }
-      --loops_it;
-    }
-
-    // No halo loop found. Do not place the sync expr here. Return
-    // false to indicate nothing is done.
-    if (!halo_loop_found) {
-      return false;
-    }
-
-    auto halo_loop = *halo_loop_it;
-
-    // Make sure there's no write to the smem buffer inside the halo
-    // loop. syncthreads is moved before the halo loop, so having
-    // writes inside the loop invalidates the consistency.
-    ValidatePlacementAfterWrites::validate(halo_loop, writes);
-
-    if (halo_loop_it == for_loops_.begin()) {
-      // place in global scope
-      auto place_before_it =
-          std::find(loop_nests_.begin(), loop_nests_.end(), halo_loop);
-      TORCH_INTERNAL_ASSERT(place_before_it != loop_nests_.end());
-      loop_nests_.insert(place_before_it, sync_expr);
-    } else {
-      auto place_in = *(halo_loop_it - 1);
-      place_in->body().insert_before(halo_loop, sync_expr);
-    }
-
-    return true;
-  }
-
-  void handle(kir::Expr* expr) {
-    if (!ir_utils::isTVOp(expr) || expr->isA<kir::Allocate>()) {
-      expr->accept(this);
-      return;
-    }
-
-    if (sync_after_.size() > 0 && sync_after_.front() == expr) {
-      sync_after_.pop_front();
-      auto last_writes = last_writes_.front();
-      last_writes_.pop_front();
-      // Found that a sync is needed
-      TORCH_INTERNAL_ASSERT(expr->outputs()[0]->isA<kir::TensorView>());
-      auto out_tv = expr->outputs()[0]->as<kir::TensorView>();
-
-      // Find where a sync needs to be inserted
-      // This is very similar to how allocations are placed, simply place sync
-      // after the expression instead of placing like allocation where it goes
-      // before.
-      // TODO: This may be a common operation, could be worth making a utility
-      // out of or saving state for tensor view ID -> for loop
-      // TODO: Explicitly test the 3 cases below
-
-      kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-      auto sync_expr = ir_builder.create<kir::Sync>();
-      if (out_tv->fuserTv()->getComputeAtPosition() == 0) {
-        // Sync should be placed at global scope, after its outer most loop if
-        // it has one.
-        kir::Expr* place_after = for_loops_.size() > 0 ? for_loops_[0] : expr;
-        // Find location in loop_nests_
-        auto place_after_it =
-            std::find(loop_nests_.begin(), loop_nests_.end(), place_after);
-        TORCH_INTERNAL_ASSERT(
-            place_after_it != loop_nests_.end(),
-            "Could not figure out where to place synchronization. ",
-            "Tried to place after, ",
-            toString(place_after),
-            ", but could not find this expression at the global scope.");
-        loop_nests_.insert(place_after_it + 1, sync_expr);
-      } else {
-        // Find the last loop in computeAt of out_tv, this is the loop where we
-        // would place an allocation for out_tv
-        auto fuser_tv = out_tv->fuserTv();
-        auto lowered_local_id =
-            GpuLower::current()
-                ->lowerValue(fuser_tv->axis(
-                    (int)out_tv->fuserTv()->getComputeAtPosition() - 1))
-                ->as<kir::IterDomain>();
-
-        auto loops_it = std::find_if(
-            for_loops_.begin(),
-            for_loops_.end(),
-            [&lowered_local_id](const auto& loop) {
-              return GpuLower::current()->caLoopMap().areMapped(
-                         loop->iter_domain(), lowered_local_id) ||
-                  loop->iter_domain()->parallelType() == ParallelType::Unroll;
-            });
-
-        TORCH_INTERNAL_ASSERT(loops_it != for_loops_.end());
-
-        // block sync must be placed before halo-extended loops
-        if (insertBeforeHaloLoop(loops_it, sync_expr, last_writes)) {
-          return;
-        }
-
-        auto place_in = *loops_it;
-        kir::Expr* place_after = nullptr;
-
-        if (loops_it + 1 == for_loops_.end()) {
-          // Inline allocation, place after expr
-          place_after = expr;
-        } else {
-          // Place allocation after the last computeAt axis
-          // TODO: may be more efficient to place after the first non-computeAt
-          // axis
-          place_after = *(loops_it + 1);
-        }
-
-        place_in->body().insert_after(place_after, sync_expr);
-      }
-    }
-  }
-
-  void visit(kir::ForLoop* fl) final {
-    for_loops_.push_back(fl);
-    // Modifying in place, make a copy of the vector
-    const std::vector<kir::Expr*> exprs = fl->body().exprs();
-    for (auto expr : exprs) {
-      handle(expr);
-    }
-    for_loops_.pop_back();
-  }
-
-  void visit(kir::IfThenElse*) final {
-    TORCH_INTERNAL_ASSERT(
-        false,
-        "Pass does not support conditional statements, ",
-        "this pass should be run before any conditionals are placed in code.");
-  }
-
-  // Clear the modify status for all shared memory buffers
-  static void cleanSharedMemory(
-      std::unordered_map<kir::Val*, kir::Expr*>& smem) {
-    smem.clear();
-  }
-
-  // Return a set of expressions that modify shared-memory
-  // tensors. Expressions are excluded when syncthreads are already
-  // placed.
-  std::unordered_set<kir::Expr*> isModifiedSharedMemory(
-      const std::unordered_map<kir::Val*, kir::Expr*>& smem,
-      const std::vector<kir::Val*>& tvs) const {
-    std::unordered_set<kir::Expr*> last_writes;
-    for (auto tv : tvs) {
-      auto it = smem.find(tv);
-      if (it != smem.end()) {
-        last_writes.insert(it->second);
-      }
-    }
-    return last_writes;
-  }
-
-  ReadAfterWriteSyncs(std::vector<kir::Expr*> _loop_nests)
-      : loop_nests_(std::move(_loop_nests)) {
-    // Fusion shared_memory values
-    // Tracks if shared memory is modified
-    std::unordered_map<kir::Val*, kir::Expr*> smem;
-
-    // Flatten all the expressions
-    auto flattened_exprs = ExprFlattener::flatten(loop_nests_);
-
-    kir::Expr* prev_tv_expr = nullptr;
-    for (auto expr : flattened_exprs) {
-      if (!ir_utils::isTVOp(expr) || expr->isA<kir::Allocate>()) {
-        continue;
-      }
-
-      auto last_writes = isModifiedSharedMemory(smem, expr->inputs());
-      if (!last_writes.empty()) {
-        TORCH_INTERNAL_ASSERT(
-            prev_tv_expr != nullptr,
-            "Can't require sync on inputs, however, detected it's needed.");
-        sync_after_.push_back(prev_tv_expr);
-        last_writes_.push_back(last_writes);
-        cleanSharedMemory(smem);
-      }
-
-      for (auto out : expr->outputs()) {
-        if (out->isA<kir::TensorView>()) {
-          if (out->as<kir::TensorView>()->memoryType() == MemoryType::Shared) {
-            smem[out] = expr;
-          }
-        }
-      }
-
-      prev_tv_expr = expr;
-    }
-
-    // Insert read after write syncs
-    const std::vector<kir::Expr*> exprs = loop_nests_;
-    for (auto expr : exprs) {
-      handle(expr);
-    }
-
-    TORCH_INTERNAL_ASSERT(
-        sync_after_.empty(), "Didn't place all required syncs.");
-  }
-
- private:
-  //! Keep track of expressions that must be followed by syncthreads
-  std::deque<kir::Expr*> sync_after_;
-
-  //! Keep track of write expressions that must be placed before
-  //! syncthreads.
-  //!
-  //! syncthreads is placed after for each expression of
-  //! sync_after_. However, if it's inside a loop with halo, it must
-  //! be placed before that. last_writes_ keeps track of expressions
-  //! modifying the smem buffer each syncthreads is used for so that
-  //! it is not placed before those write expressions.
-  std::deque<std::unordered_set<kir::Expr*>> last_writes_;
-
-  //! Keep track of for loops while inserting syncthreads
-  std::vector<kir::ForLoop*> for_loops_;
-
-  //! Loop-nests where syncthreads are inserted
-  std::vector<kir::Expr*> loop_nests_;
-
- public:
-  static std::vector<kir::Expr*> insert(
-      const std::vector<kir::Expr*>& loop_nests) {
-    ReadAfterWriteSyncs inserter(loop_nests);
-    return inserter.loop_nests_;
-  }
+  // Track sync was inserted for war hazard
+  bool has_war_hazard_sync_ = false;
 };
 
 } // namespace
 
-std::vector<kir::Expr*> insertRawThreadSynchronization(
-    const std::vector<kir::Expr*>& exprs) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::insertRawThreadSynchronization");
-  return ReadAfterWriteSyncs::insert(exprs);
-}
-
-std::vector<kir::Expr*> insertWarThreadSynchronization(
-    const std::vector<kir::Expr*>& exprs) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::insertWarThreadSynchronization");
-  LocalSyncInserter::insertSyncs(exprs);
+std::vector<Expr*> insertThreadSynchronization(
+    Fusion* fusion,
+    const std::vector<Expr*>& exprs) {
+  FUSER_PERF_SCOPE("insertThreadSynchronization");
+  FusionGuard fg(fusion);
+  LocalSyncInserter::InsertSyncs(exprs);
   return exprs;
 }
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
index 7a95434..82fab23 100644 (file)
@@ -2,8 +2,8 @@
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
 
 #include <vector>
 
@@ -13,7 +13,6 @@ namespace fuser {
 namespace cuda {
 
 //! Insert sync at end of for-loops to prevent write-after-read race condition.
-//!
 //! WAR race condition occurs when the next iteration of the loop overwrites
 //! shared memory value before a previous operation has finished reading it.
 //!
@@ -44,12 +43,9 @@ namespace cuda {
 //! If Child - End and Parent has zero remaining operations, then
 //! Parent inherits Child End.
 //!
-std::vector<kir::Expr*> insertWarThreadSynchronization(
-    const std::vector<kir::Expr*>& exprs);
-
-//! Insert syncs between writing to shared memory and then reading it.
-std::vector<kir::Expr*> insertRawThreadSynchronization(
-    const std::vector<kir::Expr*>& exprs);
+std::vector<Expr*> insertThreadSynchronization(
+    Fusion* fusion,
+    const std::vector<Expr*>& exprs);
 
 } // namespace cuda
 } // namespace fuser
index 316a531..4e90d54 100644 (file)
@@ -5,14 +5,11 @@
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/lower_utils.h>
 #include <torch/csrc/jit/codegen/cuda/transform_replay.h>
 
 #include <algorithm>
-#include <deque>
 #include <numeric>
 
 namespace torch {
@@ -20,91 +17,229 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-std::vector<kir::Expr*> LoopNestGenerator::loweredExprs(
-    const std::vector<Expr*>& exprs) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::LoopNestGenerator::loweredExprs");
-  TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr);
-  LoopNestGenerator generator(exprs);
-  return generator.lowered_exprs_;
-}
-
-LoopNestGenerator::LoopNestGenerator(const std::vector<Expr*>& exprs) {
+LoopNestGenerator::LoopNestGenerator(
+    Fusion* fusion,
+    ThreadPredicateMap& thread_predicates,
+    const std::vector<Expr*>& exprs)
+    : fusion_(fusion),
+      thread_predicates_(thread_predicates),
+      ir_builder_(GpuLower::current()->kernel()) {
   generate(exprs);
 }
 
-namespace {
+// Create, place, and return the allocation for tv
+Expr* LoopNestGenerator::pushAlloc(TensorView* tv) {
+  TORCH_INTERNAL_ASSERT(
+      !(FusionGuard::getCurFusion()->hasInput(tv) ||
+        FusionGuard::getCurFusion()->hasOutput(tv)),
+      "Tried to allocate an input or output tensor.");
+
+  const auto alloc_point = loop_utils::getAllocPoint(tv, for_loops);
+  const auto alloc_loop = alloc_point.first;
+  const auto alloc_pos = alloc_point.second;
+
+  // Grab the dimensions the allocation will be based on to compute a size
+  std::vector<Val*> alloc_dims;
+  for (size_t i = alloc_pos; i < tv->nDims(); i++) {
+    IterDomain* compute_at_dim = tv->getComputeAtAxis(i).first;
+    IterDomain* local_dim = tv->axis(i);
+    if (
+        // If shared memory, don't use any IDs bound to a grid dimension
+        (tv->memory_type_ == MemoryType::Shared &&
+         compute_at_dim->isBlockDim()) ||
+        // If local memory, don't use any IDs bound to a grid or block dimension
+        (tv->memory_type_ == MemoryType::Local && compute_at_dim->isThread()) ||
+        // If we're reducing this dimension, don't use it in the allocation
+        // computation
+        local_dim->isReduction() ||
+        // If this is a broadcast dimension, don't use it in the allocation
+        // computation
+        local_dim->isBroadcast()) {
+      continue;
+    }
+    alloc_dims.push_back(compute_at_dim->rawExtent());
+  }
 
-kir::ForLoop* openForHelper(kir::ForLoop* scope, IterDomain* id) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-  const auto kir_id = gpu_lower->lowerValue(id)->as<kir::IterDomain>();
-  auto extent_with_halo = gpu_lower->haloInfo().getExtent(kir_id);
-  kir::ForLoop* new_scope = nullptr;
-  if (extent_with_halo) {
-    // When an axis is extended with halo, unrolling and vectorization
-    // are assumed to not be used for now.
-    TORCH_INTERNAL_ASSERT(
-        id->getParallelType() != ParallelType::Unroll &&
-        !isParallelTypeVectorize(id->getParallelType()));
-    // Use the extent that's extended by halo
-    new_scope = ir_builder.create<kir::ForLoop>(
-        kir_id,
-        id->isBroadcast() ? ir_builder.zeroVal()
-                          : ir_builder.create<kir::Int>(c10::nullopt),
-        nullptr,
-        extent_with_halo,
-        nullptr,
-        false,
-        nullptr);
+  // Multiply all the dimensions we're going to use for the allocation together
+  // to get the total size
+  Val* size = nullptr;
+  if (alloc_dims.size() == 0) {
+    size = ir_builder_.create<kir::Int>(1);
   } else {
-    new_scope = ir_builder.create<kir::ForLoop>(kir_id);
+    size = GpuLower::lowerValue(alloc_dims[0]);
+    for (const auto i : c10::irange(1, alloc_dims.size())) {
+      size = ir_builder_.mulExpr(size, GpuLower::lowerValue(alloc_dims[i]));
+    }
   }
-  if (scope != nullptr) {
-    scope->body().insert(0, new_scope);
+
+  // Create the allocation node
+  const auto lowered_tv = ir_builder_.create<kir::TensorView>(tv);
+  const auto alloc = ir_builder_.create<kir::Allocate>(
+      lowered_tv, lowered_tv->memoryType(), size);
+
+  // Track Dynamic Shared Memory Allocation Nodes
+  if (tv->getMemoryType() == MemoryType::Shared) {
+    if (!size->isConstScalar()) {
+      dynamic_smem_.push_front(alloc);
+      return nullptr;
+    }
   }
-  return new_scope;
-}
 
-} // namespace
+  // Place the allocation
+  if (alloc_loop != nullptr) {
+    alloc_loop->body().insert(for_loop_allocations_[alloc_loop], alloc);
+    ++for_loop_allocations_[alloc_loop];
+  } else {
+    lowered_exprs.insert(lowered_exprs.begin(), alloc);
+  }
 
-void LoopNestGenerator::openFor(IterDomain* iter_domain) {
-  if (for_loops_.size() > 0) {
-    const auto new_scope = openForHelper(for_loops_.back(), iter_domain);
-    // for_loop_allocations_.insert({new_scope, 0});
-    for_loops_.push_back(new_scope);
+  return alloc;
+}
+
+void LoopNestGenerator::openFor(std::pair<IterDomain*, TensorView*> id_pair) {
+  compute_at_scope.push_back(id_pair);
+  IterDomain* id = id_pair.first;
+  if (for_loops.size() > 0) {
+    kir::ForLoop* new_scope = scope_utils::openFor(for_loops.back(), id);
+    for_loop_allocations_.insert({new_scope, 0});
+    for_loops.push_back(new_scope);
   } else {
-    for_loops_.push_back(openForHelper(nullptr, iter_domain));
-    lowered_exprs_.insert(lowered_exprs_.begin(), for_loops_.back());
+    for_loops.push_back(scope_utils::openFor(nullptr, id));
+    lowered_exprs.push_back(for_loops.back());
   }
 }
 
-void LoopNestGenerator::closeFor() {
-  TORCH_INTERNAL_ASSERT(!for_loops_.empty());
-  for_loops_.pop_back();
+void LoopNestGenerator::popFor() {
+  TORCH_INTERNAL_ASSERT(
+      !for_loops.empty() && !compute_at_scope.empty(),
+      "Can't pop for loop, scope is empty.");
+  for_loops.pop_back();
+  compute_at_scope.pop_back();
 }
 
-void LoopNestGenerator::pushFront(kir::Expr* expr) {
-  if (for_loops_.size() == 0) {
-    lowered_exprs_.insert(lowered_exprs_.begin(), expr);
+void LoopNestGenerator::pushBack(Expr* expr) {
+  if (for_loops.size() == 0) {
+    lowered_exprs.push_back(expr);
   } else {
-    for_loops_.back()->body().insert(0, expr);
+    scope_utils::pushBack(for_loops.back(), expr);
   }
 }
 
-void LoopNestGenerator::handle(Expr* expr) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
+// Update for loop structure based on this TensorView, if there's an allocation
+// stmt, send it in so we can make sure that we insert this initialization after
+// it
+void LoopNestGenerator::initReduction(
+    TensorView* tv,
+    Val* init_val,
+    Expr* alloc_expr) {
+  auto alloc_point = loop_utils::getAllocPoint(tv, for_loops);
+  auto alloc_loop = alloc_point.first;
+  auto alloc_pos = alloc_point.second;
+
+  // Grab the IDs that will be involved in the initialization, ignore reduction
+  // dimensions. Everything else will be iterated over to cover the entire
+  // buffer. Index compute will ignore [block, grid]Dims depending on buffer
+  // memory location
+  std::vector<kir::IterDomain*> ids;
+  for (size_t i = alloc_pos; i < tv->nDims(); i++) {
+    IterDomain* dim = tv->getComputeAtAxis(i).first;
+    if (dim->isReduction())
+      continue;
+    ids.push_back(GpuLower::lowerValue(dim)->as<kir::IterDomain>());
+  }
+
+  // Unsafe clone, as we want an exact replica of tv so we can create a UnaryOp
+  // to set the buffer to the init_val.
+  auto clone = tv->unsafeClone();
+  thread_predicates_.duplicate(clone, tv);
+  // The initilization stmt that will be located inside the loop nest (if there
+  // is one)
+  auto init_stmt = new UnaryOp(UnaryOpType::Set, clone, init_val);
+
+  // Init a pointer that will become the entirety of the initialization
+  Expr* init_loop_nest = nullptr;
+
+  // The for loop that we will place the initialization within (alloc_pos - 1),
+  // if one exists. Once we're done this inner_fl will be the inner most loop
+  // containing the init_stmt
+  kir::ForLoop* inner_fl = nullptr;
+  if (alloc_pos >= 1)
+    inner_fl = for_loops[alloc_pos - 1];
+
+  // Work through the iter domains that we need to initialize on, outside to
+  // inside, to construct the loop nest for the initialization.
+  for (auto id : ids) {
+    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+    kir::ForLoop* new_fl;
+
+    if (id->isThread()) {
+      // If based on a thread, make sure we get the named Int right
+      std::stringstream ss;
+      ss << id->getParallelType();
+      new_fl = ir_builder_.create<kir::ForLoop>(
+          ir_builder_.create<kir::NamedScalar>(ss.str(), DataType::Int),
+          id,
+          inner_fl);
+    } else {
+      // Otherwise it's just a new int-
+      new_fl = ir_builder_.create<kir::ForLoop>(
+          ir_builder_.create<kir::Int>(c10::nullopt), id, inner_fl);
+    }
+    for_loop_allocations_.insert({new_fl, 0});
+
+    if (init_loop_nest == nullptr) {
+      // If this is our first generated loop, then it will be our outer most
+      // loop nest
+      init_loop_nest = new_fl;
+    } else {
+      // Otherwise place it inside the last generated loop
+      inner_fl->body().push_back(new_fl);
+    }
+    // Increment the inner most for loop
+    inner_fl = new_fl;
+  }
+
+  if (init_loop_nest == nullptr) {
+    // If no loops were generated, than our init_stmt is all we need
+    init_loop_nest = init_stmt;
+  } else {
+    // If there were for loops generated, place the init_stmt in the inner most
+    // for loop.
+    inner_fl->body().push_back(init_stmt);
+  }
+
+  // If we don't have an alloc_loop defined it means it needs to go in
+  // lowered_exprs. Make sure to place after the allocation of what we're
+  // initializing if there is one.
+  if (alloc_loop == nullptr) {
+    if (alloc_expr != nullptr) {
+      auto it =
+          std::find(lowered_exprs.begin(), lowered_exprs.end(), alloc_expr);
+      TORCH_INTERNAL_ASSERT(
+          it != lowered_exprs.end(),
+          "Could not figure out where to initialize the buffer for ",
+          tv);
+      lowered_exprs.insert(it + 1, init_loop_nest);
+    } else {
+      lowered_exprs.insert(lowered_exprs.begin(), init_loop_nest);
+    }
+  } else {
+    if (alloc_expr != nullptr) {
+      // If there is an allocation for this TensorView
+      // place this loop nest after it
+      alloc_loop->body().insert_after(alloc_expr, init_loop_nest);
+      ++for_loop_allocations_[alloc_loop];
+    } else {
+      // Otherwise we're allocating a global value
+      alloc_loop->body().insert(0, init_loop_nest);
+    }
+  }
+}
 
+void LoopNestGenerator::handle(Expr* expr) {
   // Check if it's a tensor view expression we need to place in the loop nest
   // structure
   if (!ir_utils::isTVOp(expr)) {
-    // Close all the loops, scalar operations cannot be inside for loops based
-    // on expr sorting.
-    while (!for_loops_.empty()) {
-      closeFor();
-    }
-    pushFront(gpu_lower->lowerExpr(expr));
-
     for (auto out : expr->outputs()) {
       TORCH_INTERNAL_ASSERT(
           out->getValType().value() == ValType::Scalar,
@@ -113,131 +248,557 @@ void LoopNestGenerator::handle(Expr* expr) {
           " cannot lower ",
           out->getValType().value());
 
-      pushFront(ir_builder.create<kir::Allocate>(
-          gpu_lower->lowerValue(out),
+      pushBack(ir_builder_.create<kir::Allocate>(
+          GpuLower::lowerValue(out),
           MemoryType::Local,
-          ir_builder.create<kir::Int>(1)));
+          ir_builder_.create<kir::Int>(1)));
     }
+    pushBack(expr);
     return;
   }
 
-  TensorView* out_tv = expr->output(0)->as<TensorView>();
+  //  0) Apply SyncThreads if any shared memory inputs are modified
+  bool shared_memory_sync = false;
+  for (auto in : expr->inputs()) {
+    shared_memory_sync |= isModifiedSharedMemory(in);
+  }
+  if (shared_memory_sync) {
+    TORCH_INTERNAL_ASSERT(!for_loops.empty(), "Attempted to add SyncThreads");
+    // push Sync to the back of the last for loop
+    scope_utils::pushBack(for_loops.back(), ir_builder_.create<kir::Sync>());
+    cleanSharedMemory();
+  }
+
+  TensorView* out = expr->output(0)->as<TensorView>();
 
   // Figure out what the entire loop structure should look like.
-  std::deque<IterDomain*> loop_structure;
-
-  // Fill the entire loop structure by Looking at each axis
-  // individually in out's domain
-  for (size_t out_i = 0; out_i < out_tv->nDims(); out_i++) {
-    // Note: It is not safe to skip trivial reduction axes as they could be
-    // inlined with other tensor views. This happens in
-    // NVFuserTest.FusionBNRepro_CUDA as of this commit on norm_hack_2_rebased
-    // branch
-
-    // Look up the concrete ID in the parallel map, not in the loop
-    // map, which also maps non-CA axes.
-    auto concrete_id =
-        gpu_lower->caParallelMap().getConcreteMappedID(out_tv->axis(out_i));
-    loop_structure.push_back(concrete_id);
-  }
-
-  auto loop_structure_it = loop_structure.begin();
-  auto for_loop_it = for_loops_.begin();
-  auto last_for_loop_matched = for_loops_.begin();
-
-  // Match the loop structure with the current for-loops. Reuse
-  // matching loops and close unmatched ones.
-  while (loop_structure_it != loop_structure.end() &&
-         for_loop_it != for_loops_.end()) {
-    auto lowered_out_id =
-        gpu_lower->lowerValue(*loop_structure_it)->as<kir::IterDomain>();
-    // Similar to the above, the parallel map is used rather than the
-    // loop map. Again, non-CA axes should not share loops, so the
-    // parallel map should be used.
-    if (gpu_lower->caParallelMap().areMapped(
-            lowered_out_id, (*for_loop_it)->iter_domain())) {
-      loop_structure_it++;
-      last_for_loop_matched = ++for_loop_it;
+  std::deque<std::pair<IterDomain*, TensorView*>> loop_structure;
+
+  // As we go through iteration domains track the previous view
+  TensorView* last_ca_view = nullptr;
+  // Check where in the previous view our last axis was in that view
+  int64_t last_ca_view_ind = 0;
+
+  // Look at each axis individually in out's domain
+  for (int64_t out_i = 0; out_i < (int64_t)out->getThisComputeAtAxis();
+       out_i++) {
+    // Grab the axis information
+    auto ca_point = out->getComputeAtAxis(out_i);
+    auto ca_view = ca_point.second;
+    auto ca_id = ca_point.first;
+
+    // Figure out if there are axes in the compute at tensor view that aren't
+    // in out, make sure to also open them. Check where to start looking for
+    // them in the compute at view.
+    size_t start = 0;
+    if (last_ca_view == nullptr) {
+      // Start at the begining, we haven't processed any axes yet.
+      start = 0;
+    } else if (last_ca_view == ca_view) {
+      // This view is the same as the last axis, so start where we left off.
+      start = last_ca_view_ind + 1;
     } else {
-      ++for_loop_it;
+      // This is a new view, figure out where we are in it, and start from there
+      for (start = 0; start < ca_view->nDims(); start++) {
+        if (loop_structure.back().first ==
+            ca_view->getComputeAtAxis(start).first) {
+          break;
+        }
+      }
+      start++;
+    }
+
+    // Go from start, and open all loops in the computeAt view until we hit the
+    // one associated with out->getComputeAtAxis(out_i)
+    for (size_t ca_i = start; ca_i < ca_view->nDims(); ca_i++) {
+      // Note that ca_view->getComputeAtAxis(ca_i) is equivalent to
+      // std::pair(ca_view->axis(ca_i), ca_view)
+      loop_structure.push_back(ca_view->getComputeAtAxis(ca_i));
+
+      // Update the last view processed
+      last_ca_view_ind = ca_i;
+      last_ca_view = ca_view;
+      if (ca_view->getComputeAtAxis(ca_i).first == ca_id) {
+        break;
+      }
+    }
+
+    // Shouldn't ever hit this, but make sure we hit the break above, meaning we
+    // added all necessary axes from the compute at view.
+    TORCH_INTERNAL_ASSERT(
+        ca_view->getComputeAtAxis(last_ca_view_ind).first == ca_id);
+  }
+
+  // We're up to the compute at point in loop_structure, grab the remaining
+  // axes.
+  for (int64_t out_i = (int64_t)out->getThisComputeAtAxis();
+       out_i < (int64_t)out->nDims();
+       out_i++) {
+    // It's actually local, but getComputeAtAxis returns a std::pair, axis
+    // doesn't
+    loop_structure.push_back(out->getComputeAtAxis(out_i));
+  }
+
+  // At this point loop_structure contains our overal target loop nest structure
+  // Lets get a copy of the loop structure, and figure out which loops we need
+  // to open.
+  decltype(loop_structure) loops_to_open(loop_structure);
+  // Pop out loops already opened
+  for (const auto& existing_loop : for_loops) {
+    if (loops_to_open.empty()) {
+      // Nothing to open
+      break;
+    }
+    if (GpuLower::lowerValue(loops_to_open.front().first)
+            ->as<kir::IterDomain>() == existing_loop->iter_domain()) {
+      loops_to_open.pop_front();
+    }
+  }
+
+  // At this point for_loops + loops_to_open contains our overal target loop
+  // nest structure. Open loops in "loops_to_open".
+  while (!loops_to_open.empty()) {
+    openFor(loops_to_open.front());
+    loops_to_open.pop_front();
+  }
+
+  Expr* alloc_expr = nullptr;
+  // Place the allocation for out
+  if (!FusionGuard::getCurFusion()->hasInput(out) &&
+      !FusionGuard::getCurFusion()->hasOutput(out)) {
+    alloc_expr = pushAlloc(out);
+  }
+
+  //  If this is a reduction, initialize the output (open for loops to inner
+  //  most, predicate, initialize, place next after allocation if exists, close
+  //  to computeAt)
+  if (out->hasReduction()) {
+    initReduction(out, expr->as<ReductionOp>()->init(), alloc_expr);
+  }
+
+  //  Place the expression
+  pushBack(expr);
+
+  // If output is a shared memory buffer, set modified status
+  modifySharedMemory(out);
+
+  // Reduce the loop nest structure back to computeAt
+  if (out->getThisComputeAtAxis() == 0) {
+    while (!for_loops.empty()) {
+      popFor();
+    }
+  } else {
+    auto ca_axis = out->getThisComputeAtAxis() - 1;
+    while (for_loops.size() > 0 &&
+           for_loops.back()->iter_domain() !=
+               GpuLower::lowerValue(out->getComputeAtAxis(ca_axis).first)
+                   ->as<kir::IterDomain>()) {
+      popFor();
+    }
+  }
+}
+
+namespace {
+
+TensorView* findOutputTensor(Expr* expr) {
+  TORCH_INTERNAL_ASSERT(
+      expr->outputs().size() <= 1, "Unexpected number of outputs");
+  if (expr->outputs().size() != 1) {
+    return nullptr;
+  }
+  auto out = expr->output(0);
+  if (out->getValType() != ValType::TensorView) {
+    return nullptr;
+  }
+  return out->as<TensorView>();
+}
+
+void findTargetTensor(Expr* expr, TensorView*& target, unsigned& score) {
+  TORCH_INTERNAL_ASSERT(expr->outputs().size() <= 1);
+
+  TensorView* out_tv = findOutputTensor(expr);
+  if (out_tv == nullptr) {
+    target = nullptr;
+    score = 0;
+    return;
+  }
+
+  if (!out_tv->hasComputeAt()) {
+    target = out_tv;
+    // No computeAt, so this should come last.
+    score = std::numeric_limits<unsigned>::max();
+    return;
+  }
+
+  auto axis = out_tv->getRelativeComputeAtAxis();
+  target = out_tv->getComputeAtView();
+  while (target->hasComputeAt()) {
+    if (target->getThisComputeAtAxis() < axis) {
+      break;
     }
+    axis = target->getComputeAtRelPos(axis);
+    target = target->getComputeAtView();
   }
 
-  auto n_loops_to_close =
-      std::distance(last_for_loop_matched, for_loops_.end());
+  score = axis;
+}
 
+// Type definitions for brevity
+using ExprListT = std::vector<Expr*>;
+using TargetGroupMapT = std::unordered_map<TensorView*, ExprListT>;
+using ExprTargetMapT = std::unordered_map<Expr*, TensorView*>;
+using ScoreT = unsigned;
+using ExprScoreMapT = std::unordered_map<const Expr*, ScoreT>;
+
+void sanityCheck(
+    const ExprListT& exprs,
+    const ExprListT& reordered_exprs,
+    const ExprScoreMapT& scores,
+    const ExprTargetMapT& target_map,
+    const TargetGroupMapT& computed_at_exprs) {
+  const auto num_exprs = exprs.size();
+  TORCH_INTERNAL_ASSERT(scores.size() == num_exprs);
   TORCH_INTERNAL_ASSERT(
-      n_loops_to_close >= 0 &&
-          n_loops_to_close <= (std::ptrdiff_t)for_loops_.size(),
-      "Tried to close an invalid number of loops: ",
-      n_loops_to_close);
-
-  if (max_close < n_loops_to_close && max_close > 0) {
-    // Figure out where the last for loop matches from out_tv, go until the
-    // max_close loop marked from previous tv's producer domain. Make sure
-    // none of these domains are actually present in current out_tv. If these
-    // loops map to current out_tv, it should be responsible for deciding if
-    // they stay or go, this could result from an invalid compute at topology
-    // on the DAG or bad expression sorting.
-    auto for_loops_it = for_loops_.end() - n_loops_to_close;
-    auto for_loops_it_end = for_loops_.end() - max_close;
-
-    for (; for_loops_it != for_loops_it_end; for_loops_it++) {
+      reordered_exprs.size() + target_map.size() == num_exprs);
+  int num_computed_exprs = std::accumulate(
+      computed_at_exprs.begin(),
+      computed_at_exprs.end(),
+      0,
+      [](int acc, const std::pair<TensorView*, ExprListT>& p) {
+        return acc + p.second.size();
+      });
+  TORCH_INTERNAL_ASSERT(num_computed_exprs == (int)target_map.size());
+}
+
+// Arrange exprs into loop-nest groups. Loop-nest groups are
+// disjoint grouping of expressions based on the expression
+// where each expression is computed at.
+void groupExpressions(
+    Expr* expr,
+    ExprListT& reordered_exprs,
+    ExprTargetMapT& target_map,
+    TargetGroupMapT& computed_at_exprs,
+    ExprScoreMapT& scores) {
+  TensorView* target_tensor = nullptr;
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  ScoreT score;
+  findTargetTensor(expr, target_tensor, score);
+  scores.emplace(expr, score);
+  if (target_tensor == nullptr) {
+    reordered_exprs.push_back(expr);
+  } else {
+    target_map.emplace(expr, target_tensor);
+    if (computed_at_exprs.find(target_tensor) == computed_at_exprs.end()) {
+      computed_at_exprs.emplace(target_tensor, TargetGroupMapT::mapped_type());
+    }
+    auto& exprs = computed_at_exprs[target_tensor];
+    exprs.push_back(expr);
+  }
+}
+
+// Sort each loop-nest group based on axis (i.e., score)
+void sortGroup(ExprListT& exprs, ExprScoreMapT& scores) {
+  std::stable_sort(
+      exprs.begin(),
+      exprs.end(),
+      [&scores](const Expr* expr1, const Expr* expr2) {
+        return scores[expr1] < scores[expr2];
+      });
+}
+
+// If an expression is missing from expr_status, search for all ancestors
+// that are necessary for the expression
+void mapMissingInputsToAncestors(
+    const TensorView* tv,
+    const std::unordered_map<const Expr*, bool>& expr_status,
+    std::vector<const TensorView*>& ancestors) {
+  const Expr* expr = tv->getOrigin();
+  const auto& expr_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
+  for (auto input : expr_inputs) {
+    const Expr* input_origin = input->getOrigin();
+    if (input_origin != nullptr) {
+      if (expr_status.find(input_origin) == expr_status.end()) {
+        mapMissingInputsToAncestors(input, expr_status, ancestors);
+      } else {
+        ancestors.push_back(input);
+      }
+    }
+  }
+}
+
+// For each expression, find all TensorView inputs.
+// If an input TensorView is missing from expr_status,
+// find that input's ancestors that are present in expr_status.
+std::unordered_map<const Expr*, std::vector<const TensorView*>> findExprTvInputs(
+    const std::unordered_map<const Expr*, bool>& expr_status) {
+  std::unordered_map<const Expr*, std::vector<const TensorView*>>
+      map_expr_to_tv_inputs;
+
+  // Iterate over all exprs and filter missing expr
+  for (auto item : expr_status) {
+    const auto expr = item.first;
+    const auto& expr_inputs =
+        ir_utils::filterByType<TensorView>(expr->inputs());
+
+    map_expr_to_tv_inputs.insert({expr, std::vector<const TensorView*>()});
+    auto& tv_inputs = map_expr_to_tv_inputs[expr];
+
+    for (auto input : expr_inputs) {
+      const Expr* input_origin = input->getOrigin();
+      bool missing_input = input_origin != nullptr &&
+          expr_status.find(input_origin) == expr_status.end();
+
+      if (missing_input) {
+        // Map missing input to ancestor that is present in exprs_status
+        std::vector<const TensorView*> ancestors;
+        mapMissingInputsToAncestors(input, expr_status, ancestors);
+        tv_inputs.insert(tv_inputs.begin(), ancestors.begin(), ancestors.end());
+      } else {
+        tv_inputs.push_back(input);
+      }
+    }
+  }
+  return map_expr_to_tv_inputs;
+}
+
+// Reorder expressions that are computed at the same position in a
+// breadth-first order.
+void reorderSegmentBreadthFirst(
+    ExprListT::iterator seg_begin,
+    ExprListT::const_iterator seg_end) {
+  // mapping of each expression to a bool flag indicating if it's
+  // already been visited
+  std::unordered_map<const Expr*, bool> expr_status;
+  for (auto it = seg_begin; it != seg_end; ++it) {
+    expr_status.insert({*it, false});
+  }
+
+  // Holds all input TVs necessary for every expression.
+  const auto map_expr_to_tv_inputs = findExprTvInputs(expr_status);
+
+  while (seg_begin != seg_end) {
+    std::vector<const Expr*> visited_exprs;
+    for (auto it = seg_begin; it != seg_end; ++it) {
+      const auto expr = *it;
+      const auto& expr_inputs = map_expr_to_tv_inputs.at(expr);
+
+      // if all input expressions are visited
+      // then expr can be visited
+      const bool ready_to_visit = std::all_of(
+          expr_inputs.begin(),
+          expr_inputs.end(),
+          [&expr_status](const TensorView* input) {
+            const Expr* input_origin = input->getOrigin();
+            return input_origin == nullptr ||
+                (expr_status.find(input_origin) != expr_status.end() &&
+                 expr_status.at(input_origin));
+          });
+      if (ready_to_visit) {
+        std::iter_swap(seg_begin, it);
+        TORCH_INTERNAL_ASSERT(*seg_begin == expr);
+        ++seg_begin;
+        visited_exprs.push_back(expr);
+      }
+    }
+    for (const auto& visited_expr : visited_exprs) {
+      expr_status.at(visited_expr) = true;
+    }
+  }
+}
+
+// Reorder expressions in a group in a breadth-first order. Reordering
+// is done within a subset of expressions that have the same score
+// (i.e., computeAt position). For each subset,
+// reorderSegmentBreadthFirst is called.
+void reorderGroupBreadthFirst(ExprListT& exprs, const ExprScoreMapT& scores) {
+  auto seg_begin = exprs.begin();
+  auto seg_end = exprs.begin();
+  ScoreT seg_score = scores.at(*seg_begin);
+  while (seg_end != exprs.end()) {
+    const auto expr = *seg_end;
+    const auto cur_score = scores.at(expr);
+    if (seg_score == cur_score) {
+      // advance further
+      ++seg_end;
+      continue;
+    } else if (seg_score < cur_score) {
+      // segment ended
+      reorderSegmentBreadthFirst(seg_begin, seg_end);
+      seg_begin = seg_end;
+      seg_score = cur_score;
+    } else {
+      // exprs list is assumed to be sorted in the order of scores, so
+      // this should never be reachable
       TORCH_INTERNAL_ASSERT(
-          std::none_of(
-              loop_structure_it,
-              loop_structure.end(),
-              [&gpu_lower, &for_loops_it](IterDomain* loop_structure_id) {
-                // Check loop structure doesn't map for_loops in for loop map
-                auto id0 = (*for_loops_it)->iter_domain();
-                auto id1 = gpu_lower->lowerValue(loop_structure_id)
-                               ->as<kir::IterDomain>();
-                return gpu_lower->caLoopMap().areMapped(id0, id1);
-              }),
-          "Invalid loop found to close.");
+          false, "Unexpected expression: ", expr, ", score: ", cur_score);
     }
+  }
+  reorderSegmentBreadthFirst(seg_begin, seg_end);
+}
 
-    n_loops_to_close = std::min(n_loops_to_close, max_close);
+void mergeNonRootGroupsIntoRootGroups(
+    TargetGroupMapT& computed_at_exprs,
+    ExprTargetMapT& target_map) {
+  for (auto it = computed_at_exprs.begin(); it != computed_at_exprs.end();) {
+    TensorView* target = it->first;
+    if (target->hasComputeAt()) {
+      Expr* target_expr = target->getOrigin();
+      TensorView* target_of_target = target_map.at(target_expr);
+      auto& target_group = computed_at_exprs.at(target_of_target);
+      auto pos =
+          std::find(target_group.begin(), target_group.end(), target_expr);
+      TORCH_INTERNAL_ASSERT(pos != target_group.end());
+      target_group.insert(pos, it->second.begin(), it->second.end());
+      // Update the target map
+      for (auto& inserted_expr : it->second) {
+        TORCH_INTERNAL_ASSERT(target_map.at(inserted_expr) == target);
+        target_map.at(inserted_expr) = target_of_target;
+      }
+      it = computed_at_exprs.erase(it);
+    } else {
+      ++it;
+    }
   }
+}
 
-  for (int64_t i_loop_close = 0; i_loop_close < n_loops_to_close;
-       i_loop_close++) {
-    closeFor();
+// Merge root loop-nests into reordered_exprs
+void mergeGroupsIntoSortedList(
+    TargetGroupMapT& computed_at_exprs,
+    ExprListT& reordered_exprs) {
+  while (computed_at_exprs.size() > 0) {
+    // Find the root loop-nest that has no dependency with the other
+    // loop-nests
+    TensorView* cur_target = computed_at_exprs.begin()->first;
+    for (auto& group : computed_at_exprs) {
+      auto target = group.first;
+      if (cur_target == target)
+        continue;
+      if (DependencyCheck::isDependencyOf(target, cur_target)) {
+        cur_target = target;
+      }
+    }
+    // cur_target can be visited
+    reordered_exprs.insert(
+        reordered_exprs.end(),
+        computed_at_exprs.at(cur_target).begin(),
+        computed_at_exprs.at(cur_target).end());
+    computed_at_exprs.erase(cur_target);
   }
+}
 
-  // Open the remaining needed loops
-  for (; loop_structure_it != loop_structure.end(); ++loop_structure_it) {
-    openFor(*loop_structure_it);
+// Reorder exprs so that LoopNestGenerator::handle(Expr*) can generate
+// correct loop nests. Vector exprs is assumed to be topologically
+// sorted, but that is not sufficient as tensors computed at
+// outer loops need to be located earlier.
+void reorderExprsForComputeAt(std::vector<Expr*>& exprs) {
+  ExprListT reordered_exprs;
+
+  // expr -> target
+  ExprTargetMapT target_map;
+
+  // target -> [computed at expressions]
+  TargetGroupMapT computed_at_exprs;
+
+  // score of each expression that is calculated based on the
+  // computeAt axis. A lower score of an expression means it should be
+  // placed earlier in the expression list. This is a requirement for
+  // the loop-nest generation of this class to work.
+  ExprScoreMapT scores;
+
+  // 1. Group expressions by target tensors. Non-grouped expressions
+  // are copied into reordered_exprs.
+  for (auto& expr : exprs) {
+    groupExpressions(
+        expr, reordered_exprs, target_map, computed_at_exprs, scores);
   }
 
-  if (out_tv->getMaxProducerPosition() == 0) {
-    max_close = -1;
-  } else {
-    auto produce_at_id = loop_structure[out_tv->getMaxProducerPosition() - 1];
-    auto max_close_loop = std::find_if(
-        for_loops_.begin(),
-        for_loops_.end(),
-        [&produce_at_id, &gpu_lower](kir::ForLoop* fl) {
-          auto produce_at_lowered_it =
-              gpu_lower->lowerValue(produce_at_id)->as<kir::IterDomain>();
-          return gpu_lower->caParallelMap().areMapped(
-              produce_at_lowered_it, fl->iter_domain());
-        });
+  sanityCheck(exprs, reordered_exprs, scores, target_map, computed_at_exprs);
 
-    max_close = std::distance(max_close_loop, for_loops_.end());
-    max_close = max_close > 0 ? max_close - 1 : max_close;
+  // If no computeAt found, no need to reorder.
+  if (computed_at_exprs.size() == 0) {
+    return;
   }
-  pushFront(gpu_lower->lowerExpr(expr));
+
+  // 2. Sort each loop-nest group based on axis (i.e., score)
+  for (auto& group : computed_at_exprs) {
+    sortGroup(group.second, scores);
+
+    // Reorder expressions in a breadth-first order
+    reorderGroupBreadthFirst(group.second, scores);
+  }
+
+  // 3. Merge non-root loop-nests into root loop-nests
+  mergeNonRootGroupsIntoRootGroups(computed_at_exprs, target_map);
+
+  // At this point, only root loop-nests (i.e., no computeAt'ed)
+  // should exist.
+  for (auto& group : computed_at_exprs) {
+    // Guarantee only root loop-nests exist.
+    TensorView* target = group.first;
+    TORCH_INTERNAL_ASSERT(!target->hasComputeAt());
+  }
+
+  sanityCheck(exprs, reordered_exprs, scores, target_map, computed_at_exprs);
+
+  mergeGroupsIntoSortedList(computed_at_exprs, reordered_exprs);
+
+  // Reordering completed. Reordered exprs exist in reordered_exprs.
+
+  TORCH_INTERNAL_ASSERT(exprs.size() == reordered_exprs.size());
+  exprs = std::move(reordered_exprs);
 }
 
-// Generate the loop nest structure and place it in lowered_exprs_
+} // namespace
+
+// Generate the loop nest structure and place it in lowered_exprs
 void LoopNestGenerator::generate(const std::vector<Expr*>& exprs) {
-  TORCH_INTERNAL_ASSERT(lowered_exprs_.empty());
+  FusionGuard fg(fusion_);
+
+  // Identify all shared memory TensorViews
+  // Insert into shared_memory map <tv, modify status>
+  for (auto v : fusion_->vals()) {
+    if (v->getValType().value() == ValType::TensorView) {
+      if (v->as<TensorView>()->getMemoryType() == MemoryType::Shared) {
+        smem_.insert({v, false});
+      }
+    }
+  }
+
+  // Initialize members of the class
+  lowered_exprs = std::vector<Expr*>();
+
+  auto reordered = exprs;
+  reorderExprsForComputeAt(reordered);
+
+  for (auto* expr : reordered) {
+    handle(expr);
+  }
+
+  // Insert Dynamic Shared Memory at beginning of kernel
+  for (auto smem_alloc : dynamic_smem_) {
+    lowered_exprs.insert(lowered_exprs.begin(), smem_alloc);
+  }
+}
+
+void LoopNestGenerator::cleanSharedMemory() {
+  for (auto& item : smem_) {
+    item.second = false;
+  }
+}
+
+void LoopNestGenerator::modifySharedMemory(Val* key) {
+  auto it = smem_.find(key);
+  if (it != smem_.end()) {
+    it->second = true;
+  }
+}
 
-  // Process the carefully ordered expressions
-  for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
-    handle(*it);
+bool LoopNestGenerator::isModifiedSharedMemory(Val* key) const {
+  auto it = smem_.find(key);
+  if (it != smem_.end()) {
+    return it->second;
   }
+  return false;
 }
 
 } // namespace cuda
index 2786141..db7859c 100644 (file)
@@ -1,12 +1,10 @@
-
 #pragma once
-
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
-#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
 #include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
 
@@ -15,51 +13,107 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-//! Loop nest generator pass will get IR that looks something like:
-//! T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...* for( i :
-//! I0o{ceil(I0/4)} ) { and will generate the loop nest structure for these
-//! exprs like:
-//!
-//! for( i : I0o{ceil(I0/4)} ) {
-//!   for( j : I1o{ceil(I1/128)} ) {
-//!     for( k : I0i{4} )
-//!       for( l : I1i{128} )
-//!         T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
-//!
-//! It does not generate predicates, but it will generate allocations, and loop
-//! nests to initialize reduction buffers.
-class TORCH_CUDA_CU_API LoopNestGenerator {
+/*
+ * Loop nest generator pass will get IR that looks something like:
+ * T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...* for( i :
+ * I0o{ceil(I0/4)} ) { and will generate the loop nest structure for these exprs
+ * like:
+ *
+ * for( i : I0o{ceil(I0/4)} ) {
+ *   for( j : I1o{ceil(I1/128)} ) {
+ *     for( k : I0i{4} )
+ *       for( l : I1i{128} )
+ *         T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
+ *
+ * It does not generate predicates, but it will generate allocations, and loop
+ * nests to initialize reduction buffers.
+ *
+ */
+class TORCH_CUDA_CU_API LoopNestGenerator : public OptOutDispatch {
  public:
-  static std::vector<kir::Expr*> loweredExprs(const std::vector<Expr*>& exprs);
+  static std::vector<Expr*> loweredExprs(
+      Fusion* fusion,
+      ThreadPredicateMap& thread_predicates,
+      const std::vector<Expr*>& exprs) {
+    FUSER_PERF_SCOPE("LoopNestGenerator::loweredExprs");
+    LoopNestGenerator generator(fusion, thread_predicates, exprs);
+    return generator.lowered_exprs;
+  }
 
  private:
-  LoopNestGenerator(const std::vector<Expr*>& exprs);
+  LoopNestGenerator(
+      Fusion* fusion,
+      ThreadPredicateMap& thread_predicates,
+      const std::vector<Expr*>& exprs);
+
+  // Create the allocation for tv, place it inside the loop associated with
+  // alloc_id, return the node
+  Expr* pushAlloc(TensorView*);
+
+  // Fusion shared_memory values
+  // Tracks if shared memory is modified
+  std::unordered_map<Val*, bool> smem_;
+
+  // Track dynamic shared memory buffers
+  // Insert allocation at the beginning of the kernel
+  std::deque<kir::Allocate*> dynamic_smem_;
+
+  // Clear the modify status for all shared memory buffers
+  void cleanSharedMemory();
+
+  // Toggle modify status for this shared memory buffer
+  void modifySharedMemory(Val* key);
+
+  // Return the status of the shared memory buffer
+  // False if TensorView is not shared memory buffer
+  bool isModifiedSharedMemory(Val* key) const;
 
   // Open a new inner most for loop, track which TV it was constructed from
   // according to the computeAt chain.
-  void openFor(IterDomain*);
+  void openFor(std::pair<IterDomain*, TensorView*>);
 
   // Close the inner most for loop
-  void closeFor();
+  void popFor();
 
-  // Appends an expression to the current scope
-  void pushFront(kir::Expr* expr);
+  // Wrap pushBack in lower_utils if active_scope is null we want it to go
+  // straight to lower_exprs
+  void pushBack(Expr*);
 
-  void handle(Expr* expr);
+  // Initialize a buffer to init_val. If this buffer is in smem or registers,
+  // pass in its allocation statement so we can make sure that we insert this
+  // initialization after the allocation.
+  void initReduction(TensorView* tv, Val* init_val, Expr* alloc_expr = nullptr);
 
-  // Run the pass and accumulate output in lowered_exprs_
+  // Check if expr is a TV op and handle accordingly.
+  void handle(Expr*) final;
+
+  // Run the pass and accumulate output in lowered_exprs
   void generate(const std::vector<Expr*>& exprs);
 
  private:
+  // Track number of allocations in each for loop. It is used to insert
+  // allocations in the correct order, which is necessary for memory aliasing
+  std::unordered_map<kir::ForLoop*, size_t> for_loop_allocations_;
+
   // Lowered exprs to return
-  std::vector<kir::Expr*> lowered_exprs_;
+  std::vector<Expr*> lowered_exprs;
+
+  // Fusion pointer for convenience
+  Fusion* fusion_;
 
   // Keep all for loops conveniently to make unrolling easier, basically just a
   // stack of the active for_loops
-  std::vector<kir::ForLoop*> for_loops_;
+  std::vector<kir::ForLoop*> for_loops;
+
+  // Track the active computeAt scope, and what view we're "computeAt-ing" into
+  std::vector<std::pair<IterDomain*, TensorView*>> compute_at_scope;
+
+  // Predicates from ThreadPredicates that we will extend to reduction buffer
+  // initialization
+  ThreadPredicateMap& thread_predicates_;
 
-  // How many loops can the next iteration close
-  std::ptrdiff_t max_close = -1;
+  // Kernel IR builder
+  kir::IrBuilder ir_builder_;
 };
 
 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp b/torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp
deleted file mode 100644 (file)
index 449ea1f..0000000
+++ /dev/null
@@ -1,133 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
-
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-class MagicZeroInserter : public kir::MutableIrVisitor {
- public:
-  static std::vector<kir::Expr*> insert(const std::vector<kir::Expr*>& exprs) {
-    MagicZeroInserter inserter(exprs);
-    return inserter.loop_nests_;
-  }
-
- private:
-  struct InsertionInfo {
-    kir::Scope* scope = nullptr;
-    kir::ForLoop* fl = nullptr;
-  };
-
-  MagicZeroInserter(const std::vector<kir::Expr*>& exprs)
-      : loop_nests_(exprs), ir_builder(GpuLower::current()->kernel()) {
-    loop_nests_.insert(
-        loop_nests_.begin(), ir_builder.create<kir::InitMagicZero>());
-    for (auto expr : exprs) {
-      handle(expr);
-    }
-    insertAll();
-  }
-
-  void handle(kir::Expr* expr) {
-    if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
-      handle(ite);
-    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-      handle(for_loop);
-    }
-  }
-
-  void handle(kir::IfThenElse* ite) {
-    scope_nest_.push_back(&ite->thenBody());
-    for (auto expr : ite->thenBody().exprs()) {
-      handle(expr);
-    }
-    scope_nest_.pop_back();
-    scope_nest_.push_back(&ite->elseBody());
-    for (auto expr : ite->elseBody().exprs()) {
-      handle(expr);
-    }
-    scope_nest_.pop_back();
-  }
-
-  void handle(kir::ForLoop* fl) {
-    if (fl->isUnrollable()) {
-      kir::Scope* scope = nullptr;
-      if (!scope_nest_.empty()) {
-        scope = scope_nest_.back();
-      }
-      insertion_list_.push_back({scope, fl});
-    } else {
-      scope_nest_.push_back(&fl->body());
-      for (auto expr : fl->body().exprs()) {
-        handle(expr);
-      }
-      scope_nest_.pop_back();
-    }
-  }
-
-  void insertAll() {
-    for (const auto& info : insertion_list_) {
-      auto fl = info.fl;
-      auto scope = info.scope;
-      if (scope == nullptr) {
-        // place in global scope
-        auto loop_it = std::find(loop_nests_.begin(), loop_nests_.end(), fl);
-        TORCH_INTERNAL_ASSERT(loop_it != loop_nests_.end());
-        // Place after the loop
-        loop_it++;
-        loop_nests_.insert(loop_it, ir_builder.create<kir::UpdateMagicZero>());
-      } else {
-        scope->insert_after(fl, ir_builder.create<kir::UpdateMagicZero>());
-      }
-    }
-  }
-
-  //! Keep track for loop structure
-  std::vector<kir::Scope*> scope_nest_;
-
-  // Keep a copy of the expressions provided
-  std::vector<kir::Expr*> loop_nests_;
-
-  kir::IrBuilder ir_builder;
-
-  std::vector<InsertionInfo> insertion_list_;
-};
-
-} // namespace
-
-std::vector<kir::Expr*> insertMagicZero(const std::vector<kir::Expr*>& exprs) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::insertMagicZero");
-  // Check if magic zero was even used, if not we don't have to define it or
-  // update it.
-  bool has_magic_zero = false;
-  const auto gpu_lower = GpuLower::current();
-  auto kernel = gpu_lower->kernel();
-  for (auto& val : kernel->irNodes()) {
-    if (val->isA<kir::NamedScalar>()) {
-      auto named_scalar = val->as<kir::NamedScalar>();
-      if (named_scalar->dtype() == DataType::Int &&
-          named_scalar->name() == "nvfuser_zero") {
-        has_magic_zero = true;
-        break;
-      }
-    }
-  }
-
-  if (!has_magic_zero) {
-    return exprs;
-  }
-
-  return MagicZeroInserter::insert(exprs);
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_magic_zero.h b/torch/csrc/jit/codegen/cuda/lower_magic_zero.h
deleted file mode 100644 (file)
index 1ccf466..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Insert magic zero definition at the begining of the kernel. Insert magic
-//! zero update after every (outer most) loop nest with a compile time extent.
-//!
-//! This will make sure nvrtc does not aggressively save predicate and indices.
-std::vector<kir::Expr*> insertMagicZero(const std::vector<kir::Expr*>& exprs);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp
deleted file mode 100644 (file)
index 2404c68..0000000
+++ /dev/null
@@ -1,608 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h>
-
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-class MisalignedVectorizationModifier {
- public:
-  void process(const std::vector<kir::Expr*>& exprs) {
-    FUSER_PERF_SCOPE(
-        "GpuLower::Lower::MisalignedVectorizationModifier::process");
-    // Run through loop nests
-    // Find for-loops with misaligned vectorization domains
-    for (auto* expr : exprs) {
-      handle(expr);
-    }
-  }
-
-  const std::unordered_map<kir::Expr*, kir::Expr*>& replacementMap() const {
-    return expr_replacement_map_;
-  }
-
- private:
-  void handle(kir::Expr* expr) {
-    if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-      handle(for_loop);
-    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
-      handle(ite);
-    }
-  }
-
-  void handle(kir::ForLoop* fl) {
-    for_loops_structure_.push_back(fl);
-
-    // Make copy of exprs because we replace them inplace in fl
-    const auto exprs_copy = fl->body().exprs();
-
-    if (containsAnyDirectChildMisalignedVectorize(fl)) {
-      auto new_fl = handleMisalignedVectorize(for_loops_structure_, fl);
-      expr_replacement_map_.insert({fl, new_fl});
-    } else {
-      for (auto expr : exprs_copy) {
-        handle(expr);
-      }
-    }
-
-    for_loops_structure_.pop_back();
-  }
-
-  void handle(kir::IfThenElse* ite) {
-    for (auto expr : ite->thenBody().exprs()) {
-      handle(expr);
-    }
-    for (auto expr : ite->elseBody().exprs()) {
-      handle(expr);
-    }
-  }
-
-  struct ReferenceTensors {
-    // Input TensorView to Vectorize Set operation
-    kir::TensorView* in_tv = nullptr;
-    // Output TensorView to Vectorize Set operation
-    kir::TensorView* out_tv = nullptr;
-    // TensorView in global memory
-    kir::TensorView* global_tv = nullptr;
-    // TensorView with vectorize IterDomain and not in global memory
-    kir::TensorView* vec_tv = nullptr;
-  };
-
-  ReferenceTensors getReferenceTensors(kir::Expr* vectorized_expr) {
-    TORCH_INTERNAL_ASSERT(vectorized_expr != nullptr);
-    TORCH_INTERNAL_ASSERT(
-        vectorized_expr->outputs().front()->isA<kir::TensorView>());
-    TORCH_INTERNAL_ASSERT(
-        vectorized_expr->inputs().front()->isA<kir::TensorView>());
-
-    auto in_tv = vectorized_expr->inputs().front()->as<kir::TensorView>();
-    auto out_tv = vectorized_expr->outputs().front()->as<kir::TensorView>();
-
-    const bool global_vectorize_write_op =
-        (out_tv->memoryType() == MemoryType::Global &&
-         in_tv->memoryType() == MemoryType::Local);
-    const bool global_vectorize_read_op =
-        (out_tv->memoryType() == MemoryType::Local &&
-         in_tv->memoryType() == MemoryType::Global);
-    TORCH_INTERNAL_ASSERT(
-        global_vectorize_write_op || global_vectorize_read_op,
-        "Unsupported vectorize memory configuration detected.");
-
-    // TensorView on global memory. This is the tensor that may have
-    // a non-aligned base address.
-    auto global_tv =
-        (out_tv->memoryType() == MemoryType::Global) ? out_tv : in_tv;
-
-    // TensorView with the misaligned vec iterDomain. It is the consumer
-    // of vectorized load or the producer of vectorized store. It is
-    // assumed that when the output TV is not on global memory, this
-    // expression is a vectorized load, so the output TV is vec_tv.
-    auto vec_tv = (out_tv->memoryType() != MemoryType::Global) ? out_tv : in_tv;
-
-    return {in_tv, out_tv, global_tv, vec_tv};
-  }
-
-  struct VectorizeData {
-    kir::Val* vector_size = nullptr;
-    kir::Val* shift = nullptr;
-    kir::Val* extent = nullptr;
-    kir::Val* remainder = nullptr;
-    kir::Val* extent_minus_remainder = nullptr;
-    kir::Val* last_root_domain_index = nullptr;
-    kir::Val* last_root_domain_index_shift = nullptr;
-  };
-
-  // Create constants for handling misaligned addresses
-  VectorizeData createVectorizeConstants(
-      const std::vector<kir::ForLoop*>& for_loop_structure,
-      const ReferenceTensors& tensors,
-      kir::IfThenElse* parent_scope_ite) {
-    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-
-    // Generate vectorize index
-    auto indices = (tensors.out_tv->memoryType() == MemoryType::Global)
-        ? Index::getConsumerStridedIndices(
-              tensors.out_tv->fuserTv(), for_loop_structure)
-        : Index::getProducerStridedIndices(
-              tensors.in_tv->fuserTv(),
-              tensors.out_tv->fuserTv(),
-              for_loop_structure);
-
-    // >>>>>>>>>>>>>
-    // Number of elements in vectorize access
-    auto vector_size =
-        tensors.vec_tv->domain()->domain().back()->extent()->as<kir::Int>();
-
-    // Size of memory type for the elements
-    kir::Int* data_size_in_bytes =
-        ir_builder.create<kir::Int>(dataTypeSize(tensors.vec_tv->dtype()));
-
-    // The number of bytes in the vectorize access
-    auto vector_size_in_bytes =
-        ir_builder.mulExpr(vector_size, data_size_in_bytes);
-
-    auto index = ir_builder.create<kir::TensorIndex>(
-        tensors.global_tv->fuserTv(), indices);
-    auto address = createNamedScalarFromValue(
-        parent_scope_ite->thenBody(), index, "address", true);
-
-    // offset_size = (address % vector_size_bytes) / data_type_size_bytes
-    // shift_init = vector_size - offset_size
-    auto a = ir_builder.modExpr(address, vector_size_in_bytes);
-    auto b = ir_builder.divExpr(a, data_size_in_bytes);
-    auto c = ir_builder.subExpr(vector_size, b);
-    auto shift_init = createNamedScalarFromValue(
-        parent_scope_ite->thenBody(), c, "shift_val");
-
-    // shift = (shift_init == vector_size) ? 0 : shift_init
-    // The number of elements until the first aligned address
-    auto shift_pred = ir_builder.eqExpr(shift_init, vector_size);
-    auto shift_val =
-        ir_builder.whereExpr(shift_pred, ir_builder.zeroVal(), shift_init);
-
-    // >>>>>>>>>>>>>
-    auto shift = createNamedScalarFromValue(
-        parent_scope_ite->thenBody(), shift_val, "shift");
-
-    // >>>>>>>>>>>>>
-    // Get full extent for the inner-most, merged root domain
-    auto extent = getVectorizeExtent(tensors.in_tv, tensors.out_tv);
-
-    // remainder = (extent - shift) % vector_size
-    // The number of elements remaining not accessed by vectorized operations
-    auto remaining_extent = ir_builder.subExpr(extent, shift);
-    auto remainder_val = ir_builder.modExpr(remaining_extent, vector_size);
-    auto remainder = createNamedScalarFromValue(
-        parent_scope_ite->thenBody(), remainder_val, "remainder");
-
-    // (extent - remainder) is the upper-bound for the vectorize section
-    auto extent_remainder_val = ir_builder.subExpr(extent, remainder);
-
-    // >>>>>>>>>>>>>
-    auto extent_minus_remainder = createNamedScalarFromValue(
-        parent_scope_ite->thenBody(),
-        extent_remainder_val,
-        "extent_minus_remainder");
-
-    // >>>>>>>>>>>>>
-    auto last_root_domain_index = createNamedScalarFromValue(
-        parent_scope_ite->thenBody(), indices.back(), "last_root_domain_index");
-
-    // >>>>>>>>>>>>>
-    auto last_root_domain_index_shift =
-        ir_builder.addExpr(last_root_domain_index, shift);
-
-    return {
-        vector_size,
-        shift,
-        extent,
-        remainder,
-        extent_minus_remainder,
-        last_root_domain_index,
-        last_root_domain_index_shift};
-  }
-
-  // Vectorized : [shift - (extent-remainder))
-  // From the first to the last aligned address
-  kir::IfThenElse* createVectorizeSection(
-      const std::vector<kir::ForLoop*>& child_loops,
-      const VectorizeData& params) {
-    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-
-    auto vectorized_child_loops =
-        cloneForLoops(child_loops, params.vector_size, true, params.shift);
-
-    // Vectorize Range: [shift - (extent-remainder))
-    // (last_root_domain_index + shift) < (extent - remainder)
-    kir::Val* vectorize_cond = ir_builder.ltExpr(
-        params.last_root_domain_index_shift, params.extent_minus_remainder);
-
-    kir::Predicate* vectorize_pred =
-        ir_builder.create<kir::Predicate>(vectorize_cond->as<kir::Bool>());
-    kir::IfThenElse* vectorize_ite =
-        ir_builder.create<kir::IfThenElse>(vectorize_pred);
-
-    for (auto cloned_loop : vectorized_child_loops) {
-      vectorize_ite->thenBody().push_back(cloned_loop);
-    }
-
-    return vectorize_ite;
-  }
-
-  // Initial : [0 - shift)
-  // From the initial address until the first aligned address
-  kir::IfThenElse* createInitialSection(
-      const std::vector<kir::ForLoop*>& child_loops,
-      const VectorizeData& params) {
-    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-
-    auto pre_child_loops =
-        cloneForLoops(child_loops, params.shift, false, nullptr);
-
-    // Initial Range: [0 - shift)
-    // last_root_domain_index == 0
-    kir::Val* initial_cond =
-        ir_builder.eqExpr(params.last_root_domain_index, ir_builder.zeroVal());
-
-    kir::Predicate* initial_pred =
-        ir_builder.create<kir::Predicate>(initial_cond->as<kir::Bool>());
-    kir::IfThenElse* initial_ite =
-        ir_builder.create<kir::IfThenElse>(initial_pred);
-
-    for (auto cloned_loop : pre_child_loops) {
-      initial_ite->thenBody().push_back(cloned_loop);
-    }
-
-    return initial_ite;
-  }
-
-  // Remainder : [(extent-remainder) - extent)
-  // From the last aligned address until the end of the extent
-  kir::IfThenElse* createRemainderSection(
-      const std::vector<kir::ForLoop*>& child_loops,
-      const VectorizeData& params) {
-    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-
-    auto post_child_loops =
-        cloneForLoops(child_loops, params.remainder, false, params.shift);
-
-    // Remainder Range: [(extent-remainder) - extent)
-    // (extent - remainder) <= last_root_domain_index + shift < extent
-    kir::Val* lower_bound = ir_builder.geExpr(
-        params.last_root_domain_index_shift, params.extent_minus_remainder);
-    kir::Val* upper_bound =
-        ir_builder.ltExpr(params.last_root_domain_index_shift, params.extent);
-    kir::Val* remainder_cond = ir_builder.andExpr(lower_bound, upper_bound);
-
-    kir::Predicate* remainder_pred =
-        ir_builder.create<kir::Predicate>(remainder_cond->as<kir::Bool>());
-    kir::IfThenElse* remainder_ite =
-        ir_builder.create<kir::IfThenElse>(remainder_pred);
-
-    for (auto cloned_loop : post_child_loops) {
-      remainder_ite->thenBody().push_back(cloned_loop);
-    }
-
-    return remainder_ite;
-  }
-
-  kir::ForLoop* handleMisalignedVectorize(
-      std::vector<kir::ForLoop*> for_loop_structure,
-      const kir::ForLoop* parent_for_loop) {
-    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-
-    auto child_loops = findChildForLoops(parent_for_loop);
-
-    // Assumption: All vectorize operations have the same shift
-    auto vectorized_expr =
-        findFirstVectorizedSetOp(for_loop_structure, child_loops);
-    TORCH_INTERNAL_ASSERT(vectorized_expr != nullptr);
-
-    auto reference_tensors = getReferenceTensors(vectorized_expr);
-
-    // The parent_for_loop contains allocate, read, compute, write operations
-    const auto new_parent_for_loop =
-        ir_builder.create<kir::ForLoop>(parent_for_loop);
-
-    // Transfer all expressions except for-loops to new parent for-loop
-    // All expressions are placed at the beginning of the new for-loop
-    moveExprsExceptForLoops(parent_for_loop, new_parent_for_loop);
-
-    // Get the predicate for all but the last root domain
-    auto pred_except_last_root_domain = ir_builder.create<kir::Predicate>(
-        PredicateType::Misaligned, vectorized_expr, ir_builder.trueVal());
-    kir::IfThenElse* pred_ite =
-        ir_builder.create<kir::IfThenElse>(pred_except_last_root_domain);
-    new_parent_for_loop->body().push_back(pred_ite);
-
-    auto constants = createVectorizeConstants(
-        for_loop_structure, reference_tensors, pred_ite);
-
-    // The last root domain is divided into three sections.
-    // | Initial - N/A Shift | Vectorize - Shift | Remainder - Shift |
-
-    // Vectorized set operation with vectorize shift
-    auto vectorize_ite = createVectorizeSection(child_loops, constants);
-    pred_ite->thenBody().push_back(vectorize_ite);
-
-    // Standard set operation without vectorize shift
-    auto initial_ite = createInitialSection(child_loops, constants);
-    pred_ite->thenBody().push_back(initial_ite);
-
-    // Standard set operation with vectorize shift
-    auto remainder_ite = createRemainderSection(child_loops, constants);
-    pred_ite->thenBody().push_back(remainder_ite);
-
-    return new_parent_for_loop;
-  }
-
-  // Determine that the expression is UnaryOpType::Set AND
-  // the output TensorView domain is vectorized
-  bool isVectorizeSetOp(kir::ForLoop* fl, kir::Expr* expr) {
-    if (fl->iter_domain()->parallelType() !=
-        ParallelType::MisalignedVectorize) {
-      return false;
-    }
-
-    if (expr->isA<kir::UnaryOp>()) {
-      auto unaryOp = expr->as<kir::UnaryOp>();
-      if (unaryOp->out()->isA<kir::TensorView>()) {
-        auto out_tv = unaryOp->out()->as<kir::TensorView>();
-        return unaryOp->operation() == UnaryOpType::Set &&
-            out_tv->domain()->hasVectorize();
-      }
-    }
-    return false;
-  }
-
-  // Clone each for loop
-  // stop value - for (index = start; index < stop; index += step)
-  // vectorize flag - Do not generate for loop header
-  // shift value - Add shift to global indices generated within for loop
-  std::vector<kir::ForLoop*> cloneForLoops(
-      const std::vector<kir::ForLoop*>& for_loops,
-      kir::Val* stop,
-      bool vectorize,
-      kir::Val* vectorize_shift) {
-    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-    std::vector<kir::ForLoop*> cloned_for_loops;
-
-    for (auto fl : for_loops) {
-      auto first_expr = fl->body().exprs().front();
-      bool has_vectorize_op = isVectorizeSetOp(fl, first_expr);
-
-      // If the for loop contains a vectorize Set operation, then
-      // it should only contain a single expression
-      TORCH_INTERNAL_ASSERT(
-          !has_vectorize_op || fl->body().exprs().size() == 1);
-
-      const auto new_loop = ir_builder.create<kir::ForLoop>(
-          fl->iter_domain(),
-          fl->index(),
-          ir_builder.zeroVal(),
-          stop,
-          ir_builder.oneVal(),
-          vectorize && has_vectorize_op,
-          vectorize_shift);
-
-      for (auto expr : fl->body().exprs()) {
-        new_loop->body().push_back(expr);
-      }
-
-      cloned_for_loops.push_back(new_loop);
-    }
-    return cloned_for_loops;
-  }
-
-  // Add all expressions except for loops to new parent for loop
-  void moveExprsExceptForLoops(
-      const kir::ForLoop* for_loop,
-      kir::ForLoop* new_loop) {
-    std::vector<kir::ForLoop*> loops;
-    for (auto expr : for_loop->body().exprs()) {
-      if (!expr->isA<kir::ForLoop>()) {
-        new_loop->body().push_back(expr);
-      }
-    }
-  }
-
-  // Find any child for loops inside parent for loop
-  std::vector<kir::ForLoop*> findChildForLoops(const kir::ForLoop* for_loop) {
-    std::vector<kir::ForLoop*> loops;
-    for (auto expr : for_loop->body().exprs()) {
-      if (auto nested_for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-        loops.push_back(nested_for_loop);
-      }
-    }
-    return loops;
-  }
-
-  // Find the first vectorize set - either read or write
-  // Add child For-Loop to for_loop_structure
-  // Enable vectorize flag in child For-Loop
-  kir::Expr* findFirstVectorizedSetOp(
-      std::vector<kir::ForLoop*>& for_loop_structure,
-      const std::vector<kir::ForLoop*>& for_loops) {
-    for (auto fl : for_loops) {
-      auto first_expr = fl->body().exprs().front();
-      bool has_vectorize_op = isVectorizeSetOp(fl, first_expr);
-      if (has_vectorize_op) {
-        for_loop_structure.push_back(fl);
-        return first_expr;
-      }
-    }
-    return nullptr;
-  }
-
-  // Get full extent for the inner-most, merged root domain
-  kir::Val* getVectorizeExtent(
-      kir::TensorView* producer_tv,
-      kir::TensorView* consumer_tv) {
-    const auto gpu_lower = GpuLower::current();
-    kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-    auto consumer_fuser_tv = consumer_tv->fuserTv();
-    auto producer_fuser_tv = producer_tv->fuserTv();
-
-    auto p2c =
-        PairwiseRootDomainMap(producer_fuser_tv, consumer_fuser_tv)
-            .mapProducerToConsumer(
-                producer_fuser_tv->domain(), consumer_fuser_tv->domain());
-
-    auto consumer_root_right_of_ca_domains = IterVisitor::getInputsTo(
-        {consumer_fuser_tv->domain()->domain().begin() +
-             consumer_fuser_tv->getComputeAtPosition(),
-         consumer_fuser_tv->domain()->domain().end()});
-    auto producer_root_right_of_ca_domains = IterVisitor::getInputsTo(
-        {producer_fuser_tv->domain()->domain().begin() +
-             producer_fuser_tv->getComputeAtPosition(),
-         producer_fuser_tv->domain()->domain().end()});
-
-    const auto& consumer_contig = consumer_fuser_tv->domain()->contiguity();
-    const auto& producer_contig = producer_fuser_tv->domain()->contiguity();
-
-    // No rfactor should exist in the producer TVs
-    TORCH_INTERNAL_ASSERT(
-        !producer_tv->domain()->hasRFactor(),
-        "Invalid producer tensor: ",
-        producer_fuser_tv);
-    auto producer_root_domain = producer_fuser_tv->getRootDomain();
-
-    // Calculate extent of merged root domains
-    kir::Val* extent = nullptr;
-    auto consumer_root_idx = int(consumer_fuser_tv->getRootDomain().size()) - 1;
-    for (int i = int(producer_root_domain.size()) - 1; i >= 0; --i) {
-      auto producer_root_id = producer_root_domain.at(i);
-
-      TORCH_INTERNAL_ASSERT(
-          !gpu_lower->trivialReductionInfo().isDerived(producer_root_id),
-          "No trivial reduciton axis should exist: ",
-          producer_root_id);
-
-      // If the producer ID is reduction or broadcast, it should be safe
-      // to ignore.
-      if (producer_root_id->isReduction()) {
-        continue;
-      } else if (producer_root_id->isBroadcast()) {
-        --consumer_root_idx;
-        continue;
-      }
-
-      // There must be a matching consumer root ID as the producer ID is
-      // not reduction and the expression between them is UnaryOpType::Set.
-      auto it = p2c.find(producer_root_id);
-      TORCH_INTERNAL_ASSERT(
-          it != p2c.end(), "No matching consumer root ID found");
-      auto consumer_root_id = it->second;
-
-      // Don't extend the vectorization domain beyond the CA position
-      if (std::find(
-              consumer_root_right_of_ca_domains.begin(),
-              consumer_root_right_of_ca_domains.end(),
-              consumer_root_id) == consumer_root_right_of_ca_domains.end() ||
-          std::find(
-              producer_root_right_of_ca_domains.begin(),
-              producer_root_right_of_ca_domains.end(),
-              producer_root_id) == producer_root_right_of_ca_domains.end()) {
-        break;
-      }
-
-      // We now know it's safe to extend the vectorization domain to these
-      // axes. It shouldn't matter whether producer or consumer is used.
-      auto consumer_extent = gpu_lower->lowerValue(consumer_root_id->extent());
-      if (extent == nullptr) {
-        extent = consumer_extent;
-      } else {
-        extent = ir_builder.mulExpr(extent, consumer_extent);
-      }
-
-      // If it's not contiguous, extending the vectorization domain
-      // further is not possible
-      if (!(producer_contig.at(i) && consumer_contig.at(consumer_root_idx))) {
-        break;
-      }
-
-      --consumer_root_idx;
-    }
-
-    TORCH_INTERNAL_ASSERT(extent != nullptr);
-
-    return extent;
-  }
-
-  kir::Val* createNamedScalarFromValue(
-      kir::Scope& body,
-      kir::Val* val,
-      const std::string& name,
-      bool address = false) {
-    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-    auto namedScalar = (address) ? ir_builder.addressExprNamedScalar(name, val)
-                                 : ir_builder.setExprNamedScalar(name, val);
-    TORCH_INTERNAL_ASSERT(namedScalar->definition() != nullptr);
-
-    auto alloc = ir_builder.create<kir::Allocate>(
-        namedScalar, MemoryType::Local, ir_builder.oneVal());
-    body.push_back(alloc);
-    body.push_back(namedScalar->definition());
-    return namedScalar;
-  }
-
- private:
-  // We will track which loops in the incoming IR will be replaced and by what
-  std::unordered_map<kir::Expr*, kir::Expr*> expr_replacement_map_;
-
-  // A depth-first ordering of nested for loops
-  // It is used for indexing and predicate generation
-  std::vector<kir::ForLoop*> for_loops_structure_;
-};
-
-} // namespace
-
-std::vector<kir::Expr*> processMisalignedVectorization(
-    Fusion* fusion,
-    const std::vector<kir::Expr*>& exprs) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::processMisalignedVectorization");
-
-  MisalignedVectorizationModifier mvm;
-  mvm.process(exprs);
-
-  std::vector<kir::Expr*> mutated_exprs;
-  mutated_exprs.reserve(exprs.size());
-  for (auto expr : exprs) {
-    mutated_exprs.push_back(
-        ir_utils::applyReplacements(mvm.replacementMap(), expr));
-  }
-
-  return mutated_exprs;
-}
-
-bool containsAnyDirectChildMisalignedVectorize(const kir::ForLoop* fl) {
-  for (auto expr : fl->body().exprs()) {
-    if (expr->isA<kir::ForLoop>()) {
-      auto child_fl = expr->as<kir::ForLoop>();
-      if (child_fl->iter_domain()->parallelType() ==
-          ParallelType::MisalignedVectorize) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h
deleted file mode 100644 (file)
index db28adb..0000000
+++ /dev/null
@@ -1,118 +0,0 @@
-#pragma once
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Transform for-loop structure to handle misaligned addresses
-//!
-//! Sections of misaligned addresses are handled sequentially
-//! while aligned addresses use vectorized memory accesses.
-//!
-//! ---------------------------------------------------------------------------
-//! Before Misaligned Vectorization:
-//!
-//! Inputs: T0
-//! Outputs: T3
-//!
-//! for(...) {
-//!   T1[vector_size];
-//!   for( i : vector_size ) {
-//!     T1[i] = T0[...]
-//!   }
-//!
-//!   T2[vector_size];
-//!   for( i : vector_size ) {
-//!     T2[i] = unaryOp(T1[i])
-//!   }
-//!
-//!   for( i : vector_size ) {
-//!     T3[...] = T2[i]
-//!   }
-//! }
-//!
-//! ---------------------------------------------------------------------------
-//! After Misaligned Vectorization:
-//!
-//! Inputs: T0
-//! Outputs: T3
-//!
-//! for(...) {
-//!   T1[vector_size];
-//!   T2[vector_size];
-//!
-//!   if (inline_predicate_except_last_root_domain) {
-//!     index_except_last_root_domain = ...
-//!     address = (int64_t) &T1[index_except_last_root_domain]
-//!
-//!     offset_size = (address % vector_size_bytes) / data_type_size_bytes
-//!     shift_init = vector_size - offset_size
-//!     shift = (shift_init == vector_size) ? 0 : shift_init
-//!
-//!     // size of the last root domain
-//!     extent = ...
-//!     remainder = (extent - shift) % vector_size
-//!
-//!     last_root_domain_index = ...
-//!
-//!     // Vectorize Section
-//!     if ( (last_root_domain_index + shift) < (extent - remainder) ) {
-//!       T1[0] = vectorize_load( T0[index + shift] );
-//!
-//!       for( i : vector_size ) {
-//!         T2[i] = unaryOp(T1[i])
-//!       }
-//!
-//!       T3[index + shift] = vectorize_store( T2[0] );
-//!     }
-//!
-//!     // Initial Section
-//!     if ( last_root_domain_index == 0 ) {
-//!       for( i : shift ) {
-//!         T1[i] = T0[...]
-//!       }
-//!
-//!       for( i : shift ) {
-//!         T2[i] = unaryOp(T1[i])
-//!       }
-//!
-//!       for( i : shift ) {
-//!         T3[...] = T2[i]
-//!       }
-//!     }
-//!
-//!     // Remainder Section
-//!     if ( (last_root_domain_index + shift) >= (extent - remainder) &&
-//!          (last_root_domain_index + shift) < extent) {
-//!
-//!       for( i : remainder ) {
-//!         T1[i] = T0[index + shift]
-//!       }
-//!
-//!       for( i : remainder ) {
-//!         T2[i] = unaryOp(T1[i])
-//!       }
-//!
-//!       for( i : remainder ) {
-//!         T3[index + shift] = T2[i]
-//!       }
-//!     }
-//!   }
-//! }
-//!
-std::vector<kir::Expr*> processMisalignedVectorization(
-    Fusion* fusion,
-    const std::vector<kir::Expr*>& exprs);
-
-bool containsAnyDirectChildMisalignedVectorize(const kir::ForLoop* fl);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp
deleted file mode 100644 (file)
index ce95093..0000000
+++ /dev/null
@@ -1,596 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
-
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
-#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-class ConditionalFromPredicateModifier {
- public:
-  ConditionalFromPredicateModifier(const std::vector<kir::Expr*>& exprs) {
-    FUSER_PERF_SCOPE(
-        "GpuLower::Lower::ConditionalFromPredicateModifier::process");
-    for (auto* expr : exprs) {
-      handle(expr);
-    }
-  }
-
-  const std::unordered_map<kir::Expr*, kir::Expr*>& replacementMap() const {
-    return expr_replacement_map_;
-  }
-
- private:
-  void handle(kir::Expr* expr) {
-    if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-      handle(for_loop);
-    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
-      handle(ite);
-    } else if (expr != nullptr && expr->predicate() != nullptr) {
-      // Replace expr predicate with bool conditional
-      auto conditional = generateConditional(expr->predicate());
-      TORCH_INTERNAL_ASSERT(conditional != nullptr);
-      expr->predicate()->setValue(conditional);
-      TORCH_INTERNAL_ASSERT(expr->predicate()->value() != nullptr);
-      setWritePredicate(expr, conditional);
-    }
-  }
-
-  void setWritePredicate(kir::Expr* expr, kir::Bool* read_cond) {
-    if (expr->writePredicate() != nullptr) {
-      auto write_cond = generateConditional(expr->writePredicate());
-      if (write_cond) {
-        expr->writePredicate()->setValue(write_cond);
-      } else {
-        // If generateConditional returns null, it means no specific
-        // predicate needs to be used.
-        expr->setWritePredicate(nullptr);
-      }
-    }
-  }
-
-  void handle(kir::ForLoop* fl) {
-    for_loops_structure_.push_back(fl);
-
-    const auto exprs_copy = fl->body().exprs();
-    for (auto expr : exprs_copy) {
-      handle(expr);
-    }
-
-    for_loops_structure_.pop_back();
-  }
-
-  void handle(kir::IfThenElse* ite) {
-    TORCH_INTERNAL_ASSERT(ite->predicate() != nullptr);
-
-    // If ite already has Bool conditional, handle internal expressions
-    // Otherwise, generate conditional and update predicate
-    if (ite->predicate()->hasValue()) {
-      const auto then_exprs_copy = ite->thenBody().exprs();
-      for (auto expr : then_exprs_copy) {
-        handle(expr);
-      }
-
-      const auto else_exprs_copy = ite->elseBody().exprs();
-      for (auto expr : else_exprs_copy) {
-        handle(expr);
-      }
-    } else {
-      auto conditional = generateConditional(ite->predicate());
-      TORCH_INTERNAL_ASSERT(conditional != nullptr);
-      TORCH_INTERNAL_ASSERT(conditional->isA<kir::Bool>());
-
-      // Update bool conditional in-place
-      ite->predicate()->setValue(conditional);
-      handle(ite);
-      TORCH_INTERNAL_ASSERT(ite->predicate()->value() != nullptr);
-    }
-  }
-
-  // Generate conditional according to PredicateType
-  kir::Bool* generateConditional(kir::Predicate* pred) {
-    switch (pred->predicate_type()) {
-      case PredicateType::Inline:
-      case PredicateType::ReductionWrite:
-      case PredicateType::Misaligned: {
-        return PredicateCompute::getInlinePredicate(
-            pred->expr(),
-            for_loops_structure_,
-            pred->thread_pred(),
-            pred->predicate_type());
-      }
-      case PredicateType::Vectorize: {
-        std::vector<kir::ForLoop*> outer_loops;
-        kir::ForLoop* vectorized_loop = nullptr;
-        for (auto loop : for_loops_structure_) {
-          if (loop->iter_domain()->parallelType() == ParallelType::Vectorize) {
-            vectorized_loop = loop;
-            break;
-          } else {
-            outer_loops.emplace_back(loop);
-          }
-        }
-        TORCH_INTERNAL_ASSERT(
-            vectorized_loop != nullptr, "Should be unreachable.");
-        return UnswitchPredicate::get(outer_loops, vectorized_loop);
-      }
-      case PredicateType::Unswitch: {
-        return UnswitchPredicate::get(
-            for_loops_structure_, pred->unrolled_loop());
-      }
-      case PredicateType::Shift: {
-        kir::TensorView* out_tv = ir_utils::getTVOutput(pred->expr());
-        TORCH_INTERNAL_ASSERT(
-            out_tv != nullptr, "Missing kir::TensorView output");
-        return ShiftPredicateInserter::getPredicate(
-            pred->expr(),
-            for_loops_structure_,
-            out_tv,
-            pred->thread_pred(),
-            true);
-      }
-      case PredicateType::Padding: {
-        kir::TensorView* out_tv = ir_utils::getTVOutput(pred->expr());
-        TORCH_INTERNAL_ASSERT(
-            out_tv != nullptr, "Missing kir::TensorView output");
-        return ShiftPredicateInserter::getPredicate(
-            pred->expr(),
-            for_loops_structure_,
-            out_tv,
-            pred->thread_pred(),
-            false);
-      }
-      case PredicateType::Manual: {
-        return pred->value();
-      }
-      default:
-        break;
-    }
-    return nullptr;
-  }
-
- private:
-  // We will track which loops in the incoming IR will be replaced and by what
-  std::unordered_map<kir::Expr*, kir::Expr*> expr_replacement_map_;
-
-  // A depth-first ordering of nested for loops
-  // It is used for indexing and predicate generation
-  std::vector<kir::ForLoop*> for_loops_structure_;
-};
-
-} // namespace
-
-std::vector<kir::Expr*> generateConditionalFromPredicate(
-    Fusion* fusion,
-    const std::vector<kir::Expr*>& exprs) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::generateConditionalFromPredicate");
-
-  ConditionalFromPredicateModifier p2cm(exprs);
-
-  std::vector<kir::Expr*> mutated_exprs;
-  mutated_exprs.reserve(exprs.size());
-  for (auto expr : exprs) {
-    mutated_exprs.push_back(
-        ir_utils::applyReplacements(p2cm.replacementMap(), expr));
-  }
-
-  return mutated_exprs;
-}
-
-namespace {
-
-class PredicateAnalyzer : public OptOutDispatch {
- public:
-  //! Checks if a predicate is needed to avoid out-of-bound accesses.
-  //!
-  //! Due to the way we allocate local-memory tensors, there should
-  //! never be out-of-bound accesses with consumer tensors when allocated on
-  //! local memory. However, accessing producer tensors still may
-  //! result in out-of-bound as they are replayed as consumers.
-  static bool needsPredicate(TensorView* producer, TensorView* consumer) {
-    // Both tensors must be on local memory. Global tensors must be
-    // predicated as allocation is done based on root domains. Smem
-    // and local tensors are allocated based on leaf domains, however,
-    // smem tensors are parallelized, which is highly likely, the size
-    // of the parallelized axis is the actual size of the axis, not
-    // the number of threads. Since the number of threads can be
-    // larger than the axis size, it's not safe to skip predication
-    if (!(producer->getMemoryType() == MemoryType::Local &&
-          consumer->getMemoryType() == MemoryType::Local)) {
-      return true;
-    }
-
-    auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
-    auto c2p =
-        BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
-            .getReplay();
-
-    PredicateAnalyzer analyzer(c2p);
-
-    for (auto id : consumer->domain()->domain()) {
-      if (analyzer.needsPredicate(id)) {
-        return true;
-      }
-    }
-
-    return false;
-  }
-
- private:
-  PredicateAnalyzer(const std::unordered_map<IterDomain*, IterDomain*>& c2p_map)
-      : c2p_map_(c2p_map) {}
-
-  // Returns true if no out-of-bound accesses could occur with a
-  // producer
-  bool needsPredicate(IterDomain* consumer_id) {
-    needs_predicate_ = false;
-    handle(consumer_id);
-    return needs_predicate_;
-  }
-
-  using OptOutDispatch::handle;
-
-  void handle(IterDomain* consumer_id) override {
-    // The traversal should have ended if needs_predicate_ was true
-    TORCH_INTERNAL_ASSERT(!needs_predicate_);
-
-    // If consumer_id is not going to be materialized as a loop (e.g.,
-    // broadcast), no need to predicate
-    const auto gpu_lower = GpuLower::current();
-    if (consumer_id->isBroadcast() ||
-        gpu_lower->trivialReductionInfo().isDerived(consumer_id)) {
-      return;
-    }
-
-    // If the producer has a matching domain, it should not cause
-    // out-of-bound accesses
-    if (c2p_map_.find(consumer_id) != c2p_map_.end()) {
-      return;
-    }
-
-    // If no definition exists, stop traversing
-    if (consumer_id->definition() == nullptr) {
-      return;
-    }
-
-    handle(consumer_id->definition());
-  }
-
-  // If it splits the input axis evenly, proceeds to check the input
-  // axis. Otherwise, we can't skip predication as it might cause
-  // out-bound accesses with the producer tensor
-  void handle(Split* split) override {
-    auto factor = split->factor()->getInt();
-    if (!factor.has_value()) {
-      needs_predicate_ = true;
-      return;
-    }
-
-    ExpressionEvaluator ee(split->fusion());
-    const auto in_extent = ee.evaluate(split->in()->extent());
-
-    if (!in_extent.has_value() || ((in_extent.value() % factor.value()) != 0)) {
-      needs_predicate_ = true;
-      return;
-    }
-
-    handle(split->in());
-  }
-
-  void handle(Merge* merge) override {
-    handle(merge->inner());
-    if (needs_predicate_) {
-      return;
-    }
-    handle(merge->outer());
-  }
-
- private:
-  //! BestEffort map from consumer IDs to producer IDs
-  const std::unordered_map<IterDomain*, IterDomain*>& c2p_map_;
-  bool needs_predicate_ = false;
-};
-
-} // namespace
-
-bool PredicateElimination::needsPredicate(Expr* expr) const {
-  if (!ir_utils::isTVOp(expr)) {
-    return false;
-  }
-
-  std::vector<std::function<bool(Expr*)>> filters;
-
-  // Always predicate integer division and related ops as we don't
-  // know what values are in the out-of-bound region and they may
-  // cause exceptions
-  filters.push_back([](Expr* expr) {
-    auto dt = expr->outputs()[0]->getDataType().value();
-    return (
-        (dt == DataType::Int || dt == DataType::Int32) &&
-        expr->isA<BinaryOp>() &&
-        (expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::Div ||
-         expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::Mod ||
-         expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::Remainder ||
-         expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::CeilDiv));
-  });
-
-  // Skip if MisalignedVectorize is involved for now. This could be
-  // relaxed.
-  filters.push_back([](Expr* expr) {
-    std::vector<const std::vector<Val*>*> inputs_and_outputs = {
-        &(expr->inputs()), &(expr->outputs())};
-    for (const auto& inputs_or_outputs : inputs_and_outputs) {
-      for (auto tv : ir_utils::filterByType<TensorView>(*inputs_or_outputs)) {
-        if (std::any_of(
-                tv->domain()->domain().begin(),
-                tv->domain()->domain().end(),
-                [](IterDomain* axis) {
-                  return axis->getParallelType() ==
-                      ParallelType::MisalignedVectorize;
-                })) {
-          return true;
-        }
-      }
-    }
-    return false;
-  });
-
-  // Shift is not supported yet.
-  filters.push_back([](Expr* expr) {
-    auto& halo_info = GpuLower::current()->haloInfo();
-    auto input_tvs = ir_utils::filterByType<TensorView>(expr->inputs());
-    return halo_info.needsShiftPredicate(expr) ||
-        std::any_of(input_tvs.begin(), input_tvs.end(), [&](auto input_tv) {
-             return input_tv->definition() != nullptr &&
-                 halo_info.needsShiftPredicate(input_tv->definition());
-           });
-  });
-
-  // Predicates the expression if any producer-consumer pair of the
-  // expression needs to be predicated
-  filters.push_back([](Expr* expr) {
-    for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
-      for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
-        if (PredicateAnalyzer::needsPredicate(input, output)) {
-          return true;
-        }
-      }
-    }
-    return false;
-  });
-
-  // Predicates Welford ops
-  filters.push_back([](Expr* expr) { return expr->isA<WelfordOp>(); });
-
-  // If this is a reduction, and if we omit the predicate for the
-  // input, the input may have a garbabe value, which must not be used
-  // for this reduction. However, if the input is also an output of
-  // another reduction with the same binary op, which is a common
-  // pattern with rfactor, the input should be safe to use with no
-  // predication.
-  filters.push_back([this](Expr* expr) {
-    if (expr->isA<ReductionOp>()) {
-      auto input = expr->inputs()[0]->as<TensorView>();
-      auto input_def = input->definition();
-      // When input_def is null, input must be an input to the fusion,
-      // so that must be allocated on global memory. Since we don't omit
-      // predication for expressions involving global memory, this
-      // should never occur.
-      TORCH_INTERNAL_ASSERT(
-          input_def != nullptr, "Inconsistent input found: ", input);
-
-      if (non_predicated_exprs_.find(input_def) !=
-              non_predicated_exprs_.end() &&
-          !(input_def->isA<ReductionOp>() &&
-            (expr->as<ReductionOp>()->getReductionOpType() ==
-             input_def->as<ReductionOp>()->getReductionOpType()))) {
-        return true;
-      }
-    }
-    return false;
-  });
-
-  // If any of the filters returns true, predicate must be used.
-  return std::any_of(filters.begin(), filters.end(), [expr](auto filter) {
-    return filter(expr);
-  });
-}
-
-void PredicateElimination::handle(Expr* expr) {
-  if (!ir_utils::isTVOp(expr)) {
-    return;
-  }
-
-  if (needsPredicate(expr)) {
-    return;
-  }
-
-  non_predicated_exprs_.insert(expr);
-
-  // Ensure all inputs have some values set at the out-of-bound
-  // regions
-  for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
-    auto input_def = input->definition();
-    // When input_def is null, input must be an input to the fusion,
-    // so that must be allocated on global memory. Since we don't omit
-    // predication for expressions involving global memory, this
-    // should never occur.
-    std::stringstream ss;
-    ss << input;
-    TORCH_INTERNAL_ASSERT(
-        input_def != nullptr, "Inconsistent input found: ", ss.str());
-
-    // If input is an output of reduction, it should be fully
-    // initialied as it's allocated on local memory.
-    if (input_def->isA<ReductionOp>() || input_def->isA<WelfordOp>()) {
-      continue;
-    }
-
-    // If this expr is reduction, always initilize the input with the
-    // default value. NOTE: This can be done more
-    // intelligently. A garbage value can only cause a problem when
-    // it's reduced with non-garbage values, so if the non-reduction
-    // axes do not have any garbage, it should be just fine without
-    // explicit initialization. However, initialization cost should be
-    // cheap, so that further optimization should not make a large
-    // difference.
-    if (expr->isA<ReductionOp>()) {
-      setReductionInitValue(input, expr->as<ReductionOp>()->init());
-      continue;
-    }
-
-    // If an input does not need a predicate either, then it should
-    // have some value, so no need to set a default value
-    if (non_predicated_exprs_.find(input_def) != non_predicated_exprs_.end()) {
-      continue;
-    }
-
-    // Make sure input is initialized
-    setDefaultInitValue(input);
-  }
-}
-
-bool PredicateElimination::setDefaultInitValue(TensorView* tv) {
-  auto it = init_value_map_.find(tv);
-  // If there's already a mapping for tv, it should be mapped to a
-  // zero val or a reduction init. Either case, no need to modify
-  // the existing mapping.
-  if (it == init_value_map_.end()) {
-    init_value_map_.insert({tv, nullptr});
-  }
-  return true;
-}
-
-bool PredicateElimination::setReductionInitValue(
-    TensorView* tv,
-    Val* reduction_init) {
-  auto it = init_value_map_.find(tv);
-  if (it == init_value_map_.end()) {
-    init_value_map_.insert({tv, reduction_init});
-    return true;
-  }
-
-  auto existing_val = it->second;
-  if (existing_val == nullptr) {
-    // If the existing mapping returns nullptr, it means that a
-    // default init was set before. Overwrite with the reduction
-    // init val.
-    init_value_map_[tv] = reduction_init;
-    return true;
-  } else if (existing_val->sameAs(reduction_init)) {
-    return true;
-  } else {
-    TORCH_INTERNAL_ASSERT(
-        false,
-        "Incosistent setting of initialization value for t",
-        tv->name(),
-        ". Prev: ",
-        existing_val,
-        ", New: ",
-        reduction_init);
-    return false;
-  }
-}
-
-bool PredicateElimination::canOmitPredicate(const Expr* expr) const {
-  TORCH_INTERNAL_ASSERT(expr != nullptr);
-  const auto out_tv = ir_utils::getTVOutput(expr);
-  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Not a tensor expression");
-  // No need to predicate local tensors to which a scalar is assigned
-  if (out_tv->getMemoryType() == MemoryType::Local) {
-    if (auto uop = dynamic_cast<const UnaryOp*>(expr)) {
-      if (uop->getUnaryOpType() == UnaryOpType::Set && uop->in()->isScalar()) {
-        return true;
-      }
-    }
-  }
-  if (non_predicated_exprs_.find(expr) != non_predicated_exprs_.end()) {
-    return true;
-  }
-
-  return false;
-}
-
-bool PredicateElimination::canOmitPredicate(const kir::Expr* kir_expr) const {
-  TORCH_INTERNAL_ASSERT(kir_expr != nullptr);
-  const auto out_tv = ir_utils::getTVOutput(kir_expr);
-  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Not a tensor expression");
-  // No need to predicate local tensors to which a scalar is assigned
-  if (out_tv->memoryType() == MemoryType::Local) {
-    if (auto uop = dynamic_cast<const kir::UnaryOp*>(kir_expr)) {
-      if (uop->operation() == UnaryOpType::Set && uop->in()->isScalar()) {
-        return true;
-      }
-    }
-  }
-  const auto fuser_tv = out_tv->fuserTv();
-  if (fuser_tv == nullptr) {
-    return false;
-  }
-  return canOmitPredicate(fuser_tv->definition());
-}
-
-kir::Val* PredicateElimination::getInitValue(TensorView* tv) const {
-  auto it = init_value_map_.find(tv);
-  if (it == init_value_map_.end()) {
-    return nullptr;
-  }
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-  auto init_val = it->second;
-  if (init_val == nullptr) {
-    // No reduction restriction. Just use zero
-    return ir_builder.zeroVal();
-  } else {
-    return gpu_lower->lowerValue(init_val);
-  }
-}
-
-void PredicateElimination::build(Fusion* fusion) {
-  traverseFrom(fusion, fusion->outputs());
-}
-
-std::string PredicateElimination::toString() const {
-  std::stringstream ss;
-  ss << "Tensors that do not need predication:";
-  for (auto expr : non_predicated_exprs_) {
-    for (auto out : expr->outputs()) {
-      TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
-      ss << " T" << out->name();
-    }
-  }
-  ss << "\n";
-  ss << "Init values:";
-  for (auto kv : init_value_map_) {
-    ss << " T" << kv.first->name() << "->";
-    if (kv.second == nullptr) {
-      ss << "<default(0)>";
-    } else {
-      ss << kv.second;
-    }
-  }
-  ss << "\n";
-  return ss.str();
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.h b/torch/csrc/jit/codegen/cuda/lower_predicate.h
deleted file mode 100644 (file)
index de70640..0000000
+++ /dev/null
@@ -1,63 +0,0 @@
-#pragma once
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Update predicates with valid bool conditionals
-//!
-std::vector<kir::Expr*> generateConditionalFromPredicate(
-    Fusion* fusion,
-    const std::vector<kir::Expr*>& exprs);
-
-class TORCH_CUDA_CU_API PredicateElimination : public IterVisitor {
- public:
-  void build(Fusion* fusion);
-
-  //! True if expr does not need a predicate
-  //!
-  //! \param expr Tensor expression
-  bool canOmitPredicate(const Expr* expr) const;
-
-  //! True if expr does not need a predicate
-  //!
-  //! \param expr KIR tensor expr
-  bool canOmitPredicate(const kir::Expr* expr) const;
-
-  //! Value to initialize out-of-bound regions
-  kir::Val* getInitValue(TensorView* tv) const;
-
-  //! Dump to string for debugging
-  std::string toString() const;
-
- private:
-  using IterVisitor::handle;
-
-  void handle(Expr* expr) override;
-
-  //! Set a value to initialize out-of-bound regions
-  bool setDefaultInitValue(TensorView* tv);
-  //! Set a value to initialize out-of-bound regions of reduction tensors
-  bool setReductionInitValue(TensorView* tv, Val* reduction_init);
-
-  //! Check if expr needs to be predicated
-  bool needsPredicate(Expr* expr) const;
-
- private:
-  //! Expressions that are found to be safe without predicates
-  std::unordered_set<const Expr*> non_predicated_exprs_;
-  //! Tensors and their initialization values
-  std::unordered_map<TensorView*, Val*> init_value_map_;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp
deleted file mode 100644 (file)
index 1c494d5..0000000
+++ /dev/null
@@ -1,1040 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-
-#include <functional>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-// utility function
-kir::Bool* makeAndExpr(kir::Val* lhs, kir::Val* rhs) {
-  TORCH_INTERNAL_ASSERT(!(lhs == nullptr && rhs == nullptr));
-  if (lhs == nullptr) {
-    return rhs->as<kir::Bool>();
-  } else if (rhs == nullptr) {
-    return lhs->as<kir::Bool>();
-  } else {
-    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-    return ir_builder.andExpr(lhs, rhs)->as<kir::Bool>();
-  }
-}
-
-kir::Int* makeAddExpr(kir::Int* lhs, kir::Int::ScalarType rhs) {
-  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-  if (rhs == 0) {
-    return lhs;
-  } else if (lhs == nullptr) {
-    return ir_builder.create<kir::Int>(rhs);
-  } else if (lhs->isConst()) {
-    return ir_builder.create<kir::Int>(lhs->value().value() + rhs);
-  } else if (rhs > 0) {
-    return ir_builder.addExpr(lhs, ir_builder.create<kir::Int>(rhs))
-        ->as<kir::Int>();
-  } else {
-    return ir_builder.subExpr(lhs, ir_builder.create<kir::Int>(-rhs))
-        ->as<kir::Int>();
-  }
-}
-
-kir::Int* makeAddExpr(kir::Int* lhs, kir::Int* rhs) {
-  if (rhs == nullptr) {
-    return lhs;
-  } else if (lhs == nullptr) {
-    return rhs;
-  } else if (lhs->isConst()) {
-    return makeAddExpr(rhs, lhs->value().value());
-  } else if (rhs->isConst()) {
-    return makeAddExpr(lhs, rhs->value().value());
-  } else {
-    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-    return ir_builder.addExpr(lhs, rhs)->as<kir::Int>();
-  }
-}
-
-kir::Val* makeAddExpr(kir::Val* lhs, kir::Val* rhs) {
-  TORCH_INTERNAL_ASSERT(lhs != nullptr || rhs != nullptr);
-  if (lhs == nullptr || lhs->isZeroInt()) {
-    return rhs;
-  } else if (rhs == nullptr || rhs->isZeroInt()) {
-    return lhs;
-  }
-  auto lhs_int = dynamic_cast<kir::Int*>(lhs);
-  auto rhs_int = dynamic_cast<kir::Int*>(rhs);
-  if (lhs_int != nullptr && rhs_int != nullptr) {
-    return makeAddExpr(lhs_int, rhs_int);
-  } else {
-    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-    return ir_builder.addExpr(lhs, rhs);
-  }
-}
-
-} // namespace
-
-void ShiftPredicateInserter::insert(
-    kir::Expr* expr,
-    const std::vector<kir::ForLoop*>& loops,
-    kir::Bool* thread_pred) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  kir::TensorView* out_tv = ir_utils::getTVOutput(expr);
-  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing kir::TensorView output");
-
-  TensorView* out_fuser_tv = out_tv->fuserTv();
-  const bool needs_shift_predicate =
-      gpu_lower->haloInfo().needsShiftPredicate(out_fuser_tv->definition());
-  if (!needs_shift_predicate) {
-    return;
-  }
-
-  // The conditional branches to create:
-  //
-  // if (shift_pred) {
-  //   consumer = producer;
-  // } else {
-  //   if (padding_pred) {
-  //     consumer = 0;
-  //   }
-  // }
-
-  kir::Predicate* shift_pred = ir_builder.create<kir::Predicate>(
-      PredicateType::Shift, expr, thread_pred);
-
-  // If the expr involves a thread-block barrier, set the predicate of
-  // the expre with shift_pred. Since the expr is not shift, the
-  // padding should be safe to omit. In fact, padding is probably not
-  // necessary for all non-shift exprs (see #877)
-  if (ir_utils::hasBlockSync(expr, gpu_lower->threadPredMap())) {
-    expr->setPredicate(shift_pred);
-    return;
-  }
-
-  auto shift_ite = ir_builder.create<kir::IfThenElse>(shift_pred);
-
-  auto& scope = loops.back()->body();
-
-  // Insert the if statement
-  scope.insert_before(expr, shift_ite);
-
-  // Remove the expr from the list
-  scope.erase(expr);
-
-  // Place the expr inside the if statement
-  shift_ite->thenBody().push_back(expr);
-
-  // Padding by zero
-  kir::Predicate* padding_pred = ir_builder.create<kir::Predicate>(
-      PredicateType::Padding, expr, thread_pred);
-  auto bounds_ite = ir_builder.create<kir::IfThenElse>(padding_pred);
-  const int pad_value = 0;
-  auto pad_expr = ir_builder.create<kir::UnaryOp>(
-      UnaryOpType::Set, out_tv, ir_builder.create<kir::Int>(pad_value));
-  bounds_ite->thenBody().push_back(pad_expr);
-  // Insert the else block
-  shift_ite->elseBody().push_back(bounds_ite);
-}
-
-namespace {
-
-kir::Val* getShiftProducerIndex(
-    size_t consumer_root_axis,
-    kir::Val* consumer_index,
-    ShiftOp* shift_expr) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  const int shift_offset =
-      (shift_expr != nullptr) ? shift_expr->offset(consumer_root_axis) : 0;
-
-  if (shift_offset == 0) {
-    return consumer_index;
-  } else if (shift_offset > 0) {
-    return ir_builder.subExpr(
-        consumer_index, ir_builder.create<kir::Int>(shift_offset));
-  } else {
-    return ir_builder.addExpr(
-        consumer_index, ir_builder.create<kir::Int>(-shift_offset));
-  }
-}
-
-// Create a producer index by adjusting the corresponding consumer
-// index.
-kir::Val* getGatherProducerIndex(
-    size_t consumer_root_axis,
-    kir::Val* consumer_index,
-    GatherOp* gather_expr,
-    const std::vector<kir::Val*>& indices) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  if (gather_expr == nullptr ||
-      consumer_root_axis >= gather_expr->windowShape().size() ||
-      gather_expr->windowShape()[consumer_root_axis]->isOneInt()) {
-    return consumer_index;
-  }
-
-  // Relative to the consumer index, the producer index needs to
-  // account for:
-  // - window access
-  // - padding at offset 0
-  // This adjustment is basically the same as
-  // getProducerIndexWithGather in index_compute.cpp.
-  // TODO: Refactor shift/gather indexing and predication
-  const auto window_axis = gather_expr->gatherAxis(consumer_root_axis);
-  TORCH_INTERNAL_ASSERT(window_axis < (int)indices.size());
-  auto window_idx = indices[window_axis];
-  auto pad_size = gather_expr->padWidth()[consumer_root_axis][0];
-  auto producer_index = ir_builder.subExpr(
-      ir_builder.addExpr(consumer_index, window_idx),
-      ir_builder.create<kir::Int>(pad_size));
-  return producer_index;
-}
-
-} // namespace
-
-kir::Bool* ShiftPredicateInserter::getPredicate(
-    const kir::Expr* expr,
-    const std::vector<kir::ForLoop*>& loops,
-    kir::TensorView* out_tv,
-    kir::Bool* thread_pred,
-    bool isShiftPredicate) {
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  TensorView* out_fuser_tv = out_tv->fuserTv();
-
-  const bool needs_shift_predicate =
-      gpu_lower->haloInfo().needsShiftPredicate(out_fuser_tv->definition());
-  TORCH_INTERNAL_ASSERT(needs_shift_predicate);
-
-  const auto& root_domain = out_fuser_tv->getRootDomain();
-
-  auto shift_expr = dynamic_cast<ShiftOp*>(out_fuser_tv->definition());
-  auto gather_expr = dynamic_cast<GatherOp*>(out_fuser_tv->definition());
-
-  // Creates indices at the root domain.
-  // Set contiguity of all axes false as separate indices are needed for each
-  // root axis.
-  // Note: separate indices should be needed only for axes that
-  // require shift predication, so other axes could use the actual
-  // contiguity information. See a TODO item of issue #877.
-  const auto pred_contiguity = std::vector<bool>(root_domain.size(), false);
-  auto pred_indices =
-      Index::getConsumerRootPredIndices(out_tv, loops, pred_contiguity);
-  const auto& indices = pred_indices.first;
-  const bool buffer_init = pred_indices.second;
-
-  // No predication is needed when the expr is to initialize reduction
-  // buffer on local memory
-  if (out_tv->memoryType() == MemoryType::Local && buffer_init) {
-    return ir_builder.trueVal();
-  }
-
-  TORCH_INTERNAL_ASSERT(indices.size() == root_domain.size());
-
-  kir::Bool* predicate = nullptr;
-
-  for (size_t i = 0; i < root_domain.size(); ++i) {
-    auto root_id = root_domain[i];
-
-    if (root_id->isBroadcast() || (buffer_init && root_id->isReduction()) ||
-        gpu_lower->trivialReductionInfo().isDerived(root_id)) {
-      continue;
-    }
-
-    const auto halo_info = gpu_lower->haloInfo().getRootAxisInfo(root_id);
-
-    if (isShiftPredicate) {
-      // Below, "left" and "right" halo mean halo at offset zero and
-      // axis extent, respectively.
-      //
-      // The consumer axis looks like this:
-      //
-      // [0, left halo)[0, extent)[0, right halo)
-      //              ^         ^
-      //        left limit   right limit
-      //
-      // Accesses outside of the left and right limits are filled by
-      // zero. As illustrated above, left limit = left halo, and right
-      // limit = left halo + extent.
-
-      kir::Val* left_limit = halo_info.width(0);
-      kir::Val* right_limit = makeAddExpr(
-          out_tv->domain()->rootDomain()[i]->extent(), halo_info.width(0));
-
-      kir::Val* consumer_index = indices[i];
-      kir::Val* producer_index = nullptr;
-
-      if (shift_expr != nullptr) {
-        producer_index = getShiftProducerIndex(i, consumer_index, shift_expr);
-      } else if (gather_expr != nullptr) {
-        producer_index =
-            getGatherProducerIndex(i, consumer_index, gather_expr, indices);
-      } else {
-        producer_index = indices[i];
-      }
-
-      // If the defining expr is ShiftOp and its offset is positive,
-      // consumer access at 0 to the offset corresponds to
-      // out-of-bound producer access unless the producer has halo as
-      // well. For now, always add predication assuming no halo on the
-      // producer. This should be reivisted for performance
-      // optimization (#877).
-      if (shift_expr && shift_expr->offset(i) > 0) {
-        predicate = makeAndExpr(
-            predicate, ir_builder.geExpr(producer_index, left_limit));
-      } else if (gather_expr) {
-        // Since it's unknown if producer_index < consumer_index, we need
-        // to predicate using both of the producer and consumer
-        // indices. This would be the case if dynamic shift offset is
-        // used, which is not yet supported. This can be a performance
-        // problem, but in a common case where the input tensor is
-        // cached at SMEM, it should be possible to remove the
-        // predicate for this expression entirely.
-        predicate = makeAndExpr(
-            predicate, ir_builder.geExpr(consumer_index, left_limit));
-        if (consumer_index != producer_index) {
-          predicate = makeAndExpr(
-              predicate, ir_builder.geExpr(producer_index, left_limit));
-        }
-      } else if (!left_limit->isZeroInt()) {
-        predicate = makeAndExpr(
-            predicate, ir_builder.geExpr(consumer_index, left_limit));
-      }
-
-      // If the shift offset is negative, the maximum index is extent -
-      // abs(shift_offset). Instead of subtracting shift_offset from
-      // extent, which can result in wrap around, add the absolute value
-      // of the shift offset to the index
-      if (shift_expr && shift_expr->offset(i) < 0) {
-        predicate = makeAndExpr(
-            predicate, ir_builder.ltExpr(producer_index, right_limit));
-      } else if (gather_expr) {
-        predicate = makeAndExpr(
-            predicate, ir_builder.ltExpr(consumer_index, right_limit));
-        if (consumer_index != producer_index) {
-          predicate = makeAndExpr(
-              predicate, ir_builder.ltExpr(producer_index, right_limit));
-        }
-      } else {
-        predicate = makeAndExpr(
-            predicate, ir_builder.ltExpr(consumer_index, right_limit));
-      }
-    } else {
-      auto padding_max_offset = makeAddExpr(
-          out_tv->domain()->rootDomain()[i]->extent(), halo_info.width());
-
-      predicate = makeAndExpr(
-          predicate, ir_builder.ltExpr(indices[i], padding_max_offset));
-    }
-  }
-
-  if (thread_pred->isConst()) {
-    if (!thread_pred->value().value()) {
-      predicate = ir_builder.create<kir::Bool>(false);
-    }
-  } else {
-    predicate = makeAndExpr(predicate, thread_pred);
-  }
-
-  return predicate;
-}
-
-AxisHaloInfo::AxisHaloInfo() {
-  auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-  setWidth(0, ir_builder.zeroVal());
-  setWidth(1, ir_builder.zeroVal());
-}
-
-kir::Int* AxisHaloInfo::width() const {
-  auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-  return makeAddExpr(width(0), width(1));
-}
-
-kir::Int* AxisHaloInfo::width(int pos) const {
-  TORCH_INTERNAL_ASSERT(pos >= 0 && pos < 2);
-  TORCH_INTERNAL_ASSERT(widths_[pos] != nullptr);
-  return widths_[pos];
-}
-
-void AxisHaloInfo::setWidth(int pos, kir::Int* width) {
-  TORCH_INTERNAL_ASSERT(pos >= 0 && pos < 2);
-  widths_[pos] = width;
-}
-
-void AxisHaloInfo::merge(int pos, kir::Int* other) {
-  auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-  auto cur = width(pos);
-  kir::Int* new_width = nullptr;
-  if (cur->isConst() && other->isConst()) {
-    new_width = ir_builder.create<kir::Int>(
-        std::max(cur->value().value(), other->value().value()));
-  } else if (cur->isZeroInt()) {
-    new_width = other;
-  } else if (other->isZeroInt()) {
-    new_width = cur;
-  } else {
-    new_width = ir_builder.maxExpr(width(pos), other)->as<kir::Int>();
-  }
-  setWidth(pos, new_width);
-}
-
-void AxisHaloInfo::merge(const AxisHaloInfo& other) {
-  for (size_t i = 0; i < widths_.size(); ++i) {
-    merge(i, other.width(i));
-  }
-}
-
-bool AxisHaloInfo::hasHalo() const {
-  return std::any_of(
-      widths_.begin(), widths_.end(), [](auto w) { return !w->isZeroInt(); });
-}
-
-std::string AxisHaloInfo::toString() const {
-  std::stringstream ss;
-  ss << "<" << kir::toString(width(0)) << ", " << kir::toString(width(1))
-     << ">";
-  return ss.str();
-}
-
-const AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) const {
-  TORCH_INTERNAL_ASSERT(
-      id->definition() == nullptr || id->isRFactorProduct(),
-      "Invalid IterDomain: ",
-      id);
-  auto it = root_axis_map_.find(id);
-  TORCH_INTERNAL_ASSERT(
-      it != root_axis_map_.end(), "Halo root axis info not found for ", id);
-  return it->second;
-}
-
-AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) {
-  return const_cast<AxisHaloInfo&>(
-      const_cast<const HaloInfo*>(this)->getRootAxisInfo(id));
-}
-
-const AxisHaloInfo& HaloInfo::getRootAxisInfo(kir::IterDomain* id) const {
-  TORCH_INTERNAL_ASSERT(
-      id->definition() == nullptr || id->isRFactorProduct(),
-      "Invalid IterDomain: ",
-      id);
-  auto it = kir_root_axis_map_.find(id);
-  TORCH_INTERNAL_ASSERT(
-      it != kir_root_axis_map_.end(), "Halo root axis info not found for ", id);
-  return it->second;
-}
-
-AxisHaloInfo& HaloInfo::getRootAxisInfo(kir::IterDomain* id) {
-  return const_cast<AxisHaloInfo&>(
-      const_cast<const HaloInfo*>(this)->getRootAxisInfo(id));
-}
-
-void HaloInfo::setRootAxisInfo(
-    IterDomain* id,
-    const AxisHaloInfo& root_axis_info) {
-  TORCH_INTERNAL_ASSERT(
-      id->definition() == nullptr || id->isRFactorProduct(),
-      "Invalid IterDomain: ",
-      id);
-  root_axis_map_[id] = root_axis_info;
-  kir_root_axis_map_
-      [GpuLower::current()->lowerValue(id)->as<kir::IterDomain>()] =
-          root_axis_info;
-  return;
-}
-
-void HaloInfo::build(Fusion* fusion) {
-  const auto vals = fusion->usedMathVals();
-  auto tvs = ir_utils::filterByType<TensorView>(vals);
-
-  // Initialize all root axis info
-  for (auto tv : tvs) {
-    for (auto root_axis : tv->getRootDomain()) {
-      setRootAxisInfo(root_axis, AxisHaloInfo());
-    }
-    // Just adds a placeholder to make it not fail. Reduction and
-    // rfactor support is not yet in place.
-    if (tv->hasRFactor()) {
-      for (auto rf_root_axis : tv->getRFactorDomain()) {
-        setRootAxisInfo(rf_root_axis, AxisHaloInfo());
-      }
-    }
-  }
-
-  // Propagate backward halo information of root axes from fusion
-  // outputs to inputs
-  auto exprs = fusion->exprs();
-  for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
-    auto expr = *it;
-    if (!expr->outputs()[0]->isA<TensorView>()) {
-      continue;
-    }
-
-    propagateRootAxisInfo(expr);
-  }
-
-  // Propagates halo information from root axes down to leaf axes
-  for (auto tv : tvs) {
-    build(tv->domain());
-  }
-
-  // Note that validation requires consumer halo info
-  for (auto tv : tvs) {
-    validate(tv);
-  }
-}
-
-void HaloInfo::propagateRootAxisInfo(Expr* expr) {
-  for (auto output : expr->outputs()) {
-    auto out_tv = dynamic_cast<TensorView*>(output);
-    if (out_tv == nullptr) {
-      continue;
-    }
-    for (auto input : expr->inputs()) {
-      auto in_tv = dynamic_cast<TensorView*>(input);
-      if (in_tv == nullptr) {
-        continue;
-      }
-      propagateRootAxisInfo(in_tv, out_tv, expr);
-    }
-  }
-}
-
-void HaloInfo::propagateRootAxisInfo(
-    TensorView* producer,
-    TensorView* consumer,
-    Expr* expr) {
-  // Do not add halo to input tensors
-  if (producer->isFusionInput()) {
-    return;
-  }
-
-  auto c2p = PairwiseRootDomainMap(producer, consumer)
-                 .mapConsumerToProducer(consumer->domain(), producer->domain());
-
-  const auto& c_root = consumer->getRootDomain();
-
-  auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  for (size_t i = 0; i < c_root.size(); ++i) {
-    auto c_id = c_root[i];
-    auto it = c2p.find(c_id);
-    if (it == c2p.end()) {
-      // nothing to propagate
-      continue;
-    }
-
-    // propagate root-axis halo info from c_id to p_id
-
-    auto p_id = it->second;
-
-    auto p_info = getRootAxisInfo(p_id);
-    const auto c_info = getRootAxisInfo(c_id);
-
-    // If the root axes are broadcast, no halo should be associated
-    // with them.
-    if (c_id->isBroadcast()) {
-      TORCH_INTERNAL_ASSERT(!c_info.hasHalo());
-      p_info.merge(c_info);
-      setRootAxisInfo(p_id, p_info);
-      continue;
-    }
-
-    // If the defining expression is shift, adjust the producer halo
-    // width based on the shift offset. If the shift offset is
-    // positive, create halo at offset zero of the producer axis so
-    // that the consumer can safely access the producer. If the offset
-    // is negative, halo is created at the other end of the axis.
-    // If the expr is not shift, just merge the consumer halo info
-    // to the producer halo info so that the producer halo can be the
-    // maximum of all its consumers.
-    if (auto shift_op = dynamic_cast<ShiftOp*>(expr)) {
-      const auto offset = shift_op->offset(i);
-      if (offset == 0) {
-        p_info.merge(c_info);
-      } else {
-        int pos = (offset > 0) ? 0 : 1;
-        p_info.merge(pos, makeAddExpr(c_info.width(pos), std::abs(offset)));
-      }
-    } else if (auto gather_op = dynamic_cast<GatherOp*>(expr)) {
-      const auto window_dim =
-          gpu_lower->lowerValue(gather_op->windowShape()[i]);
-      if (window_dim->isOneInt()) {
-        p_info.merge(c_info);
-        continue;
-      }
-      const auto& pad_dim = gather_op->padWidth()[i];
-      const auto pad_dim0 = gpu_lower->lowerValue(pad_dim[0])->as<kir::Int>();
-      p_info.merge(0, makeAddExpr(c_info.width(0), pad_dim0));
-      // The right-side halo is propagated as:
-      //   consumer_right_halo + (window_dim - 1 - left_padding)
-      p_info.merge(
-          1,
-          ir_builder
-              .subExpr(
-                  makeAddExpr(c_info.width(1), window_dim),
-                  makeAddExpr(pad_dim0, 1))
-              ->as<kir::Int>());
-    } else {
-      p_info.merge(c_info);
-    }
-    setRootAxisInfo(p_id, p_info);
-  }
-}
-
-// Propagate extent information from root axes to descendants
-void HaloInfo::build(TensorDomain* td) {
-  auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  for (auto root_axis : td->getRootDomain()) {
-    const auto& halo_info = getRootAxisInfo(root_axis);
-    auto halo_width = halo_info.width();
-
-    // There should be no existing mapping. Note that at one point it
-    // wasn't the case as root axes were reused when creating
-    // reference tensors.
-    // TODO: This is not the case actually. Root domains are reused
-    // when creating some TensorDomains, so a single IterDomain can
-    // show up multiple times. That itself should be fixed, but for
-    // now disable this assertion.
-    TORCH_INTERNAL_ASSERT(
-        halo_width_map_.find(root_axis) == halo_width_map_.end(),
-        "Invalid domain: ",
-        root_axis,
-        " of ",
-        td->getRootDomain());
-
-    if (!halo_info.hasHalo()) {
-      halo_width_map_.insert({root_axis, ir_builder.zeroVal()});
-      continue;
-    }
-
-    auto expanded_extent = ir_builder.addExpr(
-        gpu_lower->lowerValue(root_axis->extent()), halo_width);
-    kir_extent_map_.insert(
-        {gpu_lower->lowerValue(root_axis)->as<kir::IterDomain>(),
-         expanded_extent});
-    halo_width_map_.insert({root_axis, halo_width});
-  }
-
-  auto exprs = ExprSort::getExprs(
-      td->fusion(),
-      std::vector<Val*>(td->domain().begin(), td->domain().end()));
-
-  // Track IDs that are generated by merging halo-extended IDs
-  std::unordered_set<IterDomain*> merged_shifted_ids;
-
-  // Propagate halo information by traversing IterDomain
-  // expressions. We populate extent_map_ and
-  // halo_width_map_.
-  // - extent_map_ maps to Expr* representing the
-  // extent of each axis including its halo. If no mapping exists for
-  // a particular axis in extent_map_, it means the axis does not have
-  // halo.
-  // - halo_width_map_ just maps to the integer size of the halo,
-  // which is used for extent comparison (e.g., extentLessEqual).
-  //
-  // - When expr is split: if the halo width of the input axis is
-  // zero, both the split outputs get zero halo in halo_width_map_. No
-  // mapping is added for extent_map_. Otherwise, the halo is
-  // propagated only to the inner output, so the inner output gets the
-  // same halo width and its mapping is created in extent_map_.
-  //
-  // One major assumption here is that splitting an axis that is
-  // an output of merging halo-extended axes is not allowed. This is
-  // because it is unclear how to split the halo part of the merged
-  // axis. This is unlikely to be a real limitation in practice.
-  //
-  // - When expr is merge: if either of the inputs has halo, a mapping
-  // for the output is created in extent_map_. No mapping is created
-  // for halo_width_map_ (see the comment on HaloInfo::halo_width_map_
-  // in lower_shift.h). If both of them don't have halo, just adds a
-  // new mapping of the output to zero in halo_width_map_. Also adds
-  // it to a set (merged_shifted_ids) to track which axes are merge
-  // outputs of halo-extended axes.
-
-  for (auto expr : exprs) {
-    if (auto split = dynamic_cast<Split*>(expr)) {
-      // Merge-then-split of halo-extended IDs is not allowed
-      TORCH_INTERNAL_ASSERT(
-          merged_shifted_ids.find(split->in()) == merged_shifted_ids.end(),
-          "Splitting IterDomain that is a merged domain of halo-extended domains is not allowed");
-
-      auto in_id = split->in();
-
-      // There must be always a mapping for the input axis of a split
-      // expr. The only exception is when the input axis is an output
-      // of merge, but that's excluded by the assertion above.
-      const auto& halo_width_it = halo_width_map_.find(in_id);
-      TORCH_INTERNAL_ASSERT(halo_width_it != halo_width_map_.end());
-
-      const auto halo_width = halo_width_it->second;
-
-      if (halo_width->isZeroInt()) {
-        halo_width_map_.insert({split->outer(), halo_width});
-        halo_width_map_.insert({split->inner(), halo_width});
-        continue;
-      }
-
-      // propagate to inner domain
-      auto out_id = split->inner();
-
-      auto expanded_extent = ir_builder.addExpr(
-          gpu_lower->lowerValue(out_id->extent()), halo_width);
-      kir_extent_map_.insert(
-          {gpu_lower->lowerValue(out_id)->as<kir::IterDomain>(),
-           expanded_extent});
-
-      halo_width_map_.insert({split->outer(), ir_builder.zeroVal()});
-      halo_width_map_.insert({split->inner(), halo_width});
-    } else if (auto merge = dynamic_cast<Merge*>(expr)) {
-      // If either of the two inputs has halo extension, propagate it
-      // to the merged output ID
-      auto inner_extent = getExtent(merge->inner());
-      auto outer_extent = getExtent(merge->outer());
-      if (inner_extent != nullptr || outer_extent != nullptr) {
-        if (inner_extent == nullptr) {
-          inner_extent = gpu_lower->lowerValue(merge->inner()->extent());
-        }
-        if (outer_extent == nullptr) {
-          outer_extent = gpu_lower->lowerValue(merge->outer()->extent());
-        }
-        auto expanded_extent = ir_builder.mulExpr(outer_extent, inner_extent);
-        kir_extent_map_.insert(
-            {gpu_lower->lowerValue(merge->out())->as<kir::IterDomain>(),
-             expanded_extent});
-        // Splitting the output of this merge is not allowed, so
-        // remember it
-        merged_shifted_ids.insert(merge->out());
-        // Note that halo_width_map_ is not updated
-      } else {
-        halo_width_map_.insert({merge->out(), ir_builder.zeroVal()});
-      }
-    } else {
-      TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr);
-    }
-  }
-}
-
-//! Restriction 1: When allocation is outside of a shifted
-//! axis, the shifted axis must be guaranteed to have a smaller extent
-//! than the concrete axis. For now, shifted axes always mean expanded
-//! allocations when the axis is located inside the allocation
-//! point. This restriction is validated at the allocation lowering
-//! pass.
-//!
-//! Restriction 2: If an expanded axis is parallelized, its memory
-//! must be accessible by all other threads. More specifically:
-//! - TIDx: It must be on shared memory. May want to consider
-//! utilizing the shuffle instructions as well.
-//! - BIDx: Not supported. If on global memory, Cooperative Launch
-//! may be used to support it, however, it's unclear in what
-//! situations block-level parallelization should be used.
-//!
-//! Other types of parallelization should be supported except for
-//! vectorization. Vectorization should be eventually supported but
-//! needs further work.
-void HaloInfo::validate(TensorView* tv) const {
-  const auto& par_map = GpuLower::current()->caParallelMap();
-  const auto& loop_map = GpuLower::current()->caLoopMap();
-  const auto mem_type = tv->getMemoryType();
-
-  for (auto axis : tv->domain()->domain()) {
-    auto concrete_id = par_map.getConcreteMappedID(axis);
-
-    // The extent is assumed to be the same
-    TORCH_INTERNAL_ASSERT(
-        extentEqual(axis, concrete_id),
-        "Axis does not have the same exact size with its concrete ID due to halo extension.",
-        " Tensor: T",
-        tv->name(),
-        ", Axis: ",
-        axis,
-        ", concrete ID: ",
-        concrete_id);
-
-    auto halo_extent = getExtent(axis);
-
-    // If no halo extent is associated with this axis, it means the
-    // axis is not extended.
-    if (halo_extent == nullptr) {
-      continue;
-    }
-
-    // Enforce restrictions on parallelization and memory type
-    const auto ptype = concrete_id->getParallelType();
-
-    if (ptype == ParallelType::Serial) {
-      continue;
-    }
-
-    // Only threading parallelism is considered for now
-    TORCH_CHECK(
-        isParallelTypeThread(ptype), "Unsupported parallel type: ", ptype);
-
-    bool shared_mem_needed = false;
-    for (auto use : tv->uses()) {
-      if (!ir_utils::isTVOp(use)) {
-        continue;
-      }
-      if (use->isA<ShiftOp>() || use->isA<GatherOp>()) {
-        shared_mem_needed = true;
-        break;
-      }
-      auto consumer = use->outputs()[0]->as<TensorView>();
-      // Find the corresponding axis in the consumer
-      auto it = std::find_if(
-          consumer->domain()->domain().begin(),
-          consumer->domain()->domain().end(),
-          [&](IterDomain* consumer_axis) {
-            return loop_map.areMapped(axis, consumer_axis);
-          });
-      if (it == consumer->domain()->domain().end()) {
-        continue;
-      }
-      if (!extentEqual(axis, *it)) {
-        shared_mem_needed = true;
-        break;
-      }
-    }
-
-    if (!shared_mem_needed) {
-      continue;
-    }
-
-    if (isParallelTypeThreadDim(ptype)) {
-      // If all the consumers have the same extent and none of the
-      // expressions is shift, any memory should be fine. Otherwise, it
-      // must be accessible by all threads involved in the
-      // parallelization.
-      TORCH_CHECK(
-          mem_type == MemoryType::Shared,
-          "TV",
-          tv->name(),
-          " must be allocated on shared memory as its halo-extended axis is parallelized by ",
-          ptype);
-
-    } else if (isParallelTypeBlockDim(ptype)) {
-      TORCH_CHECK(
-          false,
-          "Block-based parallelization of a halo-extended axis is not supported: ",
-          axis);
-    }
-  }
-  return;
-}
-
-kir::Val* HaloInfo::getExtent(IterDomain* id) const {
-  auto kir_id = GpuLower::current()->lowerValue(id)->as<kir::IterDomain>();
-  return getExtent(kir_id);
-}
-
-kir::Val* HaloInfo::getExtent(kir::IterDomain* id) const {
-  auto it = kir_extent_map_.find(id);
-  if (it != kir_extent_map_.end()) {
-    return it->second;
-  } else {
-    return nullptr;
-  }
-}
-
-kir::Int* HaloInfo::getHaloWidth(IterDomain* id) const {
-  auto it = halo_width_map_.find(id);
-  TORCH_INTERNAL_ASSERT(it != halo_width_map_.end());
-  return it->second;
-}
-
-bool HaloInfo::hasHaloWidth(IterDomain* id) const {
-  return halo_width_map_.find(id) != halo_width_map_.end();
-}
-
-namespace {
-
-//! Prove if the comparison operator, cmp, is true with the extents of
-//! id1 and id2, including their halo. The comparison is done
-//! conservatively, meaning false negative is possible.
-//!
-//! It is assumed that id1 and id2 are mapped with the CA Loop map, so
-//! what is checked here is only about halo
-//! sizes using HaloInfo::halo_width_map_. Since it does not have
-//! mappings for merged axes, each axis of merge inputs are
-//! individually compared, and only when both of the input axes
-//! return true, the merge output axis returns true.
-template <typename Cmp>
-bool extentCompare(
-    const HaloInfo& halo_map,
-    IterDomain* id1,
-    IterDomain* id2,
-    Cmp cmp) {
-  auto gpu_lower = GpuLower::current();
-  TORCH_INTERNAL_ASSERT(
-      gpu_lower->caLoopMap().areMapped(id1, id2), "Invalid axes to compare");
-
-  // It's invalid to compare two axes and when only either of them has
-  // halo.
-
-  if (halo_map.hasHaloWidth(id1)) {
-    TORCH_INTERNAL_ASSERT(
-        halo_map.hasHaloWidth(id2), "Invalid comparison: ", id1, " and ", id2);
-    // Both axes have halo. We assume the axes themselves have equal
-    // extents, excluding halo, as they are mapped with the CA
-    // map. So, we just need to compare the halo width of each axis.
-    return cmp(halo_map.getHaloWidth(id1), halo_map.getHaloWidth(id2));
-  } else {
-    TORCH_INTERNAL_ASSERT(!halo_map.hasHaloWidth(id2));
-    // Both don't have halo. The only case this can happen must be
-    // both axes are the output of a merge expression, so each merge
-    // input is recursively compared, and returns true only when both
-    // inputs return.
-    if (auto merge1 = dynamic_cast<Merge*>(id1->definition())) {
-      auto merge2 = dynamic_cast<Merge*>(id2->definition());
-      TORCH_INTERNAL_ASSERT(
-          merge2 != nullptr, "Invalid comparison: ", id1, " and ", id2);
-      auto inner_le =
-          extentCompare(halo_map, merge1->inner(), merge2->inner(), cmp);
-      auto outer_le =
-          extentCompare(halo_map, merge1->outer(), merge2->outer(), cmp);
-      return inner_le && outer_le;
-    } else {
-      // This is not considered. Should never reach here.
-      TORCH_INTERNAL_ASSERT(false, "Invalid comparison: ", id1, " and ", id2);
-    }
-  }
-}
-
-} // namespace
-
-bool HaloInfo::extentLessEqual(IterDomain* id1, IterDomain* id2) const {
-  auto cmp = [](kir::Int* x, kir::Int* y) {
-    if (x == y) {
-      return true;
-    }
-    auto xv = x->value();
-    auto yv = y->value();
-    return xv.has_value() && yv.has_value() && xv.value() <= yv.value();
-  };
-  return extentCompare(*this, id1, id2, cmp);
-}
-
-bool HaloInfo::extentEqual(IterDomain* id1, IterDomain* id2) const {
-  // Returns true only when x and y are proven to be the same. The
-  // analysis is not comprehensive and can prove in rather trivial
-  // cases only. Specifically:
-  //   - x and y are the same pointers
-  //   - Both have static values and they are the same
-  //   - Both are defined by the same expression and the inputs are
-  //     proven to be equal
-  std::function<bool(kir::Int*, kir::Int*)> cmp = [&](kir::Int* x,
-                                                      kir::Int* y) {
-    if (x == y) {
-      return true;
-    }
-
-    auto xv = x->value();
-    auto yv = y->value();
-    if (xv.has_value() && yv.has_value() && xv.value() == yv.value()) {
-      return true;
-    }
-
-    // Check if both are defined by an expression of the same type. If
-    // so, recursively check the input operands.
-    auto x_def = x->definition();
-    auto y_def = y->definition();
-    if (x_def && y_def &&
-        ((x_def->isA<kir::UnaryOp>() && y_def->isA<kir::UnaryOp>() &&
-          x_def->as<kir::UnaryOp>()->operation() ==
-              y_def->as<kir::UnaryOp>()->operation()) ||
-         (x_def->isA<kir::BinaryOp>() && y_def->isA<kir::BinaryOp>() &&
-          x_def->as<kir::BinaryOp>()->operation() ==
-              y_def->as<kir::BinaryOp>()->operation()))) {
-      for (size_t i = 0; i < x_def->inputs().size(); ++i) {
-        auto x_input = dynamic_cast<kir::Int*>(x_def->inputs()[i]);
-        auto y_input = dynamic_cast<kir::Int*>(y_def->inputs()[i]);
-        // Both must be kir::Int
-        TORCH_INTERNAL_ASSERT(x_input && y_input);
-        if (!cmp(x_input, y_input)) {
-          return false;
-        }
-      }
-      return true;
-    }
-
-    return false;
-  };
-  return extentCompare(*this, id1, id2, cmp);
-}
-
-std::string HaloInfo::toString() const {
-  std::stringstream ss;
-
-  ss << "HaloInfo:\n";
-
-  if (root_axis_map_.empty()) {
-    return ss.str();
-  }
-
-  Fusion* fusion = root_axis_map_.begin()->first->fusion();
-
-  auto used_vals = DependencyCheck::getAllValsBetween(
-      {fusion->inputs().begin(), fusion->inputs().end()}, fusion->outputs());
-
-  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
-    const auto& root = tv->getRootDomain();
-    ss << "TV" << tv->name() << " root domain: ";
-    for (auto axis : root) {
-      ss << axis << " -> " << getRootAxisInfo(axis).toString() << ", ";
-    }
-    ss << "\n";
-  }
-
-  return ss.str();
-}
-
-bool HaloInfo::needsShiftPredicate(Expr* expr) const {
-  auto consumer_td = ir_utils::getTVOutput(expr)->domain();
-  auto shift_expr = dynamic_cast<ShiftOp*>(expr);
-  auto gather_expr = dynamic_cast<GatherOp*>(expr);
-  for (size_t i = 0; i < consumer_td->getRootDomain().size(); ++i) {
-    auto consumer_id = consumer_td->getRootDomain()[i];
-    const auto consumer_halo_info = getRootAxisInfo(consumer_id);
-    if (consumer_halo_info.hasHalo() ||
-        (shift_expr != nullptr && shift_expr->offset(i) != 0 &&
-         !consumer_id->isBroadcast()) ||
-        (gather_expr != nullptr && !gather_expr->windowShape()[i]->isOneInt() &&
-         !consumer_id->isBroadcast())) {
-      return true;
-    }
-  }
-  return false;
-}
-
-bool HaloInfo::needsShiftPredicate(kir::Expr* expr) const {
-  const auto out_tv = expr->outputs()[0]->as<kir::TensorView>();
-  auto fuser_expr = out_tv->fuserTv()->definition();
-  TORCH_INTERNAL_ASSERT(fuser_expr != nullptr);
-  return needsShiftPredicate(fuser_expr);
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.h b/torch/csrc/jit/codegen/cuda/lower_shift.h
deleted file mode 100644 (file)
index bcda899..0000000
+++ /dev/null
@@ -1,207 +0,0 @@
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Auxiliary class to represent information about halo of an axis
-class AxisHaloInfo {
- public:
-  AxisHaloInfo();
-
-  //! Width of halo.
-  //!
-  //! pos is either 0 or 1. The width of halo at offset zero is set
-  //! when pos is 0.
-  kir::Int* width(int pos) const;
-
-  //! Sum of the widths of both widths
-  kir::Int* width() const;
-
-  const auto& widths() const {
-    return widths_;
-  }
-
-  //! Set the halo width of either side.
-  //! pos is either 0 or 1. The width of halo at offset zero is set
-  //! when pos is 0.
-  void setWidth(int pos, kir::Int* width);
-
-  //! Extend the halo width to account for another axis.
-  void merge(int pos, kir::Int* other);
-
-  //! Extend the halo width to account for another axis.
-  void merge(const AxisHaloInfo& other);
-
-  //! True when halo may be attached
-  bool hasHalo() const;
-
-  std::string toString() const;
-
- private:
-  //! Sizes of the halo regions of two sides. Both values are zero for
-  //! axes with no halo. When an axis has halo at offset zero,
-  //! widths_[0] is non-zero and designates the size of the
-  //! halo. Similarly, non-zero widths_[1] means the axis has halo at
-  //! the other end of the axis.
-  std::array<kir::Int*, 2> widths_ = {nullptr, nullptr};
-};
-
-//! Helper class for lowering tensors with halo. Only valid at the
-//! lowering time.
-class HaloInfo {
- public:
-  //! Scan a fusion and collect all information for lowering
-  void build(Fusion* fusion);
-
-  //! Build mappings of extent information of a TensorDomain
-  void build(TensorDomain* td);
-
-  //! Set initial AxisHaloInfo of a root axis
-  //!
-  //! This is only for root or rfactor axes. It is an error to query
-  //! with other axes.
-  void setRootAxisInfo(IterDomain* id, const AxisHaloInfo& root_axis_info);
-
-  //! Returns the registed AxisHaloInfo of a root axis.
-  //!
-  //! This is only for root axes. It is an error to query with
-  //! non-root axes.
-  const AxisHaloInfo& getRootAxisInfo(IterDomain* id) const;
-  AxisHaloInfo& getRootAxisInfo(IterDomain* id);
-  //! KIR version
-  const AxisHaloInfo& getRootAxisInfo(kir::IterDomain* id) const;
-  AxisHaloInfo& getRootAxisInfo(kir::IterDomain* id);
-
-  //! Query if an axis has a halo width.
-  //!
-  //! See the comment at halo_width_map_.
-  bool hasHaloWidth(IterDomain* id) const;
-
-  //! Return the halo width of an axis.
-  //!
-  //! It's an error if queried for an axis with no halo width
-  //! information.
-  kir::Int* getHaloWidth(IterDomain* id) const;
-
-  //! Returns an extent if id is extended for halo. Nullptr is
-  //! returned otherwise.
-  kir::Val* getExtent(IterDomain* id) const;
-  kir::Val* getExtent(kir::IterDomain* id) const;
-
-  // True when the extent of id1 is guaranteed to be lesser than or
-  // equal to id2. False when it *may* not.
-  bool extentLessEqual(IterDomain* id1, IterDomain* id2) const;
-  // True when the extent of id1 is guaranteed to be equal to
-  // id2. False when it *may* not.
-  bool extentEqual(IterDomain* id1, IterDomain* id2) const;
-
-  //! Check if expr must be predicated based on boundary conditions
-  //! directly or indirectly induced by shift expressions.
-  //!
-  //! When yes, the expression needs two predications: one for
-  //! interior and another for padding. Predicate insertion is done in
-  //! the ShiftPredicateInserter class below.
-  bool needsShiftPredicate(Expr* expr) const;
-  bool needsShiftPredicate(kir::Expr* expr) const;
-
-  std::string toString() const;
-
- private:
-  //! Propagate root axis information from outputs to inputs of an
-  //! expression
-  void propagateRootAxisInfo(Expr* expr);
-
-  //! Propagate root axis information from consumer to producer
-  void propagateRootAxisInfo(
-      TensorView* producer,
-      TensorView* consumer,
-      Expr* expr);
-
-  //! Validate shift usage
-  void validate(TensorView* td) const;
-
- private:
-  //! Halo information of root axes
-  std::unordered_map<IterDomain*, AxisHaloInfo> root_axis_map_;
-  //! KIR version
-  std::unordered_map<kir::IterDomain*, AxisHaloInfo> kir_root_axis_map_;
-
-  //! Halo-extended extents. No mapping for axes without halo extension
-  std::unordered_map<kir::IterDomain*, kir::Val*> kir_extent_map_;
-
-  //! The halo width of an axis.
-  //!
-  //! The mapped value is a sum of two widths of both sizes of an
-  //! axis. For root axes, it is equivalent to AxisHaloInfo.widths_[0]
-  //! + AxisHaloInfo.widths_[1] (or AxisHaloInfo.width()). For
-  //! example, when a root axis is extended by 1 for both sides, it'd
-  //! be mapped to 2. For axes with no halo, they are mapped to zero.
-  //!
-  //! When an axis is split, its halo is only propagated to the inner
-  //! output axis, so the value of this map for the inner output is
-  //! the same as the input of split, while the outer output is mapped
-  //! to zero.
-  //!
-  //! When an axis is merged, no mapping is created for its
-  //! output at this point primarly because it isn't clear what the
-  //! "halo width" for a merged axis should mean. Perhaps, a merged
-  //! axis of (N+a)*(M+b), where N and M correspond to the original
-  //! extens of two axes, and a and b correspond to their halo widths,
-  //! it might make sense to set the halo width of this merged axis as
-  //! (N+a)*(M+b)-N*M. Currently, however, this isn't necessary, so no
-  //! particular mapping is created for merged axes.
-  //!
-  //! This is currently used only for conservatively comparing the
-  //! overall extents of axes. See HaloInfo::extentLessEqual and
-  //! HaloInfo::extentEqual.
-  //!
-  //! Example: Suppose a root axis has {0, 1} of
-  //! AxisHaloInfo.widths_. The root axis is mapped to 1. When it is
-  //! split, say, by 4, the output axes, [N / 4] and [4], where N is
-  //! the extent of the root axis, the outer axis is mapped to 0,
-  //! whereas the inner axis is mapped to 1. Further, suppose the
-  //! inner axis is merged with another axis of extent M, we know that
-  //! the extent of the resulting output axis is 5*M, but we don't
-  //! create its mapping.
-  std::unordered_map<IterDomain*, kir::Int*> halo_width_map_;
-};
-
-class ShiftPredicateInserter {
- public:
-  //! Works mostly the same way as
-  //! PredicateCompute::getInlinePredicate but does the insertion of
-  //! the generated predicate. The branch structure is different from
-  //! the usual predicated expression, so the insertion is also done
-  //! here.
-  static void insert(
-      kir::Expr* expr,
-      const std::vector<kir::ForLoop*>& loops,
-      kir::Bool* thread_pred);
-
-  //! Returns predicates for the interior and overall domains of a
-  //! tensor.
-  //!
-  //! The isShiftPredicate flag toggles between the predicate for shifted
-  //! accesses and padding.
-  static kir::Bool* getPredicate(
-      const kir::Expr* expr,
-      const std::vector<kir::ForLoop*>& loops,
-      kir::TensorView* out_tv,
-      kir::Bool* thread_pred,
-      bool isShiftPredicate);
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index d8a6c7f..6645158 100644 (file)
@@ -16,69 +16,60 @@ namespace cuda {
 
 namespace {
 
-kir::Val* getPredicatePerParallelType(
+Val* getPredicatePerParallelType(
     ParallelType pt,
-    const ThreadPredicateMap::SourceMap& source_map) {
+    const ThreadPredicateMap::SourceMapType& source_map) {
   kir::IrBuilder ir_builder(GpuLower::current()->kernel());
 
   if (pt == ParallelType::BIDx || pt == ParallelType::BIDy ||
       pt == ParallelType::BIDz) {
     auto source = source_map.at(pt);
     TORCH_INTERNAL_ASSERT(!source.empty(), "No predicate source found");
-    kir::Val* pred = nullptr;
-    for (auto src : source) {
-      if (pred == nullptr) {
-        auto flag_name = kir::GridReduction::getPredicateFlagName(src);
-        pred = ir_builder.create<kir::NamedScalar>(flag_name, DataType::Bool);
-      } else {
-        auto flag_name = kir::GridReduction::getPredicateFlagName(src);
-        pred = ir_builder.andExpr(
-            pred,
-            ir_builder.create<kir::NamedScalar>(flag_name, DataType::Bool));
-      }
-    }
-    return pred;
+    TORCH_INTERNAL_ASSERT(source.size() == 1, "Multiple sources detected");
+    auto src = *source.begin();
+    auto flag_name = kir::GridReduction::getPredicateFlagName(src);
+    return ir_builder.create<kir::NamedScalar>(flag_name, DataType::Bool);
   } else {
     return ir_builder.eqExpr(
         kir::NamedScalar::getParallelIndex(pt), ir_builder.create<kir::Int>(0));
   }
 }
 
-kir::Bool* getPredicateFromParallelTypes(
-    const ParallelTypeBitmap& bits,
-    const ThreadPredicateMap::SourceMap& source_map) {
+kir::Bool* getPredicate(
+    const ir_utils::ParallelTypeBitmap& bits,
+    const ThreadPredicateMap::SourceMapType& source_map) {
   kir::IrBuilder ir_builder(GpuLower::current()->kernel());
 
   if (bits.none()) {
-    return ir_builder.trueVal();
+    return ir_builder.create<kir::Bool>(true);
   }
 
-  kir::Bool* pred = nullptr;
+  Val* pred = nullptr;
 
   for (const auto& pt_bool : bits.getMap()) {
     if (pt_bool.second) {
-      const auto tp = getPredicatePerParallelType(pt_bool.first, source_map);
-      if (pred == nullptr) {
-        pred = ir_builder.create<kir::Bool>(c10::nullopt);
-        ir_builder.create<kir::UnaryOp>(UnaryOpType::Set, pred, tp);
-      } else {
-        pred = ir_builder.andExpr(pred, tp)->as<kir::Bool>();
-      }
+      auto tp = getPredicatePerParallelType(pt_bool.first, source_map);
+      pred = (pred == nullptr) ? tp : ir_builder.andExpr(pred, tp);
     }
   }
 
+  // Should never be hit.
   TORCH_INTERNAL_ASSERT(pred != nullptr);
 
-  return pred;
+  TORCH_INTERNAL_ASSERT(
+      pred->getDataType().value() == DataType::Bool,
+      "Tried to return a predicate that is not a bool val.");
+
+  return pred->as<kir::Bool>();
 }
 
 void mergeSourceMap(
-    ThreadPredicateMap::SourceMap& dst,
-    const ThreadPredicateMap::SourceMap& src) {
+    ThreadPredicateMap::SourceMapType& dst,
+    const ThreadPredicateMap::SourceMapType& src) {
   for (const auto& kv : src) {
     const auto& src_key = kv.first;
     const auto& src_value = kv.second;
-    auto& dst_set = dst[src_key];
+    std::unordered_set<const TensorView*>& dst_set = dst[src_key];
     for (const auto& src_tensor : src_value) {
       dst_set.insert(src_tensor);
     }
@@ -86,9 +77,9 @@ void mergeSourceMap(
 }
 
 void addToSouceMap(
-    ThreadPredicateMap::SourceMap& dst,
+    ThreadPredicateMap::SourceMapType& dst,
     const TensorView* tv,
-    const ParallelTypeBitmap& reducton_pred) {
+    const ir_utils::ParallelTypeBitmap& reducton_pred) {
   for (const auto& kv : reducton_pred.getMap()) {
     if (kv.second) {
       ParallelType ptype = kv.first;
@@ -98,8 +89,8 @@ void addToSouceMap(
 }
 
 void maskSouceMap(
-    ThreadPredicateMap::SourceMap& src_map,
-    const ParallelTypeBitmap& mask) {
+    ThreadPredicateMap::SourceMapType& src_map,
+    const ir_utils::ParallelTypeBitmap& mask) {
   for (const auto& kv : mask.getMap()) {
     if (!kv.second) {
       ParallelType ptype = kv.first;
@@ -110,68 +101,54 @@ void maskSouceMap(
 
 // A bit of a hack for now for GEMM tiling so we don't fetch tiles multiple
 // times. It's safe to do, there may simply be a better place to do it.
-ParallelTypeBitmap avoidRedundantWritesToSmem(
-    const TensorView* out_tv,
-    const ParallelTypeBitmap& pred) {
-  const auto& ca_map = GpuLower::current()->caParallelMap();
-  auto new_pred = pred;
+void avoidRedundantWritesToSmem(
+    TensorView* out_tv,
+    ir_utils::ParallelTypeBitmap& pred) {
   if (out_tv->getMemoryType() == MemoryType::Shared) {
     for (const auto i : c10::irange(out_tv->nDims())) {
-      auto id = ca_map.getConcreteMappedID(out_tv->axis(i));
+      auto id = out_tv->getComputeAtAxis(i).first;
       if (out_tv->axis(i)->isBroadcast() && id->isThreadDim()) {
-        new_pred.set(id->getParallelType(), true);
+        pred.set(id->getParallelType(), true);
       }
     }
   }
-  return new_pred;
 }
 
 } // namespace
 
 // Update the reduction_deps bitset based on provided Expr
-void ThreadPredicateMap::updateBitSet(const Expr* expr) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::ThreadPredicateMap::updateBitSet");
+void ThreadPredicateMap::updateBitSet(Expr* expr) {
+  FUSER_PERF_SCOPE("ThreadPredicateMap::updateBitSet");
 
   // Which predicates were set for the inputs
-  ParallelTypeBitmap input_preds;
+  ir_utils::ParallelTypeBitmap input_preds;
 
   // Which dims are reductions in inputs
-  ParallelTypeBitmap input_reductions;
+  ir_utils::ParallelTypeBitmap input_reductions;
 
   // Which dims are bcast in inputs
-  ParallelTypeBitmap input_bcasts;
+  ir_utils::ParallelTypeBitmap input_bcasts;
 
-  SourceMap src_map;
+  SourceMapType src_map;
 
   // Run through inputs and update bitsets
   for (const auto* inp : expr->inputs()) {
     if (!ir_utils::isTV(inp))
       continue;
 
-    auto tv_inp = inp->as<TensorView>();
-
-    // Change for welford Op, we want the users of all outputs of welfordOp
-    //  to use a single predicate name.
-    if (auto tv_def = tv_inp->definition()) {
-      if (auto wop = dynamic_cast<WelfordOp*>(tv_def)) {
-        tv_inp = wop->out()->as<TensorView>();
-      }
-    }
-
+    auto tv_inp = ir_utils::asConstTV(inp);
     TORCH_INTERNAL_ASSERT(
         thread_predicates_.find(tv_inp) != thread_predicates_.end(),
         "Thread predicate map was not initialized, couldn't find ",
         inp);
 
-    const auto& pred_and_src = at(tv_inp);
-
-    input_preds |= pred_and_src.pred;
+    input_preds |= at(tv_inp).first;
 
-    mergeSourceMap(src_map, pred_and_src.source_map);
+    mergeSourceMap(src_map, at(tv_inp).second);
 
-    ParallelTypeBitmap id_reductions;
-    ParallelTypeBitmap id_bcasts;
-    ParallelTypeBitmap id_ptypes;
+    ir_utils::ParallelTypeBitmap id_reductions;
+    ir_utils::ParallelTypeBitmap id_bcasts;
+    ir_utils::ParallelTypeBitmap id_ptypes;
 
     for (auto id : tv_inp->domain()->domain()) {
       if (id->isThread()) {
@@ -184,7 +161,7 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
     }
 
     // Validate the combination of ptypes, reductions, bcasts
-    for (const auto i : c10::irange(ParallelTypeBitmap::num_p_type)) {
+    for (const auto i : c10::irange(ir_utils::ParallelTypeBitmap::num_p_type)) {
       if (input_reductions[i]) {
         if (id_ptypes[i]) {
           TORCH_INTERNAL_ASSERT(
@@ -214,33 +191,40 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
   auto output_preds = input_preds | input_reductions;
 
   // Figure out which dims bcast wants to reset
-  const auto bcast_reset_mask = ~(output_preds & input_bcasts);
+  auto bcast_reset_map = output_preds & input_bcasts;
 
-  // Get rid of any reductions which are bcasted
-  output_preds &= bcast_reset_mask;
+  // Flip it to make a bit mask
+  bcast_reset_map = ~bcast_reset_map;
 
+  // Get rid of any reductions which are bcasted
+  output_preds &= bcast_reset_map;
   // Similarly, drop non-relevant source tensors
-  maskSouceMap(src_map, bcast_reset_mask);
+  maskSouceMap(src_map, bcast_reset_map);
 
   // Run through outputs and set bitset predicates
   for (auto* out : expr->outputs()) {
-    if (auto tv = dynamic_cast<const TensorView*>(out)) {
-      TORCH_INTERNAL_ASSERT(find(tv) == end());
-      insert(tv, avoidRedundantWritesToSmem(tv, output_preds), src_map);
-    }
+    if (!ir_utils::isTV(out))
+      continue;
+    TORCH_INTERNAL_ASSERT(find(ir_utils::asConstTV(out)) == end());
+    auto pred_for_this_out = output_preds;
+    avoidRedundantWritesToSmem(ir_utils::asTV(out), pred_for_this_out);
+    insert(ir_utils::asConstTV(out), pred_for_this_out, src_map);
   }
 }
 
-void ThreadPredicateMap::build(Fusion* fusion) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::ThreadPredicateMap");
-
+// TODO(kir): revisit this - can we build it from the kernel IR?
+ThreadPredicateMap::ThreadPredicateMap(Fusion* _fusion) : fusion_(_fusion) {
+  FUSER_PERF_SCOPE("ThreadPredicateMap");
   // Initialize mapping for input tensors
-  for (auto inp : fusion->inputs()) {
-    if (auto tv = dynamic_cast<const TensorView*>(inp)) {
-      insert(tv, ParallelTypeBitmap(), SourceMap());
+  for (auto inp : fusion_->inputs()) {
+    if (ir_utils::isTV(inp)) {
+      insert(
+          ir_utils::asConstTV(inp),
+          ir_utils::ParallelTypeBitmap(),
+          SourceMapType());
     }
   }
-  for (auto expr : fusion->exprs()) {
+  for (auto expr : fusion_->exprs(true)) {
     updateBitSet(expr);
   }
 }
@@ -254,94 +238,46 @@ ThreadPredicateMap::const_iterator ThreadPredicateMap::end() const {
   return thread_predicates_.end();
 }
 
-const ThreadPredicateMap::PredAndSource& ThreadPredicateMap::at(
+const ThreadPredicateMap::MapType::mapped_type& ThreadPredicateMap::at(
     const TensorView* tv) const {
   return thread_predicates_.at(tv);
 }
 
-ThreadPredicateMap::PredAndSource& ThreadPredicateMap::at(
+ThreadPredicateMap::MapType::mapped_type& ThreadPredicateMap::at(
     const TensorView* tv) {
   return thread_predicates_.at(tv);
 }
 
-void ThreadPredicateMap::insert(
-    const TensorView* tv,
-    const ParallelTypeBitmap& pred,
-    const SourceMap& src_map) {
-  insert(tv, {pred, src_map});
+ThreadPredicateMap::MapType::mapped_type& ThreadPredicateMap::operator[](
+    const TensorView* tv) {
+  return thread_predicates_[tv];
 }
 
 void ThreadPredicateMap::insert(
     const TensorView* tv,
-    const PredAndSource& pred_and_src) {
-  thread_predicates_.insert({tv, pred_and_src});
+    const ir_utils::ParallelTypeBitmap& pred,
+    const SourceMapType& src_map) {
+  insert(tv, std::make_pair(pred, src_map));
 }
 
-kir::Bool* ThreadPredicateMap::getPredicate(const TensorView* tv) const {
-  // No thread predicate is needed when tv is an output of a
-  // parallel broadcast expression.
-  if (auto bop = dynamic_cast<BroadcastOp*>(tv->definition())) {
-    if (getParallelBroadcastDomains(tv).any()) {
-      return kir::IrBuilder(GpuLower::current()->kernel()).trueVal();
-    }
-  }
-  TORCH_INTERNAL_ASSERT(find(tv) != end(), "Couldn't find ", tv);
-  const auto& pred_and_src = at(tv);
-  return getPredicateFromParallelTypes(
-      pred_and_src.pred, pred_and_src.source_map);
+void ThreadPredicateMap::insert(
+    const TensorView* tv,
+    const std::pair<ir_utils::ParallelTypeBitmap, SourceMapType>&
+        pred_and_src) {
+  thread_predicates_.insert(std::make_pair(tv, pred_and_src));
 }
 
-ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains(
-    const TensorView* tv) const {
-  // If no pred is found for tv, no predicate is necessary
-  if (find(tv) == end()) {
-    return ParallelTypeBitmap();
-  }
-
-  ParallelTypeBitmap parallel_broadcast;
-
-  const auto& iter_domains = tv->domain()->domain();
-
-  // If the output is on shared memory, assume that all subsequent
-  // reads from all threads in its CTA can be done with no parallel
-  // broadcast. Only one thread will write to shared memory followed
-  // by a proper _syncthreads.
-  const bool output_smem = tv->getMemoryType() == MemoryType::Shared;
-
-  for (auto id : iter_domains) {
-    if (!id->isBroadcast()) {
-      continue;
-    }
-    if (id->isBlockDim() || (!output_smem && id->isThreadDim())) {
-      parallel_broadcast.set(id->getParallelType(), true);
-    }
+void ThreadPredicateMap::duplicate(
+    const TensorView* copy,
+    const TensorView* origin) {
+  if (find(origin) != end()) {
+    insert(copy, at(origin).first, at(origin).second);
   }
-
-  return parallel_broadcast & at(tv).pred;
 }
 
-void ThreadPredicateMap::print() const {
-  std::cout << "\nThreadPredicateMap\n";
-  std::cout << "--------------------------------\n";
-  for (const auto& kv : thread_predicates_) {
-    std::cout << "T" << kv.first->name() << " {";
-    // ParallelTypeBitmap
-    for (auto ptkv : kv.second.pred.getMap()) {
-      if (ptkv.second) {
-        std::cout << " " << ptkv.first;
-      }
-    }
-    std::cout << " }\n";
-    // SourceMap
-    for (const auto& pkv : kv.second.source_map) {
-      std::cout << "  " << pkv.first << " : [";
-      for (auto tv : pkv.second) {
-        std::cout << " T" << tv->name();
-      }
-      std::cout << " ]\n";
-    }
-  }
-  std::cout << "--------------------------------\n\n";
+kir::Bool* ThreadPredicateMap::getExpr(const TensorView* out_tv) const {
+  TORCH_INTERNAL_ASSERT(find(out_tv) != end(), "Couldn't find ", out_tv);
+  return getPredicate(at(out_tv).first, at(out_tv).second);
 }
 
 } // namespace cuda
index c5ccef2..cf13b06 100644 (file)
@@ -1,24 +1,19 @@
-
 #pragma once
-
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>
 
-#include <unordered_map>
-#include <unordered_set>
-#include <utility>
+#include <bitset>
 
 namespace torch {
 namespace jit {
 namespace fuser {
 namespace cuda {
 
-//! Maps TensorViews to a { ParallelTypeBitmap, SourceMap } pair
+//! Maps TensorViews to std::pair<ir_utils::ParallelTypeBitmap, SourceMapType>>
 //!
-//! Map from TensorView to bit set represnting <BIDx, BIDy, BIDz, TIDx, TIDy,
+//! Map from tensorview to bit set represnting <BIDx, BIDy, BIDz, TIDx, TIDy,
 //! TIDz> If any dependency of TV had a parallelized reduction, we will track
 //! it here. This will be used for predicate generation to prevent
 //! parallelization on that axis. This is important if we have a reduction on
@@ -29,52 +24,40 @@ namespace cuda {
 //!
 class TORCH_CUDA_CU_API ThreadPredicateMap {
  public:
-  using SourceMap = std::unordered_map<
+  using SourceMapType = std::unordered_map<
       ParallelType,
       std::unordered_set<const TensorView*>,
       TypeHash>;
-
-  struct PredAndSource {
-    ParallelTypeBitmap pred;
-    SourceMap source_map;
-  };
-
-  using MapType = std::unordered_map<const TensorView*, PredAndSource>;
-
+  using MapType = std::unordered_map<
+      const TensorView*,
+      std::pair<ir_utils::ParallelTypeBitmap, SourceMapType>>;
   using const_iterator = MapType::const_iterator;
 
-  void build(Fusion* fusion);
+  explicit ThreadPredicateMap(Fusion* _fusion);
 
-  // TODO(kir): these methods are only used by getParallelBroadcastDomains() ?
   const_iterator find(const TensorView* tv) const;
   const_iterator end() const;
-  const PredAndSource& at(const TensorView* tv) const;
-  PredAndSource& at(const TensorView* tv);
-
-  // Returns a Bool predicate for a given TensorView.
-  kir::Bool* getPredicate(const TensorView* tv) const;
+  const MapType::mapped_type& at(const TensorView* tv) const;
+  MapType::mapped_type& at(const TensorView* tv);
+  MapType::mapped_type& operator[](const TensorView* tv);
 
-  //! Returns a ParallelTypeBitmap representing which domain needs
-  //! blockBroadcast.
-  //!
-  //! Even when a domain is broadcast and parallelized, it does not need
-  //! blockBroadcast unless it is predicated.
-  ParallelTypeBitmap getParallelBroadcastDomains(const TensorView* tv) const;
+  void duplicate(const TensorView* copy, const TensorView* origin);
 
-  void print() const;
+  // Returns a Bool predicate expression for a given output TensorView.
+  kir::Bool* getExpr(const TensorView* out_tv) const;
 
  private:
   // Update the thread_predicates bitset based on provided Expr
-  void updateBitSet(const Expr*);
+  void updateBitSet(Expr*);
 
   void insert(
       const TensorView* tv,
-      const ParallelTypeBitmap& pred,
-      const SourceMap& src_map);
-
-  void insert(const TensorView* tv, const PredAndSource& pred_and_src);
+      const ir_utils::ParallelTypeBitmap& pred,
+      const SourceMapType& src_map);
+  void insert(const TensorView* tv, const MapType::mapped_type& pred_and_src);
 
  private:
+  Fusion* fusion_ = nullptr;
   MapType thread_predicates_;
 };
 
diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp b/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp
deleted file mode 100644 (file)
index 3365178..0000000
+++ /dev/null
@@ -1,139 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-
-#include <unordered_set>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id);
-
-bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {
-  TORCH_INTERNAL_ASSERT(
-      root_id->definition() == nullptr, "Not root IterDomain: ", root_id);
-
-  if (tv->definition() == nullptr) {
-    // This is an input tensor, so no rfactor tensor to traverse.
-    return false;
-  }
-
-  const auto& inputs = tv->definition()->inputs();
-
-  if (inputs.size() != 1 || !inputs[0]->isA<TensorView>() ||
-      (tv->definition()->getExprType() != ExprType::ReductionOp &&
-       tv->definition()->getExprType() != ExprType::WelfordOp)) {
-    // No rfactor producer found
-    return false;
-  }
-
-  auto producer = inputs[0]->as<TensorView>();
-
-  if (!producer->hasRFactor()) {
-    return false;
-  }
-
-  auto c2p = PairwiseRootDomainMap(producer, tv)
-                 .mapConsumerToProducer(tv->domain(), producer->domain());
-
-  auto producer_id_it = c2p.find(root_id);
-  if (producer_id_it == c2p.end()) {
-    // No matching producer is found. Stop traversing.
-    return false;
-  }
-
-  auto producer_root_id = producer_id_it->second;
-
-  return analyzeIfDerivedFromTrivialReduction(producer, producer_root_id);
-}
-
-bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) {
-  auto id_inputs = InputsOf::output(id->fusion(), id);
-  for (auto root_id : ir_utils::filterByType<IterDomain>(id_inputs)) {
-    if (root_id->isReduction() && root_id->extent()->isOneInt()) {
-      continue;
-    }
-    // If not possible to prove the root ID is trivial, see if the ID
-    // is derived from a rfactor tensor and, if so, continue the
-    // analysis at the rfactor tensor.
-    if (!traverseToRFactorTensor(tv, root_id)) {
-      return false;
-    }
-  }
-  return true;
-}
-
-} // namespace
-
-void TrivialReductionInfo::build(Fusion* fusion, GpuLower* gpu_lower) {
-  auto used_vals = fusion->usedMathVals();
-
-  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
-    for (auto id : tv->domain()->domain()) {
-      if (analyzeIfDerivedFromTrivialReduction(tv, id)) {
-        // If id is a trivial reduction, all of its ancestor vals are
-        // also trivial reductions.
-        for (auto dep_id : DependencyCheck::getAllValsBetween(
-                 std::unordered_set<Val*>(
-                     tv->getRootDomain().begin(), tv->getRootDomain().end()),
-                 {id})) {
-          domains_.insert(dep_id->as<IterDomain>());
-          domains_derived_from_root_.insert(dep_id->as<IterDomain>());
-        }
-      } else if (id->isReduction() && id->extent()->isOneInt()) {
-        // This happens when a leaf domain is trivial but its root
-        // axes are not. For example, consider a non-trivial domain
-        // split by one. The inner output axis is a trivial domain,
-        // whereas the outer output axis is not. Since the root axis
-        // is not trivial, a for-loop needs to be generated.
-        domains_.insert(id);
-      }
-    }
-  }
-
-  buildKir(fusion, gpu_lower);
-}
-
-void TrivialReductionInfo::buildKir(Fusion* fusion, GpuLower* gpu_lower) {
-  for (auto id : domains_) {
-    auto kir_trivial_id = gpu_lower->lowerValue(id)->as<kir::IterDomain>();
-    kir_domains_.insert(kir_trivial_id);
-  }
-
-  for (auto id : domains_derived_from_root_) {
-    auto kir_trivial_id = gpu_lower->lowerValue(id)->as<kir::IterDomain>();
-    kir_domains_derived_from_root_.insert(kir_trivial_id);
-  }
-}
-
-bool TrivialReductionInfo::isDerived(IterDomain* id) const {
-  return domains_.find(id) != domains_.end();
-}
-
-bool TrivialReductionInfo::isDerivedFromRoot(IterDomain* id) const {
-  return domains_derived_from_root_.find(id) !=
-      domains_derived_from_root_.end();
-}
-
-bool TrivialReductionInfo::isDerived(kir::IterDomain* id) const {
-  return kir_domains_.find(id) != kir_domains_.end();
-}
-
-bool TrivialReductionInfo::isDerivedFromRoot(kir::IterDomain* id) const {
-  return kir_domains_derived_from_root_.find(id) !=
-      kir_domains_derived_from_root_.end();
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h b/torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h
deleted file mode 100644 (file)
index 3f5a94d..0000000
+++ /dev/null
@@ -1,59 +0,0 @@
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <unordered_set>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class GpuLower;
-
-//! Detect almost all IterDomains that are derived from trivial
-//! reductons.
-class TORCH_CUDA_CU_API TrivialReductionInfo {
- public:
-  void build(Fusion* fusion, GpuLower* gpu_lower);
-
-  bool isDerived(IterDomain* id) const;
-  bool isDerivedFromRoot(IterDomain* id) const;
-
-  bool isDerived(kir::IterDomain* id) const;
-  bool isDerivedFromRoot(kir::IterDomain* id) const;
-
- private:
-  //! Convert the sets to KIR sets
-  void buildKir(Fusion* fusion, GpuLower* gpu_lower);
-
- private:
-  //! IterDomains that are derived only from trivial
-  //! reductons. Included domains are not limited to reduction axes as
-  //! rfactor can make reductions to normal axes.
-  //!
-  //! Note that the set should cover almost all cases but there can be
-  //! undetected trivial domains. For example, split by one creates a
-  //! trivial reduction domain, which is detected. However, if it is
-  //! further split, both of the two resulting axes are also trivial,
-  //! however, only the inner axis is recognized as trivial. While this
-  //! is a limitation, it would have very little practical
-  //! implication.
-  std::unordered_set<IterDomain*> domains_;
-  //! Subset of domains_, whose input root axes are all derived from
-  //! trivial reductions. These domains do not need to manifest as
-  //! for-loops.
-  std::unordered_set<IterDomain*> domains_derived_from_root_;
-
-  std::unordered_set<kir::IterDomain*> kir_domains_;
-  std::unordered_set<kir::IterDomain*> kir_domains_derived_from_root_;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index f610960..57e4ad5 100644 (file)
@@ -4,12 +4,8 @@
 #include <torch/csrc/jit/codegen/cuda/index_compute.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h>
 #include <torch/csrc/jit/codegen/cuda/lower_utils.h>
 #include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
 
@@ -18,124 +14,43 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-namespace {
-
-// Provide a new for loop matching the one provided
-kir::ForLoop* cloneLoopNest(const kir::ForLoop* for_loop) {
-  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-  const auto new_loop = ir_builder.create<kir::ForLoop>(
-      for_loop->iter_domain(),
-      for_loop->index(),
-      for_loop->start(),
-      for_loop->stop(),
-      for_loop->step(),
-      for_loop->vectorize(),
-      for_loop->vectorize_shift());
-  for (auto expr : for_loop->body().exprs()) {
-    if (auto nested_for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-      expr = cloneLoopNest(nested_for_loop);
+kir::Bool* UnrollPass::getThreadPredicate(TensorView* tv) {
+  // No thread predicate is needed predicate when tv is output of a
+  // parallel broadcast expression.
+  const auto origin = tv->getOrigin();
+  if (origin != nullptr && origin->getExprType() == ExprType::BroadcastOp) {
+    const auto out = origin->as<BroadcastOp>()->out();
+    if (ir_utils::getParallelBroadcastDomains(out, thread_predicates_).any()) {
+      return nullptr;
     }
-    new_loop->body().push_back(expr);
   }
-  return new_loop;
-}
 
-// Returns true if expr is an expression that initializes a reduction
-// buffer.
-bool isReductionInitExpr(const kir::Expr* expr) {
-  // False if its output isn't a TensorView
-  if (!ir_utils::isTVOp(expr)) {
-    return false;
-  }
-  // False if it doesn't have any reduction axis
-  const auto out_tv = expr->outputs()[0]->as<kir::TensorView>();
-  if (!out_tv->domain()->hasReduction()) {
-    return false;
-  }
-  // False if it has have TensorView inputs as initialization should
-  // never use TensorViews
-  const auto tv_filter_inp_view =
-      ir_utils::filterByType<kir::TensorView>(expr->inputs());
-  if (tv_filter_inp_view.begin() != tv_filter_inp_view.end()) {
-    return false;
-  }
-  return true;
+  return thread_predicates_.getExpr(tv);
 }
 
-} // namespace
-
-void UnrollPass::handle(kir::Expr* expr) {
+// Custom dispatch for Expr, want to find out of it's a TV op.
+void UnrollPass::handle(Expr* expr) {
+  // If tv op, predciate it.
   if (ir_utils::isTVOp(expr)) {
-    // If tv op, predicate it
-    const auto out_tv = ir_utils::getTVOutput(expr);
-    const bool should_predicate = !for_loops_.empty() ||
-        out_tv->memoryType() == MemoryType::Global ||
-        out_tv->memoryType() == MemoryType::Shared;
-    if (!should_predicate) {
-      return;
-    }
-
-    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-    const auto thread_pred = isReductionInitExpr(expr)
-        ? ir_builder.trueVal()
-        : GpuLower::current()->threadPredMap().getPredicate(out_tv->fuserTv());
-
-    // When a predicate needs to account for ShiftOp, it is currently
-    // taken care by its own function.
-    if (GpuLower::current()->haloInfo().needsShiftPredicate(expr)) {
-      ShiftPredicateInserter::insert(expr, for_loops_, thread_pred);
-      return;
-    }
+    TORCH_INTERNAL_ASSERT(for_loops.size() != 0);
 
-    // Reduction may need a separate predicate for writes.
-    if (!isReductionInitExpr(expr) && out_tv->domain()->hasReduction()) {
-      const auto write_pred = ir_builder.create<kir::Predicate>(
-          PredicateType::ReductionWrite, expr, thread_pred);
-      expr->setWritePredicate(write_pred);
-    }
-
-    // For expr calling a device func with block sync, don't create
-    // if-then-else but pass the predicate to the device func
-    if (ir_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) {
-      const auto pred = ir_builder.create<kir::Predicate>(
-          PredicateType::Inline, expr, thread_pred);
-      expr->setPredicate(pred);
-      return;
-    }
-
-    // Vectorized expressions should never use inline predicates
-    kir::Predicate* vectorized_pred = nullptr;
-    if (std::any_of(
-            for_loops_.begin(), for_loops_.end(), [](const kir::ForLoop* fl) {
-              return fl->iter_domain()->parallelType() ==
-                  ParallelType::Vectorize;
-            })) {
-      vectorized_pred =
-          ir_builder.create<kir::Predicate>(PredicateType::Vectorize);
-    }
-
-    const auto pred = vectorized_pred == nullptr
-        ? ir_builder.create<kir::Predicate>(
-              PredicateType::Inline, expr, thread_pred)
-        : vectorized_pred;
-
-    TORCH_INTERNAL_ASSERT(pred != nullptr);
+    auto pred = PredicateCompute::getInlinePredicate(
+        expr, for_loops, getThreadPredicate(ir_utils::getTVOutput(expr)));
 
     // If we need a predicate, put expr inside an if then else
-    non_trivial_pred_found_ = true;
-    kir::IfThenElse* inline_ite = ir_builder.create<kir::IfThenElse>(pred);
-    if (for_loops_.empty()) {
-      // Special handling for top level output expressions that still
-      // need predicates. One motivating example is a reduction op that
-      // reduces to a scalar (issue #491)
-      expr_replacement_map_.insert({expr, inline_ite});
-    } else {
-      for_loops_.back()->body().insert_before(expr, inline_ite);
-      for_loops_.back()->body().erase(expr);
+    if (!(pred->isConst()) || !(pred->isConst() && pred->value().value())) {
+      non_trivial_pred_found = true;
+      kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+      kir::IfThenElse* inline_ite =
+          ir_builder.create<kir::IfThenElse>(pred, for_loops.back());
+      inline_ite->thenBody().push_back(expr);
+      for_loops.back()->body().insert_before(expr, inline_ite);
+      for_loops.back()->body().erase(expr);
     }
-    inline_ite->thenBody().push_back(expr);
-  } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-    handle(for_loop);
+
+  } else {
+    // If not tv op, dispatch it.
+    OptOutDispatch::handle(expr);
   }
 }
 
@@ -143,150 +58,82 @@ void UnrollPass::handle(kir::Expr* expr) {
 // IR nodes "unroll_pred" or "inline_pred", then generate those later.
 void UnrollPass::handle(kir::ForLoop* fl) {
   // Setup for loop scoping
-  const bool is_unroll =
-      fl->iter_domain()->parallelType() == ParallelType::Unroll ||
-      fl->iter_domain()->parallelType() == ParallelType::Unswitch ||
-      fl->iter_domain()->parallelType() == ParallelType::Vectorize;
-
+  bool is_unroll = ir_utils::isUnrolledFor(fl);
   // If we're not looking for an unroll loop, or didn't find one, process as
   // normal.
-  if (!is_unroll || !look_for_unroll_) {
-    for_loops_.push_back(fl);
+  if (!is_unroll || !look_for_unroll) {
+    for_loops.push_back(fl);
 
+    std::vector<Expr*> exprs_copy = fl->body().exprs();
     // Make copy of exprs because we replace them inplace in fl
-    const auto exprs_copy = fl->body().exprs();
-
-    // Skip Misaligned Vectorization For-Loops here
-    if (!containsAnyDirectChildMisalignedVectorize(fl)) {
-      for (auto expr : exprs_copy) {
-        handle(expr);
-      }
+    for (auto expr : exprs_copy) {
+      handle(expr);
     }
+    for_loops.pop_back();
 
-    for_loops_.pop_back();
     return;
   }
 
-  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-  auto unroll_pred = ir_builder.create<kir::Predicate>(fl);
+  auto unroll_pred = UnrollPredicate::get(for_loops, fl, p2c_root_map);
 
-  kir::IfThenElse* unroll_ite = ir_builder.create<kir::IfThenElse>(unroll_pred);
+  kir::ForLoop* parent_scope = for_loops.empty() ? nullptr : for_loops.back();
+
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+  kir::IfThenElse* unroll_ite =
+      ir_builder.create<kir::IfThenElse>(unroll_pred, parent_scope);
 
   // Get the loop nest for the unrolled path
-  kir::ForLoop* unrolled_loop_nest = cloneLoopNest(fl);
+  kir::ForLoop* unrolled_loop_nest = scope_utils::cloneLoopNest(fl, unroll_ite);
 
   unroll_ite->thenBody().push_back(unrolled_loop_nest);
-  if (fl->iter_domain()->parallelType() == ParallelType::Vectorize) {
-    expr_replacement_map_.insert({fl, unroll_ite});
-    return;
-  }
 
   // Loop nest for inlined path
-  kir::ForLoop* inlined_loop = cloneLoopNest(fl);
+  kir::ForLoop* inlined_loop = scope_utils::cloneLoopNest(fl, unroll_ite);
 
   // Add inline predicates for inlined loop nest
-  look_for_unroll_ = false;
-  non_trivial_pred_found_ = false;
+  look_for_unroll = false;
+  non_trivial_pred_found = false;
   handle(inlined_loop);
-  look_for_unroll_ = true;
-  if (!non_trivial_pred_found_) {
-    expr_replacement_map_.insert({fl, inlined_loop});
+  look_for_unroll = true;
+  if (!non_trivial_pred_found) {
+    inlined_loop->setParentScope(parent_scope);
+    loop_replacement_map.insert({fl, inlined_loop});
   } else {
-    if (!canOmitElseClause(fl)) {
-      unroll_ite->elseBody().push_back(inlined_loop);
-    }
-    expr_replacement_map_.insert({fl, unroll_ite});
-  }
-}
-
-bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) const {
-  kir::ExpressionEvaluator eval;
-  std::vector<kir::ForLoop*> loops({fl});
-
-  const auto& pred_map = GpuLower::current()->threadPredMap();
-
-  while (loops.size() > 0) {
-    auto loop = loops.back();
-    loops.pop_back();
-
-    // If there's any expression that requires barrier
-    // synchronization, the else part can't be omitted
-    for (auto expr : loop->body().exprs()) {
-      if (expr->isA<kir::BroadcastOp>()) {
-        const ParallelTypeBitmap domains = pred_map.getParallelBroadcastDomains(
-            expr->outputs()[0]->as<kir::TensorView>()->fuserTv());
-        if (domains.any()) {
-          return false;
-        }
-      } else if (expr->isA<kir::ReductionOp>() || expr->isA<kir::WelfordOp>()) {
-        auto td = ir_utils::getTVOutput(expr)->domain();
-        if (td->hasBlockReduction() || td->hasGridReduction()) {
-          return false;
-        }
-      }
-    }
-    // If the number of visits of the loop body per thread is one, the
-    // unswitch predicate is sufficient.
-    // When the loop stop is the same as the extent of its IterDomain,
-    // the per-thread visit count is guaranteed to be one at most (see
-    // CudaKernelGenerator::visit(kir::ForLoop*) as well. Also, when a
-    // loop is vectorized (not misaligned), the count must be one at
-    // most. Even if not parallelized nor vectoirzed, it is also
-    // sufficient if the loop stop is in fact one.
-    bool visit_once = false;
-    auto id = loop->iter_domain();
-    if ((id->isThread() && (loop->stop() == id->extent())) ||
-        id->parallelType() == ParallelType::Vectorize) {
-      visit_once = true;
-    }
-    if (!visit_once) {
-      const auto result = eval.evaluate(loop->stop());
-      if (result.has_value() && result.value() == 1) {
-        visit_once = true;
-      }
-    }
-
-    // The visit count is not guaranteed to be one, so the else part
-    // must be created.
-    if (!visit_once) {
-      return false;
-    }
-
-    // The unswitch predicate is sufficient for this loop. Proceed to
-    // nested loops.
-    for (auto nested_loop :
-         ir_utils::filterByType<kir::ForLoop>(loop->body().exprs())) {
-      loops.push_back(nested_loop);
-    }
+    unroll_ite->elseBody().push_back(inlined_loop);
+    loop_replacement_map.insert({fl, unroll_ite});
   }
-
-  return true;
 }
 
 // Generate the loop nest structure and place it in lowered_exprs
-UnrollPass::UnrollPass(const std::vector<kir::Expr*>& exprs) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::UnrollPass::computeMap");
+void UnrollPass::computeMap() {
+  FUSER_PERF_SCOPE("UnrollPass::computeMap");
+
+  FusionGuard fg(fusion_);
 
   // Run through loop nests and further lower the expressions
-  for (auto* expr : exprs) {
-    handle(expr);
+  for (auto* expr : incoming_exprs_) {
+    OptOutDispatch::handle(expr);
   }
 }
 
-std::vector<kir::Expr*> UnrollPass::runPass(
+std::vector<Expr*> UnrollPass::runPass(
     Fusion* fusion,
-    const std::vector<kir::Expr*>& exprs) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::UnrollPass::runPass");
-
-  UnrollPass unroll_pass(exprs);
-
-  std::vector<kir::Expr*> mutated_exprs;
-  mutated_exprs.reserve(exprs.size());
-  for (auto expr : exprs) {
-    mutated_exprs.push_back(
-        ir_utils::applyReplacements(unroll_pass.replacementMap(), expr));
+    const std::vector<Expr*>& exprs,
+    const ThreadPredicateMap& thread_predicates) {
+  FUSER_PERF_SCOPE("UnrollPass::runPass");
+  FusionGuard fg(fusion);
+  UnrollPass up(fusion, exprs, thread_predicates);
+  up.computeMap();
+  std::vector<Expr*> mutated_exprs;
+  for (Expr* expr : exprs) {
+    if (up.loop_replacement_map.find(expr) != up.loop_replacement_map.end()) {
+      mutated_exprs.push_back(up.loop_replacement_map[expr]);
+    } else {
+      if (ir_utils::isScope(expr))
+        scope_utils::replaceExprsInScope(expr, up.loop_replacement_map);
+      mutated_exprs.push_back(expr);
+    }
   }
-
   return mutated_exprs;
 }
 
index fe297a4..0231550 100644 (file)
 #pragma once
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
 
 #include <bitset>
-#include <unordered_map>
 
 namespace torch {
 namespace jit {
 namespace fuser {
 namespace cuda {
 
-//! Unroll pass
-//!
-//! A bit deceptively: UnrollPass adds all predicates, so it needs to be run
-//! even if we don't unroll any loops.
-//!
-//! Unrolling pass will get IR that looks something like:
-//! for( i : I0o{ceil(I0/4)} ) {
-//!   for( j : I1o{ceil(I1/128)} ) {
-//!     for( k : I0i{4} )
-//!       for( l : I1i{128} )
-//!         T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
-//!
-//! And it will return the following:
-//! for( i : I0o{ceil(I0/4)} ) {
-//!   for( j : I1o{ceil(I1/128)} ) {
-//!
-//!     if( i * 4 + 3 < I && j * 128 + 127 < J ){
-//!       for( k : I0i{4} )
-//!         for( l : I1i{128} )
-//!           T0[ ( i * 4 + k ) * J + j * 128 + l ] = ...
-//!     } else {
-//!       for( k : I0i{4} )
-//!         for( l : I1i{128} )
-//!           if( i * 4 + k < I && j * 128 + l < J)
-//!              T0[ ( i * 4 + k ) * J + j * 128 + l ] = ...
-//!     }
-//!
-//!   }
-//! }
-//!
-//! As can be seen it generates two sets of loops for I0i{4} and I1i{128}. The
-//! first set is protected by a predicate that makes sure there's a full
-//! internal tile we can iterate over. This way we remove the predicate nested
-//! in the inner most loop. There's of course a second set of loops, which has a
-//! predicate still in the inner most loop, making sure that we cover edges and
-//! corners.
-//!
-class TORCH_CUDA_CU_API UnrollPass {
- public:
-  // Take the incoming exprs and run loop unrolling, returning the new IR
-  static std::vector<kir::Expr*> runPass(
-      Fusion* fusion,
-      const std::vector<kir::Expr*>& exprs);
+/*
+ * A bit deceptively: UnrollPass adds all predicates, so it needs to be run even
+ * if we don't unroll any loops.
+ *
+ * Unrolling pass will get IR that looks something like:
+ * for( i : I0o{ceil(I0/4)} ) {
+ *   for( j : I1o{ceil(I1/128)} ) {
+ *     for( k : I0i{4} )
+ *       for( l : I1i{128} )
+ *         T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
+ *
+ * And it will return the following:
+ * for( i : I0o{ceil(I0/4)} ) {
+ *   for( j : I1o{ceil(I1/128)} ) {
+ *
+ *     if( i * 4 + 3 < I && j * 128 + 127 < J ){
+ *       for( k : I0i{4} )
+ *         for( l : I1i{128} )
+ *           T0[ ( i * 4 + k ) * J + j * 128 + l ] = ...
+ *     } else {
+ *       for( k : I0i{4} )
+ *         for( l : I1i{128} )
+ *           if( i * 4 + k < I && j * 128 + l < J)
+ *              T0[ ( i * 4 + k ) * J + j * 128 + l ] = ...
+ *     }
+ *
+ *   }
+ * }
+ *
+ * As can be seen it generates two sets of loops for I0i{4} and I1i{128}. The
+ * first set is protected by a predicate that makes sure there's a full internal
+ * tile we can iterate over. This way we remove the predicate nested in the
+ * inner most loop. There's of course a second set of loops, which has a
+ * predicate still in the inner most loop, making sure that we cover edges and
+ * corners.
+ */
 
+class TORCH_CUDA_CU_API UnrollPass : public OptOutDispatch {
  private:
-  // Generate the for Expr replacement map
-  UnrollPass(const std::vector<kir::Expr*>& exprs);
+  // Wrapper to access thread_predicates_ based on an output TV
+  kir::Bool* getThreadPredicate(TensorView*);
 
-  const std::unordered_map<kir::Expr*, kir::Expr*>& replacementMap() const {
-    return expr_replacement_map_;
-  }
+  // We will track which loops in the incomming IR will be replaced and by what
+  std::unordered_map<Expr*, Expr*> loop_replacement_map;
 
-  void handle(kir::ForLoop* fl);
+  // Hold on to a reference to the fusion for convenience
+  Fusion* fusion_;
 
-  void handle(kir::Expr* expr);
+  // Hold on to the incoming exprs, but don't modify them. We don't set the
+  // Expr* to be const as Exprs' are const by virtue of their interface design
+  const std::vector<Expr*>& incoming_exprs_;
 
-  bool canOmitElseClause(kir::ForLoop* fl) const;
+  // Keep all for loops conveniently to make unrolling easier
+  std::vector<kir::ForLoop*> for_loops;
 
- private:
-  // We will track which loops in the incoming IR will be replaced and by what
-  std::unordered_map<kir::Expr*, kir::Expr*> expr_replacement_map_;
+  // Map from TensorView
+  const ThreadPredicateMap& thread_predicates_;
 
-  // Keep all for loops conveniently to make unrolling easier
-  std::vector<kir::ForLoop*> for_loops_;
+  std::unordered_map<IterDomain*, IterDomain*> p2c_root_map;
 
   // keep track if we're within an unrolled loop
-  bool look_for_unroll_ = true;
+  bool look_for_unroll = true;
 
   // As we generate inline predicates check if we actually generated a
   // non-trivial one.
-  bool non_trivial_pred_found_ = false;
+  bool non_trivial_pred_found = false;
+
+  // Custom dispatch for Expr, want to find out of it's a TV op
+  void handle(Expr*) final;
+
+  // Open the for loop.
+  void handle(kir::ForLoop*) final;
+
+  // Constructor
+  UnrollPass(
+      Fusion* _fusion,
+      const std::vector<Expr*>& _incoming_exprs,
+      const ThreadPredicateMap& _thread_predicates)
+      : fusion_(_fusion),
+        incoming_exprs_(_incoming_exprs),
+        thread_predicates_(_thread_predicates) {
+    p2c_root_map = loop_utils::p2cRootMap(_fusion->exprs(true));
+  }
+
+  // Generate the for Expr replacement map
+  void computeMap();
+
+ public:
+  // Take the incoming fusion and exprs and run loop unrolling, returning the
+  // new IR.
+  static std::vector<Expr*> runPass(
+      Fusion* fusion,
+      const std::vector<Expr*>& exprs,
+      const ThreadPredicateMap& thread_predicates);
 };
 
 } // namespace cuda
index 368c713..b7892f3 100644 (file)
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
 
 #include <algorithm>
 
-// TODO: refactor this file (one per namespace)
-
 namespace torch {
 namespace jit {
 namespace fuser {
 namespace cuda {
-
 namespace scope_utils {
 
-std::vector<kir::ForLoop*> getLoops(kir::Expr* scope) {
-  std::vector<kir::ForLoop*> loops;
-  while (scope != nullptr) {
-    if (auto loop = dynamic_cast<kir::ForLoop*>(scope)) {
-      loops.push_back(loop);
+// START SCOPE HELPER SYSTEMS
+namespace {
+
+class Loops : private OptInDispatch {
+ private:
+  std::deque<kir::ForLoop*> loops;
+  void handle(kir::ForLoop* fl) final {
+    loops.insert(loops.begin(), fl);
+  }
+
+  void handle(kir::IfThenElse* ite) final {}
+
+  void handle(Expr* expr) final {
+    OptInDispatch::handle(expr);
+  }
+
+ public:
+  static std::vector<kir::ForLoop*> getLoops(Expr* scope) {
+    Loops loops;
+    Expr* it = scope;
+    while (it != nullptr) {
+      loops.handle(it);
+      it = scope_utils::getParent(it);
+    }
+    return std::vector<kir::ForLoop*>(loops.loops.begin(), loops.loops.end());
+  }
+};
+
+class scopePushBack : private OptInDispatch {
+ private:
+  Expr* expr_;
+  void handle(kir::ForLoop* fl) final {
+    fl->body().push_back(expr_);
+  }
+
+  void handle(kir::IfThenElse* ite) final {
+    ite->thenBody().push_back(expr_);
+  }
+
+  void handle(Expr* expr) final {
+    OptInDispatch::handle(expr);
+  }
+
+  scopePushBack(Expr* expr) : expr_(expr) {}
+
+ public:
+  static void push(Expr* scope, Expr* expr) {
+    scopePushBack pb(expr);
+    TORCH_INTERNAL_ASSERT(
+        expr != nullptr && scope != nullptr,
+        "Cannot push back, scope or expr is a nullptr.");
+    pb.handle(scope);
+  }
+};
+
+class scopeInsertBefore : private OptInDispatch {
+ private:
+  Expr* ref_;
+  Expr* expr_;
+  void handle(kir::ForLoop* fl) final {
+    fl->body().insert_before(ref_, expr_);
+  }
+
+  void handle(kir::IfThenElse* ite) final {
+    ite->thenBody().insert_before(ref_, expr_);
+  }
+
+  void handle(Expr* expr) final {
+    OptInDispatch::handle(expr);
+  }
+
+  scopeInsertBefore(Expr* ref, Expr* expr) : ref_(ref), expr_(expr) {}
+
+ public:
+  static void insert(Expr* scope, Expr* ref, Expr* expr) {
+    scopeInsertBefore scb(ref, expr);
+    TORCH_INTERNAL_ASSERT(
+        expr != nullptr && scope != nullptr,
+        "Cannot push back, scope or expr is a nullptr.");
+    scb.handle(scope);
+  }
+};
+
+class ExprInScope : private OptInDispatch {
+ private:
+  Expr* expr_;
+  bool contains_ = false;
+
+  void handle(kir::ForLoop* fl) final {
+    if (fl->body().contains(expr_)) {
+      contains_ = true;
+    }
+  }
+
+  void handle(kir::IfThenElse* ite) final {
+    if (ite->thenBody().contains(expr_)) {
+      contains_ = true;
+    }
+  }
+
+  void handle(Expr* expr) final {
+    OptInDispatch::handle(expr);
+  }
+
+  ExprInScope(Expr* expr) : expr_(expr) {}
+
+ public:
+  static bool find(Expr* scope, Expr* expr) {
+    ExprInScope eis(expr);
+    TORCH_INTERNAL_ASSERT(
+        expr != nullptr && scope != nullptr,
+        "Cannot push back, scope or expr is a nullptr.");
+    eis.handle(scope);
+    return eis.contains_;
+  }
+};
+
+class parentScope : private OptInDispatch {
+ private:
+  Expr* parent_ = nullptr;
+
+  void handle(kir::ForLoop* fl) final {
+    parent_ = fl->parentScope();
+  }
+
+  void handle(kir::IfThenElse* ite) final {
+    parent_ = ite->parentScope();
+  }
+
+  void handle(Expr* expr) final {
+    OptInDispatch::handle(expr);
+  }
+
+ public:
+  static Expr* get(Expr* scope) {
+    parentScope sp;
+    sp.handle(scope);
+    return sp.parent_;
+  }
+};
+
+void assertScope(Expr* expr) {
+  TORCH_INTERNAL_ASSERT(
+      expr->getExprType() == ExprType::ForLoop ||
+          expr->getExprType() == ExprType::IfThenElse,
+      "Assert Scope failed when calling a scope_util function.");
+}
+
+class CloneLoopNest : public OptOutMutator {
+ private:
+  Expr* parent_scope_ = nullptr;
+  Expr* to_clone_ = nullptr;
+
+  Statement* mutate(kir::ForLoop* fl) final {
+    kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+    const auto parent_scope =
+        fl == to_clone_ ? parent_scope_ : fl->parentScope();
+    auto new_loop = ir_builder.create<kir::ForLoop>(
+        fl->index(), fl->iter_domain(), parent_scope);
+    for (Expr* expr : fl->body().exprs()) {
+      new_loop->body().push_back(ir_utils::asExpr(OptOutMutator::mutate(expr)));
+    }
+    return new_loop;
+  }
+
+  CloneLoopNest(Expr* _to_clone, Expr* _parent_scope)
+      : parent_scope_(_parent_scope), to_clone_(_to_clone) {}
+
+ public:
+  static kir::ForLoop* getClone(kir::ForLoop* _to_clone, Expr* _parent_scope) {
+    TORCH_INTERNAL_ASSERT(
+        _to_clone != nullptr,
+        "Tried to clone a scope, but received a nullptr.");
+    CloneLoopNest cln(_to_clone, _parent_scope);
+    return ir_utils::asForLoop(ir_utils::asExpr(cln.mutate(_to_clone)));
+  }
+};
+
+class ReplaceExprsInScope : public OptOutDispatch {
+ public:
+  static void replace(
+      Expr* scope,
+      std::unordered_map<Expr*, Expr*> replacement_map) {
+    ReplaceExprsInScope reis(std::move(replacement_map));
+    reis.handle(scope);
+  }
+
+ private:
+  explicit ReplaceExprsInScope(std::unordered_map<Expr*, Expr*> replacement_map)
+      : replacement_map_(std::move(replacement_map)) {}
+
+  void handleScope(kir::Scope& scope) {
+    for (const auto i : c10::irange(scope.size())) {
+      const auto it = replacement_map_.find(scope[i]);
+      if (it == replacement_map_.end()) {
+        handle(scope[i]);
+        continue;
+      }
+      scope[i] = it->second;
+    }
+  }
+
+  void handle(Expr* expr) final {
+    OptOutDispatch::handle(expr);
+  }
+
+  void handle(kir::ForLoop* fl) final {
+    handleScope(fl->body());
+  }
+
+  void handle(kir::IfThenElse* ite) final {
+    handleScope(ite->thenBody());
+    handleScope(ite->elseBody());
+  }
+
+ private:
+  std::unordered_map<Expr*, Expr*> replacement_map_;
+};
+
+class FirstInnerMostScope : private OptInDispatch {
+ private:
+  Expr* active_scope = nullptr;
+
+  void handle(kir::ForLoop* fl) final {
+    for (auto expr : fl->body().exprs()) {
+      if (ir_utils::isScope(expr)) {
+        active_scope = expr;
+        return;
+      }
     }
-    scope = scope->parentScope();
+    active_scope = nullptr;
   }
-  std::reverse(loops.begin(), loops.end());
-  return loops;
+
+  void handle(kir::IfThenElse* ite) final {
+    for (auto expr : ite->thenBody().exprs()) {
+      if (ir_utils::isScope(expr)) {
+        active_scope = expr;
+        return;
+      }
+    }
+    for (auto expr : ite->elseBody().exprs()) {
+      if (ir_utils::isScope(expr)) {
+        active_scope = expr;
+        return;
+      }
+    }
+    active_scope = nullptr;
+  }
+
+  Expr* getInner(Expr* expr) {
+    OptInDispatch::handle(expr);
+    return active_scope;
+  }
+
+ public:
+  static Expr* get(Expr* scope) {
+    TORCH_INTERNAL_ASSERT(
+        scope != nullptr,
+        "Tried to get inner most scope, but was provided nullptr.");
+
+    FirstInnerMostScope fims;
+    Expr* inner = fims.getInner(scope);
+
+    if (inner == nullptr)
+      return scope;
+
+    while (fims.getInner(inner) != nullptr)
+      inner = fims.getInner(inner);
+    return inner;
+  }
+};
+
+// END SCOPE HELPER SYSTEMS
+} // namespace
+
+// Grab the ForLoop starting from scope working out
+std::vector<kir::ForLoop*> getLoops(Expr* scope) {
+  if (scope == nullptr)
+    return std::vector<kir::ForLoop*>();
+  assertScope(scope);
+  return Loops::getLoops(scope);
 }
 
-void insertBefore(kir::Expr* scope, kir::Expr* ref, kir::Expr* expr) {
-  if (auto ite = dynamic_cast<kir::IfThenElse*>(scope)) {
-    ite->thenBody().insert_before(ref, expr);
-  } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(scope)) {
-    for_loop->body().insert_before(ref, expr);
+// Push back an expr to scope
+void pushBack(Expr* scope, Expr* expr) {
+  TORCH_INTERNAL_ASSERT(
+      scope != nullptr, "Scope is a nullptr, cannot push an expr to it.");
+  assertScope(scope);
+  scopePushBack::push(scope, expr);
+}
+
+// Insert expr in scope before ref
+void insertBefore(Expr* scope, Expr* ref, Expr* expr) {
+  scopeInsertBefore::insert(scope, ref, expr);
+}
+
+bool exprInScope(Expr* scope, Expr* expr) {
+  return ExprInScope::find(scope, expr);
+}
+
+// Return the parent of the active scope
+Expr* getParent(Expr* scope) {
+  TORCH_INTERNAL_ASSERT(
+      scope != nullptr,
+      "Tried to close the active scope, but there isn't one set.");
+  assertScope(scope);
+  return parentScope::get(scope);
+}
+
+// Open a new inner most for loop
+kir::ForLoop* openFor(Expr* scope, IterDomain* id) {
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+  const auto kir_id = GpuLower::lowerValue(id)->as<kir::IterDomain>();
+  kir::ForLoop* new_scope = nullptr;
+  if (id->isThread()) {
+    std::stringstream ss;
+    ss << id->getParallelType();
+    new_scope = ir_builder.create<kir::ForLoop>(
+        ir_builder.create<kir::NamedScalar>(ss.str(), DataType::Int),
+        kir_id,
+        scope);
   } else {
-    TORCH_INTERNAL_ASSERT(false, "Unexpected scope expression");
+    new_scope = ir_builder.create<kir::ForLoop>(
+        ir_builder.create<kir::Int>(c10::nullopt), kir_id, scope);
   }
+  if (scope != nullptr)
+    pushBack(scope, new_scope);
+  return new_scope;
+}
+
+kir::ForLoop* cloneLoopNest(kir::ForLoop* to_clone, Expr* parent_scope) {
+  return CloneLoopNest::getClone(to_clone, parent_scope);
+}
+
+void replaceExprsInScope(
+    Expr* scope,
+    std::unordered_map<Expr*, Expr*> replacement_map) {
+  TORCH_INTERNAL_ASSERT(
+      replacement_map.find(scope) == replacement_map.end(),
+      "Error trying to replace expressions in a scope, scope wants to be replaced entirely.");
+  ReplaceExprsInScope::replace(scope, std::move(replacement_map));
+}
+
+Expr* firstInnerMostScope(Expr* scope) {
+  return FirstInnerMostScope::get(scope);
 }
 
 } // namespace scope_utils
@@ -86,36 +407,31 @@ std::vector<IterDomain*> iterDomainInputsOfOrderedAs(
   return ordered_inputs;
 }
 
+std::vector<Val*> indices(std::vector<kir::ForLoop*> loops) {
+  std::vector<Val*> inds(loops.size());
+  std::transform(
+      loops.begin(), loops.end(), inds.begin(), [](kir::ForLoop* fl) {
+        return fl->index();
+      });
+  return inds;
+}
+
 bool isTV(const Val* val) {
   return val->getValType().value() == ValType::TensorView;
 }
 
 // Check if we're a TensorView op that we can generate code for.
 bool isTVOp(const Expr* expr) {
-  if (std::any_of(
-          expr->outputs().begin(),
-          expr->outputs().end(),
-          [](Val* v) { return isTV(v); }) &&
+  if (expr->outputs().size() == 1 && isTV(expr->output(0)) &&
       (expr->getExprType().value() == ExprType::BinaryOp ||
        expr->getExprType().value() == ExprType::UnaryOp ||
        expr->getExprType().value() == ExprType::TernaryOp ||
        expr->getExprType().value() == ExprType::ReductionOp ||
-       expr->getExprType().value() == ExprType::WelfordOp ||
-       expr->getExprType().value() == ExprType::BroadcastOp ||
-       expr->getExprType().value() == ExprType::TransposeOp ||
-       expr->getExprType().value() == ExprType::ShiftOp ||
-       expr->getExprType().value() == ExprType::GatherOp)) {
+       expr->getExprType().value() == ExprType::BroadcastOp))
     return true;
-  }
   return false;
 }
 
-bool isTVOp(const kir::Expr* expr) {
-  const auto& outputs = expr->outputs();
-  return outputs.size() >= 1 && outputs[0]->isA<kir::TensorView>();
-}
-
-// TODO: why do we assume there's a single TV output?
 TensorView* getTVOutput(const Expr* expr) {
   for (auto out : expr->outputs()) {
     if (out->getValType().value() == ValType::TensorView) {
@@ -125,17 +441,6 @@ TensorView* getTVOutput(const Expr* expr) {
   return nullptr;
 }
 
-kir::TensorView* getTVOutput(const kir::Expr* expr) {
-  for (auto out : expr->outputs()) {
-    if (auto tv = dynamic_cast<kir::TensorView*>(out)) {
-      return tv;
-    } else if (auto ti = dynamic_cast<kir::TensorIndex*>(out)) {
-      return ti->view();
-    }
-  }
-  return nullptr;
-}
-
 bool isScalarOp(const Expr* expr) {
   for (auto out : expr->outputs())
     if (!out->isScalar())
@@ -143,8 +448,15 @@ bool isScalarOp(const Expr* expr) {
   return true;
 }
 
+void ASSERT_EXPR(Statement* stmt) {
+  TORCH_INTERNAL_ASSERT(
+      stmt->isExpr(),
+      "Tried to generate a kernel but hit a non expression during lowering: ",
+      stmt);
+}
+
 Expr* asExpr(Statement* stmt) {
-  TORCH_INTERNAL_ASSERT(stmt->isExpr());
+  ASSERT_EXPR(stmt);
   return stmt->as<Expr>();
 }
 
@@ -153,73 +465,177 @@ TensorView* asTV(Val* val) {
   return val->as<TensorView>();
 }
 
-bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) {
-  if (!isTVOp(expr)) {
+bool isScope(const Expr* expr) {
+  return expr->getExprType() == ExprType::ForLoop ||
+      expr->getExprType() == ExprType::IfThenElse;
+}
+
+kir::ForLoop* asForLoop(Statement* stmt) {
+  Expr* expr = asExpr(stmt);
+  TORCH_INTERNAL_ASSERT(expr->getExprType() == ExprType::ForLoop);
+  return expr->as<kir::ForLoop>();
+}
+
+const TensorView* asConstTV(const Val* val) {
+  TORCH_INTERNAL_ASSERT(isTV(val));
+  return val->as<TensorView>();
+}
+
+bool isUnrolledFor(const Expr* expr) {
+  if (expr->getExprType() != ExprType::ForLoop) {
     return false;
   }
+  return expr->as<kir::ForLoop>()->iter_domain()->getParallelType() ==
+      ParallelType::Unroll;
+}
 
-  auto tv = getTVOutput(expr);
+const std::unordered_map<ParallelType, int, TypeHash>
+    ParallelTypeBitmap::pt_to_offset_{
+        {ParallelType::BIDx, 0},
+        {ParallelType::BIDy, 1},
+        {ParallelType::BIDz, 2},
+        {ParallelType::TIDx, 3},
+        {ParallelType::TIDy, 4},
+        {ParallelType::TIDz, 5}};
+
+const std::unordered_map<int, ParallelType> ParallelTypeBitmap::offset_to_pt_ =
+    {{0, ParallelType::BIDx},
+     {1, ParallelType::BIDy},
+     {2, ParallelType::BIDz},
+     {3, ParallelType::TIDx},
+     {4, ParallelType::TIDy},
+     {5, ParallelType::TIDz}};
+
+bool ParallelTypeBitmap::get(ParallelType pt) const {
+  if (pt_to_offset_.find(pt) == pt_to_offset_.end()) {
+    TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type.");
+  }
+  return bitset_[pt_to_offset_.at(pt)];
+}
 
-  if ((expr->isA<ReductionOp>() || expr->isA<WelfordOp>()) &&
-      (tv->hasBlockReduction() || tv->hasGridReduction())) {
-    return true;
-  } else if (expr->isA<BroadcastOp>()) {
-    const ParallelTypeBitmap pt_map =
-        GpuLower::current()->threadPredMap().getParallelBroadcastDomains(tv);
-    return pt_map.hasTID();
+bool ParallelTypeBitmap::set(ParallelType pt, bool new_val) {
+  if (pt_to_offset_.find(pt) == pt_to_offset_.end()) {
+    TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type.");
   }
+  bool old_val = bitset_[pt_to_offset_.at(pt)];
+  bitset_[pt_to_offset_.at(pt)] = new_val;
+  return old_val;
+}
 
-  return false;
+ParallelTypeBitmap ParallelTypeBitmap::operator&=(
+    const ParallelTypeBitmap& other) {
+  bitset_ &= other.bitset_;
+  return *this;
+}
+
+ParallelTypeBitmap ParallelTypeBitmap::operator|=(
+    const ParallelTypeBitmap& other) {
+  bitset_ |= other.bitset_;
+  return *this;
+}
+
+ParallelTypeBitmap ParallelTypeBitmap::operator^=(
+    const ParallelTypeBitmap& other) {
+  bitset_ ^= other.bitset_;
+  return *this;
 }
 
-bool hasBlockSync(const kir::Expr* expr, const ThreadPredicateMap& pred_map) {
-  if (expr->isA<kir::ReductionOp>() || expr->isA<kir::GridReduction>() ||
-      expr->isA<kir::BroadcastOp>() || expr->isA<kir::WelfordOp>() ||
-      expr->isA<kir::GridWelford>()) {
-    auto fuser_tv = getTVOutput(expr)->fuserTv();
-    auto fuser_expr = fuser_tv->definition();
-    TORCH_INTERNAL_ASSERT(fuser_expr != nullptr);
-    return hasBlockSync(fuser_expr, pred_map);
+ParallelTypeBitmap ParallelTypeBitmap::operator~() const {
+  return ParallelTypeBitmap(~bitset_);
+}
+
+bool ParallelTypeBitmap::none() const {
+  return bitset_.none();
+}
+
+bool ParallelTypeBitmap::any() const {
+  return bitset_.any();
+}
+
+bool ParallelTypeBitmap::all() const {
+  return bitset_.all();
+}
+
+bool ParallelTypeBitmap::operator[](size_t pos) const {
+  TORCH_INTERNAL_ASSERT(
+      pos < num_p_type, "Invalid index to ParallelTypeBitset: ", pos);
+  return bitset_[pos];
+}
+
+std::map<ParallelType, bool> ParallelTypeBitmap::getMap() const {
+  std::map<ParallelType, bool> map;
+  for (const auto& pt_offset : pt_to_offset_) {
+    map.emplace(pt_offset.first, bitset_[pt_offset.second]);
   }
+  return map;
+}
 
-  return false;
+ParallelTypeBitmap operator&(
+    const ParallelTypeBitmap& lhs,
+    const ParallelTypeBitmap& rhs) {
+  auto x = lhs;
+  x &= rhs;
+  return x;
 }
 
-kir::Expr* applyReplacements(
-    const std::unordered_map<kir::Expr*, kir::Expr*>& expr_replacement_map,
-    kir::Expr* expr) {
-  auto handle_scope = [&](kir::Scope& scope) {
-    for (size_t i = 0; i < scope.size(); ++i) {
-      scope[i] = applyReplacements(expr_replacement_map, scope[i]);
-    }
-  };
+ParallelTypeBitmap operator|(
+    const ParallelTypeBitmap& lhs,
+    const ParallelTypeBitmap& rhs) {
+  auto x = lhs;
+  x |= rhs;
+  return x;
+}
 
-  const auto it = expr_replacement_map.find(expr);
-  if (it != expr_replacement_map.end()) {
-    return it->second;
-  } else {
-    if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-      handle_scope(for_loop->body());
-    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
-      handle_scope(ite->thenBody());
-      handle_scope(ite->elseBody());
+ParallelTypeBitmap operator^(
+    const ParallelTypeBitmap& lhs,
+    const ParallelTypeBitmap& rhs) {
+  auto x = lhs;
+  x ^= rhs;
+  return x;
+}
+
+ParallelTypeBitmap getParallelBroadcastDomains(
+    const Val* bop_out,
+    const ThreadPredicateMap& preds) {
+  if (bop_out->getValType().value() == ValType::TensorIndex) {
+    bop_out = bop_out->as<kir::TensorIndex>()->view()->fuserTv();
+  }
+  TORCH_INTERNAL_ASSERT(
+      bop_out->getValType().value() == ValType::TensorView,
+      "Out is not tensor view");
+  auto out_tv = bop_out->as<TensorView>();
+  // If no pred is found for out_tv, no predicate is necessary
+  if (preds.find(out_tv) == preds.end()) {
+    return ParallelTypeBitmap();
+  }
+  const ParallelTypeBitmap& out_pred = preds.at(out_tv).first;
+
+  ParallelTypeBitmap parallel_broadcast;
+  const auto& iter_domains = out_tv->domain()->domain();
+  // If the output is on shared memory, assume that all subsequent
+  // reads from all threads in its CTA can be done with no parallel
+  // broadcast. Only one thread will write to shared memory followed
+  // by a proper _syncthreads.
+  const bool output_smem = out_tv->getMemoryType() == MemoryType::Shared;
+  for (auto id : iter_domains) {
+    if (!id->isBroadcast()) {
+      continue;
+    }
+    if (id->isBlockDim() || (!output_smem && id->isThreadDim())) {
+      parallel_broadcast.set(id->getParallelType(), true);
     }
-    return expr;
   }
+
+  return parallel_broadcast & out_pred;
 }
 
 } // namespace ir_utils
 
 namespace loop_utils {
 
-// TODO: Clean this up, Naoya added a mechanism we should be able to reuse.
 std::pair<kir::ForLoop*, int64_t> getAllocPoint(
-    const TensorView* tv,
-    const std::vector<kir::ForLoop*>& loops,
-    const std::unordered_map<IterDomain*, IterDomain*>& id_map,
-    bool use_id_map) {
-  const auto gpu_lower = GpuLower::current();
-
+    TensorView* tv,
+    const std::vector<kir::ForLoop*>& loops) {
   // If in global memory, it can be all the way outside the loops.
   if (tv->getMemoryType() == MemoryType::Global) {
     return {nullptr, 0};
@@ -230,36 +646,32 @@ std::pair<kir::ForLoop*, int64_t> getAllocPoint(
   kir::ForLoop* alloc_loop = nullptr;
 
   auto loops_it = loops.begin();
+
   // Look at each axis individually in out's domain
-  for (const auto tv_i : c10::irange((int64_t)tv->getComputeAtPosition())) {
+  for (const auto tv_i : c10::irange((int64_t)tv->getThisComputeAtAxis())) {
     // Grab the axis ID
 
-    auto local_id = tv->axis(tv_i);
-    if (use_id_map) {
-      auto id_it = id_map.find(local_id);
-      if (id_it != id_map.end()) {
-        local_id = id_it->second;
-      }
-    }
+    auto ca_id = tv->getComputeAtAxis(tv_i).first;
+    auto kir_ca_id = GpuLower::lowerValue(ca_id)->as<kir::IterDomain>();
 
-    if (gpu_lower->trivialReductionInfo().isDerivedFromRoot(local_id)) {
-      continue;
-    }
-
-    auto lowered_local_id =
-        gpu_lower->lowerValue(local_id)->as<kir::IterDomain>();
-    loops_it = std::find_if(
-        loops_it, loops.end(), [&lowered_local_id](const auto& loop) {
-          return GpuLower::current()->caLoopMap().areMapped(
-                     lowered_local_id, loop->iter_domain()) ||
-              loop->iter_domain()->parallelType() == ParallelType::Unroll;
+    loops_it =
+        std::find_if(loops_it, loops.end(), [&kir_ca_id](const auto& loop) {
+          return kir_ca_id == loop->iter_domain() ||
+              loop->iter_domain()->getParallelType() == ParallelType::Unroll;
         });
 
+    if (loops_it == loops.end()) {
+      for (auto loop : loops) {
+        std::cout << kir::toString(loop->iter_domain()) << "  ";
+      }
+      std::cout << std::endl;
+    }
     TORCH_INTERNAL_ASSERT(
         loops_it != loops.end(),
         "Could not find all required axes for indexing when trying to index into ",
         tv);
-    if ((*loops_it)->iter_domain()->parallelType() == ParallelType::Unroll) {
+
+    if (kir_ca_id->getParallelType() == ParallelType::Unroll) {
       return {alloc_loop, tv_i};
     }
 
@@ -267,13 +679,44 @@ std::pair<kir::ForLoop*, int64_t> getAllocPoint(
     ++loops_it;
   }
 
-  return {alloc_loop, (int64_t)tv->getComputeAtPosition()};
+  return {alloc_loop, (int64_t)tv->getThisComputeAtAxis()};
 }
 
-std::pair<kir::ForLoop*, int64_t> getAllocPoint(
-    const TensorView* tv,
-    const std::vector<kir::ForLoop*>& loops) {
-  return getAllocPoint(tv, loops, {}, false);
+std::unordered_map<IterDomain*, IterDomain*> p2cRootMap(
+    const std::vector<Expr*>& exprs) {
+  std::unordered_map<IterDomain*, IterDomain*> p2c_root_map;
+
+  for (auto expr : exprs) {
+    auto out_tv = ir_utils::getTVOutput(expr);
+    for (auto inp : expr->inputs()) {
+      if (inp->getValType().value() != ValType::TensorView) {
+        continue;
+      }
+
+      auto root_p2c = TensorDomain::mapRootPtoC(
+          inp->as<TensorView>()->domain(), out_tv->domain());
+      for (auto entry : root_p2c) {
+        auto p_id = entry.first;
+        auto c_id = entry.second;
+        // Careful we don't allow circular references
+        if (p_id != c_id) {
+          p2c_root_map[p_id] = c_id;
+        }
+      }
+    }
+  }
+
+  return p2c_root_map;
+}
+
+IterDomain* getTermIDInMap(
+    IterDomain* root_id,
+    std::unordered_map<IterDomain*, IterDomain*> p2c_root_map) {
+  auto entry = root_id;
+  while (p2c_root_map.find(entry) != p2c_root_map.end()) {
+    entry = p2c_root_map.at(entry);
+  }
+  return entry;
 }
 
 } // namespace loop_utils
index b8ca98a..1a2c16a 100644 (file)
@@ -1,11 +1,8 @@
-
 #pragma once
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>
 
 #include <bitset>
 #include <map>
@@ -19,19 +16,39 @@ namespace cuda {
 
 class ThreadPredicateMap;
 
-using IterDomainMap = std::unordered_map<kir::IterDomain*, kir::IterDomain*>;
-
 namespace scope_utils {
 
-//! Returns the list of nesting loops starting at `scope`
-// Primarily used in indexing, maybe could be moved there
-std::vector<kir::ForLoop*> getLoops(kir::Expr* scope);
+// Grab the ForLoop starting from scope working out
+std::vector<kir::ForLoop*> getLoops(Expr* scope);
+
+// Track how far our for loop scope is
+unsigned int computeForDepth(Expr* scope);
+
+// Push back an expr to scope
+void pushBack(Expr* scope, Expr* expr);
+
+// Insert expr in scope before ref
+void insertBefore(Expr* scope, Expr* ref, Expr* expr);
+
+// Returns if expr is in scope, does not check nested scopes
+bool exprInScope(Expr* scope, Expr* expr);
+
+// Return the parent of the active scope
+Expr* getParent(Expr* scope);
+
+// Open a new inner most for loop
+kir::ForLoop* openFor(Expr* scope, IterDomain*);
 
-//! Insert expr in scope before ref
-//!
-//! \warning for kir::IfThenElse we implicitly insert in the "then" branch!
-//!
-void insertBefore(kir::Expr* scope, kir::Expr* ref, kir::Expr* expr);
+// Provide a new for loop matching the one provided, sets parent_scope as
+// parent_scope, but does not insert into parent scope.
+kir::ForLoop* cloneLoopNest(kir::ForLoop* to_clone, Expr* parent_scope);
+
+// Run through a scope and replace expressions inside with replacement_map
+void replaceExprsInScope(
+    Expr* scope,
+    std::unordered_map<Expr*, Expr*> replacement_map);
+
+Expr* firstInnerMostScope(Expr* scope);
 
 } // namespace scope_utils
 
@@ -62,43 +79,76 @@ std::vector<IterDomain*> iterDomainInputsOfOrderedAs(
     const std::vector<IterDomain*>& of,
     const std::vector<IterDomain*>& order);
 
-bool isTV(const Val* const);
+std::vector<Val*> indices(std::vector<kir::ForLoop*>);
 
-TORCH_CUDA_CU_API bool isTVOp(const Expr*);
+bool isTV(const Val* const);
 
-bool isTVOp(const kir::Expr* expr);
+bool isTVOp(const Expr*);
 
 TensorView* getTVOutput(const Expr*);
-kir::TensorView* getTVOutput(const kir::Expr*);
 
 bool isScalarOp(const Expr*);
 
-// TODO(kir): remove
+void ASSERT_EXPR(Statement*);
+
+bool isScope(const Expr*);
+
 Expr* asExpr(Statement*);
 
-// TODO(kir): Remove in favor of ->as<TensorView>()
+// TODO: Remove in favor of ->as<TensorView>()
 TensorView* asTV(Val*);
 
-bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map);
-bool hasBlockSync(const kir::Expr* expr, const ThreadPredicateMap& pred_map);
-
-// expr_replacement_map maps an expression to its replacement.
-//
-// The applyReplacement function serves two purposes.
-//
-// 1. If expr is found in expr_replacement_map, return the value for expr key.
-// Otherwise, return the original expression.
-//
-// 2. If a replacement is not found and the expression is a ForLoop or an
-// IfThenElse, it modifies the expressions in its scope by running the
-// handle_scope function
-//
-// The handle_scope function iterates over the expressions in the scope.
-// For each expression, it updates the expression the value returned by
-// applyReplacement.
-kir::Expr* applyReplacements(
-    const std::unordered_map<kir::Expr*, kir::Expr*>& expr_replacement_map,
-    kir::Expr* expr);
+// TODO: Remove in favor of ->as<ForLoop>()
+kir::ForLoop* asForLoop(Statement*);
+
+// TODO: Remove in favor of ->as<TensorView>()
+const TensorView* asConstTV(const Val*);
+
+bool isUnrolledFor(const Expr*);
+
+// Represents mapping to bool from BIDx, BIDy, BIDz, TIDx, TIDy and TIDz.
+class ParallelTypeBitmap {
+ public:
+  static constexpr int num_p_type = 6;
+  ParallelTypeBitmap() = default;
+  bool get(ParallelType pt) const;
+  bool set(ParallelType pt, bool);
+  ParallelTypeBitmap operator&=(const ParallelTypeBitmap& other);
+  ParallelTypeBitmap operator|=(const ParallelTypeBitmap& other);
+  ParallelTypeBitmap operator^=(const ParallelTypeBitmap& other);
+  ParallelTypeBitmap operator~() const;
+  bool none() const;
+  bool any() const;
+  bool all() const;
+  bool operator[](size_t pos) const;
+  std::map<ParallelType, bool> getMap() const;
+
+ private:
+  ParallelTypeBitmap(const std::bitset<num_p_type>& bs) : bitset_(bs) {}
+  std::bitset<num_p_type> bitset_;
+  const static std::unordered_map<ParallelType, int, TypeHash> pt_to_offset_;
+  const static std::unordered_map<int, ParallelType> offset_to_pt_;
+};
+
+ParallelTypeBitmap operator&(
+    const ParallelTypeBitmap& lhs,
+    const ParallelTypeBitmap& rhs);
+
+ParallelTypeBitmap operator|(
+    const ParallelTypeBitmap& lhs,
+    const ParallelTypeBitmap& rhs);
+
+ParallelTypeBitmap operator^(
+    const ParallelTypeBitmap& lhs,
+    const ParallelTypeBitmap& rhs);
+
+// Returns a ParallelTypeBitmap representing which domain needs
+// blockBroadcast.
+// Even when a domain is broadcast and parallelized, it does not need
+// blockBroadcast unless it is predicated.
+ParallelTypeBitmap getParallelBroadcastDomains(
+    const Val* bop_out,
+    const ThreadPredicateMap& preds);
 
 } // namespace ir_utils
 
@@ -113,17 +163,22 @@ namespace loop_utils {
 // outside the first loop in loops. Also find out which index in tv the
 // first dimension that needs to be allocated is. Meaning we need to allocate
 // that local axis and above.
-// TODO: Only remaining use of this is in index compute, remove use from there,
-// or refactor and use in lower_allocation
-std::pair<kir::ForLoop*, int64_t> getAllocPoint(
-    const TensorView* tv,
-    const std::vector<kir::ForLoop*>& loops,
-    const std::unordered_map<IterDomain*, IterDomain*>& id_map,
-    bool use_id_map);
-
 std::pair<kir::ForLoop*, int64_t> getAllocPoint(
-    const TensorView* tv,
+    TensorView* tv,
     const std::vector<kir::ForLoop*>& loops);
+
+// Go through exprs mapping root domains from producer to consumer. Provides a
+// ground truth for how root domains map through our expressions. Needed for
+// unrolling.
+std::unordered_map<IterDomain*, IterDomain*> p2cRootMap(
+    const std::vector<Expr*>& exprs);
+
+// Given a root IterationDomain and a p2c_root_map find the root IterationDomain
+// furthest down in the sorted expr list it maps to. Needed for unrolling.
+IterDomain* getTermIDInMap(
+    IterDomain* root_id,
+    std::unordered_map<IterDomain*, IterDomain*> p2c_root_map);
+
 } // namespace loop_utils
 } // namespace cuda
 } // namespace fuser
index 16dedb2..4ffc8c6 100644 (file)
@@ -1,11 +1,9 @@
+#include <c10/util/irange.h>
+
 #include <torch/csrc/jit/codegen/cuda/lower_validation.h>
 
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/lower_utils.h>
 #include <torch/csrc/jit/codegen/cuda/transform_replay.h>
 #include <torch/csrc/jit/codegen/cuda/type.h>
@@ -15,590 +13,60 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-namespace {
-
-//! A parallel type validation pass to make sure all the outputs of
-//!   welford ops are parallelized the same way. Will infer and modify serial
-//!   parallel types if other output/s are parallelized, so that
-//!   user wouldn't have to specify the same parallelization
-//!   3 times. Will throw if conflicts are detected, i.e.
-//!   TIDx vs BIDx etc.
-class ValidateParallelType : public IterVisitor {
- public:
-  static void validate(Fusion* fusion) {
-    ValidateParallelType VPT;
-    VPT.traverse(fusion);
-  }
-
- private:
-  using IterVisitor::handle;
-  // Parallelize id1 and id0 consistently if one is serial and the other isn't
-  void convertIterDomain(IterDomain* id0, IterDomain* id1) {
-    const auto ptype0 = id0->getParallelType();
-    const auto ptype1 = id1->getParallelType();
-
-    if (ptype0 == ParallelType::Vectorize ||
-        ptype1 == ParallelType::Vectorize) {
-      auto other_type = ptype0 == ParallelType::Vectorize ? ptype1 : ptype0;
-      TORCH_INTERNAL_ASSERT(
-          other_type == ParallelType::Vectorize ||
-              (!isParallelTypeThreadDim(other_type) &&
-               !isParallelTypeBlockDim(other_type)),
-          "Vectorize type was parallelized inconsistently in. ",
-          "Detected during promoting parallel types.");
-      return;
-    }
-
-    if (ptype0 != ptype1) {
-      TORCH_CHECK(
-          ptype0 == ParallelType::Serial || ptype1 == ParallelType::Serial,
-          "Error promoting parallel types");
-      if (ptype0 == ParallelType::Serial) {
-        id0->parallelize(ptype1);
-      }
-      if (ptype1 == ParallelType::Serial) {
-        id1->parallelize(ptype0);
-      }
-    }
-  }
-
-  void handle(WelfordOp* wop) override {
-    auto out_avg = wop->outAvg()->as<TensorView>();
-    auto out_var = wop->outVar()->as<TensorView>();
-    auto out_n = wop->outN()->as<TensorView>();
-    TORCH_INTERNAL_ASSERT(out_avg->nDims() == out_var->nDims());
-    TORCH_INTERNAL_ASSERT(out_avg->nDims() == out_n->nDims());
-    for (size_t i = 0; i < out_avg->nDims(); i++) {
-      // TODO: can be cleaner.
-      convertIterDomain(out_avg->axis(i), out_var->axis(i));
-      convertIterDomain(out_avg->axis(i), out_n->axis(i));
-      convertIterDomain(out_n->axis(i), out_var->axis(i));
-    }
-  }
-};
-
-// Make sure all IterDomains are only used for a unique
-// TensorView. Several mappings from IterDomains are
-// created during lowering, which relies on the unique usage of
-// IterDomains.
-void validateIterDomainUsage(Fusion* fusion) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::validateIterDomainUse");
-  FusionGuard fg(fusion);
-
-  auto used_vals = fusion->usedMathVals();
-  std::unordered_map<IterDomain*, TensorView*> domain_use_map;
-
-  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
-    std::unordered_set<Val*> root_domains;
-    std::copy(
-        tv->getRootDomain().begin(),
-        tv->getRootDomain().end(),
-        std::inserter(root_domains, root_domains.begin()));
-
-    std::vector<Val*> leaf_domains;
-    std::copy(
-        tv->domain()->domain().begin(),
-        tv->domain()->domain().end(),
-        std::back_inserter(leaf_domains));
-
-    auto all_domain_vals =
-        DependencyCheck::getAllValsBetween(root_domains, leaf_domains);
-
-    for (auto id : ir_utils::filterByType<IterDomain>(all_domain_vals)) {
-      auto it = domain_use_map.find(id);
-      TORCH_INTERNAL_ASSERT(
-          it == domain_use_map.end(),
-          "Multiple use of ",
-          id,
-          " detected.",
-          " Used in both TV",
-          tv->name(),
-          " and TV",
-          it->second->name());
-      domain_use_map.insert({id, tv});
-    }
-  }
-}
-
-} // namespace
-
 void validateIr(Fusion* fusion) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::validateIr");
-
-  FusionGuard fg(fusion);
-
-  fusion->validateInputs();
-
-  // Convert all input broadcast iterdomains to strided
-  for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
-    for (auto id : tv->getMaybeRFactorDomain()) {
-      if (id->isBroadcast()) {
-        id->toStridedBroadcast();
-      }
-    }
-  }
-
-  // Convert all output broadcast iterdomains to strided
-  for (auto tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
-    for (auto id : tv->getMaybeRFactorDomain()) {
-      if (id->isBroadcast()) {
-        id->toStridedBroadcast();
-      }
-    }
-  }
-
-  // Validate Parallelization
-  ValidateParallelType::validate(fusion);
-
-  validateIterDomainUsage(fusion);
-}
-
-namespace {
-
-// Check contiguity for all root domains associated with Misaligned Vectorize
-// ParallelType
-void checkContiguity(
-    const std::unordered_set<IterDomain*>& domains,
-    TensorView* tv) {
-  TORCH_INTERNAL_ASSERT(tv->getMemoryType() == MemoryType::Global);
-
-  for (size_t idx = 0; idx < tv->getRootDomain().size(); ++idx) {
-    auto root = tv->getRootDomain()[idx];
-    if (domains.find(root) != domains.end()) {
-      TORCH_INTERNAL_ASSERT(
-          !root->isBroadcast(),
-          "Misaligned vectorization prohibits merging broadcast domains.",
-          "Issue found in, ",
-          tv);
-      TORCH_INTERNAL_ASSERT(
-          tv->domain()->contiguity()[idx],
-          "Cannot merge non-contiguous root domains with misaligned vectorization.",
-          "Issue found in, ",
-          tv);
-    }
-  }
-}
-
-// Check all root iter domains in consumer that are present in domain, making
-// sure they're contiguous. Map these domains to producer and make sure they are
-// also contiguous in producer. Producer-consumer relationship is assumed to be
-// through a set operation.
-void checkContiguity(
-    const std::unordered_set<IterDomain*>& domains,
-    TensorView* consumer,
-    TensorView* producer) {
-  // This seems not quite right, shouldn't we be able to reverse this?
-  TORCH_INTERNAL_ASSERT(consumer->getMemoryType() == MemoryType::Local);
-  TORCH_INTERNAL_ASSERT(producer->getMemoryType() == MemoryType::Global);
-
-  auto root_c2p =
-      PairwiseRootDomainMap(producer, consumer)
-          .mapConsumerToProducer(consumer->domain(), producer->domain());
-
-  std::unordered_map<IterDomain*, bool> producer_domain_contiguity;
-  for (size_t idx = 0; idx < producer->getRootDomain().size(); ++idx) {
-    auto root = producer->getRootDomain()[idx];
-    auto contiguity = producer->domain()->contiguity()[idx];
-    producer_domain_contiguity.insert({root, contiguity});
-  }
-
-  for (auto consumer_root : consumer->getRootDomain()) {
-    if (domains.find(consumer_root) != domains.end()) {
-      auto producer_root = root_c2p[consumer_root];
-      TORCH_INTERNAL_ASSERT(
-          producer_domain_contiguity.find(producer_root) !=
-          producer_domain_contiguity.end());
-
-      TORCH_INTERNAL_ASSERT(
-          !consumer_root->isBroadcast() || !producer_root->isBroadcast(),
-          "Misaligned vectorization prohibits merging broadcast domains.",
-          "Issue found in, ",
-          consumer);
-
-      TORCH_INTERNAL_ASSERT(root_c2p.find(consumer_root) != root_c2p.end());
-
-      TORCH_INTERNAL_ASSERT(
-          producer_domain_contiguity[producer_root],
-          "Cannot merge non-contiguous root domains with misaligned vectorization.",
-          "Issue found in, ",
-          consumer);
-    }
-  }
-}
-
-class VectorizeValidator : public OptInDispatch {
- private:
-  // Initially, vectorized_id is the IterDomain with Vectorize ParallelType
-  // After processing all merge and split operations,
-  // vectorized_id is the corresponding root domain
-  VectorizeValidator(IterDomain* vectorized_id)
-      : vectorized_id_(vectorized_id) {}
-
-  using OptInDispatch::handle;
-
-  void handle(Split* s) final {
-    if (s->outer() == vectorized_id_) {
-      is_valid = false;
-    } else if (s->inner() == vectorized_id_) {
-      vectorized_id_ = s->in();
-    }
-    domains_.insert(s->outer());
-    domains_.insert(s->inner());
-  }
+  FUSER_PERF_SCOPE("validateIr");
 
-  void handle(Merge* m) final {
-    if (m->out() == vectorized_id_) {
-      if (m->inner()->isBroadcast() && !m->outer()->isBroadcast()) {
-        vectorized_id_ = m->outer();
-      } else {
-        vectorized_id_ = m->inner();
-      }
-    }
-    domains_.insert(m->outer());
-    domains_.insert(m->inner());
-  }
-
- private:
-  std::unordered_set<IterDomain*> domains_;
-  IterDomain* vectorized_id_ = nullptr;
-  bool is_valid = true;
-
- public:
-  static void validate(TensorView* tv) {
-    // Make sure there's only one vectorized ID
-    IterDomain* v_id = nullptr;
-    bool misaligned_vectorize = false;
-    for (auto id : tv->domain()->domain()) {
-      if (id->getParallelType() == ParallelType::Vectorize ||
-          id->getParallelType() == ParallelType::MisalignedVectorize) {
-        TORCH_INTERNAL_ASSERT(
-            v_id == nullptr,
-            "Found two vectorized domains in ",
-            tv,
-            " only one is allowed.");
-        v_id = id;
-        misaligned_vectorize =
-            id->getParallelType() == ParallelType::MisalignedVectorize;
-      }
-    }
-
-    // If no vectorized id's found simply return;
-    if (v_id == nullptr) {
-      return;
-    }
-
-    auto fusion = FusionGuard::getCurFusion();
-
-    TORCH_CHECK(
-        v_id->extent()->isConstScalar(),
-        "Vectorizing a domain requires a constant size.");
-
-    ExpressionEvaluator const_expr_eval(fusion);
-
-    auto vector_size_optional = const_expr_eval.evaluate(v_id->extent());
-
-    TORCH_CHECK(
-        vector_size_optional.has_value(),
-        "Could not evaluate constant value bound to vectorized dim.");
-
-    auto vector_size = ((int64_t)dataTypeSize(tv->getDataType().value())) *
-        vector_size_optional.value();
-
-    // Allow half2, float2, float4 and same sized vtypes.
-    std::array<int64_t, 4> allowed_vector_sizes = {2, 4, 8, 16}; // NOLINT
-
-    TORCH_CHECK(
-        std::find(
-            allowed_vector_sizes.begin(),
-            allowed_vector_sizes.end(),
-            vector_size) != allowed_vector_sizes.end(),
-        "Tried to vectorize a dim resulting in a word size of ",
-        vector_size,
-        " however, vector sizes only upto and including 16 bytes are supported.");
-
-    auto replay_exprs = ExprSort::getExprs(fusion, {v_id});
-
-    VectorizeValidator validator(v_id);
-
-    for (auto expr_it = replay_exprs.rbegin(); expr_it != replay_exprs.rend();
-         ++expr_it) {
-      auto expr = *expr_it;
-      validator.handle(expr);
-    }
-
-    TORCH_CHECK(
-        validator.is_valid,
-        "Invalid vectorized pattern found, vectorization iter domains must be descendants of inner-most dimension.",
-        "Issue found in, ",
-        tv,
-        "\n");
-
-    if (misaligned_vectorize) {
-      if (tv->getMemoryType() == MemoryType::Global) {
-        checkContiguity(validator.domains_, tv);
-      } else if (
-          tv->definition()->getExprType() == ExprType::UnaryOp &&
-          tv->definition()->as<UnaryOp>()->getUnaryOpType() ==
-              UnaryOpType::Set) {
-        auto input = tv->definition()->input(0);
-        TORCH_INTERNAL_ASSERT(input->isA<TensorView>());
-        auto input_tv = input->as<TensorView>();
-        checkContiguity(validator.domains_, tv, input_tv);
-      }
-    }
-
-    TORCH_INTERNAL_ASSERT(validator.vectorized_id_ != nullptr);
-
-    // TODO: Contiguity is based on root domain not rfactor. Seems this
-    // generally doesn't cause problems, though contiguity should be on rfactor
-    // domain as that's the domain we index on.
-    IterDomain* last_root_dim = nullptr;
-    int last_root_dim_pos = -1;
-    for (size_t i = tv->getRootDomain().size(); i > 0; i--) {
-      auto r_id = tv->getRootDomain()[i - 1];
-      if (r_id->isReduction() || r_id->isBroadcast()) {
-        continue;
-      }
-      last_root_dim = r_id;
-      last_root_dim_pos = (int)i - 1;
-      break;
-    }
-
-    if (last_root_dim == nullptr) {
-      // Should never get here, but that would mean there are no concrete dims,
-      // so we should be fine.
-      return;
-    }
-
-    TORCH_CHECK(
-        last_root_dim == validator.vectorized_id_ &&
-            tv->domain()->contiguity()[last_root_dim_pos],
-        "Vectorized dim has to be from a contiguous inner most position: ",
-        tv,
-        "\n");
-  }
-};
-
-} // namespace
-
-void validateVectorize(Fusion* fusion) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::validateVectorize");
   FusionGuard fg(fusion);
 
-  auto used_vals = fusion->usedMathVals();
+  auto used_vals = DependencyCheck::getAllValsBetween(
+      {fusion->outputs().begin(), fusion->outputs().end()}, fusion->inputs());
 
   std::unordered_set<TensorView*> used_tvs;
 
-  for (auto val : used_vals) {
+  for (const auto& val : used_vals) {
     if (ir_utils::isTV(val)) {
       used_tvs.emplace(val->as<TensorView>());
     }
   }
 
-  for (auto tv : used_tvs) {
-    bool has_vectorize_dim = false;
-    bool has_misaligned_vectorize_dim = false;
-
-    for (size_t i = 0; i < tv->nDims(); i++) {
-      IterDomain* id = tv->axis(i);
-      IterDomain* concrete_id =
-          GpuLower::current()->caParallelMap().getConcreteMappedID(id);
+  fusion->validateInputs();
 
-      auto ptype = concrete_id->getParallelType();
+  for (const auto& tv : used_tvs) {
+    for (const auto i : c10::irange(tv->nDims())) {
+      IterDomain* id = tv->getComputeAtAxis(i).first;
 
-      if (ptype == ParallelType::Vectorize) {
-        // If we want to do this check up front we would have to do 2 things:
-        // (1) Check that the tensor view with vectorize being set on it is
-        // getting set outside the local compute at position
-        // (2) Check any producers of the tensor view with vectorize being set
-        // on it to make sure their compute at position isn't to the right of
-        // the vectorize dim.
-        TORCH_INTERNAL_ASSERT(
-            i >= tv->getComputeAtPosition(),
-            "IterDomains to the left of the compute at point cannot be vectorized: ",
+      if (id->isBlockDim()) {
+        TORCH_CHECK(
+            !id->isBroadcast(),
+            "Parallelization across blocks on broadcast axes is not supported, but found on, ",
             tv,
-            "\n");
-        has_vectorize_dim = true;
-      }
-
-      if (concrete_id->getParallelType() == ParallelType::MisalignedVectorize) {
-        TORCH_INTERNAL_ASSERT(
-            !tv->hasComputeAt() ||
-                tv->getComputeAtPosition() == tv->nDims() - 1,
-            "Only allow misaligned vectorization in the -2 computeAt position.");
-        TORCH_INTERNAL_ASSERT(
-            tv->getMemoryType() == MemoryType::Local ||
-                tv->getMemoryType() == MemoryType::Global,
-            "Only allow misaligned vectorization between global and local memory.");
-        has_misaligned_vectorize_dim = true;
+            ".");
       }
-    }
-    if (has_vectorize_dim) {
-      TORCH_INTERNAL_ASSERT(
-          tv->definition() == nullptr ||
-              (tv->definition()->isA<UnaryOp>() &&
-               tv->definition()->as<UnaryOp>()->getUnaryOpType() ==
-                   UnaryOpType::Set),
-          "Vectorized accesses cannot be inline with computation, they are only supported with a Set operation.",
-          "TensorView: ",
-          tv);
-    }
-    if (has_vectorize_dim || has_misaligned_vectorize_dim) {
-      VectorizeValidator::validate(tv);
-    }
-  }
-}
-
-namespace {
-
-//! Return true if axis is derived from a root axis that is an input
-//! to a CA leaf axis.
-bool derivedFromRootCAAxes(TensorView* tv, IterDomain* axis) {
-  std::vector<IterDomain*> ca_axes(
-      tv->domain()->domain().begin(),
-      tv->domain()->domain().begin() + tv->getComputeAtPosition());
-
-  auto ca_root_vals = IterVisitor::getInputsTo(
-      std::vector<Val*>(ca_axes.begin(), ca_axes.end()));
-
-  auto root_vals = IterVisitor::getInputsTo({axis});
-
-  return std::any_of(
-      root_vals.begin(), root_vals.end(), [&ca_root_vals](auto root) {
-        return std::find(ca_root_vals.begin(), ca_root_vals.end(), root) !=
-            ca_root_vals.end();
-      });
-}
-
-} // namespace
-
-void validateParallelize(Fusion* fusion) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::validateParallelize");
-  FusionGuard fg(fusion);
-
-  const auto& par_map = GpuLower::current()->caParallelMap();
-  const auto& loop_map = GpuLower::current()->caLoopMap();
-  const auto& index_map = GpuLower::current()->caIndexMap();
-  const auto& pred_map = GpuLower::current()->threadPredMap();
-
-  auto exprs = ExprSort::getExprs(fusion);
-
-  for (auto expr : exprs) {
-    if (!ir_utils::isTVOp(expr)) {
-      continue;
-    }
-    for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
-      // Parallelization on input tensors have no effect.
-      if (producer->isFusionInput()) {
-        continue;
-      }
-      const auto parallel_bcast_doms =
-          pred_map.getParallelBroadcastDomains(producer);
-      ParallelTypeBitmap pt_map;
-      for (size_t i = 0; i < producer->nDims(); ++i) {
-        // If a producer axis is threaded, either with threadIdx or
-        // blockIdx, there must be a mapped consumer axis with the
-        // same ParallelType. An exception is when the producer is
-        // allocated on shared memory and its parallelized with
-        // threadIdx. In that case, there is no parallelization
-        // constraint on the consumer as syncthreads will be inserted
-        // when necessary.
-        auto producer_axis = producer->axis(i);
-        auto producer_ptype =
-            par_map.getConcreteMappedID(producer_axis)->getParallelType();
-        if (!isParallelTypeThread(producer_ptype)) {
-          continue;
-        }
-        // Each ParallelType can be used only once.
-        TORCH_INTERNAL_ASSERT(
-            !pt_map.get(producer_ptype),
-            "Multiple use of ",
-            producer_ptype,
-            " in tensor t",
-            producer->name(),
-            ": ",
-            producer);
-        pt_map.set(producer_ptype, true);
-        // When the producer axis is a broadcast, it is not really
-        // parallelized unless thread-predicated
-        if (producer_axis->isBroadcast() && parallel_bcast_doms.none()) {
-          continue;
-        }
-        // No constraint on the consumer tensor when the producer
-        // axis is parallelized with threadIdx and allocates on
-        // shared memory
-        if (isParallelTypeThreadDim(producer_ptype) &&
-            producer->getMemoryType() == MemoryType::Shared) {
-          continue;
-        }
-        // There should be also nothing to validate when the producer
-        // axis is reduction.
-        if (producer_axis->isReduction()) {
-          continue;
-        }
-        // There must be a consumer axis that uses the same indexing
-        // with the same parallel type as the producer axis. The index
-        // map is used to to find such an axis. In addition, even when
-        // no mapped axis is found in the index map, but when an
-        // mapped axis exists in the loop map, the producer and
-        // consumer axes may still use the same indexing. That only
-        // happens when the producer is derived from a root axis that
-        // is an input to any leaf CA axes. In such a case, the axis
-        // in the reference tensor that maps to
-        // the producer axis is created based on the consumer, so both
-        // the producer and consumer axes should have the same
-        // indexing. See issue #995 as well as the
-        // FusionValidateParallelize6 test for a concrete example.
-        for (auto consumer :
-             ir_utils::filterByType<TensorView>(expr->outputs())) {
-          auto it = std::find_if(
-              consumer->domain()->domain().begin(),
-              consumer->domain()->domain().end(),
-              [&](IterDomain* consumer_axis) {
-                return index_map.areMapped(producer_axis, consumer_axis) ||
-                    (loop_map.areMapped(producer_axis, consumer_axis) &&
-                     derivedFromRootCAAxes(producer, producer_axis));
-              });
-          TORCH_INTERNAL_ASSERT(
-              it != consumer->domain()->domain().end(),
-              "Inconsistent parallelization found between TV",
-              producer->name(),
-              " (",
-              producer,
-              ") and TV",
-              consumer->name(),
-              "(",
-              consumer,
-              "). ",
-              "TV",
-              consumer->name(),
-              " does not have a matching axis for parallelized producer axis, ",
-              producer_axis,
-              ". CA Map: ",
-              loop_map.toString());
-          auto consumer_axis = *it;
-          auto consumer_ptype =
-              par_map.getConcreteMappedID(consumer_axis)->getParallelType();
-          TORCH_INTERNAL_ASSERT(
-              producer_ptype == consumer_ptype,
-              "Inconsistent parallelization found between TV",
-              producer->name(),
-              " (",
-              producer,
-              ") and TV",
-              consumer->name(),
-              "(",
-              consumer,
-              "). "
-              "Producer axis, ",
-              producer_axis,
-              " is parallelized with ",
-              stringifyThread(producer_ptype),
-              ", but the parallel type of its matching consumer axis, ",
-              consumer_axis,
-              " is ",
-              stringifyThread(consumer_ptype),
-              ".");
+      if (tv->hasBroadcast() && tv->getMemoryType() != MemoryType::Global) {
+        auto td = tv->domain()->domain();
+        auto ca_inputs = ir_utils::iterDomainInputsOf(
+            {td.begin(), td.begin() + tv->getThisComputeAtAxis()});
+        auto non_ca_inputs = ir_utils::iterDomainInputsOf(
+            {td.begin() + tv->getThisComputeAtAxis(), td.end()});
+
+        std::unordered_set<IterDomain*> ca_inputs_set(
+            ca_inputs.begin(), ca_inputs.end());
+        std::unordered_set<IterDomain*> non_ca_inputs_set(
+            non_ca_inputs.begin(), non_ca_inputs.end());
+
+        for (const auto& id : tv->getRootDomain()) {
+          if (id->isBroadcast()) {
+            // If a broadcast dimension is an input to both an axis within the
+            // computeAt point and outside the compute at point we would have to
+            // look at consumers to figure out what that axis will be
+            // broadcasted to, because we would have to generate everything the
+            // consumer could need on that axis. This could be supported but is
+            // not at this point.
+            TORCH_INTERNAL_ASSERT(
+                !(ca_inputs_set.find(id) != ca_inputs_set.end() &&
+                  non_ca_inputs_set.find(id) != non_ca_inputs_set.end()),
+                "Cannot generate a kernel where a root broadcast dimension is input to both IterDomains outside and within the computeAt point.");
+          }
         }
       }
     }
index 445de03..eddee4f 100644 (file)
@@ -11,17 +11,6 @@ namespace cuda {
 
 void validateIr(Fusion* fusion);
 
-void validateVectorize(Fusion* fusion);
-
-//! Validates all tensors are consistently parallelized. Basically,
-//! when a producer axis is threaded, either with threadIdx or
-//! blockIdx, there must be a mapped consumer axis with the
-//! same ParallelType with some exceptions.
-//!
-//! This function assumes Loop and Parallel ComputeAtMaps are already
-//! built as they are used to validate consistency.
-void validateParallelize(Fusion* fusion);
-
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
index 6495978..d46653b 100644 (file)
@@ -5,9 +5,8 @@
 #include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
 #include <torch/csrc/jit/codegen/cuda/manager.h>
 #include <torch/csrc/jit/codegen/cuda/parser.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
+#include <torch/csrc/jit/codegen/cuda/scheduler.h>
 #include <torch/csrc/jit/codegen/cuda/shape_inference.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
 #include <torch/csrc/jit/passes/canonicalize.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/runtime/graph_executor.h>
@@ -73,8 +72,8 @@ class CudaFusionManager {
     // We should not call `EraseShapeInformation(graph);`, graph representation
     // does not incorporate static sizes, but just rank of input tensors, which
     // is exactly what we wanted.
-    auto canonical_graph = Canonicalize(graph, false);
-    auto repr = canonical_graph->toString(false);
+    Canonicalize(graph, false);
+    auto repr = graph->toString(false);
 
     // create new graph_cache_ids_ entry if none existed yet;
     if (graph_cache_ids_.count(repr) == 0) {
@@ -87,24 +86,10 @@ class CudaFusionManager {
     return graph_cache_ids_[repr];
   };
 
-  void unregisterCacheId(std::shared_ptr<Graph>& graph) {
-    auto canonical_graph = Canonicalize(graph, false);
-    auto repr = canonical_graph->toString(false);
-
-    // create new graph_cache_ids_ entry if none existed yet;
-    if (graph_cache_ids_.count(repr) > 0) {
-      int32_t kernel_id = graph_cache_ids_[repr];
-      graph_cache_.erase(kernel_id);
-      graph_cache_ids_.erase(repr);
-    }
-  }
-
   std::vector<at::Tensor> runFusionNode(
       int32_t kernel_id,
       const at::ArrayRef<IValue> inputs) {
     std::lock_guard<std::mutex> guard(mutex_);
-    TORCH_INTERNAL_ASSERT(
-        graph_cache_.count(kernel_id) > 0, "graph cache miss at run time");
     return graph_cache_[kernel_id]->runGraphWithInputs(inputs);
   }
 
@@ -220,7 +205,7 @@ class CudaFusionManager {
 } // namespace
 
 void compileCudaFusionGroup(Node* fusion_node) {
-  FUSER_PERF_SCOPE("nvFuser::Manager::compileCudaFusionGroup");
+  FUSER_PERF_SCOPE("compileCudaFusionGroup");
 
   TORCH_CHECK(
       fusion_node->kind() == prim::CudaFusionGroup,
@@ -231,61 +216,37 @@ void compileCudaFusionGroup(Node* fusion_node) {
   // This is not a critical code path, it's OK to do graph copy here;
   auto graph = fusion_node->g(attr::Subgraph)->copy();
 
-  auto compile_fusion = [&]() {
-    // type propagation is needed, as the protocol only requires scalar type on
-    // input tensors.
-    // Note that even for Profiling Executor, scalar type could still be
-    // missing, especially for output tensor from a given node (as profiling
-    // node only insert meta information after itself).
-    TypePropagate(graph);
-
-    int32_t fusion_cache_id =
-        CudaFusionManager::getManager().registerOrGetCacheId(graph);
-    fusion_node->i_(attr::cache_id, fusion_cache_id);
-  };
-
-  if (useFallback()) {
-    try {
-      compile_fusion();
-    } catch (...) {
-      TORCH_WARN(
-          "FALLBACK path has been taken. This is an indication that codegen"
-          "Failed for some reason. To debug try disable codegen fallback path"
-          "via setting the env variable"
-          "`export PYTORCH_NVFUSER_DISABLE_FALLBACK=1`");
-      CudaFusionManager::getManager().unregisterCacheId(graph);
-    }
-  } else {
-    compile_fusion();
-  }
+  // type propagation is needed, as the protocol only requires scalar type on
+  // input tensors.
+  // Note that even for Profiling Executor, scalar type could still be missing,
+  // especially for output tensor from a given node (as profiling node only
+  // insert meta information after itself).
+  TypePropagate(graph);
+
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  int32_t fusion_cache_id =
+      CudaFusionManager::getManager().registerOrGetCacheId(graph);
+  fusion_node->i_(attr::cache_id, fusion_cache_id);
 }
 
 void runCudaFusionGroup(const Node* fusion_node, Stack& stack) {
-  FUSER_PERF_SCOPE("nvFuser::Manager::runCudaFusionGroup");
-
-  // Fallback to use if anything goes wrong
-  auto take_fallback = [&]() {
-    // copying graph here since we are eliminating shape information;
-    auto copied_graph = fusion_node->g(attr::Subgraph)->copy();
-    EraseShapeInformation(copied_graph);
-    InterpreterState{Code(copied_graph, "fallback_cuda_fuser")}.run(stack);
-  };
+  FUSER_PERF_SCOPE("runCudaFusionGroup");
 
-  auto run_fusion = [&]() {
-    TORCH_CHECK(
-        fusion_node->kind() == prim::CudaFusionGroup,
-        "prim::CudaFusionGroup expected");
-    // TODO: should we support runtime compilation with updated dynamic shape;
-    //       shape inference would be needed so we can allocate output;
-    TORCH_CHECK(
-        fusion_node->hasAttribute(attr::cache_id),
-        "node prim::CudaFusionGroup has not been compiled yet");
+  TORCH_CHECK(
+      fusion_node->kind() == prim::CudaFusionGroup,
+      "prim::CudaFusionGroup expected");
+  // TODO: should we support runtime compilation with updated dynamic shape;
+  //       shape inference would be needed so we can allocate output;
+  TORCH_CHECK(
+      fusion_node->hasAttribute(attr::cache_id),
+      "node prim::CudaFusionGroup has not been compiled yet");
+  int32_t kernel_id = fusion_node->i(attr::cache_id);
 
-    int32_t kernel_id = fusion_node->i(attr::cache_id);
-    // Currently we just construct I/O tensors for static graph;
+  // Currently we just construct I/O tensors for static graph;
 
-    const auto nInputs = fusion_node->g(attr::Subgraph)->inputs().size();
+  const auto nInputs = fusion_node->g(attr::Subgraph)->inputs().size();
 
+  auto execute_lambda = [&]() {
     at::ArrayRef<IValue> inputs = last(stack, nInputs);
 
     auto outputs =
@@ -298,19 +259,24 @@ void runCudaFusionGroup(const Node* fusion_node, Stack& stack) {
         std::make_move_iterator(outputs.end()));
   };
 
-  if (useFallback()) {
+  const char* disable_fb_env = getenv("PYTORCH_CUDA_FUSER_DISABLE_FALLBACK");
+  int disable_fb_flag = disable_fb_env ? atoi(disable_fb_env) : 0;
+  if (disable_fb_flag) {
+    execute_lambda();
+  } else {
     try {
-      run_fusion();
+      execute_lambda();
     } catch (...) {
       TORCH_WARN(
-          "FALLBACK path has been taken. This is an indication that codegen"
+          "FALLBACK path is taken. This is an indication that codegen"
           "Failed for some reason. To debug try disable codegen fallback path"
           "via setting the env variable"
-          "`export PYTORCH_NVFUSER_DISABLE_FALLBACK=1`");
-      take_fallback();
+          "`export PYTORCH_CUDA_FUSER_DISABLE_FALLBACK=1`");
+      // copying graph here since we are eliminating shape information;
+      auto copied_graph = fusion_node->g(attr::Subgraph)->copy();
+      EraseShapeInformation(copied_graph);
+      InterpreterState{Code(copied_graph, "fallback_cuda_fuser")}.run(stack);
     }
-  } else {
-    run_fusion();
   }
 }
 
index a717b9f..281da0d 100644 (file)
@@ -10,6 +10,24 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
+void OptOutMutator::mutate(Fusion* fusion) {
+  std::vector<Expr*> orig_exprs = fusion->exprs();
+
+  /*
+   * We go through all the exprs, in topologically sorted order. We call mutate
+   * on them which could insert nodes, removes nodes, or both. These operations
+   * modify the dag and the Fusion will keep track of what has/hasn't been
+   * changed by the origin dependency tracking that it does. If an operation is
+   * added, and its output node is a val which previously was the output of
+   * another expresion, that older expresion will be removed as we can only
+   * assign a Val once due to our SSA restriction. Therefore we don't need to
+   * manually track what expressions stayed constant or were changed.
+   */
+
+  for (Statement* stmt : orig_exprs)
+    mutate(stmt);
+}
+
 // MUTATE FUNCTIONS FOR VALS
 
 Statement* OptOutMutator::mutate(IterDomain* id) {
@@ -46,20 +64,37 @@ Statement* OptOutMutator::mutate(TensorDomain* td) {
 Statement* OptOutMutator::mutate(TensorView* tv) {
   TensorDomain* td = mutateAsVal(tv->domain())->as<TensorDomain>();
 
-  if (!tv->domain()->sameAs(td)) {
+  TensorView* computeAtView = nullptr;
+  if (tv->hasComputeAt())
+    computeAtView = mutateAsVal(tv->getComputeAtView())->as<TensorView>();
+
+  if (!tv->domain()->sameAs(td) ||
+      (tv->hasComputeAt() && !tv->getComputeAtView()->sameAs(computeAtView))) {
     TensorView* mutated_tv = new TensorView(td, tv->getDataType().value());
+    if (tv->hasComputeAt()) {
+      mutated_tv->setComputeAt(
+          computeAtView, (int)(tv->getRelativeComputeAtAxis()));
+    }
     registerMutation(tv, mutated_tv);
     return mutated_tv;
   }
   return tv;
 }
 
+Statement* OptOutMutator::mutate(kir::TensorIndex* ti) {
+  return ti;
+}
+
 Statement* OptOutMutator::mutate(Bool* b) {
   return b;
 }
 
-Statement* OptOutMutator::mutate(Double* d) {
-  return d;
+Statement* OptOutMutator::mutate(Float* f) {
+  return f;
+}
+
+Statement* OptOutMutator::mutate(Half* h) {
+  return h;
 }
 
 Statement* OptOutMutator::mutate(Int* i) {
@@ -72,6 +107,14 @@ Statement* OptOutMutator::mutate(NamedScalar* ns) {
 
 // MUTATE FUNCTIONS FOR EXPRESSIONS.
 
+Statement* OptOutMutator::mutate(kir::Allocate* a) {
+  return a;
+}
+
+Statement* OptOutMutator::mutate(kir::Sync* a) {
+  return a;
+}
+
 Statement* OptOutMutator::mutate(Split* s) {
   IterDomain* ot = mutateAsVal(s->outer())->as<IterDomain>();
   IterDomain* inr = mutateAsVal(s->inner())->as<IterDomain>();
@@ -83,7 +126,7 @@ Statement* OptOutMutator::mutate(Split* s) {
     return s;
   }
   FusionGuard::getCurFusion()->removeExpr(s);
-  return new Split(ot, inr, in, fact, s->innerSplit());
+  return new Split(ot, inr, in, fact);
 }
 
 Statement* OptOutMutator::mutate(Merge* m) {
@@ -141,83 +184,20 @@ Statement* OptOutMutator::mutate(ReductionOp* rop) {
   return new ReductionOp(rop->getReductionOpType(), init, out, in);
 }
 
-namespace {
-__inline__ bool compareOptional(Val* a, Val* b) {
-  if (!a || !b) {
-    return (!a && !b);
-  }
-  return a->sameAs(b);
-}
-
-} // namespace
-
-Statement* OptOutMutator::mutate(WelfordOp* wop) {
-  Val* out_avg = mutateAsVal(wop->outAvg())->asVal();
-  Val* out_var = mutateAsVal(wop->outVar())->asVal();
-  Val* out_N = mutateAsVal(wop->outN())->asVal();
-
-  Val* in_avg = mutateAsVal(wop->inAvg())->asVal();
-  Val* in_var = wop->inVar() ? mutateAsVal(wop->inVar())->asVal() : nullptr;
-  Val* in_N = mutateAsVal(wop->inN())->asVal();
-
-  Val* init_avg =
-      wop->initAvg() ? mutateAsVal(wop->initAvg())->asVal() : nullptr;
-  Val* init_var =
-      wop->initVar() ? mutateAsVal(wop->initVar())->asVal() : nullptr;
-  Val* init_N = mutateAsVal(wop->initN())->asVal();
-
-  const bool out_compare = out_avg->sameAs(wop->outAvg()) &&
-      out_var->sameAs(wop->outVar()) && out_N->sameAs(wop->outN());
-  const bool in_compare = in_avg->sameAs(wop->inAvg()) &&
-      compareOptional(in_var, wop->inVar()) && in_N->sameAs(wop->inN());
-  const bool init_compare = compareOptional(init_avg, wop->initAvg()) &&
-      compareOptional(init_var, wop->initVar()) && init_N->sameAs(wop->initN());
-
-  if (out_compare && init_compare && in_compare) {
-    return wop;
-  } else {
-    return new WelfordOp(
-        out_avg,
-        out_var,
-        out_N,
-        init_avg,
-        init_var,
-        init_N,
-        in_avg,
-        in_var,
-        in_N);
-  }
+Statement* OptOutMutator::mutate(kir::GridReduction* gr) {
+  return gr;
 }
 
 Statement* OptOutMutator::mutate(BroadcastOp* bop) {
   return bop;
 }
 
-Statement* OptOutMutator::mutate(TransposeOp* top) {
-  return top;
+Statement* OptOutMutator::mutate(kir::ForLoop* fl) {
+  return fl;
 }
 
-Statement* OptOutMutator::mutate(ShiftOp* sop) {
-  Val* out = mutateAsVal(sop->out())->asVal();
-  Val* in = mutateAsVal(sop->in())->asVal();
-
-  if (out->sameAs(sop->out()) && in->sameAs(sop->in()))
-    return sop;
-  auto offsets = sop->offsets();
-  FusionGuard::getCurFusion()->removeExpr(sop);
-  return new ShiftOp(out, in, offsets);
-}
-
-Statement* OptOutMutator::mutate(GatherOp* op) {
-  Val* out = mutateAsVal(op->out())->asVal();
-  Val* in = mutateAsVal(op->in())->asVal();
-
-  if (out->sameAs(op->out()) && in->sameAs(op->in()))
-    return op;
-  auto window_shape = op->windowShape();
-  auto pad_width = op->padWidth();
-  FusionGuard::getCurFusion()->removeExpr(op);
-  return new GatherOp(out, in, window_shape, pad_width);
+Statement* OptOutMutator::mutate(kir::IfThenElse* ite) {
+  return ite;
 }
 
 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/ops/all_ops.h b/torch/csrc/jit/codegen/cuda/ops/all_ops.h
deleted file mode 100644 (file)
index 1ebd2bb..0000000
+++ /dev/null
@@ -1,4 +0,0 @@
-#pragma once
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/ops/composite.h>
-#include <torch/csrc/jit/codegen/cuda/ops/normalization.h>
diff --git a/torch/csrc/jit/codegen/cuda/ops/composite.cpp b/torch/csrc/jit/codegen/cuda/ops/composite.cpp
deleted file mode 100644 (file)
index 9ab96d5..0000000
+++ /dev/null
@@ -1,159 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/ops/composite.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-ForwardDropoutResult dropout(TensorView* x, Val* prob) {
-  auto p1m = sub(new Double(1.), prob);
-  auto zero_check = add(eq(p1m, new Double(0.)), p1m);
-  auto scale = div(new Double(1.), zero_check);
-  return dropout(x, p1m, scale);
-}
-
-ForwardDropoutResult dropout(TensorView* x, Val* prob, Val* scale) {
-  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-  TORCH_INTERNAL_ASSERT(
-      prob != nullptr && prob->getDataType().has_value() &&
-          prob->getDataType().value() == DataType::Double,
-      "Probability is not a valid Double.");
-  TORCH_INTERNAL_ASSERT(
-      scale != nullptr && scale->getDataType().has_value() &&
-          scale->getDataType().value() == DataType::Double,
-      "Scale is not a valid Double.");
-
-  auto rand_vals = unaryOp(UnaryOpType::RandLike, x);
-  auto mask = lt(rand_vals, prob);
-  auto apply_mask = mul(x, mask);
-  auto y = mul(apply_mask, scale);
-
-  return {y, mask};
-}
-
-TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) {
-  TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
-  TORCH_INTERNAL_ASSERT(mask != nullptr, "Mask is invalid");
-  TORCH_INTERNAL_ASSERT(
-      scale != nullptr && scale->getDataType().has_value() &&
-          scale->getDataType().value() == DataType::Double,
-      "Scale is not a valid Double.");
-
-  auto grad_mask = mul(dy, mask);
-  auto dx = mul(grad_mask, scale);
-
-  return dx;
-}
-
-Val* softplus(Val* x, Val* beta, Val* threshold) {
-  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-  TORCH_INTERNAL_ASSERT(beta != nullptr, "Beta is invalid.");
-  TORCH_INTERNAL_ASSERT(
-      threshold != nullptr, "Threshold is not a valid Double.");
-
-  auto op_beta = mul(x, beta);
-  auto maybe_result = div(
-      unaryOp(UnaryOpType::Log1p, unaryOp(UnaryOpType::Exp, op_beta)), beta);
-  auto y = where(gt(op_beta, threshold), x, maybe_result);
-  return y;
-}
-
-LstmResult lstm(
-    TensorView* prev_cell,
-    TensorView* in_x,
-    TensorView* forget_x,
-    TensorView* cell_x,
-    TensorView* out_x) {
-  TORCH_INTERNAL_ASSERT(
-      prev_cell != nullptr, "Previous cell state is invalid.");
-  TORCH_INTERNAL_ASSERT(in_x != nullptr, "In-gate input is invalid");
-  TORCH_INTERNAL_ASSERT(forget_x != nullptr, "Forget-gate input is invalid");
-  TORCH_INTERNAL_ASSERT(cell_x != nullptr, "Cell-gate input is invalid");
-  TORCH_INTERNAL_ASSERT(out_x != nullptr, "Out-gate input is invalid");
-
-  const auto in_gate = unaryOp(UnaryOpType::Sigmoid, in_x);
-  const auto forget_gate = unaryOp(UnaryOpType::Sigmoid, forget_x);
-  const auto cell_gate = unaryOp(UnaryOpType::Tanh, cell_x);
-  const auto out_gate = unaryOp(UnaryOpType::Sigmoid, out_x);
-
-  const auto cell = add(mul(forget_gate, prev_cell), mul(in_gate, cell_gate));
-  const auto hidden = mul(out_gate, unaryOp(UnaryOpType::Tanh, cell));
-
-  return {cell, hidden};
-}
-
-Val* fast_gelu(Val* x) {
-  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid");
-
-  constexpr double kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
-  constexpr double kKappa = 0.044715;
-
-  auto x_cube = mul(x, mul(x, x));
-
-  auto inner_1 = mul(new Double(kKappa), x_cube);
-  auto inner_2 = add(x, inner_1);
-  auto inner_3 = mul(new Double(kBeta), inner_2);
-  auto tanh_inner = unaryOp(UnaryOpType::Tanh, inner_3);
-
-  auto out = mul(x, add(new Double(1.), tanh_inner));
-  auto y = mul(new Double(0.5), out);
-  return y;
-}
-
-Val* fast_gelu_backward(Val* dy, Val* x) {
-  TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
-  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid");
-
-  constexpr double kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
-  constexpr double kKappa = 0.044715;
-
-  auto x_sq = mul(x, x);
-  auto x_cube = mul(x, x_sq);
-
-  auto inner_1 = mul(new Double(kKappa), x_cube);
-  auto inner_2 = add(x, inner_1);
-  auto inner_3 = mul(new Double(kBeta), inner_2);
-  auto tanh_inner = unaryOp(UnaryOpType::Tanh, inner_3);
-
-  auto left = mul(new Double(0.5), x);
-  auto right = add(new Double(1.), tanh_inner);
-
-  auto left_derivative = mul(new Double(0.5), right);
-
-  auto tanh_inner_sq = mul(tanh_inner, tanh_inner);
-  auto tanh_derivative = sub(new Double(1), tanh_inner_sq);
-
-  auto constant_mul_x_sq = mul(new Double(kBeta * 3 * kKappa), x_sq);
-  auto inner_derivative = add(new Double(kBeta), constant_mul_x_sq);
-  auto right_derivative = mul(left, mul(tanh_derivative, inner_derivative));
-
-  auto dx = mul(dy, add(left_derivative, right_derivative));
-  return dx;
-}
-
-Val* gelu_backward(Val* dy, Val* x) {
-  TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
-  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid");
-
-  constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
-  const double kHalf = 0.5;
-
-  auto cdf_1 = mul(x, new Double(M_SQRT1_2));
-  auto cdf_2 = unaryOp(UnaryOpType::Erf, cdf_1);
-  auto cdf_3 = add(cdf_2, new Double(1.));
-  auto cdf_4 = mul(cdf_3, new Double(kHalf));
-
-  auto pdf_1 = mul(x, x);
-  auto pdf_2 = mul(pdf_1, new Double(-kHalf));
-  auto pdf_3 = unaryOp(UnaryOpType::Exp, pdf_2);
-
-  auto out = addcmul(cdf_4, x, pdf_3, new Double(kAlpha));
-  auto dx = mul(out, dy);
-  return dx;
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/ops/composite.h b/torch/csrc/jit/codegen/cuda/ops/composite.h
deleted file mode 100644 (file)
index f130b27..0000000
+++ /dev/null
@@ -1,55 +0,0 @@
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
-
-//
-// The operations defined in this header is intended as user facing functions.
-// The user will provide the necessary input TensorViews and the function will
-// create the correct intermediate nodes and return the output TensorViews.
-//
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-struct ForwardDropoutResult {
-  TensorView* output = nullptr;
-  TensorView* mask = nullptr;
-};
-
-TORCH_CUDA_CU_API ForwardDropoutResult dropout(TensorView* x, Val* prob);
-
-TORCH_CUDA_CU_API ForwardDropoutResult
-dropout(TensorView* x, Val* prob, Val* scale);
-
-TORCH_CUDA_CU_API TensorView* dropout_backward(
-    TensorView* dy,
-    TensorView* mask,
-    Val* scale);
-
-TORCH_CUDA_CU_API Val* softplus(Val* x, Val* beta, Val* threshold);
-
-struct LstmResult {
-  TensorView* cell = nullptr;
-  TensorView* hidden = nullptr;
-};
-
-TORCH_CUDA_CU_API LstmResult lstm(
-    TensorView* prev_cell,
-    TensorView* in_x,
-    TensorView* forget_x,
-    TensorView* cell_x,
-    TensorView* out_x);
-
-TORCH_CUDA_CU_API Val* fast_gelu(Val* x);
-TORCH_CUDA_CU_API Val* fast_gelu_backward(Val* dy, Val* x);
-TORCH_CUDA_CU_API Val* gelu_backward(Val* dy, Val* x);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/ops/normalization.cpp b/torch/csrc/jit/codegen/cuda/ops/normalization.cpp
deleted file mode 100644 (file)
index f3ea0cf..0000000
+++ /dev/null
@@ -1,560 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/ops/normalization.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-TensorView* softmax(TensorView* x, int dim) {
-  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-
-  const int kNumberOfDims =
-      TensorDomain::noReductions(x->getRootDomain()).size();
-  const int kReductionAxis = (dim < 0) ? dim + kNumberOfDims : dim;
-  TORCH_INTERNAL_ASSERT(kReductionAxis >= 0 && kReductionAxis < kNumberOfDims);
-
-  std::vector<bool> broadcast_mask(kNumberOfDims, false);
-  broadcast_mask[kReductionAxis] = true;
-
-  auto max_val = max(x, {kReductionAxis});
-  auto bcast_max = broadcast(max_val, broadcast_mask);
-  auto x_max_sub = sub(x, bcast_max);
-  auto exp = unaryOp(UnaryOpType::Exp, x_max_sub);
-  auto sum_exp = sum(exp, {kReductionAxis});
-  auto bcast_sum = broadcast(sum_exp, broadcast_mask);
-  auto y = div(exp, bcast_sum);
-
-  return y;
-}
-
-TensorView* softmax_backward(
-    TensorView* dy,
-    TensorView* y,
-    int dim,
-    TensorView* x) {
-  TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
-  TORCH_INTERNAL_ASSERT(y != nullptr, "Output is invalid.");
-  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-
-  const int kNumberOfDims =
-      TensorDomain::noReductions(x->getRootDomain()).size();
-  const int kReductionAxis = (dim < 0) ? dim + kNumberOfDims : dim;
-  TORCH_INTERNAL_ASSERT(kReductionAxis >= 0 && kReductionAxis < kNumberOfDims);
-
-  std::vector<bool> broadcast_mask(kNumberOfDims, false);
-  broadcast_mask[kReductionAxis] = true;
-
-  auto new_grad = mul(dy, y);
-  auto sum_new_grad = sum(new_grad, {kReductionAxis});
-  auto bcast_sum = broadcast(sum_new_grad, broadcast_mask);
-  auto output_sum_mul = mul(y, bcast_sum);
-  auto dx = sub(new_grad, output_sum_mul);
-
-  return dx;
-}
-
-ForwardNormResult layer_norm(
-    TensorView* x,
-    const std::vector<int64_t>& norm_shape,
-    TensorView* weight,
-    TensorView* bias,
-    Val* eps) {
-  return layer_norm(x, norm_shape.size(), weight, bias, eps);
-}
-
-ForwardNormResult layer_norm(
-    TensorView* x,
-    const size_t kNormShapeNumDims,
-    TensorView* weight,
-    TensorView* bias,
-    Val* eps) {
-  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-  TORCH_INTERNAL_ASSERT(
-      eps != nullptr && eps->getDataType().has_value() &&
-          eps->getDataType().value() == DataType::Double,
-      "Epsilon (eps) is not a valid Double.");
-
-  // (B, C, H, W, D) tensor
-  // norm_shape = [H, W, D]
-  // M = outer = product of remaining dimensions = B * C
-  // N = reduction = product of norm_shape = H * W * D
-  // weight = bias = norm_shape tensor
-  const size_t kNumberOfDims =
-      TensorDomain::noReductions(x->getRootDomain()).size();
-  const size_t kOuterNumDims = kNumberOfDims - kNormShapeNumDims;
-
-  std::vector<int> outer_reduction_axes(kOuterNumDims);
-  std::vector<bool> outer_broadcast_mask(kNumberOfDims, false);
-  for (size_t idx = 0; idx < kOuterNumDims; ++idx) {
-    outer_reduction_axes[idx] = idx;
-    outer_broadcast_mask[idx] = true;
-  }
-
-  std::vector<int> inner_reduction_axes(kNormShapeNumDims);
-  std::vector<bool> inner_broadcast_mask(kNumberOfDims, false);
-  Val* num_features = new Double(1);
-  for (size_t idx = 0; idx < kNormShapeNumDims; ++idx) {
-    const size_t axis = kNumberOfDims - 1 - idx;
-    inner_reduction_axes[idx] = axis;
-    inner_broadcast_mask[axis] = true;
-    num_features = mul(num_features, x->domain()->domain()[axis]->extent());
-  }
-
-  // Main algorithm
-  auto welford_out = Welford(x, inner_reduction_axes);
-  auto mean_bcast = broadcast(welford_out.avg, inner_broadcast_mask);
-  auto x_sub_mean = sub(x, mean_bcast);
-
-  auto var_sum_bcast = broadcast(welford_out.var_sum, inner_broadcast_mask);
-  auto var = div(var_sum_bcast, num_features);
-  auto var_eps = add(var, eps);
-  auto invstd = unaryOp(UnaryOpType::Rsqrt, var_eps);
-
-  auto y = mul(x_sub_mean, invstd);
-
-  // Optional: norm * weight
-  if (weight != nullptr) {
-    auto weight_bcast = broadcast(weight, outer_broadcast_mask);
-    y = mul(y, weight_bcast);
-  }
-
-  // Optional: norm * weight + bias
-  if (bias != nullptr) {
-    auto bias_bcast = broadcast(bias, outer_broadcast_mask);
-    y = add(y, bias_bcast);
-  }
-
-  return {y, mean_bcast, invstd};
-}
-
-BackwardNormResult layer_norm_backward(
-    TensorView* dy,
-    TensorView* x,
-    const std::vector<int64_t>& norm_shape,
-    TensorView* mean,
-    TensorView* invstd,
-    TensorView* weight,
-    TensorView* bias,
-    const std::vector<bool>& output_mask) {
-  TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
-  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-  TORCH_INTERNAL_ASSERT(mean != nullptr, "Mean is invalid.");
-  TORCH_INTERNAL_ASSERT(invstd != nullptr, "Inv std is invalid.");
-
-  // (B, C, H, W, D) tensor
-  // norm_shape = [H, W, D]
-  // M = outer = product of remaining dimensions = B * C
-  // N = reduction = product of norm_shape = H * W * D
-  // weight = bias = norm_shape tensor
-  const size_t kNumberOfDims =
-      TensorDomain::noReductions(x->getRootDomain()).size();
-  const size_t kNormShapeNumDims = norm_shape.size();
-  const size_t kOuterNumDims = kNumberOfDims - kNormShapeNumDims;
-
-  std::vector<int> outer_reduction_axes(kOuterNumDims);
-  std::vector<bool> outer_broadcast_mask(kNumberOfDims, false);
-  for (size_t idx = 0; idx < kOuterNumDims; ++idx) {
-    outer_reduction_axes[idx] = idx;
-    outer_broadcast_mask[idx] = true;
-  }
-
-  std::vector<int> inner_reduction_axes(kNormShapeNumDims);
-  std::vector<bool> inner_broadcast_mask(kNumberOfDims, false);
-  Val* num_features = new Double(1);
-  for (size_t idx = 0; idx < kNormShapeNumDims; ++idx) {
-    const size_t axis = kNumberOfDims - 1 - idx;
-    inner_reduction_axes[idx] = axis;
-    inner_broadcast_mask[axis] = true;
-    num_features = mul(num_features, x->domain()->domain()[axis]->extent());
-  }
-
-  auto x_hat = mul(sub(x, mean), invstd);
-
-  TensorView* grad_x_hat = nullptr;
-  if (weight != nullptr) {
-    auto* bcast_weight = broadcast(weight, outer_broadcast_mask);
-    grad_x_hat = mul(dy, bcast_weight);
-  } else {
-    grad_x_hat = dy;
-  }
-
-  auto a = mul(num_features, grad_x_hat);
-
-  auto b = sum(grad_x_hat, inner_reduction_axes);
-  auto bcast_b = broadcast(b, inner_broadcast_mask);
-
-  auto c1 = mul(grad_x_hat, x_hat);
-  auto c2 = sum(c1, inner_reduction_axes);
-  auto bcast_c2 = broadcast(c2, inner_broadcast_mask);
-  auto c3 = mul(x_hat, bcast_c2);
-
-  auto inner = sub(sub(a, bcast_b), c3);
-  auto reciprocal_size = unaryOp(UnaryOpType::Reciprocal, num_features);
-
-  TensorView* dx = nullptr;
-  if (output_mask[0]) {
-    dx = mul(mul(reciprocal_size, invstd), inner);
-  }
-
-  TensorView* dw = nullptr;
-  if (output_mask[1] && weight != nullptr) {
-    dw = sum(mul(dy, x_hat), outer_reduction_axes);
-  }
-
-  TensorView* db = nullptr;
-  if (output_mask[2] && bias != nullptr) {
-    db = sum(dy, outer_reduction_axes);
-  }
-  return {dx, dw, db};
-}
-
-ForwardNormResult batch_norm(
-    TensorView* x,
-    TensorView* weight,
-    TensorView* bias,
-    TensorView* running_mean,
-    TensorView* running_var,
-    const bool kTraining,
-    Val* momentum,
-    Val* eps) {
-  auto fusion = FusionGuard::getCurFusion();
-
-  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-
-  TORCH_INTERNAL_ASSERT(
-      !((running_var == nullptr) ^ (running_mean == nullptr)),
-      "running stats should comes in pairs");
-
-  TORCH_INTERNAL_ASSERT(
-      momentum != nullptr && momentum->getDataType().has_value() &&
-          momentum->getDataType().value() == DataType::Double,
-      "Momentum is not a valid Double.");
-
-  TORCH_INTERNAL_ASSERT(
-      eps != nullptr && eps->getDataType().has_value() &&
-          eps->getDataType().value() == DataType::Double,
-      "Epsilon (eps) is not a valid Double.");
-
-  // (B, C, H, W, D) tensor
-  // M = outer = channels
-  // N = reduction = B * H * W * D
-  // weight = bias = (C) tensor
-  // const size_t kChannelsDim = 1;
-  const size_t kNumberOfDims =
-      TensorDomain::noReductions(x->getRootDomain()).size();
-
-  std::vector<int> reduction_axes;
-  std::vector<bool> broadcast_mask(kNumberOfDims, false);
-  Val* num_features = new Double(1);
-  for (size_t axis = 0; axis < kNumberOfDims; ++axis) {
-    if (axis != 1) {
-      reduction_axes.push_back(axis);
-      broadcast_mask[axis] = true;
-      num_features = mul(num_features, x->domain()->domain()[axis]->extent());
-    }
-  }
-
-  TensorView* y = nullptr;
-  TensorView* mean = nullptr;
-  TensorView* invstd = nullptr;
-  if (kTraining || running_mean == nullptr) {
-    // Algorithm
-    auto welford_out = Welford(x, reduction_axes);
-
-    // updating running mean and running var
-    if (running_mean != nullptr && running_var != nullptr) {
-      auto rev_momentum = sub(new Double(1.0), momentum);
-      auto current_mean_hat = mul(welford_out.avg, momentum);
-      auto mean_hat = mul(running_mean, rev_momentum);
-      auto new_mean_hat = add(mean_hat, current_mean_hat);
-      fusion->addOutput(new_mean_hat);
-      fusion->aliasOutputToInput(new_mean_hat, running_mean);
-
-      auto num_feature_decrement = sub(num_features, new Int(1));
-      auto unbiased_var = div(welford_out.var_sum, num_feature_decrement);
-      auto current_var_hat = mul(unbiased_var, momentum);
-      auto var_hat = mul(running_var, rev_momentum);
-      auto new_var_hat = add(var_hat, current_var_hat);
-      fusion->addOutput(new_var_hat);
-      fusion->aliasOutputToInput(new_var_hat, running_var);
-    }
-
-    mean = welford_out.avg;
-    auto mean_bcast = broadcast(mean, broadcast_mask);
-    auto x_sub_mean = sub(x, mean_bcast);
-
-    auto var = div(welford_out.var_sum, num_features);
-    auto var_eps = add(var, eps);
-    invstd = unaryOp(UnaryOpType::Rsqrt, var_eps);
-    auto invstd_bcast = broadcast(invstd, broadcast_mask);
-
-    y = mul(x_sub_mean, invstd_bcast);
-  } else {
-    // This is inference mode with running stats
-    auto r_mean_bcasted = broadcast(running_mean, broadcast_mask);
-    auto x_sub_mean = sub(x, r_mean_bcasted);
-
-    auto var_eps = add(running_var, eps);
-    auto unbiased_invstd = unaryOp(UnaryOpType::Rsqrt, var_eps);
-    auto invstd_bcast = broadcast(unbiased_invstd, broadcast_mask);
-
-    // During inference, mean/invstd output are empty tensors
-    mean = TensorViewBuilder().shape({0}).build();
-    invstd = TensorViewBuilder().shape({0}).build();
-    y = mul(x_sub_mean, invstd_bcast);
-  }
-
-  // Optional: norm * weight
-  if (weight) {
-    auto weight_bcast = broadcast(weight, broadcast_mask);
-    y = mul(y, weight_bcast);
-  }
-
-  // Optional: norm * weight + bias
-  if (bias) {
-    auto bias_bcast = broadcast(bias, broadcast_mask);
-    y = add(y, bias_bcast);
-  }
-  return {y, mean, invstd};
-}
-
-BackwardNormResult batch_norm_backward(
-    TensorView* x,
-    TensorView* dy,
-    TensorView* weight,
-    TensorView* running_mean,
-    TensorView* running_var,
-    TensorView* save_mean,
-    TensorView* save_invstd,
-    const bool kTraining,
-    Val* eps,
-    const std::vector<bool>& output_mask) {
-  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-  TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
-  TORCH_INTERNAL_ASSERT(
-      eps != nullptr && eps->getDataType().has_value() &&
-          eps->getDataType().value() == DataType::Double,
-      "Epsilon (eps) is not a valid Double.");
-
-  // (B, C, H, W, D) tensor
-  // M = outer = channels
-  // N = reduction = B * H * W * D
-  // weight = bias = (C) tensor
-  const size_t kChannelsDim = 1;
-  const size_t kNumberOfDims =
-      TensorDomain::noReductions(x->getRootDomain()).size();
-
-  std::vector<int> reduction_axes;
-  std::vector<bool> broadcast_mask(kNumberOfDims, false);
-  Val* num_features = new Double(1);
-  for (size_t axis = 0; axis < kNumberOfDims; ++axis) {
-    if (axis != kChannelsDim) {
-      reduction_axes.push_back(axis);
-      broadcast_mask[axis] = true;
-      num_features = mul(num_features, x->domain()->domain()[axis]->extent());
-    }
-  }
-
-  Val* bcast_weight = nullptr;
-  if (weight != nullptr) {
-    bcast_weight = broadcast(weight, broadcast_mask);
-  } else {
-    bcast_weight = new Double(1);
-  }
-
-  TensorView* dx = nullptr;
-  TensorView* dw = nullptr;
-  TensorView* db = nullptr;
-  if (kTraining) {
-    TORCH_INTERNAL_ASSERT(
-        save_mean != nullptr && save_invstd != nullptr,
-        "When training=True, save_mean and save_invstd are required.");
-
-    auto bcast_rstd = broadcast(save_invstd, broadcast_mask);
-    auto bcast_mean = broadcast(save_mean, broadcast_mask);
-    auto x_hat = mul(sub(x, bcast_mean), bcast_rstd);
-    auto grad_x_hat = mul(dy, bcast_weight);
-
-    auto a = mul(num_features, grad_x_hat);
-
-    auto b = sum(grad_x_hat, reduction_axes);
-    auto bcast_b = broadcast(b, broadcast_mask);
-
-    auto c1 = mul(grad_x_hat, x_hat);
-    auto c2 = sum(c1, reduction_axes);
-    auto bcast_c2 = broadcast(c2, broadcast_mask);
-    auto c3 = mul(x_hat, bcast_c2);
-
-    auto inner = sub(sub(a, bcast_b), c3);
-
-    auto reciprocal_size = unaryOp(UnaryOpType::Reciprocal, num_features);
-
-    if (output_mask[0]) {
-      dx = mul(mul(reciprocal_size, bcast_rstd), inner);
-    }
-
-    if (output_mask[1]) {
-      dw = sum(mul(dy, x_hat), reduction_axes);
-    }
-  } else {
-    // TODO: this is not a legit assumption? Can't we run with
-    // track_running_stats == false && training == false
-    // which should just run through the case above.
-    TORCH_INTERNAL_ASSERT(
-        running_mean != nullptr && running_var != nullptr,
-        "When training=False, running_mean and running_invstd are required.");
-
-    auto bcast_var = broadcast(running_var, broadcast_mask);
-    auto var_eps = add(bcast_var, eps);
-    auto bcast_rstd = unaryOp(UnaryOpType::Rsqrt, var_eps);
-    auto bcast_mean = broadcast(running_mean, broadcast_mask);
-
-    if (output_mask[0]) {
-      dx = mul(mul(dy, bcast_rstd), bcast_weight);
-    }
-
-    if (output_mask[1]) {
-      auto x_hat = mul(sub(x, bcast_mean), bcast_rstd);
-      dw = sum(mul(dy, x_hat), reduction_axes);
-    }
-  }
-
-  if (output_mask[2]) {
-    db = sum(dy, reduction_axes);
-  }
-
-  return {dx, dw, db};
-}
-
-ForwardNormResult instance_norm(
-    TensorView* x,
-    TensorView* weight,
-    TensorView* bias,
-    TensorView* running_mean,
-    TensorView* running_var,
-    const bool kUseInputStats,
-    Val* momentum,
-    Val* eps) {
-  auto fusion = FusionGuard::getCurFusion();
-
-  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-
-  TORCH_INTERNAL_ASSERT(
-      !((running_var == nullptr) ^ (running_mean == nullptr)),
-      "running stats should comes in pairs");
-
-  TORCH_INTERNAL_ASSERT(
-      momentum != nullptr && momentum->getDataType().has_value() &&
-          momentum->getDataType().value() == DataType::Double,
-      "Momentum is not a valid Double.");
-
-  TORCH_INTERNAL_ASSERT(
-      eps != nullptr && eps->getDataType().has_value() &&
-          eps->getDataType().value() == DataType::Double,
-      "Epsilon (eps) is not a valid Double.");
-
-  // (B, C, H, W, D) tensor
-  // M = outer = B * C
-  // N = reduction = H * W * D
-  // weight = bias = C tensor
-  const size_t kBatchDim = 0;
-  const size_t kChannelsDim = 1;
-  const size_t kNumberOfDims =
-      TensorDomain::noReductions(x->getRootDomain()).size();
-
-  std::vector<int> x_reduction_axes;
-  std::vector<bool> x_broadcast_mask(kNumberOfDims, false);
-  Val* N = new Double(1);
-  for (size_t axis = 0; axis < kNumberOfDims; ++axis) {
-    if (axis != kBatchDim && axis != kChannelsDim) {
-      x_reduction_axes.push_back(axis);
-      x_broadcast_mask[axis] = true;
-      N = mul(N, x->domain()->domain()[axis]->extent());
-    }
-  }
-  Val* B = new Double(1);
-  B = mul(B, x->domain()->domain()[kBatchDim]->extent());
-
-  std::vector<bool> channels_only_broadcast_mask(kNumberOfDims, false);
-  for (size_t axis = 0; axis < kNumberOfDims; ++axis) {
-    if (axis != kChannelsDim) {
-      channels_only_broadcast_mask[axis] = true;
-    }
-  }
-
-  TensorView* y = nullptr;
-  TensorView* mean = nullptr;
-  TensorView* invstd = nullptr;
-  if (kUseInputStats || running_mean == nullptr) {
-    // Algorithm
-    auto welford_out = Welford(x, x_reduction_axes);
-
-    // updating running mean and running var
-    if (running_mean != nullptr && running_var != nullptr) {
-      auto rev_momentum = sub(new Double(1.0), momentum);
-      auto current_mean_hat = mul(welford_out.avg, momentum);
-      auto mean_hat = mul(running_mean, rev_momentum);
-      auto new_mean_hat = add(mean_hat, current_mean_hat);
-
-      auto new_mean_sum = sum(new_mean_hat, {kBatchDim});
-      auto new_mean_channels_only = div(new_mean_sum, B);
-      fusion->addOutput(new_mean_channels_only);
-      fusion->aliasOutputToInput(new_mean_channels_only, running_mean);
-
-      auto num_feature_decrement = sub(N, new Int(1));
-      auto unbiased_var = div(welford_out.var_sum, num_feature_decrement);
-      auto current_var_hat = mul(unbiased_var, momentum);
-      auto var_hat = mul(running_var, rev_momentum);
-      auto new_var_hat = add(var_hat, current_var_hat);
-
-      auto new_var_sum = sum(new_var_hat, {kBatchDim});
-      auto new_var_channels_only = div(new_var_sum, B);
-      fusion->addOutput(new_var_channels_only);
-      fusion->aliasOutputToInput(new_var_channels_only, running_var);
-    }
-
-    mean = welford_out.avg;
-    auto mean_bcast = broadcast(mean, x_broadcast_mask);
-    auto x_sub_mean = sub(x, mean_bcast);
-
-    auto var = div(welford_out.var_sum, N);
-    auto var_eps = add(var, eps);
-    invstd = unaryOp(UnaryOpType::Rsqrt, var_eps);
-    auto invstd_bcast = broadcast(invstd, x_broadcast_mask);
-
-    y = mul(x_sub_mean, invstd_bcast);
-  } else {
-    // This is inference mode with running stats
-    auto r_mean_bcasted = broadcast(running_mean, channels_only_broadcast_mask);
-    auto x_sub_mean = sub(x, r_mean_bcasted);
-
-    auto var_eps = add(running_var, eps);
-    auto unbiased_invstd = unaryOp(UnaryOpType::Rsqrt, var_eps);
-    auto invstd_bcast =
-        broadcast(unbiased_invstd, channels_only_broadcast_mask);
-
-    // During inference, mean/invstd output are empty tensors
-    mean = TensorViewBuilder().shape({0}).build();
-    invstd = TensorViewBuilder().shape({0}).build();
-    y = mul(x_sub_mean, invstd_bcast);
-  }
-
-  // Optional: norm * weight
-  if (weight) {
-    auto weight_bcast = broadcast(weight, channels_only_broadcast_mask);
-    y = mul(y, weight_bcast);
-  }
-
-  // Optional: norm * weight + bias
-  if (bias) {
-    auto bias_bcast = broadcast(bias, channels_only_broadcast_mask);
-    y = add(y, bias_bcast);
-  }
-  return {y, mean, invstd};
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/ops/normalization.h b/torch/csrc/jit/codegen/cuda/ops/normalization.h
deleted file mode 100644 (file)
index a951b12..0000000
+++ /dev/null
@@ -1,98 +0,0 @@
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
-
-//
-// The operations defined in this header is intended as user facing functions.
-// The user will provide the necessary input TensorViews and the function will
-// create the correct intermediate nodes and return the output TensorViews.
-//
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-struct ForwardNormResult {
-  TensorView* output = nullptr;
-  TensorView* mean = nullptr;
-  TensorView* invstd = nullptr;
-};
-
-struct BackwardNormResult {
-  TensorView* grad_input = nullptr;
-  TensorView* grad_weight = nullptr;
-  TensorView* grad_bias = nullptr;
-};
-
-TORCH_CUDA_CU_API TensorView* softmax(TensorView* x, int dim);
-
-TORCH_CUDA_CU_API TensorView* softmax_backward(
-    TensorView* dy,
-    TensorView* y,
-    const int dim,
-    TensorView* x);
-
-TORCH_CUDA_CU_API ForwardNormResult layer_norm(
-    TensorView* x,
-    const std::vector<int64_t>& norm_shape,
-    TensorView* weight,
-    TensorView* bias,
-    Val* eps);
-
-TORCH_CUDA_CU_API ForwardNormResult layer_norm(
-    TensorView* x,
-    const size_t kNormShapeNumDims,
-    TensorView* weight,
-    TensorView* bias,
-    Val* eps);
-
-TORCH_CUDA_CU_API BackwardNormResult layer_norm_backward(
-    TensorView* dy,
-    TensorView* x,
-    const std::vector<int64_t>& norm_shape,
-    TensorView* mean,
-    TensorView* rstd,
-    TensorView* weight,
-    TensorView* bias,
-    const std::vector<bool>& output_mask);
-
-TORCH_CUDA_CU_API ForwardNormResult batch_norm(
-    TensorView* x,
-    TensorView* weight,
-    TensorView* bias,
-    TensorView* running_mean,
-    TensorView* running_var,
-    const bool kTraining,
-    Val* momentum,
-    Val* eps);
-
-TORCH_CUDA_CU_API BackwardNormResult batch_norm_backward(
-    TensorView* x,
-    TensorView* dy,
-    TensorView* weight,
-    TensorView* running_mean,
-    TensorView* running_var,
-    TensorView* save_mean,
-    TensorView* save_invstd,
-    const bool kTraining,
-    Val* eps,
-    const std::vector<bool>& output_mask);
-
-TORCH_CUDA_CU_API ForwardNormResult instance_norm(
-    TensorView* x,
-    TensorView* weight,
-    TensorView* bias,
-    TensorView* running_mean,
-    TensorView* running_var,
-    const bool kUseInputStats,
-    Val* momentum,
-    Val* eps);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp b/torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp
deleted file mode 100644 (file)
index a27c0be..0000000
+++ /dev/null
@@ -1,296 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/parallel_dimension_map.h>
-
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-
-#include <sstream>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-void ParallelDimensionMap::build(Fusion* fusion) {
-  // Scan all TVs to build ParallelType maps
-  auto all_vals = fusion->usedMathVals();
-  for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
-    for (auto id : tv->domain()->domain()) {
-      registerConstantExtent(id);
-      if (!isParallelTypeThread(id->getParallelType())) {
-        continue;
-      }
-      handleParallelDomain(id);
-    }
-  }
-
-  // Populate the dimension map for each parallel type
-  for (const auto& kv : concrete_dom_map_) {
-    auto pt = kv.first;
-    const auto& concrete_dom_set = kv.second;
-    TORCH_INTERNAL_ASSERT(!concrete_dom_set.empty());
-    if (concrete_dom_set.size() == 1) {
-      populateDimensionMapWithSingleCASet(pt, concrete_dom_set);
-    } else {
-      populateDimensionMapWithMultipleCASet(pt, concrete_dom_set);
-    }
-  }
-}
-
-void ParallelDimensionMap::registerConstantExtent(IterDomain* id) {
-  ExpressionEvaluator ee(id->fusion());
-  auto extent_int = ee.evaluate(id->extent());
-  if (!extent_int.has_value()) {
-    // Nothing to do if not constant
-    return;
-  }
-
-  auto const_extent = extent_int.value();
-
-  // Ignore if this is derived from a size-1 domain as it is likely a
-  // size-1 broadcast domain and that does not represent the actual
-  // dimension even if it's constant. Being size-1 may not always mean
-  // it's a broadcast domain, but it'd be safe to assume it is mostly
-  // the case. If it is not a broadcast, ignoring this domain does not
-  // impact the correctness.
-  auto extent_inputs = InputsOf::output(id->fusion(), id->extent());
-  if (std::any_of(extent_inputs.begin(), extent_inputs.end(), [](Val* input) {
-        return input->isOneInt();
-      })) {
-    return;
-  }
-
-  auto concrete_id = getCAMappedConcreteDomain(id);
-
-  auto existing_it = constant_extent_map_.find(id);
-
-  // Adds the constant extent to the set for the concrete domain. If
-  // multiple constants are found, this concrete domain has multiple
-  // distinctive extents, which can happen with broadcast.
-  if (existing_it == constant_extent_map_.end()) {
-    constant_extent_map_.insert({concrete_id, {const_extent}});
-  } else {
-    existing_it->second.insert(const_extent);
-  }
-}
-
-// Adds the conrecte domain of id to the mappsed set for its
-// parallel type
-void ParallelDimensionMap::handleParallelDomain(IterDomain* id) {
-  auto pt = id->getParallelType();
-  TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt));
-  auto concrete_id = getCAMappedConcreteDomain(id);
-
-  auto it = concrete_dom_map_.find(pt);
-  if (it == concrete_dom_map_.end()) {
-    concrete_dom_map_.insert({pt, {concrete_id}});
-  } else {
-    it->second.insert(concrete_id);
-  }
-}
-
-void ParallelDimensionMap::populateDimensionMapWithSingleCASet(
-    ParallelType pt,
-    const std::unordered_set<IterDomain*>& dom_set) {
-  TORCH_INTERNAL_ASSERT(dom_set.size() == 1);
-
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  // pt is used by only one concrete domain
-  auto id = *dom_set.begin();
-  auto it = constant_extent_map_.find(id);
-
-  if (it != constant_extent_map_.end()) {
-    if (it->second.size() == 1) {
-      dim_map_.insert({pt, ir_builder.create<kir::Int>(*(it->second.begin()))});
-      exact_types_.insert(pt);
-    } else {
-      // Multiple constant dimensions found; Use the corresponding
-      // symbolic parallel dim
-      dim_map_.insert({pt, kir::NamedScalar::getParallelDim(pt)});
-    }
-  } else {
-    // Prefer to use blockDim/gridDim if not constant
-    dim_map_.insert({pt, kir::NamedScalar::getParallelDim(pt)});
-    exact_types_.insert(pt);
-  }
-}
-
-void ParallelDimensionMap::populateDimensionMapWithMultipleCASet(
-    ParallelType pt,
-    const std::unordered_set<IterDomain*>& dom_set) {
-  TORCH_INTERNAL_ASSERT(dom_set.size() > 1);
-
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  bool all_equal = true;
-  kir::Val* known_dimension =
-      gpu_lower->lowerValue((*dom_set.begin())->extent());
-  // Set it -1 to signal it's not initialied yet
-  int64_t known_const = -1;
-
-  // Check all of concrete domains to see if they match all together.
-  for (auto concrete_id : dom_set) {
-    // If this concrete domain has a constant extent, check if it
-    // matches with the known constant extent.
-    auto it = constant_extent_map_.find(concrete_id);
-    if (it != constant_extent_map_.end()) {
-      const auto& const_extent_set = it->second;
-      // If multiple constants are detected, it's not exact.
-      if (const_extent_set.size() > 1) {
-        all_equal = false;
-        break;
-      }
-      auto this_const = *(const_extent_set.begin());
-      // known_const is initialized to -1
-      if (known_const == -1) {
-        known_const = this_const;
-      } else if (known_const == this_const) {
-        // Matched with previously known const. The extent of this
-        // domain must be equal to that's previously known.
-        continue;
-      } else {
-        // Unmatched. This dom_set extents may not be unique.
-        all_equal = false;
-        break;
-      }
-    }
-
-    // At this point, it still remains undetermined whether this id
-    // matches with those previously looked at. Constant check failed,
-    // but symbolic matching may succeed.
-    if (!equalDim(
-            known_dimension, gpu_lower->lowerValue(concrete_id->extent()))) {
-      all_equal = false;
-      break;
-    }
-  }
-
-  // If all_equal is still true, the dimension of this paralel type
-  // must be exact.
-  if (all_equal) {
-    exact_types_.insert(pt);
-  }
-  // Use the const value, if found, as its dimension
-  if (all_equal && known_const != -1) {
-    dim_map_.insert({pt, ir_builder.create<kir::Int>(known_const)});
-  } else {
-    dim_map_.insert({pt, kir::NamedScalar::getParallelDim(pt)});
-  }
-}
-
-kir::Val* ParallelDimensionMap::get(ParallelType pt) const {
-  TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt), "Invalid ParallelType: ", pt);
-  auto it = dim_map_.find(pt);
-  if (it == dim_map_.end()) {
-    return nullptr;
-  } else {
-    return it->second;
-  }
-}
-
-bool ParallelDimensionMap::isExact(ParallelType pt) const {
-  return exact_types_.find(pt) != exact_types_.end();
-}
-
-IterDomain* ParallelDimensionMap::getCAMappedConcreteDomain(IterDomain* id) {
-  const auto gpu_lower = GpuLower::current();
-  const auto& ca_map = gpu_lower->caIndexMap();
-  return ca_map.getConcreteMappedID(id);
-}
-
-// Symbolically compares equality of two KIR vals. Comparison is done
-// conservatively, so returning false does not guarantee non-equality.
-bool ParallelDimensionMap::equalDim(kir::Val* dim1, kir::Val* dim2) {
-  TORCH_INTERNAL_ASSERT(dim1 != nullptr && dim2 != nullptr);
-
-  if (dim1 == dim2) {
-    return true;
-  }
-
-  // When Both are Int, they are same if both have the same constant
-  auto dim1_int = dynamic_cast<kir::Int*>(dim1);
-  auto dim2_int = dynamic_cast<kir::Int*>(dim2);
-  if (dim1_int && dim2_int) {
-    if (dim1_int->isConst() && dim2_int->isConst()) {
-      return dim1_int->value() == dim2_int->value();
-    }
-  }
-
-  // When both are NamedScalar, they are same if Both have the same
-  // name
-  auto dim1_ns = dynamic_cast<kir::NamedScalar*>(dim1);
-  auto dim2_ns = dynamic_cast<kir::NamedScalar*>(dim2);
-  if (dim1_ns && dim2_ns) {
-    return dim1_ns->name() == dim2_ns->name();
-  }
-
-  // Check recursively their definitions
-
-  auto dim1_def = dim1->definition();
-  auto dim2_def = dim2->definition();
-
-  if (dim1_def == nullptr || dim2_def == nullptr) {
-    return false;
-  }
-
-  // If both are BinaryOp or UnaryOp, check their inputs. Since these
-  // Vals are IterDomain extents, UnaryOp should not occur, but
-  // checking shouldn't be harmful.
-  if ((dim1_def->isA<kir::BinaryOp>() && dim2_def->isA<kir::BinaryOp>() &&
-       (dim1_def->as<kir::BinaryOp>()->operation() ==
-        dim2_def->as<kir::BinaryOp>()->operation())) ||
-      (dim1_def->isA<kir::UnaryOp>() && dim2_def->isA<kir::UnaryOp>() &&
-       (dim1_def->as<kir::UnaryOp>()->operation() ==
-        dim2_def->as<kir::UnaryOp>()->operation()))) {
-    for (size_t i = 0; i < dim1_def->inputs().size(); ++i) {
-      if (!equalDim(dim1_def->inputs()[0], dim2_def->inputs()[0])) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  return false;
-}
-
-std::string ParallelDimensionMap::toString() const {
-  std::stringstream ss;
-
-  const std::array<ParallelType, 6> ptypes{
-      ParallelType::BIDx,
-      ParallelType::BIDy,
-      ParallelType::BIDz,
-      ParallelType::TIDx,
-      ParallelType::TIDy,
-      ParallelType::TIDz};
-
-  for (auto pt : ptypes) {
-    ss << pt << ": ";
-    auto dim = get(pt);
-    if (dim != nullptr) {
-      ss << kir::toString(dim);
-      if (isExact(pt)) {
-        ss << ", exact";
-      } else {
-        ss << ", non-exact";
-      }
-    } else {
-      ss << "unused";
-    }
-    ss << "\n";
-  }
-
-  return ss.str();
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/parallel_dimension_map.h b/torch/csrc/jit/codegen/cuda/parallel_dimension_map.h
deleted file mode 100644 (file)
index e1054fb..0000000
+++ /dev/null
@@ -1,74 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <deque>
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Maps TID/BID to its dimension. It is by default blockDim/gridDim,
-//! but if use of a ParallelType is mapped to a unique constant
-//! extent, the constant value is used instead since presumably it's
-//! more efficient.
-class TORCH_CUDA_CU_API ParallelDimensionMap {
- public:
-  void build(Fusion* fusion);
-
-  //! Returns the dimension of a ParallelType. nullptr is returned if
-  //! a ParallelType is unused.
-  kir::Val* get(ParallelType pt) const;
-
-  //! True if the dimension of a ParallelType is known to be exact
-  bool isExact(ParallelType pt) const;
-
-  std::string toString() const;
-
-  //! Symbolically analyze if two extent vals are equal
-  static bool equalDim(kir::Val* dim1, kir::Val* dim2);
-
- private:
-  //! Register the extent of an IterDomain if its constant
-  void registerConstantExtent(IterDomain* id);
-
-  void handleParallelDomain(IterDomain* id);
-
-  void populateDimensionMapWithSingleCASet(
-      ParallelType pt,
-      const std::unordered_set<IterDomain*>& dom_set);
-
-  void populateDimensionMapWithMultipleCASet(
-      ParallelType pt,
-      const std::unordered_set<IterDomain*>& dom_set);
-
-  static IterDomain* getCAMappedConcreteDomain(IterDomain* id);
-
- private:
-  //! Maps from parallel types to dimensions, which are constant if
-  //! a unique value is found.
-  std::unordered_map<ParallelType, kir::Val*, TypeHash> dim_map_;
-  //! Set of parallel types whose dimensions are identified to be
-  //! exactly the same as extents of mapped domains.
-  std::unordered_set<ParallelType, TypeHash> exact_types_;
-
-  // Below are temporary maps to build the ParallelType-to-dimension
-  // map. Only used during build().
-
-  //! Map from a parallel type to a set of concrete domains where the
-  //! parallel type is used.
-  std::unordered_map<ParallelType, std::unordered_set<IterDomain*>, TypeHash>
-      concrete_dom_map_;
-  //! Keep track of constant extents found for a CA domain set
-  //! represented by the concrete domain.
-  std::unordered_map<IterDomain*, std::unordered_set<int64_t>>
-      constant_extent_map_;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.cpp b/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.cpp
deleted file mode 100644 (file)
index 7efd569..0000000
+++ /dev/null
@@ -1,134 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-const std::unordered_map<ParallelType, int, TypeHash>
-    ParallelTypeBitmap::pt_to_offset_{
-        {ParallelType::BIDx, 0},
-        {ParallelType::BIDy, 1},
-        {ParallelType::BIDz, 2},
-        {ParallelType::TIDx, 3},
-        {ParallelType::TIDy, 4},
-        {ParallelType::TIDz, 5}};
-
-const std::unordered_map<int, ParallelType> ParallelTypeBitmap::offset_to_pt_ =
-    {{0, ParallelType::BIDx},
-     {1, ParallelType::BIDy},
-     {2, ParallelType::BIDz},
-     {3, ParallelType::TIDx},
-     {4, ParallelType::TIDy},
-     {5, ParallelType::TIDz}};
-
-bool ParallelTypeBitmap::get(ParallelType pt) const {
-  if (pt_to_offset_.find(pt) == pt_to_offset_.end()) {
-    TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type.");
-  }
-  return bitset_[pt_to_offset_.at(pt)];
-}
-
-bool ParallelTypeBitmap::set(ParallelType pt, bool new_val) {
-  if (pt_to_offset_.find(pt) == pt_to_offset_.end()) {
-    TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type.");
-  }
-  bool old_val = bitset_[pt_to_offset_.at(pt)];
-  bitset_[pt_to_offset_.at(pt)] = new_val;
-  return old_val;
-}
-
-ParallelTypeBitmap ParallelTypeBitmap::operator&=(
-    const ParallelTypeBitmap& other) {
-  bitset_ &= other.bitset_;
-  return *this;
-}
-
-ParallelTypeBitmap ParallelTypeBitmap::operator|=(
-    const ParallelTypeBitmap& other) {
-  bitset_ |= other.bitset_;
-  return *this;
-}
-
-ParallelTypeBitmap ParallelTypeBitmap::operator^=(
-    const ParallelTypeBitmap& other) {
-  bitset_ ^= other.bitset_;
-  return *this;
-}
-
-ParallelTypeBitmap ParallelTypeBitmap::operator~() const {
-  return ParallelTypeBitmap(~bitset_);
-}
-
-bool ParallelTypeBitmap::none() const {
-  return bitset_.none();
-}
-
-bool ParallelTypeBitmap::any() const {
-  return bitset_.any();
-}
-
-bool ParallelTypeBitmap::all() const {
-  return bitset_.all();
-}
-
-bool ParallelTypeBitmap::operator[](size_t pos) const {
-  TORCH_INTERNAL_ASSERT(
-      pos < num_p_type, "Invalid index to ParallelTypeBitset: ", pos);
-  return bitset_[pos];
-}
-
-bool ParallelTypeBitmap::hasTID() const {
-  for (auto pt : {ParallelType::TIDx, ParallelType::TIDy, ParallelType::TIDz}) {
-    if (get(pt)) {
-      return true;
-    }
-  }
-  return false;
-}
-
-bool ParallelTypeBitmap::hasBID() const {
-  for (auto pt : {ParallelType::BIDx, ParallelType::BIDy, ParallelType::BIDz}) {
-    if (get(pt)) {
-      return true;
-    }
-  }
-  return false;
-}
-
-std::map<ParallelType, bool> ParallelTypeBitmap::getMap() const {
-  std::map<ParallelType, bool> map;
-  for (const auto& pt_offset : pt_to_offset_) {
-    map.emplace(pt_offset.first, bitset_[pt_offset.second]);
-  }
-  return map;
-}
-
-ParallelTypeBitmap operator&(
-    const ParallelTypeBitmap& lhs,
-    const ParallelTypeBitmap& rhs) {
-  auto x = lhs;
-  x &= rhs;
-  return x;
-}
-
-ParallelTypeBitmap operator|(
-    const ParallelTypeBitmap& lhs,
-    const ParallelTypeBitmap& rhs) {
-  auto x = lhs;
-  x |= rhs;
-  return x;
-}
-
-ParallelTypeBitmap operator^(
-    const ParallelTypeBitmap& lhs,
-    const ParallelTypeBitmap& rhs) {
-  auto x = lhs;
-  x ^= rhs;
-  return x;
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h b/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h
deleted file mode 100644 (file)
index 2260e20..0000000
+++ /dev/null
@@ -1,76 +0,0 @@
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
-
-#include <bitset>
-#include <map>
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Represents mapping to bool from BIDx, BIDy, BIDz, TIDx, TIDy and TIDz.
-class ParallelTypeBitmap {
- public:
-  static constexpr int num_p_type = 6;
-
-  ParallelTypeBitmap() = default;
-
-  //! Return true if pt is included
-  bool get(ParallelType pt) const;
-  //! Set the mapping of pt
-  bool set(ParallelType pt, bool);
-  //! Assign logical AND with other
-  ParallelTypeBitmap operator&=(const ParallelTypeBitmap& other);
-  //! Assign logical OR with other
-  ParallelTypeBitmap operator|=(const ParallelTypeBitmap& other);
-  //! Assign logical NOR with other
-  ParallelTypeBitmap operator^=(const ParallelTypeBitmap& other);
-  //! Return logical compliment
-  ParallelTypeBitmap operator~() const;
-  //! Return true if none of the mapppings is true
-  bool none() const;
-  //! Return true if any of the mapppings is true
-  bool any() const;
-  //! Return true if all of the mapppings is true
-  bool all() const;
-  //! Return true if the parallel type corresponding to a position
-  //! defined in offset_to_pt_ is true
-  bool operator[](size_t pos) const;
-  //! Return an equivalent std::map
-  std::map<ParallelType, bool> getMap() const;
-  //! Return true if TIDx/y/z is included
-  bool hasTID() const;
-  //! Return true if BIDx/y/z is included
-  bool hasBID() const;
-
- private:
-  ParallelTypeBitmap(const std::bitset<num_p_type>& bs) : bitset_(bs) {}
-
- private:
-  std::bitset<num_p_type> bitset_;
-  //! Map of ParallelType to bit positions
-  const static std::unordered_map<ParallelType, int, TypeHash> pt_to_offset_;
-  //! Map of bit positions to ParallelType
-  const static std::unordered_map<int, ParallelType> offset_to_pt_;
-};
-
-ParallelTypeBitmap operator&(
-    const ParallelTypeBitmap& lhs,
-    const ParallelTypeBitmap& rhs);
-
-ParallelTypeBitmap operator|(
-    const ParallelTypeBitmap& lhs,
-    const ParallelTypeBitmap& rhs);
-
-ParallelTypeBitmap operator^(
-    const ParallelTypeBitmap& lhs,
-    const ParallelTypeBitmap& rhs);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index e4d9432..423f882 100644 (file)
@@ -4,7 +4,6 @@
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
 
 #include <torch/csrc/jit/frontend/function_schema_parser.h>
 #include <torch/csrc/jit/ir/constants.h>
@@ -21,31 +20,13 @@ typedef Node JitOp;
 namespace fuser {
 namespace cuda {
 
-constexpr auto kNumUnaryOps = 32;
-constexpr auto kNumBinaryOps = 29;
+constexpr auto kNumUnaryOps = 31;
+constexpr auto kNumBinaryOps = 24;
 constexpr auto kNumBinaryOpsWithAlpha = 4;
 constexpr auto kNumLerpOps = 2;
-constexpr auto kNumLayernormFwd = 2;
-constexpr auto kNumBatchnormFwd = 3;
-constexpr auto kNumInstancenormFwd = 1;
-constexpr auto kNumSumToSize = 2;
-// constexpr auto kNumAutocastOps = 2;
 
 namespace {
 
-#define REGISTER_PARSE_RULE(op, func_body, ...)                             \
-  registerParseRule(                                                        \
-      op,                                                                   \
-      [](const Node* node,                                                  \
-         std::unordered_map<size_t, CgValue>& value_map) -> void func_body, \
-      __VA_ARGS__)
-
-const auto& sizeAttr = Symbol::attr("profiled_size");
-const auto& intListAttr = Symbol::attr("profiled_int_list");
-const auto& intAttr = Symbol::attr("profiled_int");
-const auto& boolListAttr = Symbol::attr("profiled_bool_list");
-const auto& boolAttr = Symbol::attr("profiled_bool");
-
 typedef Val* CgValue;
 typedef Expr* CgOp;
 
@@ -54,52 +35,37 @@ typedef bool (*MergeQueryFuncPtr)(const Node*);
 
 // TODO: add a mutex to make it thread safe.
 class IrParser {
-  enum class OperatorType {
-    ElementWise,
-    Reduction,
-    ReductionToSize,
-    Normalization
-  };
-  typedef OperatorType (*OperatorTypeFuncPtr)(const Node*);
-
   class RegistrationEntry {
    public:
-    RegistrationEntry(
-        ParseFuncPtr parse_f,
-        MergeQueryFuncPtr merge_f = nullptr,
-        OperatorTypeFuncPtr type_f = nullptr)
-        : parse_f_(parse_f), merge_f_(merge_f), type_f_(type_f) {}
+    RegistrationEntry(ParseFuncPtr parse_f, MergeQueryFuncPtr merge_f = nullptr)
+        : parse_f_(parse_f), merge_f_(merge_f) {}
 
-    void parse(const Node* node, std::unordered_map<size_t, CgValue>& values)
-        const {
+    void parse(const Node* node, std::unordered_map<size_t, CgValue>& values) {
       parse_f_(node, values);
     }
 
-    bool isCompatible(const Node* node) const {
+    bool is_compatible(const Node* node) {
       if (merge_f_ == nullptr) {
         return true;
       }
       return merge_f_(node);
     }
 
-    bool isType(const Node* node, OperatorType type) const {
-      auto n_type =
-          type_f_ == nullptr ? OperatorType::ElementWise : type_f_(node);
-      return n_type == type;
-    }
-
    private:
     ParseFuncPtr parse_f_;
     MergeQueryFuncPtr merge_f_;
-    OperatorTypeFuncPtr type_f_;
   };
 
  public:
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
   IrParser(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
-    initRegistry();
+    if (init_registry_) {
+      registerJitOperator();
+      init_registry_ = false;
+    }
   }
 
+  // Fuses pointwise ops with loop unrolling (factor = 4).
   std::unique_ptr<Fusion> parse() {
     auto fusion = std::make_unique<Fusion>();
     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@@ -110,10 +76,7 @@ class IrParser {
     for (auto val : block->inputs()) {
       TORCH_INTERNAL_ASSERT(
           registerValue(val),
-          "Failure when register value: ",
-          *(val->node()),
-          " with type: ",
-          val->type());
+          "Error trying to register value with code generation.");
       fusion->addInput(value_map_[val->unique()]);
 
       auto opt_dtype = value_map_[val->unique()]->getDataType();
@@ -125,11 +88,22 @@ class IrParser {
       }
     }
 
+    // TODO: disable unroll to ensure rand_like generates identical output as
+    // with eager mode
+    bool disable_unroll = false;
+    bool has_reduction = false;
     // compose nodes in topo order;
     for (const JitOp* node : block->nodes()) {
       processJitNode(node);
+      if (node->kind() == aten::rand_like) {
+        // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
+        disable_unroll = true;
+      }
+      if (node->kind() == aten::sum) {
+        // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
+        has_reduction = true;
+      }
     }
-    auto alias_indices = fusion->getInputAliasIndices();
 
     // mark output;
     for (auto jit_output : block->outputs()) {
@@ -144,86 +118,37 @@ class IrParser {
       }
       fusion->addOutput(out);
     }
-
     return fusion;
   }
 
-  // return nullptr if entry does not exist
-  static const RegistrationEntry* lookupInRegistry(const Node* node) {
-    // we need to use maybeSchema for nodes like prim::Constant, which doesn't
-    // have a schema
-    auto schema_ptr = node->maybeSchema();
-    if (schema_ptr != nullptr) {
-      // search cached entry first
-      auto cache_it = cached_registry_lookup_.find(schema_ptr);
-      if (cache_it != cached_registry_lookup_.end()) {
-        return cache_it->second;
-      } else {
-        // match signature
-        auto schema_str = canonicalSchemaString(*schema_ptr);
-
-        auto iter = jit_operator_registry_.find(schema_str);
-        if (iter != jit_operator_registry_.end()) {
-          // update cache entry
-          cached_registry_lookup_.insert(cache_it, {schema_ptr, &iter->second});
-          return &iter->second;
-        }
-      }
-    }
-    return nullptr;
-  }
-
-  static void initRegistry() {
+  static bool canParseNode(const Node* node) {
     if (init_registry_) {
       // TODO: mutex this guy;
       registerJitOperator();
       init_registry_ = false;
     }
-  }
-
-  static bool canParseNode(const Node* node) {
-    initRegistry();
 
     // match signature.
-    auto schema_ptr = node->maybeSchema();
-    if (schema_ptr == nullptr) {
+    auto iter = jit_operator_registry_.find(node->kind());
+    if (iter == jit_operator_registry_.end()) {
       return false;
     }
-    auto reg_entry = lookupInRegistry(node);
-    return reg_entry != nullptr && reg_entry->isCompatible(node);
-  }
-
-  static bool isReductionToSizeNode(const Node* node) {
-    initRegistry();
-
-    auto reg_entry = lookupInRegistry(node);
-    return reg_entry != nullptr &&
-        reg_entry->isType(node, OperatorType::ReductionToSize);
+    for (auto& pair_op_func : iter->second) {
+      if (node->matches(pair_op_func.first->schema())) {
+        return pair_op_func.second.is_compatible(node);
+      }
+    }
+    return false;
   }
 
   static bool isReductionNode(const Node* node) {
-    initRegistry();
-
-    auto reg_entry = lookupInRegistry(node);
-    return reg_entry != nullptr &&
-        (reg_entry->isType(node, OperatorType::Reduction) ||
-         reg_entry->isType(node, OperatorType::ReductionToSize));
-  }
-
-  static bool isNormalizationNode(const Node* node) {
-    initRegistry();
-
-    auto reg_entry = lookupInRegistry(node);
-    return reg_entry != nullptr &&
-        reg_entry->isType(node, OperatorType::Normalization);
-  }
-
-  static bool isElementWiseNode(const Node* node) {
-    initRegistry();
+    if (init_registry_) {
+      // TODO: mutex this guy;
+      registerJitOperator();
+      init_registry_ = false;
+    }
 
-    auto reg_entry = lookupInRegistry(node);
-    return reg_entry != nullptr &&
-        reg_entry->isType(node, OperatorType::ElementWise);
+    return jit_reduction_op_registry_.count(node->kind());
   }
 
   // TODO: is_reduction is too hacky here. we should categorize operation types
@@ -233,11 +158,16 @@ class IrParser {
       std::shared_ptr<Operator>& op,
       ParseFuncPtr parse_fn,
       MergeQueryFuncPtr merge_query_fn = nullptr,
-      OperatorTypeFuncPtr type_fn = nullptr) {
-    jit_operator_registry_.emplace(
-        std::piecewise_construct,
-        std::forward_as_tuple(canonicalSchemaString(op->schema())),
-        std::forward_as_tuple(parse_fn, merge_query_fn, type_fn));
+      bool is_reduction = false) {
+    jit_operator_registry_[Symbol::fromQualString(op->schema().name())]
+        .emplace_back(
+            std::piecewise_construct,
+            std::forward_as_tuple(op),
+            std::forward_as_tuple(parse_fn, merge_query_fn));
+    if (is_reduction) {
+      jit_reduction_op_registry_.emplace(
+          Symbol::fromQualString(op->schema().name()));
+    }
   }
 
  private:
@@ -253,9 +183,10 @@ class IrParser {
         "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor"};
     for (auto signature : BinaryOpWithAlpha) {
       auto ptr_op = getOperatorForLiteral(signature);
-      REGISTER_PARSE_RULE(
+      registerParseRule(
           ptr_op,
-          {
+          [](const Node* node,
+             std::unordered_map<size_t, CgValue>& value_map) -> void {
             using BinaryOpWithAlphaType = Val* (*)(Val*, Val*, Val*);
             static std::unordered_map<
                 Symbol,
@@ -281,9 +212,7 @@ class IrParser {
               auto out = op_mapping[node->kind()].second(lhs, rhs, alpha);
               value_map.emplace(node->output()->unique(), out);
             }
-          },
-          nullptr,
-          nullptr);
+          });
     }
 
     std::array<const char*, kNumBinaryOps> BinaryOp = {
@@ -299,11 +228,6 @@ class IrParser {
         "aten::pow(Scalar self, Tensor exponent) -> Tensor",
         "aten::remainder(Tensor self, Tensor other) -> Tensor",
         "aten::fmod(Tensor self, Tensor other) -> Tensor",
-        "aten::__and__(Tensor self, Tensor other) -> Tensor",
-        "aten::__or__(Tensor self, Tensor other) -> Tensor",
-        "aten::__xor__(Tensor self, Tensor other) -> Tensor",
-        "aten::__lshift__(Tensor self, Tensor other) -> Tensor",
-        "aten::__rshift__(Tensor self, Tensor other) -> Tensor",
         "aten::eq(Tensor self, Tensor other) -> Tensor",
         "aten::eq(Tensor self, Scalar other) -> Tensor",
         "aten::ne(Tensor self, Tensor other) -> Tensor",
@@ -318,9 +242,10 @@ class IrParser {
         "aten::lt(Tensor self, Scalar other) -> Tensor"};
     for (auto signature : BinaryOp) {
       auto ptr_op = getOperatorForLiteral(signature);
-      REGISTER_PARSE_RULE(
+      registerParseRule(
           ptr_op,
-          {
+          [](const Node* node,
+             std::unordered_map<size_t, CgValue>& value_map) -> void {
             static std::unordered_map<Symbol, BinaryOpType> op_mapping(
                 {{aten::div, BinaryOpType::Div},
                  {aten::mul, BinaryOpType::Mul},
@@ -337,20 +262,13 @@ class IrParser {
                  {aten::gt, BinaryOpType::GT},
                  {aten::ge, BinaryOpType::GE},
                  {aten::ne, BinaryOpType::NE},
-                 {aten::eq, BinaryOpType::Eq},
-                 {aten::__and__, BinaryOpType::And},
-                 {aten::__or__, BinaryOpType::Or},
-                 {aten::__xor__, BinaryOpType::Xor},
-                 {aten::__lshift__, BinaryOpType::Lshift},
-                 {aten::__rshift__, BinaryOpType::Rshift}});
+                 {aten::eq, BinaryOpType::Eq}});
             auto lhs = value_map[node->inputs()[0]->unique()];
             auto rhs = value_map[node->inputs()[1]->unique()];
 
             auto out = binaryOp(op_mapping[node->kind()], lhs, rhs);
             value_map.emplace(node->output()->unique(), out);
-          },
-          nullptr,
-          nullptr);
+          });
     }
 
     // TODO: cast operations should be merged in.
@@ -381,18 +299,18 @@ class IrParser {
         "aten::floor(Tensor self) -> Tensor",
         "aten::round(Tensor self) -> Tensor",
         "aten::trunc(Tensor self) -> Tensor",
-        "aten::bitwise_not(Tensor self) -> Tensor",
         "aten::frac(Tensor self) -> Tensor",
         "aten::reciprocal(Tensor self) -> Tensor",
         "aten::relu(Tensor self) -> Tensor",
         "aten::sigmoid(Tensor self) -> Tensor",
-        "aten::silu(Tensor self) -> Tensor",
+        "aten::gelu(Tensor self) -> Tensor",
     };
     for (auto signature : UnaryOp) {
       auto ptr_op = getOperatorForLiteral(signature);
-      REGISTER_PARSE_RULE(
+      registerParseRule(
           ptr_op,
-          {
+          [](const Node* node,
+             std::unordered_map<size_t, CgValue>& value_map) -> void {
             static std::unordered_map<Symbol, UnaryOpType> op_mapping({
                 {aten::neg, UnaryOpType::Neg},
                 {aten::abs, UnaryOpType::Abs},
@@ -420,107 +338,84 @@ class IrParser {
                 {aten::floor, UnaryOpType::Floor},
                 {aten::round, UnaryOpType::Round},
                 {aten::trunc, UnaryOpType::Trunc},
-                {aten::bitwise_not, UnaryOpType::Not},
                 {aten::frac, UnaryOpType::Frac},
                 {aten::reciprocal, UnaryOpType::Reciprocal},
                 {aten::relu, UnaryOpType::Relu},
                 {aten::sigmoid, UnaryOpType::Sigmoid},
-                {aten::silu, UnaryOpType::Silu},
+                {aten::gelu, UnaryOpType::Gelu},
             });
             auto operand = value_map[node->input()->unique()];
 
             auto out = unaryOp(op_mapping[node->kind()], operand);
             value_map.emplace(node->output()->unique(), out);
-          },
-          nullptr,
-          nullptr);
+          });
     }
 
     {
       auto ptr_op = getOperatorForLiteral(
           "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
-      REGISTER_PARSE_RULE(
+      registerParseRule(
           ptr_op,
-          {
+          [](const Node* node,
+             std::unordered_map<size_t, CgValue>& value_map) -> void {
             auto operand = value_map[node->inputs()[0]->unique()];
 
             auto out = unaryOp(UnaryOpType::RandLike, operand);
             value_map.emplace(node->output()->unique(), out);
-          },
-          nullptr,
-          nullptr);
-    }
-
-    {
-      auto ptr_op = getOperatorForLiteral(
-          "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            auto operand = value_map[node->inputs()[0]->unique()];
-            auto beta = value_map[node->inputs()[1]->unique()];
-            auto threshold = value_map[node->inputs()[2]->unique()];
-            auto out = softplus(operand, beta, threshold);
-            value_map.emplace(node->output()->unique(), out);
-          },
-          nullptr,
-          nullptr);
+          });
     }
 
     {
       auto ptr_op = getOperatorForLiteral(
           "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor");
-      REGISTER_PARSE_RULE(
+      registerParseRule(
           ptr_op,
-          {
+          [](const Node* node,
+             std::unordered_map<size_t, CgValue>& value_map) -> void {
             auto operand = value_map[node->inputs()[0]->unique()];
             auto th = value_map[node->inputs()[1]->unique()];
             auto value = value_map[node->inputs()[2]->unique()];
 
             auto out = threshold(operand, th, value);
             value_map.emplace(node->output()->unique(), out);
-          },
-          nullptr,
-          nullptr);
+          });
     }
 
     {
       auto ptr_op = getOperatorForLiteral(
           "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor");
-      REGISTER_PARSE_RULE(
+      registerParseRule(
           ptr_op,
-          {
+          [](const Node* node,
+             std::unordered_map<size_t, CgValue>& value_map) -> void {
             auto operand = value_map[node->inputs()[0]->unique()];
             // TODO: we need to get a proper lower bound per dtype in operand.
             auto low = value_map.count(node->inputs()[1]->unique()) != 0
                 ? value_map[node->inputs()[1]->unique()]
-                : new Double(std::numeric_limits<float>::min());
+                : new Float(std::numeric_limits<float>::min());
             auto high = value_map.count(node->inputs()[2]->unique()) != 0
                 ? value_map[node->inputs()[2]->unique()]
-                : new Double(std::numeric_limits<float>::max());
+                : new Float(std::numeric_limits<float>::max());
 
             auto out = clamp(operand, low, high);
             value_map.emplace(node->output()->unique(), out);
-          },
-          nullptr,
-          nullptr);
+          });
     }
 
     {
       auto ptr_op = getOperatorForLiteral(
           "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor");
-      REGISTER_PARSE_RULE(
+      registerParseRule(
           ptr_op,
-          {
+          [](const Node* node,
+             std::unordered_map<size_t, CgValue>& value_map) -> void {
             auto condition = value_map[node->inputs()[0]->unique()];
             auto x = value_map[node->inputs()[1]->unique()];
             auto y = value_map[node->inputs()[2]->unique()];
 
             auto out = where(condition, x, y);
             value_map.emplace(node->output()->unique(), out);
-          },
-          nullptr,
-          nullptr);
+          });
     }
 
     {
@@ -529,27 +424,27 @@ class IrParser {
           "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor"};
       for (auto signature : LerpOp) {
         auto ptr_op = getOperatorForLiteral(signature);
-        REGISTER_PARSE_RULE(
+        registerParseRule(
             ptr_op,
-            {
+            [](const Node* node,
+               std::unordered_map<size_t, CgValue>& value_map) -> void {
               auto self = value_map[node->inputs()[0]->unique()];
               auto end = value_map[node->inputs()[1]->unique()];
               auto weight = value_map[node->inputs()[2]->unique()];
 
               auto out = lerp(self, end, weight);
               value_map.emplace(node->output()->unique(), out);
-            },
-            nullptr,
-            nullptr);
+            });
       }
     }
 
     {
       auto ptr_op = getOperatorForLiteral(
           "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor");
-      REGISTER_PARSE_RULE(
+      registerParseRule(
           ptr_op,
-          {
+          [](const Node* node,
+             std::unordered_map<size_t, CgValue>& value_map) -> void {
             auto self = value_map[node->inputs()[0]->unique()];
             auto tensor1 = value_map[node->inputs()[1]->unique()];
             auto tensor2 = value_map[node->inputs()[2]->unique()];
@@ -557,677 +452,39 @@ class IrParser {
 
             auto out = addcmul(self, tensor1, tensor2, value);
             value_map.emplace(node->output()->unique(), out);
-          },
-          nullptr,
-          nullptr);
-    }
-
-    {
-      auto ptr_op = getOperatorForLiteral(
-          "aten::dropout(Tensor input, float p, bool train) -> Tensor");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            auto input = value_map[node->input(0)->unique()]->as<TensorView>();
-            auto train = constant_as<bool>(node->input(2));
-            TORCH_INTERNAL_ASSERT(
-                train.has_value(), "dropout needs constant `train` flag");
-
-            if (train.value()) {
-              auto prob = value_map[node->input(1)->unique()];
-              auto result = dropout(input, prob);
-
-              value_map.emplace(node->output()->unique(), result.output);
-            } else {
-              value_map.emplace(node->output()->unique(), input);
-            }
-          },
-          nullptr,
-          nullptr);
-    }
-
-    {
-      std::array<const char*, kNumInstancenormFwd> InstanceNormFwd = {
-          "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor"};
-      for (auto signature : InstanceNormFwd) {
-        auto ptr_op = getOperatorForLiteral(signature);
-        REGISTER_PARSE_RULE(
-            ptr_op,
-            {
-              auto fusion = FusionGuard::getCurFusion();
-
-              auto input =
-                  value_map[node->input(0)->unique()]->as<TensorView>();
-
-              TensorView* weight = nullptr;
-              if (!node->input(1)->type()->isSubtypeOf(
-                      static_cast<c10::TypePtr>(NoneType::get()))) {
-                weight = value_map[node->input(1)->unique()]->as<TensorView>();
-              }
-
-              TensorView* bias = nullptr;
-              if (!node->input(2)->type()->isSubtypeOf(
-                      static_cast<c10::TypePtr>(NoneType::get()))) {
-                bias = value_map[node->input(2)->unique()]->as<TensorView>();
-              }
-
-              TensorView* running_mean = nullptr;
-              if (!node->input(3)->type()->isSubtypeOf(
-                      static_cast<c10::TypePtr>(NoneType::get()))) {
-                running_mean =
-                    value_map[node->input(3)->unique()]->as<TensorView>();
-                TORCH_INTERNAL_ASSERT(
-                    fusion->hasInput(running_mean),
-                    "IO_tensor `batch_norm::running_mean` can only be input tensor to fusion");
-              }
-
-              TensorView* running_var = nullptr;
-              if (!node->input(4)->type()->isSubtypeOf(
-                      static_cast<c10::TypePtr>(NoneType::get()))) {
-                running_var =
-                    value_map[node->input(4)->unique()]->as<TensorView>();
-                TORCH_INTERNAL_ASSERT(
-                    fusion->hasInput(running_var),
-                    "IO_tensor `batch_norm::running_var` can only be input tensor to fusion");
-              }
-
-              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-              auto use_input_stats = constant_as<bool>(node->input(5));
-              TORCH_INTERNAL_ASSERT(
-                  use_input_stats.has_value(),
-                  "The use_input_stats (bool) parameter is required.");
-              const bool kUseInputStats = use_input_stats.value();
-
-              Val* momentum_ptr = nullptr;
-              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-              if (auto momentum = constant_as<float>(node->input(6))) {
-                momentum_ptr = new Double(momentum.value());
-              } else {
-                // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-                momentum_ptr = value_map[node->input(6)->unique()];
-              }
-
-              Val* eps_ptr = nullptr;
-              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-              if (auto eps = constant_as<float>(node->input(7))) {
-                eps_ptr = new Double(eps.value());
-              } else {
-                // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-                eps_ptr = value_map[node->input(7)->unique()];
-              }
-
-              auto result = instance_norm(
-                  input,
-                  weight,
-                  bias,
-                  running_mean,
-                  running_var,
-                  kUseInputStats,
-                  momentum_ptr,
-                  eps_ptr);
-
-              if (node->kind() ==
-                  c10::Symbol::fromQualString("aten::instance_norm")) {
-                value_map.emplace(node->output()->unique(), result.output);
-              }
-            },
-            [](const Node* node) -> bool { return true; },
-            [](const Node* node) -> OperatorType {
-              return OperatorType::Normalization;
-            });
-      }
-    }
-
-    {
-      std::array<const char*, kNumBatchnormFwd> BatchNormFwd = {
-          "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)",
-          "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
-          "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor"};
-      for (auto signature : BatchNormFwd) {
-        auto ptr_op = getOperatorForLiteral(signature);
-        REGISTER_PARSE_RULE(
-            ptr_op,
-            {
-              auto fusion = FusionGuard::getCurFusion();
-
-              auto input =
-                  value_map[node->input(0)->unique()]->as<TensorView>();
-
-              TensorView* weight = nullptr;
-              if (!node->input(1)->type()->isSubtypeOf(
-                      static_cast<c10::TypePtr>(NoneType::get()))) {
-                weight = value_map[node->input(1)->unique()]->as<TensorView>();
-              }
-
-              TensorView* bias = nullptr;
-              if (!node->input(2)->type()->isSubtypeOf(
-                      static_cast<c10::TypePtr>(NoneType::get()))) {
-                bias = value_map[node->input(2)->unique()]->as<TensorView>();
-              }
-
-              TensorView* running_mean = nullptr;
-              if (!node->input(3)->type()->isSubtypeOf(
-                      static_cast<c10::TypePtr>(NoneType::get()))) {
-                running_mean =
-                    value_map[node->input(3)->unique()]->as<TensorView>();
-                TORCH_INTERNAL_ASSERT(
-                    fusion->hasInput(running_mean),
-                    "IO_tensor `batch_norm::running_mean` can only be input tensor to fusion");
-              }
-
-              TensorView* running_var = nullptr;
-              if (!node->input(4)->type()->isSubtypeOf(
-                      static_cast<c10::TypePtr>(NoneType::get()))) {
-                running_var =
-                    value_map[node->input(4)->unique()]->as<TensorView>();
-                TORCH_INTERNAL_ASSERT(
-                    fusion->hasInput(running_var),
-                    "IO_tensor `batch_norm::running_var` can only be input tensor to fusion");
-              }
-
-              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-              auto training = constant_as<bool>(node->input(5));
-              TORCH_INTERNAL_ASSERT(
-                  training.has_value(),
-                  "The training (bool) parameter is required.");
-              const bool kTraining = training.value();
-
-              Val* momentum_ptr = nullptr;
-              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-              if (auto momentum = constant_as<float>(node->input(6))) {
-                momentum_ptr = new Double(momentum.value());
-              } else {
-                // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-                momentum_ptr = value_map[node->input(6)->unique()];
-              }
-
-              Val* eps_ptr = nullptr;
-              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-              if (auto eps = constant_as<float>(node->input(7))) {
-                eps_ptr = new Double(eps.value());
-              } else {
-                // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-                eps_ptr = value_map[node->input(7)->unique()];
-              }
-
-              auto result = batch_norm(
-                  input,
-                  weight,
-                  bias,
-                  running_mean,
-                  running_var,
-                  kTraining,
-                  momentum_ptr,
-                  eps_ptr);
-
-              if (node->kind() ==
-                  c10::Symbol::fromQualString("aten::native_batch_norm")) {
-                value_map.emplace(node->output(0)->unique(), result.output);
-
-                value_map.emplace(node->output(1)->unique(), result.mean);
-
-                value_map.emplace(node->output(2)->unique(), result.invstd);
-              } else if (
-                  node->kind() ==
-                  c10::Symbol::fromQualString("aten::batch_norm")) {
-                value_map.emplace(node->output()->unique(), result.output);
-              } else if (
-                  node->kind() ==
-                  c10::Symbol::fromQualString("aten::_batch_norm_impl_index")) {
-                value_map.emplace(node->output(0)->unique(), result.output);
-
-                value_map.emplace(node->output(1)->unique(), result.mean);
-
-                value_map.emplace(node->output(2)->unique(), result.invstd);
-
-                // TODO: output 3 & 4 are not created
-                //       we are not creating these outputs because codegen
-                //       currently lacks the support.
-              }
-            },
-            [](const Node* node) -> bool { return true; },
-            [](const Node* node) -> OperatorType {
-              return OperatorType::Normalization;
-            });
-      }
-    }
-
-    {
-      auto ptr_op = getOperatorForLiteral(
-          "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            // discard impl_index and reservedSpace since we don't use them
-
-            auto input = value_map[node->input(1)->unique()]->as<TensorView>();
-
-            auto grad_out =
-                value_map[node->input(2)->unique()]->as<TensorView>();
-
-            TensorView* weight = nullptr;
-            if (!node->input(3)->type()->isSubtypeOf(
-                    static_cast<c10::TypePtr>(NoneType::get()))) {
-              weight = value_map[node->input(3)->unique()]->as<TensorView>();
-            }
-
-            TensorView* running_mean = nullptr;
-            if (!node->input(4)->type()->isSubtypeOf(
-                    static_cast<c10::TypePtr>(NoneType::get()))) {
-              running_mean =
-                  value_map[node->input(4)->unique()]->as<TensorView>();
-            }
-
-            TensorView* running_var = nullptr;
-            if (!node->input(5)->type()->isSubtypeOf(
-                    static_cast<c10::TypePtr>(NoneType::get()))) {
-              running_var =
-                  value_map[node->input(5)->unique()]->as<TensorView>();
-            }
-
-            TensorView* save_mean = nullptr;
-            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-            if (!node->input(6)->type()->isSubtypeOf(
-                    static_cast<c10::TypePtr>(NoneType::get()))) {
-              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-              save_mean = value_map[node->input(6)->unique()]->as<TensorView>();
-            }
-
-            TensorView* save_invstd = nullptr;
-            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-            if (!node->input(7)->type()->isSubtypeOf(
-                    static_cast<c10::TypePtr>(NoneType::get()))) {
-              save_invstd =
-                  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-                  value_map[node->input(7)->unique()]->as<TensorView>();
-            }
-
-            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-            auto training = constant_as<bool>(node->input(8));
-            TORCH_INTERNAL_ASSERT(
-                training.has_value(),
-                "The training (bool) parameter is required.");
-            const bool kTraining = training.value();
-
-            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-            Val* eps_ptr = nullptr;
-            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-            if (auto eps = constant_as<float>(node->input(9))) {
-              eps_ptr = new Double(eps.value());
-            } else {
-              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-              eps_ptr = value_map[node->input(7)->unique()];
-            }
-
-            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-            auto out_mask_list = constant_as<c10::List<bool>>(node->input(10));
-            TORCH_INTERNAL_ASSERT(
-                out_mask_list.has_value(),
-                "output mask for batch_norm_backward");
-            std::vector<bool> output_mask;
-            for (const auto value : out_mask_list->vec()) {
-              output_mask.emplace_back(static_cast<bool>(value));
-            }
-
-            // TODO: merge this loop below.
-            if (kTraining) {
-              TORCH_INTERNAL_ASSERT(
-                  save_mean != nullptr && save_invstd != nullptr,
-                  "When training=True, save_mean and save_invstd are required.");
-            } else {
-              // TODO: this is not a legit assumption? Can't we run with
-              // track_running_stats == false && training == false
-              // which should just run through the case above.
-              TORCH_INTERNAL_ASSERT(
-                  running_mean != nullptr && running_var != nullptr,
-                  "When training=False, running_mean and running_invstd are required.");
-            }
-
-            auto grads = batch_norm_backward(
-                input,
-                grad_out,
-                weight,
-                running_mean,
-                running_var,
-                save_mean,
-                save_invstd,
-                kTraining,
-                eps_ptr,
-                output_mask);
-
-            if (output_mask[0]) {
-              TORCH_INTERNAL_ASSERT(grads.grad_input != nullptr);
-              value_map.emplace(node->output(0)->unique(), grads.grad_input);
-            } else {
-              TORCH_INTERNAL_ASSERT(grads.grad_input == nullptr);
-              value_map.emplace(
-                  node->output(1)->unique(), TensorViewBuilder().build());
-            }
-
-            if (output_mask[1]) {
-              TORCH_INTERNAL_ASSERT(grads.grad_weight != nullptr);
-              value_map.emplace(node->output(1)->unique(), grads.grad_weight);
-            } else {
-              TORCH_INTERNAL_ASSERT(grads.grad_weight == nullptr);
-              value_map.emplace(
-                  node->output(1)->unique(), TensorViewBuilder().build());
-            }
-
-            if (output_mask[2]) {
-              TORCH_INTERNAL_ASSERT(grads.grad_bias != nullptr);
-              value_map.emplace(node->output(2)->unique(), grads.grad_bias);
-            } else {
-              TORCH_INTERNAL_ASSERT(grads.grad_bias == nullptr);
-              value_map.emplace(
-                  node->output(2)->unique(), TensorViewBuilder().build());
-            }
-          },
-          [](const Node* node) -> bool { return true; },
-          [](const Node* node) -> OperatorType {
-            return OperatorType::Normalization;
-          });
-    }
-
-    {
-      std::array<const char*, kNumLayernormFwd> LayerNormFwd = {
-          "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)",
-          "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"};
-      for (auto signature : LayerNormFwd) {
-        auto ptr_op = getOperatorForLiteral(signature);
-        REGISTER_PARSE_RULE(
-            ptr_op,
-            {
-              auto input =
-                  value_map[node->input(0)->unique()]->as<TensorView>();
-
-              auto norm_shape_optional =
-                  constant_as<c10::List<int64_t>>(node->input(1));
-              TORCH_INTERNAL_ASSERT(
-                  norm_shape_optional.has_value(),
-                  "The Normalized_Shape list is required.");
-              auto norm_shape = norm_shape_optional->vec();
-
-              TensorView* weight = nullptr;
-              if (!node->input(2)->type()->isSubtypeOf(
-                      static_cast<c10::TypePtr>(NoneType::get()))) {
-                weight = value_map[node->input(2)->unique()]->as<TensorView>();
-              }
-
-              TensorView* bias = nullptr;
-              if (!node->input(3)->type()->isSubtypeOf(
-                      static_cast<c10::TypePtr>(NoneType::get()))) {
-                bias = value_map[node->input(3)->unique()]->as<TensorView>();
-              }
-
-              Val* eps_ptr = nullptr;
-              if (auto eps = constant_as<float>(node->input(4))) {
-                eps_ptr = new Double(eps.value());
-              } else {
-                eps_ptr = value_map[node->input(4)->unique()];
-              }
-
-              auto result =
-                  layer_norm(input, norm_shape, weight, bias, eps_ptr);
-
-              if (node->kind() ==
-                  c10::Symbol::fromQualString("aten::native_layer_norm")) {
-                value_map.emplace(node->output(0)->unique(), result.output);
-                value_map.emplace(node->output(1)->unique(), result.mean);
-                value_map.emplace(node->output(2)->unique(), result.invstd);
-              } else if (
-                  node->kind() ==
-                  c10::Symbol::fromQualString("aten::layer_norm")) {
-                value_map.emplace(node->output()->unique(), result.output);
-              }
-            },
-            // TODO: #ProfileIValue List should update this
-            [](const Node* node) -> bool { return true; },
-            [](const Node* node) -> OperatorType {
-              return OperatorType::Normalization;
-            });
-      }
-    }
-
-    {
-      auto ptr_op = getOperatorForLiteral(
-          "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            auto grad_out =
-                value_map[node->input(0)->unique()]->as<TensorView>();
-
-            auto input = value_map[node->input(1)->unique()]->as<TensorView>();
-
-            auto norm_shape_optional =
-                constant_as<c10::List<int64_t>>(node->input(2));
-            TORCH_INTERNAL_ASSERT(
-                norm_shape_optional.has_value(),
-                "The Normalized_Shape list is required.");
-            auto norm_shape = norm_shape_optional->vec();
-
-            auto mean = value_map[node->input(3)->unique()]->as<TensorView>();
-            auto rstd = value_map[node->input(4)->unique()]->as<TensorView>();
-
-            TensorView* weight = nullptr;
-            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-            if (!node->input(5)->type()->isSubtypeOf(
-                    static_cast<c10::TypePtr>(NoneType::get()))) {
-              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-              weight = value_map[node->input(5)->unique()]->as<TensorView>();
-            }
-
-            TensorView* bias = nullptr;
-            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-            if (!node->input(6)->type()->isSubtypeOf(
-                    static_cast<c10::TypePtr>(NoneType::get()))) {
-              // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-              bias = value_map[node->input(6)->unique()]->as<TensorView>();
-            }
-
-            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-            auto output_mask_optional =
-                constant_as<c10::List<bool>>(node->input(7));
-            TORCH_INTERNAL_ASSERT(
-                output_mask_optional.has_value(),
-                "output mask for layer_norm_backward");
-            std::vector<bool> output_mask = output_mask_optional->vec();
-
-            auto grad = layer_norm_backward(
-                grad_out,
-                input,
-                norm_shape,
-                mean,
-                rstd,
-                weight,
-                bias,
-                output_mask);
-
-            if (output_mask[0]) {
-              TORCH_INTERNAL_ASSERT(grad.grad_input != nullptr);
-              value_map.emplace(node->output(0)->unique(), grad.grad_input);
-            } else {
-              TORCH_INTERNAL_ASSERT(grad.grad_input == nullptr);
-              value_map.emplace(
-                  node->output(0)->unique(), TensorViewBuilder().build());
-            }
-
-            if (output_mask[1] && weight != nullptr) {
-              TORCH_INTERNAL_ASSERT(grad.grad_weight != nullptr);
-              value_map.emplace(node->output(1)->unique(), grad.grad_weight);
-            } else {
-              TORCH_INTERNAL_ASSERT(grad.grad_weight == nullptr);
-              value_map.emplace(
-                  node->output(1)->unique(), TensorViewBuilder().build());
-            }
-
-            if (output_mask[2] && bias != nullptr) {
-              TORCH_INTERNAL_ASSERT(grad.grad_bias != nullptr);
-              value_map.emplace(node->output(2)->unique(), grad.grad_bias);
-            } else {
-              TORCH_INTERNAL_ASSERT(grad.grad_bias == nullptr);
-              value_map.emplace(
-                  node->output(2)->unique(), TensorViewBuilder().build());
-            }
-          },
-          // TODO: #ProfileIValue List should update this
-          [](const Node* node) -> bool { return true; },
-          [](const Node* node) -> OperatorType {
-            return OperatorType::Normalization;
-          });
-    }
-
-    {
-      auto ptr_op = getOperatorForLiteral(
-          "aten::softmax.int(Tensor self, int dim, int? dtype) -> Tensor");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            auto input = value_map[node->input(0)->unique()]->as<TensorView>();
-
-            auto dim_value = constant_as<int>(node->input(1));
-            TORCH_INTERNAL_ASSERT(
-                dim_value.has_value(), "dim in softmax is not valid");
-
-            auto output = softmax(input, dim_value.value());
-            value_map.emplace(node->output()->unique(), output);
-          },
-          [](const Node* node) -> bool {
-            if (node->inputs()[1]->node()->kind() != prim::Constant) {
-              return false;
-            }
-            if (!node->inputs()[2]->type()->isSubtypeOf(
-                    static_cast<c10::TypePtr>(NoneType::get()))) {
-              return false;
-            }
-            return true;
-          },
-          [](const Node* node) -> OperatorType {
-            return OperatorType::Normalization;
-          });
-    }
-
-    {
-      auto ptr_op = getOperatorForLiteral(
-          "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            auto grad_output =
-                value_map[node->input(0)->unique()]->as<TensorView>();
-
-            auto output = value_map[node->input(1)->unique()]->as<TensorView>();
-
-            auto dim_value = constant_as<int>(node->input(2));
-            TORCH_INTERNAL_ASSERT(
-                dim_value.has_value(), "dim in softmax is not valid");
-
-            auto input = value_map[node->input(3)->unique()]->as<TensorView>();
-
-            auto grad_input =
-                softmax_backward(grad_output, output, dim_value.value(), input);
-            value_map.emplace(node->output()->unique(), grad_input);
-          },
-          [](const Node* node) -> bool {
-            if (node->inputs()[2]->node()->kind() != prim::Constant) {
-              return false;
-            }
-            return true;
-          },
-          [](const Node* node) -> OperatorType {
-            return OperatorType::Normalization;
           });
     }
 
     {
       auto ptr_op = getOperatorForLiteral(
           "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
-      REGISTER_PARSE_RULE(
+      registerParseRule(
           ptr_op,
-          {
+          [](const Node* node,
+             std::unordered_map<size_t, CgValue>& value_map) -> void {
             auto self = value_map[node->input(0)->unique()];
             auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
             TORCH_INTERNAL_ASSERT(
-                dims_list.has_value(),
-                "aten::sum cannot be fused with dynamic axes");
-            std::vector<int> dims;
-            for (const auto dim : dims_list->vec()) {
-              dims.emplace_back(static_cast<int>(dim));
-            }
+                dims_list.has_value(), "requires static reduce axes");
             auto keepdim = constant_as<bool>(node->input(2));
-            TORCH_INTERNAL_ASSERT(
-                keepdim.has_value(),
-                "aten::sum cannot be fused with dynamic keepdim");
-            auto out = sum(self->as<TensorView>(), dims, keepdim.value());
-            value_map.emplace(node->output()->unique(), out);
-          },
-          [](const Node* node) -> bool {
-            // TODO: support cast of output types
-            if (!node->inputs()[3]->type()->isSubtypeOf(
-                    static_cast<c10::TypePtr>(NoneType::get()))) {
-              // We can only handle output as half, float, and double;
-              if (const auto opt_ivalue = toIValue(node->input(3))) {
-                const auto scalar_type = opt_ivalue->toScalarType();
-                if (scalar_type == at::ScalarType::Double ||
-                    scalar_type == at::ScalarType::Float ||
-                    scalar_type == at::ScalarType::Half) {
-                  return true;
-                }
-              }
-              return false;
-            }
-            // we don't support dynamic reduction axes;
-            if (node->inputs()[1]->node()->kind() != prim::Constant) {
-              return false;
-            }
-            // we don't support dynamic keepdim yet;
-            if (node->inputs()[2]->node()->kind() != prim::Constant) {
-              return false;
-            }
-            return true;
-          },
-          [](const Node* node) -> OperatorType {
-            return OperatorType::Reduction;
-          });
-    }
-
-    {
-      auto ptr_op = getOperatorForLiteral(
-          "aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            auto self = value_map[node->input(0)->unique()]->as<TensorView>();
-            auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
-            TORCH_INTERNAL_ASSERT(
-                dims_list.has_value(),
-                "aten::mean cannot be fused with dynamic axes");
             std::vector<int> dims;
             for (const auto dim : dims_list->vec()) {
               dims.emplace_back(static_cast<int>(dim));
             }
-            auto keepdim = constant_as<bool>(node->input(2));
             TORCH_INTERNAL_ASSERT(
-                keepdim.has_value(),
-                "aten::mean cannot be fused with dynamic keepdim");
-            auto o_sum = sum(self, dims, keepdim.value());
-            Val* num_features = new Double(1);
-            for (const auto axis : dims) {
-              num_features =
-                  mul(num_features, self->domain()->domain()[axis]->extent());
-            }
-            auto out = div(o_sum, num_features);
+                keepdim.has_value() && !keepdim.value(),
+                "Keep dim in reduction is not a const false");
+            auto out = sum(self->as<TensorView>(), dims);
             value_map.emplace(node->output()->unique(), out);
           },
           [](const Node* node) -> bool {
-            // TODO: support cast of output types
+            // TODO: support cast of output types yet;
             if (!node->inputs()[3]->type()->isSubtypeOf(
                     static_cast<c10::TypePtr>(NoneType::get()))) {
-              // We can only handle output as half, float, and double;
+              // We can only handle output as half and float;
               if (const auto opt_ivalue = toIValue(node->input(3))) {
                 const auto scalar_type = opt_ivalue->toScalarType();
-                if (scalar_type == at::ScalarType::Double ||
-                    scalar_type == at::ScalarType::Float ||
+                if (scalar_type == at::ScalarType::Float ||
                     scalar_type == at::ScalarType::Half) {
                   return true;
                 }
@@ -1238,185 +495,14 @@ class IrParser {
             if (node->inputs()[1]->node()->kind() != prim::Constant) {
               return false;
             }
-            // we don't support dynamic keepdim yet;
-            if (node->inputs()[2]->node()->kind() != prim::Constant) {
+            // we don't support keepdim yet;
+            if (node->inputs()[2]->node()->kind() != prim::Constant ||
+                *constant_as<bool>(node->input(2))) {
               return false;
             }
             return true;
           },
-          [](const Node* node) -> OperatorType {
-            return OperatorType::Reduction;
-          });
-    }
-    {
-      std::array<const char*, kNumSumToSize> SumToSize = {
-          "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)",
-          "aten::sum_to_size(Tensor self, int[] size) -> Tensor"};
-      for (auto signature : SumToSize) {
-        auto ptr_op = getOperatorForLiteral(signature);
-        REGISTER_PARSE_RULE(
-            ptr_op,
-            {
-              auto self = value_map[node->input(0)->unique()];
-              auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
-              TORCH_INTERNAL_ASSERT(
-                  size_to.has_value(),
-                  "aten::sum cannot be fused with dynamic axes");
-              if (!size_to->empty()) {
-                auto out = sum_to(self->as<TensorView>(), size_to->vec());
-                value_map.emplace(node->output()->unique(), out);
-              } else {
-                // We are introducing alias here!
-                value_map.emplace(node->output()->unique(), self);
-              }
-            },
-            [](const Node* node) -> bool {
-              // we don't support dynamic reduction axes;
-              if (node->inputs()[1]->node()->kind() != prim::Constant) {
-                return false;
-              }
-              return true;
-              // auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
-              // return size_to.has_value() && !size_to->empty();
-            },
-            [](const Node* node) -> OperatorType {
-              auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
-              // technically size_to->empty() should never occur, as specialized
-              // _grad_sum_to_size should have been removed by optimization pass
-              if (size_to->empty()) {
-                return OperatorType::ElementWise;
-              } else {
-                return OperatorType::ReductionToSize;
-              }
-            });
-      }
-    }
-
-    // Limiting aten::to implementation to only change the dtype of a tensor
-    {
-      auto ptr_op = getOperatorForLiteral(
-          "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            const auto self = value_map[node->input(0)->unique()];
-
-            // we need static type for cast
-            TORCH_INTERNAL_ASSERT(
-                node->input(1)->node()->kind() == prim::Constant);
-            auto dtype = toIValue(node->input(1))->toScalarType();
-
-            // We want to keep our internal fusion math in FP32
-            // Shape Inference will continue to propagate the right
-            // type to outputs unchanged.
-            if (dtype == at::ScalarType::Half) {
-              dtype = at::ScalarType::Float;
-            }
-
-            auto out = castOp(aten_to_data_type(dtype), self);
-            value_map.emplace(node->output()->unique(), out);
-          },
-          nullptr,
-          nullptr);
-    }
-
-    {
-      auto ptr_op = getOperatorForLiteral(
-          "aten::type_as(Tensor self, Tensor other) -> Tensor");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            auto self = value_map[node->inputs()[0]->unique()];
-
-            // TODO: switch to PyTorch dtype as it's closer to truth.
-            // For now, reality is that PyTorch IR profiling information could
-            // be missing even with profiling executor, due to upstream
-            // transformations between profiling runs to fusion pass.
-            auto opt_dtype =
-                value_map[node->inputs()[1]->unique()]->getDataType();
-            TORCH_INTERNAL_ASSERT(opt_dtype.has_value());
-
-            auto out = castOp(opt_dtype.value(), self);
-            value_map.emplace(node->output()->unique(), out);
-          },
-          nullptr,
-          nullptr);
-    }
-
-    {
-      // We are not fusing `linear` yet, because we can't codegen efficient gemm
-      // However, we still need this here, so PE would insert profile node for
-      // this node.
-      // During fusion pass, We decompose linear into gemm + elementwise.
-      auto ptr_op = getOperatorForLiteral(
-          "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            // this entry is created so we do profile input tensors;
-            TORCH_INTERNAL_ASSERT(false, "not implemented yet");
-          },
-          [](const Node* node) -> bool {
-            // We only profile `linear` layer with bias.
-            if (node->input(2)->type()->isSubtypeOf(
-                    static_cast<c10::TypePtr>(NoneType::get()))) {
-              return false;
-            }
-            return true;
-          });
-    }
-
-    {
-      auto ptr_op = getOperatorForLiteral(
-          "prim::add_optional(Tensor(a) input, Tensor? bias) -> Tensor(a)");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            // this entry is created so we do profile input tensors;
-            if (node->input(1)->type()->isSubtypeOf(
-                    static_cast<c10::TypePtr>(NoneType::get()))) {
-              // forwarding the value;
-              value_map.emplace(
-                  node->output()->unique(),
-                  value_map[node->inputs()[0]->unique()]);
-            } else {
-              auto lhs = value_map[node->inputs()[0]->unique()];
-              auto rhs = value_map[node->inputs()[1]->unique()];
-
-              auto out = binaryOp(BinaryOpType::Add, lhs, rhs);
-              value_map.emplace(node->output()->unique(), out);
-            }
-          },
-          nullptr,
-          nullptr);
-    }
-
-    {
-      auto ptr_op = getOperatorForLiteral("aten::gelu(Tensor self) -> Tensor");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            auto self = value_map[node->inputs()[0]->unique()];
-            auto output = unaryOp(UnaryOpType::Gelu, self);
-            value_map.emplace(node->output()->unique(), output);
-          },
-          nullptr,
-          nullptr);
-    }
-
-    {
-      auto ptr_op = getOperatorForLiteral(
-          "aten::gelu_backward(Tensor grad, Tensor self) -> Tensor");
-      REGISTER_PARSE_RULE(
-          ptr_op,
-          {
-            auto grad = value_map[node->inputs()[0]->unique()];
-            auto self = value_map[node->inputs()[1]->unique()];
-            auto grad_in = gelu_backward(grad, self);
-            value_map.emplace(node->output()->unique(), grad_in);
-          },
-          nullptr,
-          nullptr);
+          true);
     }
   }
 
@@ -1433,12 +519,22 @@ class IrParser {
             *node);
       }
     } else {
-      auto reg_entry = lookupInRegistry(node);
+      auto iter = IrParser::jit_operator_registry_.find(node->kind());
+      // make sure we have a parser for the op;
+      TORCH_INTERNAL_ASSERT(
+          iter != IrParser::jit_operator_registry_.end(),
+          "CudaFusionGroup Parser doesn't handle operator kind(): ",
+          node->kind().toDisplayString());
+      for (auto& pair_op_func : iter->second) {
+        if (node->matches(pair_op_func.first->schema())) {
+          pair_op_func.second.parse(node, value_map_);
+          return;
+        }
+      }
       TORCH_INTERNAL_ASSERT(
-          reg_entry != nullptr,
-          "CudaFusionGroup Parser doesn't handle node: ",
+          false,
+          "CudaFusionGroup Parser doesn't recognize operator overload:",
           canonicalSchemaString(node->schema()));
-      reg_entry->parse(node, value_map_);
     }
   }
 
@@ -1450,10 +546,11 @@ class IrParser {
     if (val->type()->isSubtypeOf(static_cast<c10::TypePtr>(FloatType::get()))) {
       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
       CgValue cg_val;
-      if (auto ival = constant_as<double>(val)) {
-        cg_val = new Double(ival.value());
+      // NOLINTNEXTLINE(bugprone-branch-clone)
+      if (auto ival = constant_as<float>(val)) {
+        cg_val = new Float(ival.value());
       } else {
-        cg_val = new Double();
+        cg_val = new Float();
       }
       value_map_.emplace(val->unique(), cg_val);
       return true;
@@ -1461,7 +558,8 @@ class IrParser {
                    static_cast<c10::TypePtr>(IntType::get()))) {
       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
       CgValue cg_val;
-      if (auto ival = constant_as<int64_t>(val)) {
+      // NOLINTNEXTLINE(bugprone-branch-clone)
+      if (auto ival = constant_as<int>(val)) {
         cg_val = new Int(ival.value());
       } else {
         cg_val = new Int();
@@ -1472,6 +570,7 @@ class IrParser {
                    static_cast<c10::TypePtr>(BoolType::get()))) {
       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
       CgValue cg_val;
+      // NOLINTNEXTLINE(bugprone-branch-clone)
       if (auto ival = constant_as<bool>(val)) {
         cg_val = new Bool(ival.value());
       } else {
@@ -1495,17 +594,7 @@ class IrParser {
   bool registerTensor(const JitValue* val) {
     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
     CgValue cg_val;
-    // Don't register if we don't support the type
-    if (auto tensor_type = val->type()->cast<c10::TensorType>()) {
-      if (!tensor_type->scalarType().has_value()) {
-        return false;
-      }
-
-      if (aten_to_data_type(tensor_type->scalarType().value()) ==
-          DataType::Null) {
-        return false;
-      }
-
+    if (auto tensor_type = val->type()->cast<TensorType>()) {
       // TODO: make this a static function in Tensor class;
       // create tensor;
       cg_val = new TensorView(tensor_type);
@@ -1520,207 +609,31 @@ class IrParser {
   // maps from JitValue::unique() to fusion Val;
   std::unordered_map<size_t, CgValue> value_map_;
   // parsing rule registry.
-  static std::unordered_map<std::string, RegistrationEntry>
-      jit_operator_registry_; // NOLINT
-
-  // pointing cached entry stored in `jit_operator_registry_`
-  static std::unordered_map<const FunctionSchema*, const RegistrationEntry*>
-      cached_registry_lookup_; // NOLINT
-
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+  static std::unordered_map<
+      Symbol,
+      std::vector<std::pair<std::shared_ptr<Operator>, RegistrationEntry>>>
+      jit_operator_registry_;
+  static std::unordered_set<Symbol> jit_reduction_op_registry_;
   static bool init_registry_;
 };
 
-std::unordered_map<std::string, IrParser::RegistrationEntry>
-    IrParser::jit_operator_registry_; // NOLINT
-std::unordered_map<const FunctionSchema*, const IrParser::RegistrationEntry*>
-    IrParser::cached_registry_lookup_; // NOLINT
-
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+std::unordered_map<
+    Symbol,
+    std::vector<
+        std::pair<std::shared_ptr<Operator>, IrParser::RegistrationEntry>>>
+    IrParser::jit_operator_registry_;
+std::unordered_set<Symbol> IrParser::jit_reduction_op_registry_;
 bool IrParser::init_registry_ = true;
 
-ProfileIValueOp* insertProfileIValueOp(
-    Node* node,
-    size_t offset,
-    ProfilingRecord* pr) {
-  auto in_val = node->input(offset);
-  auto pn = pr->createProfileIValueNode(in_val);
-  pn->insertBefore(node);
-  node->replaceInput(offset, pn->output());
-  return pn;
-}
-
-void profileSize(ProfilingRecord* pr, Node* node, size_t offset) {
-  auto pn = insertProfileIValueOp(node, offset, pr);
-
-  const auto ivalue_profiler = [pr, pn](Stack& stack) {
-    std::lock_guard<std::mutex> lock(pr->mutex_);
-
-    // TODO: we don't care about merging multiple profiling runs as we don't
-    // support it at all;
-    int64_t frame_id = 0;
-    pop(stack, frame_id);
-    IValue value;
-    pop(stack, value);
-
-    std::vector<int64_t> size_vec;
-    if (value.isIntList()) {
-      size_vec = value.toIntVector();
-    } else if (value.isNone()) {
-      size_vec.clear();
-    } else {
-      TORCH_INTERNAL_ASSERT(
-          false, "profileSize does not support data type: ", value.tagKind());
-    }
-    if (!pn->hasAttribute(sizeAttr)) {
-      pn->is_(sizeAttr, size_vec);
-    } else {
-      auto profiled_ints = pn->is(sizeAttr);
-      TORCH_INTERNAL_ASSERT(
-          profiled_ints.size() == size_vec.size() &&
-              std::equal(
-                  profiled_ints.begin(), profiled_ints.end(), size_vec.begin()),
-          "profiling ivalue doesn't support merge");
-    }
-    push(stack, value);
-  };
-  pn->setCallback(ivalue_profiler);
-}
-
-void profileIntList(ProfilingRecord* pr, Node* node, size_t offset) {
-  auto pn = insertProfileIValueOp(node, offset, pr);
-
-  const auto ivalue_profiler = [pr, pn](Stack& stack) {
-    std::lock_guard<std::mutex> lock(pr->mutex_);
-
-    // TODO: we don't care about merging multiple profiling runs as we don't
-    // support it at all;
-    int64_t frame_id = 0;
-    pop(stack, frame_id);
-    IValue value;
-    pop(stack, value);
-    TORCH_INTERNAL_ASSERT(
-        value.isIntList(), "profiling seeing the wrong data type");
-    if (!pn->hasAttribute(intListAttr)) {
-      pn->is_(intListAttr, value.toIntVector());
-    } else {
-      auto profiled_ints = pn->is(intListAttr);
-      auto input_ints = value.toIntList();
-      TORCH_INTERNAL_ASSERT(
-          profiled_ints.size() == input_ints.size() &&
-              std::equal(
-                  profiled_ints.begin(),
-                  profiled_ints.end(),
-                  input_ints.begin()),
-          "profiling ivalue doesn't support merge");
-    }
-    push(stack, value);
-  };
-
-  pn->setCallback(ivalue_profiler);
-}
-
-void profileBool(ProfilingRecord* pr, Node* node, size_t offset) {
-  auto pn = insertProfileIValueOp(node, offset, pr);
-
-  const auto ivalue_profiler = [pr, pn](Stack& stack) {
-    std::lock_guard<std::mutex> lock(pr->mutex_);
-
-    // TODO: we don't care about merging multiple profiling runs as we don't
-    // support it at all;
-    int64_t frame_id = 0;
-    pop(stack, frame_id);
-    IValue value;
-    pop(stack, value);
-    TORCH_INTERNAL_ASSERT(
-        value.isBool(), "profiling seeing the wrong data type");
-    if (!pn->hasAttribute(boolAttr)) {
-      pn->i_(boolAttr, value.toBool());
-    } else {
-      auto profiled_bool = pn->i(boolAttr);
-      auto input_bool = value.toBool();
-      TORCH_INTERNAL_ASSERT(
-          input_bool == profiled_bool,
-          "profiling ivalue doesn't support merge");
-    }
-    push(stack, value);
-  };
-
-  pn->setCallback(ivalue_profiler);
-}
-
-void profileInt(ProfilingRecord* pr, Node* node, size_t offset) {
-  auto pn = insertProfileIValueOp(node, offset, pr);
-
-  const auto ivalue_profiler = [pr, pn](Stack& stack) {
-    std::lock_guard<std::mutex> lock(pr->mutex_);
-
-    // TODO: we don't care about merging multiple profiling runs as we don't
-    // support it at all;
-    int64_t frame_id = 0;
-    pop(stack, frame_id);
-    IValue value;
-    pop(stack, value);
-    TORCH_INTERNAL_ASSERT(
-        value.isInt(), "profiling seeing the wrong data type");
-    if (!pn->hasAttribute(intAttr)) {
-      pn->i_(intAttr, value.toInt());
-    } else {
-      auto profiled_int = pn->i(intAttr);
-      auto input_int = value.toInt();
-      TORCH_INTERNAL_ASSERT(
-          input_int == profiled_int, "profiling ivalue doesn't support merge");
-    }
-    push(stack, value);
-  };
-
-  pn->setCallback(ivalue_profiler);
-}
-
-void profileBoolList(ProfilingRecord* pr, Node* node, size_t offset) {
-  auto pn = insertProfileIValueOp(node, offset, pr);
-
-  const auto ivalue_profiler = [pr, pn](Stack& stack) {
-    std::lock_guard<std::mutex> lock(pr->mutex_);
-
-    // TODO: we don't care about merging multiple profiling runs as we don't
-    // support it at all;
-    int64_t frame_id = 0;
-    pop(stack, frame_id);
-    IValue value;
-    pop(stack, value);
-    TORCH_INTERNAL_ASSERT(
-        value.isBoolList(), "profiling seeing the wrong data type");
-    if (!pn->hasAttribute(boolListAttr)) {
-      auto list = value.toBoolList();
-      std::vector<int64_t> val(list.begin(), list.end());
-      pn->is_(boolListAttr, val);
-    } else {
-      auto profiled_ints = pn->is(boolListAttr);
-      auto input_bools = value.toBoolList();
-      TORCH_INTERNAL_ASSERT(
-          profiled_ints.size() == input_bools.size() &&
-              std::equal(
-                  input_bools.begin(),
-                  input_bools.end(),
-                  profiled_ints.begin()),
-          "profiling ivalue doesn't support merge");
-    }
-    push(stack, value);
-  };
-
-  pn->setCallback(ivalue_profiler);
-}
+} // namespace
 
-bool anyInBlock(
-    const Block* block,
-    const std::function<bool(const Node*)>& fn) {
+bool hasReductionNode(const Block* block) {
   for (auto node : block->nodes()) {
-    if (fn(node)) {
+    if (isReductionNode(node)) {
       return true;
     }
     for (auto block : node->blocks()) {
-      if (anyInBlock(block, fn)) {
+      if (hasReductionNode(block)) {
         return true;
       }
     }
@@ -1728,243 +641,14 @@ bool anyInBlock(
   return false;
 }
 
-} // namespace
-
-bool hasReductionNode(const Block* block) {
-  return anyInBlock(block, isReductionNode);
-}
-
 bool isReductionNode(const Node* node) {
   return IrParser::isReductionNode(node);
 }
 
-bool isReductionToSizeNode(const Node* node) {
-  return IrParser::isReductionToSizeNode(node);
-}
-
-bool hasNormalizationNode(const Block* block) {
-  return anyInBlock(block, isNormalizationNode);
-}
-
-bool isNormalizationNode(const Node* node) {
-  return IrParser::isNormalizationNode(node);
-}
-
-bool isElementWiseNode(const Node* node) {
-  return IrParser::isElementWiseNode(node);
-}
-
 bool isNodeParsible(const Node* node) {
   return IrParser::canParseNode(node);
 }
 
-bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
-  // is skip constant necessary?
-  if (node->input(offset)->node()->kind() == prim::Constant) {
-    return false;
-  }
-
-  static auto dropout_schema =
-      getOperatorForLiteral(
-          "aten::dropout(Tensor input, float p, bool train) -> Tensor")
-          ->schema();
-  if (node->matches(dropout_schema)) {
-    switch (offset) {
-      // argument 2: Is training?
-      case 2:
-        profileBool(pr, node, offset);
-        break;
-      default:
-        return false;
-    }
-    return true;
-  }
-
-  static auto reduction_operator_schema =
-      getOperatorForLiteral(
-          "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)")
-          ->schema();
-  if (node->matches(reduction_operator_schema)) {
-    switch (offset) {
-      // argument 1: reduction axes;
-      case 1:
-        profileIntList(pr, node, offset);
-        break;
-      // argument 2: keepdim;
-      case 2:
-        profileBool(pr, node, offset);
-        break;
-      default:
-        return false;
-    }
-    return true;
-  }
-
-  static auto sum_to_size_schema =
-      getOperatorForLiteral(
-          "aten::sum_to_size(Tensor self, int[] size) -> Tensor")
-          ->schema();
-  static auto grad_sum_to_size_schema =
-      getOperatorForLiteral(
-          "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)")
-          ->schema();
-  if (node->matches(sum_to_size_schema) ||
-      node->matches(grad_sum_to_size_schema)) {
-    switch (offset) {
-      // argument 1: reduction sizes;
-      case 1:
-        // TODO(profile_size): double check optional[size]?
-        profileSize(pr, node, offset);
-        break;
-      default:
-        return false;
-    }
-    return true;
-  }
-
-  static auto batch_norm_impl_index_schema =
-      getOperatorForLiteral(
-          "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)")
-          ->schema();
-  static auto native_batch_norm_schema =
-      getOperatorForLiteral(
-          "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")
-          ->schema();
-  static auto batch_norm_schema =
-      getOperatorForLiteral(
-          "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")
-          ->schema();
-  static auto instance_norm_schema =
-      getOperatorForLiteral(
-          "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor")
-          ->schema();
-  if (node->matches(native_batch_norm_schema) ||
-      node->matches(batch_norm_impl_index_schema) ||
-      node->matches(batch_norm_schema) || node->matches(instance_norm_schema)) {
-    switch (offset) {
-      // argument 5: training;
-      case 5:
-        profileBool(pr, node, offset);
-        break;
-      default:
-        return false;
-    }
-    return true;
-  }
-
-  static auto native_layer_norm_schema =
-      getOperatorForLiteral(
-          "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)")
-          ->schema();
-  static auto layer_norm_schema =
-      getOperatorForLiteral(
-          "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor")
-          ->schema();
-  if (node->matches(native_layer_norm_schema) ||
-      node->matches(layer_norm_schema)) {
-    switch (offset) {
-      case 1:
-        profileIntList(pr, node, offset);
-        break;
-      default:
-        return false;
-    }
-    return true;
-  }
-
-  static auto native_batch_norm_backward_schema =
-      getOperatorForLiteral(
-          "aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)")
-          ->schema();
-  if (node->matches(native_batch_norm_backward_schema)) {
-    switch (offset) {
-      // argument 7: training;
-      case 7:
-        profileBool(pr, node, offset);
-        break;
-      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-      case 9:
-        profileBoolList(pr, node, offset);
-        break;
-      default:
-        return false;
-    }
-    return true;
-  }
-
-  static auto batch_norm_impl_index_backward_schema =
-      getOperatorForLiteral(
-          "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)")
-          ->schema();
-  if (node->matches(batch_norm_impl_index_backward_schema)) {
-    switch (offset) {
-      // TODO: guard impl_index, but I think that's not needed;
-      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-      case 8: // argument 8: training;
-        profileBool(pr, node, offset);
-        break;
-      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-      case 10:
-        profileBoolList(pr, node, offset);
-        break;
-      default:
-        return false;
-    }
-    return true;
-  }
-
-  static auto native_layer_norm_backward_schema =
-      getOperatorForLiteral(
-          "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)")
-          ->schema();
-  if (node->matches(native_layer_norm_backward_schema)) {
-    switch (offset) {
-      case 2:
-        profileIntList(pr, node, offset);
-        break;
-      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-      case 7:
-        profileBoolList(pr, node, offset);
-        break;
-      default:
-        return false;
-    }
-    return true;
-  }
-
-  static auto to_dtype_schema =
-      getOperatorForLiteral(
-          "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor")
-          ->schema();
-  if (node->matches(to_dtype_schema)) {
-    switch (offset) {
-      case 1:
-        profileInt(pr, node, offset);
-        return true;
-      default:
-        return false;
-    }
-  }
-
-  return false;
-}
-
-void insertProfileNodesForCUDAFuser_(Block* block, ProfilingRecord* pr) {
-  for (const auto& n : block->nodes()) {
-    for (size_t offset = 0; offset < n->inputs().size(); offset++) {
-      insertProfileIValue(pr, n, offset);
-    }
-
-    for (auto ib : n->blocks()) {
-      insertProfileNodesForCUDAFuser_(ib, pr);
-    }
-  }
-}
-
-void InsertProfileNodes(ProfilingRecord* pr) {
-  insertProfileNodesForCUDAFuser_(pr->profiled_graph_->block(), pr);
-}
-
 std::unique_ptr<Fusion> parseJitIR(const std::shared_ptr<Graph>& graph) {
   FUSER_PERF_SCOPE("parseJitIR");
 
index 56d935d..592263e 100644 (file)
@@ -2,7 +2,6 @@
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/ir/ir.h>
-#include <torch/csrc/jit/runtime/profiling_record.h>
 
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 
@@ -32,19 +31,12 @@ constexpr int kNonFcdReductionThreadX = 32;
 constexpr int kNonFcdReductionThreadY = 32;
 
 TORCH_CUDA_CU_API bool hasReductionNode(const Block* block);
-TORCH_CUDA_CU_API bool isReductionToSizeNode(const Node* node);
-TORCH_CUDA_CU_API bool isReductionNode(const Node* node);
-
-TORCH_CUDA_CU_API bool hasNormalizationNode(const Block* block);
-TORCH_CUDA_CU_API bool isNormalizationNode(const Node* node);
 
-TORCH_CUDA_CU_API bool isElementWiseNode(const Node* node);
+TORCH_CUDA_CU_API bool isReductionNode(const Node* node);
 
 // returns whether or not a parsing function exists for the given node type.
 TORCH_CUDA_CU_API bool isNodeParsible(const Node* node);
 
-void InsertProfileNodes(ProfilingRecord* pr);
-
 // lowers PyTorch jit graph to `Fusion`.
 TORCH_CUDA_CU_API std::unique_ptr<Fusion> parseJitIR(
     const std::shared_ptr<Graph>& graph);
index 3167c27..90e3f4f 100644 (file)
@@ -1,7 +1,6 @@
 #include <torch/csrc/jit/codegen/cuda/partition.h>
 
 #include <ATen/core/jit_type.h>
-#include <ATen/cuda/CUDAContext.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/parser.h>
@@ -13,22 +12,6 @@ namespace cuda {
 
 namespace {
 
-bool hasNonElementWiseOperation(const Node* node) {
-  if (node->kind() == prim::CudaFusionGroup) {
-    for (auto n : node->g(attr::Subgraph)->nodes()) {
-      if (hasNonElementWiseOperation(n)) {
-        return true;
-      }
-    }
-  } else {
-    // prim::Constant is not parsible, but it is also not nonElementWise
-    if (node->kind() != prim::Constant && !isElementWiseNode(node)) {
-      return true;
-    }
-  }
-  return false;
-}
-
 // Check all outputs are:
 //   1. TensorType
 //   2. on the same device;
@@ -52,7 +35,7 @@ static c10::optional<c10::Device> getDevice(const Node* node) {
   return c10::nullopt;
 }
 
-static bool isFusibleDevice(const Node* node, const c10::Device device) {
+static bool isFusableDevice(const Node* node, const c10::Device device) {
   for (auto value : node->outputs()) {
     auto output_device = getDevice(value);
     if (output_device.has_value() && output_device.value() != device) {
@@ -63,94 +46,28 @@ static bool isFusibleDevice(const Node* node, const c10::Device device) {
 }
 
 // TODO: we need to check input type when we handle `to()`
-static bool isFusibleDevice(const Node* node) {
+static bool isFusableDevice(const Node* node) {
   auto device = getDevice(node);
   if (!device.has_value()) {
     return true;
   }
-  return device->is_cuda() &&
-      (at::cuda::getDeviceProperties(device->index())->major >= 7 ||
-       !hasNonElementWiseOperation(node));
-}
-
-bool compatibleType(const torch::jit::Value* val) {
-  if (auto tensor_type = val->type()->cast<c10::TensorType>()) {
-    if (tensor_type->scalarType().has_value()) {
-      if (aten_to_data_type(tensor_type->scalarType().value()) ==
-          DataType::Null) {
-        return false;
-      }
-    }
-  }
-  return true;
+  return device->is_cuda();
 }
 
-bool checkInputTensorTypes(const Node* node) {
-  for (size_t i = 0; i < node->inputs().size(); i++) {
-    const auto& val = node->inputs()[i];
-    if (!compatibleType(val)) {
-      // special case on aten::_batch_norm_impl_index_backward, the 11th output
-      // is going to be discarded, so no need to check data type there.
-      if (node->kind() ==
-              c10::Symbol::fromQualString(
-                  "aten::_batch_norm_impl_index_backward") &&
-          i == 11) {
-        continue;
-      }
-      return false;
-    }
-  }
-  return true;
-}
-
-bool checkOutputTensorTypes(const Node* node) {
-  for (size_t i = 0; i < node->outputs().size(); i++) {
-    const auto& val = node->outputs()[i];
-    if (!compatibleType(val)) {
-      // special case on aten::_batch_norm_impl_index, the 4th output
-      // is going to be discarded, so no need to check data type there.
-      if (node->kind() ==
-              c10::Symbol::fromQualString("aten::_batch_norm_impl_index") &&
-          i == 3) {
-        continue;
-      }
-      return false;
-    }
-  }
-  return true;
+inline bool isFusableNode(const Node* node) {
+  // checks if node is compatible with parser:
+  // 1. if we have a parsing rule; or 2. if the node is already a fusion group.
+  return (isNodeParsible(node) || node->kind() == prim::CudaFusionGroup);
 }
 
-inline bool isFusibleNode(const Node* node) {
-  if (node->kind() == prim::CudaFusionGroup)
+bool hasReductionOperation(const Node* node) {
+  if (isReductionNode(node)) {
     return true;
-  // Check we have a parsing rule
-  bool isFusible = isNodeParsible(node);
-  // Check if we have a tensor type it's one we support
-  isFusible = isFusible && checkInputTensorTypes(node);
-  isFusible = isFusible && checkOutputTensorTypes(node);
-  // Check if already part of a fusion group
-  return isFusible;
-}
-
-bool maybeBroadcast(
-    const TensorTypePtr& type,
-    const std::vector<c10::optional<int64_t>>& shape) {
-  if (type->dim()) {
-    if (type->dim().value() < shape.size()) {
-      // no broadcast for reduction operation;
-      return false;
-    } else if (type->dim().value() > shape.size()) {
-      // increased rank means there is reduction;
-      return true;
-    } else {
-      // same rank, we need to iterate through sizes and check if size-1
-      // exists in input `shape`
-      for (const auto& opt_size : shape) {
-        // TODO: not sure if we need to check for output size != 1, since we
-        // are currently marking all size-1 dimension as broadcast in codegen.
-        if (opt_size.has_value() && opt_size.value() == 1) {
-          return true;
-        }
+  }
+  if (node->kind() == prim::CudaFusionGroup) {
+    for (auto n : node->g(attr::Subgraph)->nodes()) {
+      if (hasReductionOperation(n)) {
+        return true;
       }
     }
   }
@@ -168,44 +85,33 @@ bool maybeBroadcast(
 bool maybeBroadcastOnShape(
     const Node* n,
     const std::vector<c10::optional<int64_t>>& shape) {
-  // TODO: we are only checking output 0. This means that our current check for
-  // normalization is not complete.
+  TORCH_INTERNAL_ASSERT(
+      n->outputs().size() == 1,
+      "not expecting multiple outputs from a node, graph partitioning logic needs to be updated");
   // assumes that if output is not a tensor type, it's not broadcasting
   if (auto out_type = n->output(0)->type()->cast<TensorType>()) {
-    return maybeBroadcast(out_type, shape);
-  }
-  return false;
-};
-
-// return true if node is pointwise operation and input tensors all have
-// identical shape.
-bool isNonBroadcastElementWise(const Node* n) {
-  if (hasNonElementWiseOperation(n)) {
-    return false;
-  }
-
-  for (const auto output : n->outputs()) {
-    const auto& n_output_type = output->type()->cast<TensorType>();
-
-    // TODO: we need to stay on safer side instead of "default to return true
-    // when shape information is not available.", Change that when we enable
-    // profiling on autodiff FW execution.
-    if (n_output_type != nullptr && n_output_type->sizes().sizes()) {
-      const std::vector<c10::optional<int64_t>>& n_output_shape =
-          n_output_type->sizes().sizes().value();
-
-      for (auto input : n->inputs()) {
-        if (auto t_type = input->type()->cast<TensorType>()) {
-          if (maybeBroadcast(t_type, n_output_shape)) {
-            return false;
+    if (out_type->dim()) {
+      if (out_type->dim().value() < shape.size()) {
+        // no broadcast for reduction operation;
+        return false;
+      } else if (out_type->dim().value() > shape.size()) {
+        // increased rank means there is reduction;
+        return true;
+      } else {
+        // same rank, we need to iterate through sizes and check if size-1
+        // exists in input `shape`
+        for (const auto& opt_size : shape) {
+          // TODO: not sure if we need to check for output size != 1, since we
+          // are currently marking all size-1 dimension as broadcast in codegen.
+          if (opt_size.has_value() && opt_size.value() == 1) {
+            return true;
           }
         }
       }
     }
   }
-
-  return true;
-}
+  return false;
+};
 
 //! [ Note - tricky broadcasting ]
 //!
@@ -385,31 +291,30 @@ bool createTrickyBroadcast(const Node* consumer, const Node* producer) {
 
 } // namespace
 
-bool isFusibleCudaFusionGroup(const Node* node) {
-  FUSER_PERF_SCOPE("isFusibleCudaFusionGroup");
+bool isFusableCudaFusionGroup(const Node* node) {
+  FUSER_PERF_SCOPE("isFusableCudaFusionGroup");
 
-  if (isFusibleNode(node)) {
-    auto ret = isFusibleDevice(node);
-    return ret;
+  if (isFusableNode(node)) {
+    return isFusableDevice(node);
   }
   return false;
 }
 
-bool isFusibleCudaFusionGroup(const Node* fusion, const Node* node) {
-  FUSER_PERF_SCOPE("isFusibleCudaFusionGroup");
-  bool fused = false;
+bool isFusableCudaFusionGroup(const Node* fusion, const Node* node) {
+  FUSER_PERF_SCOPE("isFusableCudaFusionGroup");
+
   // TODO: lift the restriction of not fusing producer containing reduction when
   //       we have proper scheduling.
-  if (isFusibleCudaFusionGroup(node)) {
+  if (isFusableCudaFusionGroup(node) && !hasReductionOperation(node) &&
+      !createTrickyBroadcast(fusion, node)) {
     // ensure if the node has a designated device, it's on the same device with
     // fusion.
     // TODO: is there a danger of us fusing operations that's supposed to be on
     //       separate GPUs? And is that necessarily bad?
     auto device = getDevice(fusion);
-    fused = (!device.has_value() || isFusibleDevice(node, device.value()));
+    return (!device.has_value() || isFusableDevice(node, device.value()));
   }
-
-  return fused;
+  return false;
 }
 
 } // namespace cuda
index 4ebac40..61d0df4 100644 (file)
@@ -19,10 +19,10 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-TORCH_CUDA_CU_API bool isFusibleCudaFusionGroup(const Node* node);
+TORCH_CUDA_CU_API bool isFusableCudaFusionGroup(const Node* node);
 
 // consider if `node` could be fused into `fusion`
-TORCH_CUDA_CU_API bool isFusibleCudaFusionGroup(
+TORCH_CUDA_CU_API bool isFusableCudaFusionGroup(
     const Node* fusion,
     const Node* node);
 
index e2a1e46..714e7ef 100644 (file)
@@ -1,14 +1,13 @@
 #include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
 
 #include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/index_compute.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
 #include <torch/csrc/jit/codegen/cuda/transform_iter.h>
 
 #include <c10/util/irange.h>
@@ -18,255 +17,278 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-namespace {
-
-// find the first (and only) TensorView output
-//
-// TODO(kir): same question as ir_utils::getTvOutput():
-//    why do we assume a single TV output?
-//
-kir::TensorView* firstTensorViewOutput(const kir::Expr* expr) {
-  TORCH_INTERNAL_ASSERT(expr != nullptr);
-  for (auto out : expr->outputs()) {
-    if (out->isA<kir::TensorView>()) {
-      return out->as<kir::TensorView>();
-    } else if (out->isA<kir::TensorIndex>()) {
-      return out->as<kir::TensorIndex>()->view();
+std::vector<kir::Bool*> PredicateCompute::computePredicates(
+    const TensorView* tv,
+    const std::vector<Val*>& indices,
+    bool use_rfactor) {
+  FUSER_PERF_SCOPE("computePredicates");
+
+  const std::vector<IterDomain*>& root =
+      use_rfactor ? tv->getMaybeRFactorDomain() : tv->getRootDomain();
+
+  TORCH_INTERNAL_ASSERT(root.size() == indices.size());
+
+  bool no_pred_needed = true;
+  for (auto id : tv->domain()->domain()) {
+    if (id->getOrigin() != nullptr) {
+      no_pred_needed = false;
     }
   }
-  TORCH_INTERNAL_ASSERT(false, "Missing kir::TensorView output");
-}
 
-bool isTensorIndexOp(kir::Expr* expr) {
-  const auto& outputs = expr->outputs();
-  return outputs.size() >= 1 && outputs[0]->isA<kir::TensorIndex>();
-}
+  if (no_pred_needed) {
+    return {};
+  }
 
-bool isOutputLocal(const kir::Expr* expr) {
-  return std::all_of(
-      expr->outputs().begin(),
-      expr->outputs().end(),
-      [](const kir::Val* output) {
-        return !output->isA<kir::TensorView>() ||
-            output->as<kir::TensorView>()->memoryType() == MemoryType::Local;
-      });
-}
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+
+  auto true_bool = ir_builder.create<kir::Bool>(true);
+  std::vector<kir::Bool*> preds(root.size(), true_bool);
+  Val* extent = nullptr;
 
-} // namespace
+  for (const auto i : c10::irange(indices.size())) {
+    const bool zero_ind = indices[i]->isZeroInt();
+    const bool simple_ind = indices[i]->getOrigin() == nullptr;
+
+    if (root[i]->isBroadcast()) {
+      continue;
+    } else if (simple_ind && !zero_ind) {
+      extent = nullptr;
+      continue;
+    } else if (zero_ind) {
+      if (root[i]->extent()->isOneInt()) {
+        continue;
+      }
+      const auto lowered_extent = GpuLower::lowerValue(root[i]->extent());
+      if (extent == nullptr) {
+        extent = lowered_extent;
+      } else {
+        extent = ir_builder.mulExpr(extent, lowered_extent);
+      }
+    } else {
+      auto local_extent = GpuLower::lowerValue(root[i]->extent());
+      if (extent != nullptr) {
+        local_extent = ir_builder.mulExpr(extent, local_extent);
+      }
+      auto pred = ir_builder.ltExpr(indices[i], local_extent);
+      extent = nullptr;
+      TORCH_INTERNAL_ASSERT(
+          pred->getValType().value() == ValType::KirScalar &&
+          pred->getDataType().value() == DataType::Bool);
+      preds[i] = pred->as<kir::Bool>();
+    }
+  }
+  return preds;
+}
 
 kir::Bool* PredicateCompute::getInlinePredicate(
-    const kir::Expr* expr,
+    Expr* expr,
     const std::vector<kir::ForLoop*>& loops,
     kir::Bool* thread_pred,
-    PredicateType pred_type) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::getInlinePredicate");
+    bool ignore_block_grid_reductions) {
+  FUSER_PERF_SCOPE("getInlinePredicate");
 
-  const auto gpu_lower = GpuLower::current();
-  kir::IrBuilder ir_builder(gpu_lower->kernel());
-
-  // If outputs are registers, no need to predicate for threads
-  if (isOutputLocal(expr)) {
-    thread_pred = ir_builder.trueVal();
-  }
+  kir::IrBuilder ir_builder(GpuLower::current()->kernel());
 
   if (loops.empty()) {
-    TORCH_INTERNAL_ASSERT(thread_pred != nullptr);
-    return thread_pred;
+    return ir_builder.create<kir::Bool>(true);
   }
 
-  auto out_tv = firstTensorViewOutput(expr);
-
-  if (gpu_lower->predicateElimination().canOmitPredicate(expr)) {
-    return thread_pred;
+  // Handle these elsewhere
+  if (ignore_block_grid_reductions &&
+      expr->getExprType() == ExprType::ReductionOp &&
+      (expr->as<ReductionOp>()->out()->as<TensorView>()->hasBlockReduction() ||
+       expr->as<ReductionOp>()->out()->as<TensorView>()->hasGridReduction())) {
+    return ir_builder.create<kir::Bool>(true);
   }
 
-  auto all_preds = Index::getReferenceRootPredicates(out_tv, loops);
+  TORCH_INTERNAL_ASSERT(
+      ir_utils::isTVOp(expr),
+      "Cannot generate predicate based on operation without a TensorView.");
 
-  std::vector<kir::Bool*> preds;
+  auto out_tv = ir_utils::getTVOutput(expr);
 
-  auto is_true = [](const kir::Bool* p) {
-    return p->isConst() && p->value().value();
-  };
-
-  // When pred_type is ReductionWrite, filter out predicates for
-  // reduction axes. For blockReduce, this is necessary when reduction
-  // axes start at non-zero offsets and parallelized with TID since
-  // blockReduce returns a valid output only at offset-zero
-  // threads. Similarly, for gridReduce, the last block to store the
-  // output may be predicated out with the read predicate, so the
-  // write predicate needs to ignore the reduction axes.
-  bool non_zero_start_found = false;
-  for (size_t i = 0; i < all_preds.first.size(); ++i) {
-    auto pred = all_preds.first[i];
-    if (pred_type == PredicateType::ReductionWrite) {
-      const auto& concrete_root_ids = all_preds.second[i];
-      bool pred_for_reduction_axis = false;
-      for (auto pred_root_id : concrete_root_ids) {
-        auto kir_pred_root_id =
-            gpu_lower->lowerValue(pred_root_id)->as<kir::IterDomain>();
-        auto it = std::find_if(
-            out_tv->domain()->rootDomain().begin(),
-            out_tv->domain()->rootDomain().end(),
-            [&](const auto& out_root_id) {
-              return gpu_lower->caIndexMap().areMapped(
-                  kir_pred_root_id, out_root_id);
-            });
-        TORCH_INTERNAL_ASSERT(
-            it != out_tv->domain()->rootDomain().end(),
-            "No corresponding root ID found for ",
-            pred_root_id);
-        auto out_root_id = *it;
-        if (out_root_id->isReduction()) {
-          if (!out_root_id->start()->isZeroInt()) {
-            non_zero_start_found = true;
-          }
-          pred_for_reduction_axis = true;
-          break;
-        }
-      }
-      // Don't add the predicate if it corresponds to a reduction axis
-      if (pred_for_reduction_axis) {
-        continue;
-      }
+  auto pred_contiguity = out_tv->domain()->contiguity();
+
+  for (auto inp : expr->inputs()) {
+    if (!ir_utils::isTV(inp)) {
+      continue;
     }
-    if (!is_true(pred)) {
-      preds.push_back(pred);
+    auto inp_tv = inp->as<TensorView>();
+    // NOLINTNEXTLINE(bugprone-branch-clone)
+    if (inp_tv->domain()->hasRFactor()) {
+      continue;
+    } else if (
+        inp_tv->getMemoryType() == MemoryType::Shared ||
+        inp_tv->getMemoryType() == MemoryType::Local) {
+      continue;
+    } else {
+      pred_contiguity = IndexCompute::contiguityAnd(
+          pred_contiguity,
+          IndexCompute::contiguityPasC(inp_tv->domain(), out_tv->domain()));
     }
   }
 
-  // When generating a predicate for blockReduce writes and not for
-  // gridReduce, if all reduction axes start with zero, we can just
-  // use the same predicate for reads. nullptr is returned then.
-  if (pred_type == PredicateType::ReductionWrite && !non_zero_start_found &&
-      !out_tv->fuserTv()->domain()->hasGridReduction()) {
-    return nullptr;
+  auto pred_inds =
+      Index::getConsumerRootPredIndices(out_tv, loops, pred_contiguity);
+  auto root_indices = pred_inds.first;
+  bool use_maybe_rfactor = pred_inds.second;
+
+  if (out_tv->getMemoryType() == MemoryType::Local && out_tv->hasReduction() &&
+      !use_maybe_rfactor) {
+    auto tv_filter_inp_view =
+        ir_utils::filterByType<TensorView>(expr->inputs());
+    auto has_tv_inputs = tv_filter_inp_view.begin() != tv_filter_inp_view.end();
+    // If predicates doesn't need maybe_rfactor, but it has reduction axes, and
+    // expr has no inputs, we're pretty confident we're intializing a reduction
+    // buffer. If we're initing a reduction buffer don't generate an inline
+    // predicate.
+    if (!has_tv_inputs) {
+      return ir_builder.create<kir::Bool>(true);
+    }
   }
 
-  if (thread_pred != nullptr && !is_true(thread_pred)) {
-    preds.push_back(thread_pred);
+  auto all_preds = PredicateCompute::computePredicates(
+      out_tv, root_indices, use_maybe_rfactor);
+
+  // If we have thread predicates, add those
+  if (thread_pred != nullptr) {
+    all_preds.push_back(thread_pred);
   }
 
+  std::vector<kir::Bool*> preds;
+
+  for (auto pred : all_preds)
+    if (!(pred->isConst()) || !(pred->isConst() && pred->value().value()))
+      preds.push_back(pred);
+
   if (preds.empty()) {
-    return ir_builder.trueVal();
+    return ir_builder.create<kir::Bool>(true);
   }
 
-  kir::Val* cond = preds[0];
-  for (size_t i = 1; i < preds.size(); i++) {
+  Val* cond = preds[0];
+
+  for (decltype(preds.size()) i{1}; i < preds.size(); i++) {
     cond = ir_builder.andExpr(cond, preds[i]);
   }
 
+  TORCH_INTERNAL_ASSERT(
+      cond->getValType().value() == ValType::KirScalar &&
+          cond->getDataType().value() == DataType::Bool,
+      "Error computing predicate, should be returning a Bool, but returning ",
+      cond->getDataType().value());
+
   return cond->as<kir::Bool>();
 }
 
-kir::Bool* UnswitchPredicate::get(
+kir::Bool* UnrollPredicate::get(
     const std::vector<kir::ForLoop*>& outer_loops,
-    kir::ForLoop* unrolled_loop) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::get");
+    kir::ForLoop* unrolled_loop,
+    const std::unordered_map<IterDomain*, IterDomain*>& p2c_root_map) {
+  FUSER_PERF_SCOPE("UnrollPredicate::get");
 
   kir::IrBuilder ir_builder(GpuLower::current()->kernel());
 
-  UnswitchPredicate up(outer_loops, unrolled_loop);
+  UnrollPredicate up(outer_loops, unrolled_loop, p2c_root_map);
 
-  kir::Val* unroll_pred = nullptr;
-  for (auto pred : up.predicates_) {
-    if (pred->isConst() && pred->value().value()) {
-      continue;
-    } else if (unroll_pred == nullptr) {
+  std::unordered_set<kir::Bool*> pred_set;
+  for (auto entry : up.predicates_) {
+    pred_set.emplace(entry.second);
+  }
+
+  if (up.predicates_.empty()) {
+    return ir_builder.create<kir::Bool>(true);
+  }
+
+  Val* unroll_pred = nullptr;
+  for (auto pred : pred_set) {
+    if (unroll_pred == nullptr) {
       unroll_pred = pred;
     } else {
       unroll_pred = ir_builder.andExpr(unroll_pred, pred);
     }
   }
-
-  return unroll_pred == nullptr ? ir_builder.trueVal()
-                                : unroll_pred->as<kir::Bool>();
+  TORCH_INTERNAL_ASSERT(
+      // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
+      unroll_pred->getValType().value() == ValType::KirScalar &&
+      unroll_pred->getDataType().value() == DataType::Bool);
+  return unroll_pred->as<kir::Bool>();
 }
 
-void UnswitchPredicate::predicateOn(kir::Expr* tv_expr) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::predicateOn");
+void UnrollPredicate::predicateOn(Expr* tv_expr) {
+  FUSER_PERF_SCOPE("UnrollPredicate::predicateOn");
 
   if (for_loops_.empty()) {
     return;
   }
 
-  const auto gpu_lower = GpuLower::current();
-
-  if (gpu_lower->predicateElimination().canOmitPredicate(tv_expr)) {
-    return;
-  }
-
-  auto out_tv = firstTensorViewOutput(tv_expr);
+  auto out_tv = ir_utils::getTVOutput(tv_expr);
 
-  auto pred_info = Index::getReferenceRootPredicates(out_tv, for_loops_, true);
+  auto pred_contiguity = out_tv->domain()->contiguity();
 
-  for (auto i : c10::irange(pred_info.first.size())) {
-    auto pred = pred_info.first[i];
-    if (pred->isConst() && pred->value()) {
+  for (auto inp : tv_expr->inputs()) {
+    if (!ir_utils::isTV(inp)) {
+      continue;
+    }
+    auto inp_tv = inp->as<TensorView>();
+    // NOLINTNEXTLINE(bugprone-branch-clone)
+    if (inp_tv->domain()->hasRFactor()) {
       continue;
+    } else if (
+        inp_tv->getMemoryType() == MemoryType::Shared ||
+        inp_tv->getMemoryType() == MemoryType::Local) {
+      continue;
+    } else {
+      pred_contiguity = IndexCompute::contiguityAnd(
+          pred_contiguity,
+          IndexCompute::contiguityPasC(inp_tv->domain(), out_tv->domain()));
     }
+  }
 
-    const auto& root_ids = pred_info.second[i];
+  auto pred_inds = Index::getConsumerRootPredIndices(
+      out_tv, for_loops_, pred_contiguity, true);
+  auto root_indices = pred_inds.first;
+  auto use_rfactor = pred_inds.second;
 
-    bool add_pred = false;
+  auto all_preds =
+      PredicateCompute::computePredicates(out_tv, root_indices, use_rfactor);
 
-    for (auto root_id : root_ids) {
-      auto kir_root_id = gpu_lower->lowerValue(root_id)->as<kir::IterDomain>();
+  auto root_dom =
+      use_rfactor ? out_tv->getMaybeRFactorDomain() : out_tv->getRootDomain();
 
-      if (kir_root_id->isBroadcast()) {
-        continue;
-      }
+  TORCH_INTERNAL_ASSERT(
+      all_preds.size() == root_dom.size(),
+      "Predicates should be produced for every dimension, even if it's simply set as true.");
 
-      if (std::find(
-              predicated_iter_dom_.begin(),
-              predicated_iter_dom_.end(),
-              kir_root_id) == predicated_iter_dom_.end()) {
-        add_pred = true;
-        predicated_iter_dom_.push_back(kir_root_id);
-      }
-    }
-    if (add_pred) {
-      predicates_.push_back(pred);
+  for (const auto i : c10::irange(all_preds.size())) {
+    if (all_preds[i]->isConst() && all_preds[i]->value().value()) {
+      continue;
     }
+    auto term_id = loop_utils::getTermIDInMap(root_dom[i], p2c_root_map_);
+    predicates_[term_id] = all_preds[i];
   }
 }
 
-void UnswitchPredicate::openLoop(kir::ForLoop* fl) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::openLoop");
+void UnrollPredicate::openLoop(kir::ForLoop* fl) {
+  FUSER_PERF_SCOPE("UnrollPredicate::openLoop");
 
   for_loops_.push_back(fl);
 
   for (auto expr : fl->body().exprs()) {
-    if (ir_utils::isTVOp(expr) || isTensorIndexOp(expr)) {
+    if (ir_utils::isTVOp(expr)) {
       predicateOn(expr);
-    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
-      openIte(ite);
-    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-      openLoop(for_loop);
+    } else if (expr->getExprType().value() == ExprType::ForLoop) {
+      openLoop(expr->as<kir::ForLoop>());
     }
   }
 
   for_loops_.pop_back();
 }
 
-void UnswitchPredicate::openIte(kir::IfThenElse* ite) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::openIte");
-
-  // only expand the ite thenBody
-  for (auto expr : ite->thenBody().exprs()) {
-    if (ir_utils::isTVOp(expr) || isTensorIndexOp(expr)) {
-      predicateOn(expr);
-    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
-      openIte(ite);
-    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-      openLoop(for_loop);
-    }
-  }
-}
-
-UnswitchPredicate::UnswitchPredicate(
+UnrollPredicate::UnrollPredicate(
     std::vector<kir::ForLoop*> outer_loops,
-    kir::ForLoop* unrolled_loop)
-    : for_loops_(std::move(outer_loops)) {
+    kir::ForLoop* unrolled_loop,
+    const std::unordered_map<IterDomain*, IterDomain*>& _p2c_root_map)
+    : for_loops_(std::move(outer_loops)), p2c_root_map_(_p2c_root_map) {
   openLoop(unrolled_loop);
 }
 
index 62e925e..2aa681c 100644 (file)
@@ -1,9 +1,33 @@
 #pragma once
 
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
+#include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+/*
+ * Predicate compute takes a TensorView and set of indices. The number of
+ * indices and the root of the TensorView are required to have the same number
+ * of dimensions. Predicate compute should be run after index compute, and the
+ * result of index compute should be used for the indices entry.
+ *
+ * A vector of Int values are returned which are the output of the operation
+ * index[i] < get_root(TV)->domain()->axis(i)->size()
+ *
+ * It is assumed that no predicate is required if index[i] is an index directly
+ * from a for loop. This will not catch all cases if we actually have static
+ * size information for example:
+ *
+ * TV[I].split(4)
+ * would produce the code:
+ * for(i : I/4)
+ *   for(j : 4)
+ *     if( i * 4 + j < TV.size(0))
+ *       TV[i * 4 + j]...
+ *
+ * However if we had TV.size[0] = 16 at "compile time" then we wouldn't need the
+ * predicate. However we will still generate: for(i : 4) for(j : 4) if( i * 4 +
+ * j < TV.size(0)) TV[i * 4 + j]...
+ *
+ */
 
 namespace torch {
 namespace jit {
@@ -12,42 +36,42 @@ namespace cuda {
 
 class PredicateCompute {
  public:
-  // ignore_internal_syncthread_ops will prevent creation of predicates on
-  // block/grid broadcast/reduce as these have syncthread calls within them
-  // so all threads need to execute the function.
+  // Return the series of predicates, if an axis doesn't have a predicate
+  // reutrns 1
+  static std::vector<kir::Bool*> computePredicates(
+      const TensorView* tv,
+      const std::vector<Val*>& indices,
+      bool use_rfactor);
+
   static kir::Bool* getInlinePredicate(
-      const kir::Expr* expr,
+      Expr* expr,
       const std::vector<kir::ForLoop*>& loops,
       kir::Bool* thread_pred,
-      PredicateType pred_type);
+      bool ignore_block_grid_reductions = true);
 };
 
-class TORCH_CUDA_CU_API UnswitchPredicate {
+class TORCH_CUDA_CU_API UnrollPredicate {
  public:
   static kir::Bool* get(
       const std::vector<kir::ForLoop*>& outer_loops,
-      kir::ForLoop* unrolled_loop);
+      kir::ForLoop* unrolled_loop,
+      const std::unordered_map<IterDomain*, IterDomain*>& p2c_root_map);
 
  private:
-  UnswitchPredicate(
+  UnrollPredicate(
       std::vector<kir::ForLoop*> outer_loops,
-      kir::ForLoop* unrolled_loop);
+      kir::ForLoop* unrolled_loop,
+      const std::unordered_map<IterDomain*, IterDomain*>& _p2c_root_map);
 
-  void predicateOn(kir::Expr*);
+  void predicateOn(Expr*);
 
   void openLoop(kir::ForLoop*);
 
-  void openIte(kir::IfThenElse*);
-
  private:
-  // Track which iter domains have been predicated, uses concrete_id from
-  // caLoopMap.
-  std::vector<kir::IterDomain*> predicated_iter_dom_;
-
-  // The predicates that have been generated.
-  std::vector<kir::Bool*> predicates_;
-
+  std::unordered_map<IterDomain*, kir::Bool*> predicates_;
   std::vector<kir::ForLoop*> for_loops_;
+
+  const std::unordered_map<IterDomain*, IterDomain*>& p2c_root_map_;
 };
 
 } // namespace cuda
index ce4504d..284ee05 100644 (file)
@@ -19,11 +19,12 @@ class RegisterInterface {
  public:
   RegisterInterface() {
     auto ptr = getFuserInterface();
-    ptr->fn_compile_n = &compileCudaFusionGroup;
-    ptr->fn_run_n_s = &runCudaFusionGroup;
-    ptr->fn_fuse_graph = &CudaFuseGraph;
-    ptr->fn_can_fuse_n = &isFusibleCudaFusionGroup;
-    ptr->fn_insert_profile_inodes = &InsertProfileNodes;
+    ptr->fn_compile_n_ = &compileCudaFusionGroup;
+    ptr->fn_run_n_s_ = &runCudaFusionGroup;
+    ptr->fn_fuse_graph_ = &CudaFuseGraph;
+    ptr->fn_can_fuse_n_ = &isFusableCudaFusionGroup;
+
+    RegisterProfilingNode(canFuseNode);
   }
 };
 
diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp
deleted file mode 100644 (file)
index 3b6a772..0000000
+++ /dev/null
@@ -1,938 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-
-#include <sstream>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-std::unordered_map<IterDomain*, IterDomain*> RootDomainMap::
-    mapProducerToConsumer(
-        const TensorDomain* producer,
-        const TensorDomain* consumer,
-        const std::unordered_set<IterDomain*>& root_dims_to_map) const {
-  return map(producer, consumer, root_dims_to_map, true);
-}
-
-std::unordered_map<IterDomain*, IterDomain*> RootDomainMap::
-    mapProducerToConsumer(
-        const TensorDomain* producer,
-        const TensorDomain* consumer) const {
-  std::unordered_set<IterDomain*> root_dims_to_map(
-      producer->getMaybeRFactorDomain().begin(),
-      producer->getMaybeRFactorDomain().end());
-  return mapProducerToConsumer(producer, consumer, root_dims_to_map);
-}
-
-std::unordered_map<IterDomain*, IterDomain*> RootDomainMap::
-    mapConsumerToProducer(
-        const TensorDomain* consumer,
-        const TensorDomain* producer,
-        const std::unordered_set<IterDomain*>& root_dims_to_map) const {
-  return map(producer, consumer, root_dims_to_map, false);
-}
-
-std::unordered_map<IterDomain*, IterDomain*> RootDomainMap::
-    mapConsumerToProducer(
-        const TensorDomain* consumer,
-        const TensorDomain* producer) const {
-  std::unordered_set<IterDomain*> root_dims_to_map(
-      consumer->getRootDomain().begin(), consumer->getRootDomain().end());
-  return mapConsumerToProducer(consumer, producer, root_dims_to_map);
-}
-
-PairwiseRootDomainMap::PairwiseRootDomainMap(
-    const TensorView* producer,
-    const TensorView* consumer)
-    : producer_tv_(producer), consumer_tv_(consumer) {
-  TORCH_INTERNAL_ASSERT(producer != nullptr);
-  TORCH_INTERNAL_ASSERT(consumer != nullptr);
-  TORCH_INTERNAL_ASSERT(producer->fusion() == consumer->fusion());
-  // Make sure they are really a producer and its consumer
-  TORCH_INTERNAL_ASSERT(
-      producer->isConsumerOf(consumer),
-      "Not a producer-consumer pair: ",
-      producer,
-      ", ",
-      consumer);
-}
-
-std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
-    const TensorDomain* producer,
-    const TensorDomain* consumer,
-    const std::unordered_set<IterDomain*>& root_dims_to_map,
-    bool producer_to_consumer) const {
-  // Sanity check that the given producer and consumer domains are
-  // really the TensorDomains of the producer and consumer TensorViews
-  // given to the constructor.
-  TORCH_INTERNAL_ASSERT(producer_tv_->domain() == producer);
-  TORCH_INTERNAL_ASSERT(consumer_tv_->domain() == consumer);
-
-  if (consumer_tv_->definition()->isA<TransposeOp>()) {
-    return mapTranspose(
-        producer, consumer, root_dims_to_map, producer_to_consumer);
-  }
-
-  std::vector<bool> broadcast_flags;
-  if (BroadcastOp* bop =
-          dynamic_cast<BroadcastOp*>(consumer_tv_->definition())) {
-    broadcast_flags = bop->getBroadcastDimFlags();
-  }
-
-  std::unordered_map<IterDomain*, IterDomain*> dom_map;
-  const auto producer_root =
-      TensorDomain::noReductions(producer->getMaybeRFactorDomain());
-  const auto& consumer_root = consumer->getRootDomain();
-  size_t itc = 0, itp = 0;
-  while (itc < consumer_root.size() && itp < producer_root.size()) {
-    IterDomain* producer_id = producer_root[itp];
-    IterDomain* consumer_id = consumer_root[itc];
-
-    // When the consumer ID is a new broadcast domain, there is no
-    // mapping for it.
-    if (!broadcast_flags.empty() && broadcast_flags.at(itc)) {
-      TORCH_INTERNAL_ASSERT(consumer_id->isBroadcast());
-      itc++;
-      continue;
-    }
-
-    IterDomain* map_key_id = producer_id;
-    IterDomain* map_value_id = consumer_id;
-    if (!producer_to_consumer) {
-      std::swap(map_key_id, map_value_id);
-    }
-
-    if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) {
-      dom_map.insert(std::make_pair(map_key_id, map_value_id));
-    }
-    itc++;
-    itp++;
-  }
-  return dom_map;
-}
-
-std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::
-    mapTranspose(
-        const TensorDomain* producer,
-        const TensorDomain* consumer,
-        const std::unordered_set<IterDomain*>& root_dims_to_map,
-        bool producer_to_consumer) const {
-  const auto producer_root =
-      TensorDomain::noReductions(producer->getMaybeRFactorDomain());
-  const auto& consumer_root = consumer->getRootDomain();
-
-  std::unordered_map<IterDomain*, IterDomain*> dom_map;
-
-  TransposeOp* top = dynamic_cast<TransposeOp*>(consumer_tv_->definition());
-  TORCH_INTERNAL_ASSERT(top != nullptr);
-
-  const auto& new2old = top->new2old();
-  for (size_t i = 0; i < consumer_root.size(); ++i) {
-    IterDomain* map_key_id = producer_root[new2old[i]];
-    IterDomain* map_value_id = consumer_root[i];
-    if (!producer_to_consumer) {
-      std::swap(map_key_id, map_value_id);
-    }
-    if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) {
-      dom_map.insert(std::make_pair(map_key_id, map_value_id));
-    }
-  }
-  return dom_map;
-}
-
-std::string toString(const PairwiseRootDomainMap& root_map) {
-  std::stringstream ss;
-  ss << "{producer: " << root_map.producer()
-     << ", consumer: " << root_map.consumer() << "}";
-  return ss.str();
-}
-
-namespace {
-
-template <typename T>
-auto ensureMapping(
-    T& m,
-    const typename T::key_type& key,
-    const typename T::mapped_type& init_value) {
-  auto it = m.find(key);
-  if (it == m.end()) {
-    it = m.insert({key, init_value}).first;
-  }
-  return it;
-}
-
-} // namespace
-
-std::string toString(const DomainKey& key) {
-  std::stringstream ss;
-  ss << "{";
-  if (key.td()) {
-    ss << key.td() << " (root: " << key.td()->getRootDomain()
-       << ", maybe rfactor: " << key.td()->getMaybeRFactorDomain() << ")";
-  } else {
-    ss << "null";
-  }
-  ss << ", ";
-  if (key.id()) {
-    ss << key.id();
-  } else {
-    ss << "null";
-  }
-  if (key.concreteId()) {
-    ss << " (" << key.concreteId() << ")";
-  }
-  ss << "}";
-  return ss.str();
-}
-
-UnmappableReductionDomains::UnmappableReductionDomains() {
-  Fusion* fusion = FusionGuard::getCurFusion();
-  traverse(fusion);
-}
-
-namespace {
-
-//! Find all domains that a given domain is depeendent on
-class FindInputDomains : BackwardVisitor {
- private:
-  FindInputDomains(TensorView* tv, const IterDomain* id)
-      : BackwardVisitor(false), tv_(tv) {
-    input_keys.insert(DomainKey(tv_->domain(), id));
-  }
-
-  DomainKeySet find() {
-    traverseFrom(tv_->fusion(), {tv_});
-    return input_keys;
-  }
-
-  void handle(Expr* expr) override {
-    for (auto output : expr->outputs()) {
-      if (!output->isA<TensorView>()) {
-        continue;
-      }
-      for (auto input : expr->inputs()) {
-        if (!input->isA<TensorView>()) {
-          continue;
-        }
-        propagate(input->as<TensorView>(), output->as<TensorView>());
-      }
-    }
-  }
-
-  void propagate(TensorView* in_tv, TensorView* out_tv) {
-    auto c2p = PairwiseRootDomainMap(in_tv, out_tv)
-                   .mapConsumerToProducer(out_tv->domain(), in_tv->domain());
-    for (auto root_dom : out_tv->getRootDomain()) {
-      DomainKey out_key({out_tv->domain(), root_dom});
-      if (input_keys.find(out_key) == input_keys.end()) {
-        continue;
-      }
-      auto input_id_it = c2p.find(root_dom);
-      if (input_id_it == c2p.end()) {
-        continue;
-      }
-      DomainKey input_key(in_tv->domain(), input_id_it->second);
-      input_keys.insert(input_key);
-    }
-  }
-
- private:
-  TensorView* tv_ = nullptr;
-  DomainKeySet input_keys;
-
- public:
-  static DomainKeySet find(TensorView* tv, const IterDomain* id) {
-    return FindInputDomains(tv, id).find();
-  }
-};
-
-} // namespace
-
-void UnmappableReductionDomains::handleReductionOutput(TensorView* out_tv) {
-  std::vector<DomainKey> reduction_keys;
-  for (const auto id : out_tv->getRootDomain()) {
-    if (id->isReduction()) {
-      DomainKey key(out_tv->domain(), id);
-      reduction_keys.push_back(key);
-      reduction_domains_.insert({key, {}});
-    }
-  }
-  auto use_chains = DependencyCheck::getAllUseChains(out_tv);
-  for (const auto& chain : use_chains) {
-    for (const auto& tv : ir_utils::filterByType<TensorView>(chain)) {
-      const auto& root_domain = tv->getRootDomain();
-      for (const auto& id : root_domain) {
-        DomainKey consumer_key(tv->domain(), id);
-        for (const auto& reduction_key : reduction_keys) {
-          reduction_domains_.at(reduction_key).insert(consumer_key);
-        }
-      }
-    }
-  }
-  for (const auto& reduction_key : reduction_keys) {
-    reduction_domain_inputs_.insert(
-        {reduction_key, FindInputDomains::find(out_tv, reduction_key.id())});
-  }
-}
-
-void UnmappableReductionDomains::handle(ReductionOp* op) {
-  // Builds a map from reduction domains to consumer domains.
-  TensorView* out_tv = op->out()->as<TensorView>();
-  handleReductionOutput(out_tv);
-}
-
-void UnmappableReductionDomains::handle(WelfordOp* op) {
-  // Builds a map from reduction domains to consumer domains.
-  handleReductionOutput(op->outAvg()->as<TensorView>());
-  handleReductionOutput(op->outVar()->as<TensorView>());
-  handleReductionOutput(op->outN()->as<TensorView>());
-}
-
-bool UnmappableReductionDomains::isReductionOutputMapped(
-    const std::vector<DomainKey>& consumer_domains,
-    const ComputeAtRootDomainMap& root_map) const {
-  for (const auto& kv : reduction_domains_) {
-    const DomainKey& reduction_domain = kv.first;
-    const DomainKeySet& incompatible_domains = kv.second;
-    DomainKey consumer_domain_with_reduction;
-    bool reduction_found = false;
-    const auto& input_keys = reduction_domain_inputs_.at(reduction_domain);
-    for (const DomainKey& consumer_domain : consumer_domains) {
-      for (const auto& input_key : input_keys) {
-        if (input_key == consumer_domain) {
-          consumer_domain_with_reduction = consumer_domain;
-          reduction_found = true;
-          break;
-        }
-      }
-    }
-    if (!reduction_found) {
-      continue;
-    }
-    // Make sure no incompatible domains will be merged with the reduction
-    // domain.
-    for (const auto& consumer_domain : consumer_domains) {
-      if (consumer_domain == consumer_domain_with_reduction) {
-        continue;
-      }
-      if (std::any_of(
-              incompatible_domains.begin(),
-              incompatible_domains.end(),
-              [&](const DomainKey& incompatible_domain) {
-                return root_map.canMap(
-                    consumer_domain.td(),
-                    consumer_domain.id(),
-                    incompatible_domain.td(),
-                    incompatible_domain.id());
-              })) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-void ComputeAtRootDomainMap::build(bool map_through_reduction) {
-  // Make sure we start from scratch. Throw away previous results.
-  eq_set_.clear();
-  bcast_map_.clear();
-  new_broadcast_domains_.clear();
-  ComputeAtRootDomainMapBuilder builder(*this, map_through_reduction);
-}
-
-bool ComputeAtRootDomainMap::canMap(
-    const TensorDomain* td_a,
-    const IterDomain* id_a,
-    const TensorDomain* td_b,
-    const IterDomain* id_b) const {
-  TORCH_INTERNAL_ASSERT(
-      id_a->definition() == nullptr || id_a->isRFactorProduct(),
-      "Non-root domain is not supproted: ",
-      id_a);
-  TORCH_INTERNAL_ASSERT(
-      id_b->definition() == nullptr || id_b->isRFactorProduct(),
-      "Non-root domain is not supproted: ",
-      id_b);
-
-  // Forward to overloaded functions
-  if (!id_a->isBroadcast() && !id_b->isBroadcast()) {
-    return canMap(DomainKey(td_a, id_a), DomainKey(td_b, id_b));
-  } else if (!id_a->isBroadcast()) {
-    return canMap(DomainKey(td_a, id_a), td_b, id_b);
-  } else if (!id_b->isBroadcast()) {
-    return canMap(DomainKey(td_b, id_b), td_a, id_a);
-  }
-
-  // At this point, both are broadcast. Every pair of concrete IDs of
-  // both id_a and id_b needs to be looked at. Whether they are
-  // mappable depends on whether the concrete IDs are broadcast or
-  // not. Note that a broadcast axis is used a concrete ID when it is
-  // part of an output tensor domain, i.e., when it never gets
-  // concretized with any non-broadcast axis.
-
-  // If there exists a pair of non-broadcast concrete IDs is not
-  // mappable, id_a and id_b can't be mapped together. Otherwise, they
-  // can be mapped when there is any mappable pair is found.
-  bool mappable_pair_found = false;
-  for (const auto& key_a : getConcretizedKeys(td_a, id_a)) {
-    for (const auto& key_b : getConcretizedKeys(td_b, id_b)) {
-      const bool mappable = canMap(key_a, key_b);
-      mappable_pair_found = mappable_pair_found || mappable;
-      // If both concrete IDs are not broadcast, they must be
-      // mappable. Also, if either of the concrete IDs is a reduction,
-      // that means a trivial reduction (i.e., broadcast immediately
-      // followed by reduction), which does not prevent any mapping.
-      if (!key_a.concreteId()->isBroadcast() &&
-          !key_b.concreteId()->isBroadcast() &&
-          !key_a.concreteId()->isReduction() &&
-          !key_b.concreteId()->isReduction() && !mappable) {
-        return false;
-      }
-    }
-  }
-
-  return mappable_pair_found;
-}
-
-bool ComputeAtRootDomainMap::canMap(
-    const DomainKey& key_a,
-    const TensorDomain* td_b,
-    const IterDomain* id_b) const {
-  TORCH_INTERNAL_ASSERT(
-      id_b->definition() == nullptr || id_b->isRFactorProduct(),
-      "Non-root domain is not supproted: ",
-      id_b);
-
-  if (!id_b->isBroadcast()) {
-    return canMap(key_a, DomainKey(td_b, id_b));
-  }
-
-  // If id_b is broadcast, look at all the concrete IDs that id_b may
-  // be concretized to. Whether it is mappable with key_a depends on
-  // whether key_a's concrete ID is also broadcast.
-  // 1) key_a's concrete ID is also broadcast: They are mappable when
-  // there is any mappable concrete ID exists in the concrete ID set
-  // of id_b.
-  // 2) key_a's concrete ID is not broadcast: Since key_a is indeed
-  // concrete, it must be mappable with any of concrete ID of id_b,
-  // except when a id_b concrete is broadcast.
-  const bool key_a_bcast =
-      key_a.concreteId() && key_a.concreteId()->isBroadcast();
-  const bool key_a_reduction =
-      (key_a.concreteId() && key_a.concreteId()->isReduction()) ||
-      key_a.id()->isReduction();
-  bool mappable_pair_found = false;
-  for (const auto& key_b : getConcretizedKeys(td_b, id_b)) {
-    const bool mappable = canMap(key_a, key_b);
-    mappable_pair_found = mappable_pair_found || mappable;
-    // If both concrete IDs are not broadcast, they must be mappable.
-    // However, if key_b's concrete ID is a reduction, the concrete ID
-    // is a result of a trivial reduction, so it should not prevent
-    // any other mapping. Similarly, if key_a is a reduction, it just
-    // needs to find any concrete ID of key_b that can be mapped.
-    if (!key_a_bcast && !key_b.concreteId()->isBroadcast() &&
-        !key_b.concreteId()->isReduction() && !key_a_reduction && !mappable) {
-      return false;
-    }
-  }
-
-  return mappable_pair_found;
-}
-
-bool ComputeAtRootDomainMap::canMap(
-    const DomainKey& key_a,
-    const DomainKey& key_b) const {
-  return key_a == key_b || eq_set_.areEquivalent(key_a, key_b);
-}
-
-void ComputeAtRootDomainMap::setAlias(
-    const TensorDomain* td,
-    const TensorDomain* td_alias) {
-  auto tmp_bcast_map = bcast_map_;
-  for (const auto& kv : bcast_map_) {
-    const auto& bcast_map_key = kv.first;
-    const auto& bcast_concrete_id_set = kv.second;
-    if (bcast_map_key.td() == td) {
-      DomainKey alias_key(td_alias, bcast_map_key.id());
-      tmp_bcast_map.insert({alias_key, bcast_concrete_id_set});
-    }
-  }
-  bcast_map_ = tmp_bcast_map;
-
-  for (const auto& key : eq_set_.getAllElements()) {
-    if (key.td() == td) {
-      DomainKey alias_key(td_alias, key.id(), key.concreteId());
-      eq_set_.join(key, alias_key);
-    }
-  }
-
-  auto tmp_new_broadcast_domains = new_broadcast_domains_;
-  for (const auto& key : new_broadcast_domains_) {
-    if (key.td() == td) {
-      DomainKey alias_key(td_alias, key.id());
-      tmp_new_broadcast_domains.insert(alias_key);
-    }
-  }
-  new_broadcast_domains_ = tmp_new_broadcast_domains;
-}
-
-std::vector<DomainKey> ComputeAtRootDomainMap::getConcretizedKeys(
-    const TensorDomain* td,
-    const IterDomain* id) const {
-  DomainKey key(td, id);
-  auto it = bcast_map_.find(key);
-  TORCH_INTERNAL_ASSERT(it != bcast_map_.end(), "Not found: ", toString(key));
-  std::vector<DomainKey> domains;
-  std::transform(
-      it->second.begin(),
-      it->second.end(),
-      std::back_inserter(domains),
-      [&](const IterDomain* concrete_id) {
-        return DomainKey(td, id, concrete_id);
-      });
-  return domains;
-}
-
-std::unordered_set<const IterDomain*>& ComputeAtRootDomainMap::
-    getConcretizedDomains(const TensorDomain* td, const IterDomain* id) {
-  DomainKey key(td, id);
-  auto it = bcast_map_.find(key);
-  TORCH_INTERNAL_ASSERT(it != bcast_map_.end(), "Not found: ", toString(key));
-  return it->second;
-}
-
-std::unordered_map<IterDomain*, IterDomain*> ComputeAtRootDomainMap::
-    mapBestEffort(
-        const TensorDomain* from_td,
-        const std::vector<IterDomain*>& from_root,
-        const TensorDomain* to_td,
-        const std::vector<IterDomain*>& to_root) const {
-  std::unordered_map<IterDomain*, IterDomain*> id_map;
-  for (auto& from_id : from_root) {
-    for (const auto& to_id : to_root) {
-      if (canMap(from_td, from_id, to_td, to_id)) {
-        TORCH_INTERNAL_ASSERT(
-            id_map.insert({from_id, to_id}).second,
-            "Multiple matching ID detected for ",
-            from_id);
-      }
-    }
-  }
-  return id_map;
-}
-
-std::unordered_map<IterDomain*, IterDomain*> ComputeAtRootDomainMap::map(
-    const TensorDomain* producer,
-    const TensorDomain* consumer,
-    const std::unordered_set<IterDomain*>& root_dims_to_map,
-    bool producer_to_consumer) const {
-  const auto& producer_root = producer->getMaybeRFactorDomain();
-  const auto& consumer_root = consumer->getRootDomain();
-  const TensorDomain* from_td = producer_to_consumer ? producer : consumer;
-  const TensorDomain* to_td = producer_to_consumer ? consumer : producer;
-  const auto& from_ids = producer_to_consumer ? producer_root : consumer_root;
-  const auto& to_ids = producer_to_consumer ? consumer_root : producer_root;
-  std::unordered_map<IterDomain*, IterDomain*> id_map =
-      mapBestEffort(from_td, from_ids, to_td, to_ids);
-  for (auto& from_id : from_ids) {
-    if (root_dims_to_map.find(from_id) == root_dims_to_map.end()) {
-      // Remove mapping if exists
-      id_map.erase(from_id);
-      continue;
-    }
-    if (id_map.find(from_id) != id_map.end()) {
-      continue;
-    }
-    // Matching ID not found. It's an error unless: from_id is
-    // reduction of a producer domain; from_id is a new broadcast of a
-    // consumer domain; or from_id is a window axis of a consumer
-    // domain.
-    if ((producer_to_consumer && from_id->isReduction()) ||
-        (!producer_to_consumer &&
-         (new_broadcast_domains_.find(DomainKey(from_td, from_id)) !=
-              new_broadcast_domains_.end() ||
-          (window_axes_.count(from_id) > 0)))) {
-      continue;
-    }
-    TORCH_INTERNAL_ASSERT(
-        false,
-        "Mapping IterDomain ",
-        from_id,
-        " of ",
-        from_td,
-        " not possible as it would require recomputing the source tensor.",
-        " Producer root: ",
-        producer_root,
-        ". Consumer root: ",
-        consumer_root,
-        ". Mapping: ",
-        toString(*this));
-  }
-  return id_map;
-}
-
-std::unordered_set<IterDomain*> ComputeAtRootDomainMap::getMappableDims(
-    const TensorDomain* producer,
-    const TensorDomain* consumer) const {
-  const auto& producer_root = producer->getMaybeRFactorDomain();
-  const auto& consumer_root = consumer->getRootDomain();
-
-  std::unordered_map<IterDomain*, IterDomain*> id_map =
-      mapBestEffort(producer, producer_root, consumer, consumer_root);
-
-  std::unordered_set<IterDomain*> mappable_ids;
-
-  for (auto& from_id : producer_root) {
-    if (id_map.find(from_id) != id_map.end()) {
-      mappable_ids.emplace(from_id);
-      mappable_ids.emplace(id_map.at(from_id));
-    }
-  }
-  return mappable_ids;
-}
-
-std::string toString(const ComputeAtRootDomainMap& root_map) {
-  std::stringstream ss;
-  root_map.eq_set_.print(ss);
-  return ss.str();
-}
-
-ComputeAtRootDomainMapBuilder::ComputeAtRootDomainMapBuilder(
-    ComputeAtRootDomainMap& root_map,
-    bool map_through_reduction)
-    : BackwardVisitor(false),
-      root_map_(root_map),
-      map_through_reduction_(map_through_reduction) {
-  Fusion* fusion = FusionGuard::getCurFusion();
-  TORCH_INTERNAL_ASSERT(fusion != nullptr);
-  traverseFrom(fusion, fusion->outputs(), false);
-  if (!pending_map_.empty()) {
-    std::stringstream ss;
-    ss << "pending map:\n";
-    for (auto& kv : pending_map_) {
-      ss << "\t" << toString(kv.first) << "\n";
-      for (auto& dk : kv.second) {
-        ss << "\t\t" << toString(dk) << "\n";
-      }
-    }
-    std::cerr << ss.str();
-  }
-  TORCH_INTERNAL_ASSERT(pending_map_.empty());
-}
-
-// Set concrete domains for broadcast domains that never get joined
-// with a concrete domain. Just set its own domain as a concrete
-// domain, which is not concrete but is sufficient for this analysis.
-void ComputeAtRootDomainMapBuilder::initializeBcastMap(
-    const TensorView* tv,
-    const IterDomain* id) {
-  TORCH_INTERNAL_ASSERT(id->isBroadcast(), "Not a broadcast axis");
-  auto key = DomainKey(tv->domain(), id);
-  auto it = root_map_.bcast_map_.find(key);
-  if (it != root_map_.bcast_map_.end()) {
-    // already initialized.
-    return;
-  }
-
-  // This initialization should be only used for fusion output tensors and
-  // outputs of multi-consumer expressions that are not fusion outputs.
-  TORCH_INTERNAL_ASSERT(
-      tv->isFusionOutput() || tv->definition()->outputs().size() > 1,
-      "Invalid tensor to initialize bcast map: t",
-      tv->name());
-  root_map_.bcast_map_.insert({key, {id}});
-}
-
-void ComputeAtRootDomainMapBuilder::addToPendingList(
-    const DomainKey& producer,
-    const DomainKey& consumer) {
-  auto it = ensureMapping(pending_map_, producer, {});
-  auto& consumer_set = it->second;
-  consumer_set.insert(consumer);
-}
-
-void ComputeAtRootDomainMapBuilder::setMapped(
-    const DomainKey& producer,
-    const DomainKey& consumer) {
-  root_map_.eq_set_.join(producer, consumer);
-}
-
-void ComputeAtRootDomainMapBuilder::setMaybeMapped(
-    const TensorDomain* producer_td,
-    const IterDomain* producer_id,
-    const TensorDomain* consumer_td,
-    const IterDomain* consumer_id) {
-  const DomainKey producer_key(producer_td, producer_id);
-  const DomainKey consumer_key(consumer_td, consumer_id);
-
-  if (producer_id->isBroadcast()) {
-    ensureMapping(root_map_.bcast_map_, producer_key, {});
-  }
-
-  if (consumer_id->isBroadcast()) {
-    TORCH_INTERNAL_ASSERT(producer_id->isBroadcast());
-    // Get bcast_map_ entry for consumer_id
-    const auto consumer_bcast_domains =
-        root_map_.getConcretizedKeys(consumer_td, consumer_id);
-    auto& producer_domains =
-        root_map_.getConcretizedDomains(producer_td, producer_id);
-
-    // If consumer id is broadcasted, make sure to propagate its concrete_id(s)
-    // to producer
-    for (const auto& consumer_bcast_key : consumer_bcast_domains) {
-      const auto concrete_id = consumer_bcast_key.concreteId();
-      const DomainKey producer_bcast_key(producer_td, producer_id, concrete_id);
-      producer_domains.insert(concrete_id);
-      addToPendingList(producer_bcast_key, consumer_bcast_key);
-    }
-  } else {
-    TORCH_INTERNAL_ASSERT(
-        !consumer_id->isBroadcast(),
-        "No concrete domain found for a broadcast domain: ",
-        toString(consumer_key));
-    auto producer_concrete_key = producer_key;
-    if (producer_id->isBroadcast()) {
-      const auto concrete_id = consumer_id;
-      auto& producer_domains =
-          root_map_.getConcretizedDomains(producer_td, producer_id);
-      producer_concrete_key = DomainKey(producer_td, producer_id, concrete_id);
-      producer_domains.insert(concrete_id);
-    }
-    addToPendingList(producer_concrete_key, consumer_key);
-  }
-}
-
-void ComputeAtRootDomainMapBuilder::handle(Expr* e) {
-  // Avoid visiting expressions multiple times
-  if (visited_.find(e) != visited_.end()) {
-    return;
-  }
-  BackwardVisitor::handle(e);
-  visited_.insert(e);
-}
-
-void ComputeAtRootDomainMapBuilder::mapPointwiseOrReductionOp(Expr* e) {
-  if (e->output(0)->getValType() != ValType::TensorView) {
-    return;
-  }
-
-  // Broadcast is handled separately, so e should never be BroadcastOp.
-  TORCH_INTERNAL_ASSERT(e->getExprType() != ExprType::BroadcastOp);
-
-  TORCH_INTERNAL_ASSERT(e->outputs().size() >= 1);
-  const TensorView* out_tv = e->output(0)->as<TensorView>();
-  const TensorDomain* out_td = out_tv->domain();
-  const auto& out_root = out_td->getRootDomain();
-
-  // Record equalities from output to all the inputs
-  // ignores un-concretizable broadcasts
-  for (auto* i : ir_utils::filterByType<TensorView>(e->inputs())) {
-    const TensorDomain* in_td = i->domain();
-    std::vector<IterDomain*> in_root =
-        TensorDomain::noReductions(i->getMaybeRFactorDomain());
-    TORCH_INTERNAL_ASSERT(
-        in_root.size() == out_root.size(),
-        "\nExpression: ",
-        e,
-        "\nInput root domain: ",
-        in_root,
-        "\nOutput root domain: ",
-        out_root);
-    for (size_t it = 0; it < in_root.size(); it++) {
-      if (e->outputs().size() > 1) {
-        TORCH_INTERNAL_ASSERT(
-            e->isA<WelfordOp>(), "Only supported multioutput op is welford");
-        for (auto o : e->outputs()) {
-          auto o_tv = o->as<TensorView>();
-          auto o_td = o_tv->domain();
-          auto o_root = o_td->getRootDomain();
-          setMaybeMapped(in_td, in_root[it], o_td, o_root[it]);
-        }
-      } else {
-        setMaybeMapped(in_td, in_root[it], out_td, out_root[it]);
-      }
-    }
-  }
-}
-
-void ComputeAtRootDomainMapBuilder::handle(BroadcastOp* op) {
-  const TensorDomain* in_td = op->in()->as<TensorView>()->domain();
-  const TensorDomain* out_td = op->out()->as<TensorView>()->domain();
-  const auto in_root = TensorDomain::noReductions(in_td->getRootDomain());
-  const auto& out_root = out_td->getRootDomain();
-  const auto& bcast_dim_flags = op->getBroadcastDimFlags();
-  TORCH_INTERNAL_ASSERT(
-      out_root.size() == bcast_dim_flags.size(),
-      "dim flags: ",
-      bcast_dim_flags,
-      ", out root: ",
-      out_root);
-  auto in_it = in_root.begin();
-  auto out_it = out_root.begin();
-  while (in_it != in_root.end() && out_it != out_root.end()) {
-    if (bcast_dim_flags.at(std::distance(out_root.begin(), out_it))) {
-      // new broadcast dim. No matching dimension in the input
-      // tensor.
-      root_map_.new_broadcast_domains_.insert(DomainKey(out_td, *out_it));
-      ++out_it;
-      continue;
-    }
-    setMaybeMapped(in_td, *in_it, out_td, *out_it);
-    ++in_it;
-    ++out_it;
-  }
-  // At this point, the input domain should have been scanned
-  // entirely.
-  TORCH_INTERNAL_ASSERT(
-      in_it == in_root.end(),
-      "Unmatched domain detected: ",
-      *in_it,
-      " of ",
-      in_td);
-  // On the other hand, the output may still have some domains left,
-  // and they must be new broadcast domains.
-  for (; out_it != out_root.end(); ++out_it) {
-    TORCH_INTERNAL_ASSERT(
-        bcast_dim_flags.at(std::distance(out_root.begin(), out_it)),
-        "Unmatched domain detected: ",
-        *out_it,
-        " of ",
-        out_td);
-    root_map_.new_broadcast_domains_.insert(DomainKey(out_td, *out_it));
-  }
-}
-
-void ComputeAtRootDomainMapBuilder::handle(TransposeOp* op) {
-  const TensorDomain* in_td = op->in()->as<TensorView>()->domain();
-  std::vector<IterDomain*> in_root =
-      TensorDomain::noReductions(in_td->getRootDomain());
-
-  const TensorDomain* out_td = op->out()->as<TensorView>()->domain();
-  const auto& out_root = out_td->getRootDomain();
-
-  TORCH_INTERNAL_ASSERT(in_root.size() == out_root.size());
-
-  const auto& new2old = op->new2old();
-
-  for (size_t it = 0; it < out_root.size(); it++) {
-    setMaybeMapped(in_td, in_root[new2old[it]], out_td, out_root[it]);
-  }
-}
-
-void ComputeAtRootDomainMapBuilder::handle(GatherOp* op) {
-  const TensorDomain* in_td = op->in()->as<TensorView>()->domain();
-  const TensorDomain* out_td = op->out()->as<TensorView>()->domain();
-  const auto in_root = TensorDomain::noReductions(in_td->getRootDomain());
-  const auto& out_root = out_td->getRootDomain();
-
-  // Only maps the input root axes. Do not map the new window axes.
-  for (size_t it = 0; it < in_root.size(); it++) {
-    setMaybeMapped(in_td, in_root[it], out_td, out_root[it]);
-  }
-
-  // Keep track of window axes so that they can be skipped when
-  // mapping root domains
-  for (size_t it = in_root.size(); it < out_root.size(); it++) {
-    root_map_.window_axes_.insert(out_root[it]);
-  }
-}
-
-bool ComputeAtRootDomainMapBuilder::mapAllConsumers(
-    const DomainKey& producer_key) {
-  auto it = pending_map_.find(producer_key);
-  if (it == pending_map_.end()) {
-    return false;
-  }
-  const auto& consumer_set = it->second;
-  // All entries in key_set must be equivalent with each other.
-  TORCH_INTERNAL_ASSERT(consumer_set.size() > 0);
-  bool consistent = safeToMap(consumer_set);
-  if (consistent) {
-    for (const auto pending_consumer : consumer_set) {
-      setMapped(producer_key, pending_consumer);
-    }
-  }
-  // This entry should never be used again, so remove it.
-  pending_map_.erase(it);
-  return consistent;
-}
-
-void ComputeAtRootDomainMapBuilder::handle(TensorView* tv) {
-  const TensorDomain* td = tv->domain();
-  const auto root = TensorDomain::noReductions(td->getMaybeRFactorDomain());
-  for (auto id : root) {
-    if (id->isBroadcast()) {
-      initializeBcastMap(tv, id);
-      for (const auto& key : root_map_.getConcretizedKeys(td, id)) {
-        mapAllConsumers(key);
-      }
-    } else {
-      mapAllConsumers(DomainKey(td, id));
-    }
-  }
-}
-
-// Checks whether all consumers of a producer can be joined without
-// introducing unsupported mappings. Specifically, if a domain of a
-// consumer has a mapped iteration domain in another consumer that
-// does not correspond to the same producer iteration domain, mapping
-// the consumer domains would result in the producer iteration domain
-// mapped to two different consumer iteration domains, requiring
-// recomputations.
-bool ComputeAtRootDomainMapBuilder::hasMatchingDomains(
-    const std::vector<DomainKey>& unique_domains) {
-  for (const auto& key : unique_domains) {
-    for (const auto& other_key : unique_domains) {
-      if (key == other_key) {
-        continue;
-      }
-      const auto& other_root = other_key.td()->getRootDomain();
-      if (std::any_of(
-              other_root.begin(), other_root.end(), [&](const IterDomain* id) {
-                return root_map_.canMap(key, other_key.td(), id);
-              })) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-// Checks whether all consumers of a producer can be joined without
-// introducing unsupported mappings, i.e., requiring recomputations.
-bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) {
-  if (domains.size() <= 1) {
-    return true;
-  }
-  // Filter out equivalent domains
-  std::vector<DomainKey> unique_domains;
-  for (const auto& domain : domains) {
-    if (std::none_of(
-            unique_domains.begin(),
-            unique_domains.end(),
-            [&](const auto& unique_dom) {
-              return root_map_.canMap(domain, unique_dom);
-            })) {
-      unique_domains.push_back(domain);
-    }
-  }
-  if (hasMatchingDomains(unique_domains)) {
-    return false;
-  }
-  // Can't map if reduction output domains would be mapped
-  if (incompatible_domains_.isReductionOutputMapped(
-          unique_domains, root_map_) &&
-      !map_through_reduction_) {
-    return false;
-  }
-  return true;
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.h b/torch/csrc/jit/codegen/cuda/root_domain_map.h
deleted file mode 100644 (file)
index dbc16c3..0000000
+++ /dev/null
@@ -1,427 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Generic interface for mapping root domains of a producer-consumer pair.
-class TORCH_CUDA_CU_API RootDomainMap : public PolymorphicBase {
- public:
-  //! Return a map from a producer TensorDomain to a consumer
-  //! TensorDomain
-  //!
-  //! \param producer A producer TensorDomain
-  //! \param consumer A consumer TensorDomain
-  //! \param root_dims_to_map Maps only producer root domains in this set
-  std::unordered_map<IterDomain*, IterDomain*> mapProducerToConsumer(
-      const TensorDomain* producer,
-      const TensorDomain* consumer,
-      const std::unordered_set<IterDomain*>& root_dims_to_map) const;
-
-  //! Return a map from a producer TensorDomain to a consumer
-  //! TensorDomain
-  //!
-  //! \param producer A producer TensorDomain
-  //! \param consumer A consumer TensorDomain
-  std::unordered_map<IterDomain*, IterDomain*> mapProducerToConsumer(
-      const TensorDomain* producer,
-      const TensorDomain* consumer) const;
-
-  //! Return a map from a consumer TensorDomain to a producer
-  //! TensorDomain
-  //!
-  //! \param consumer A consumer TensorDomain
-  //! \param producer A producer TensorDomain
-  //! \param root_dims_to_map Maps only consumer root domains in this set
-  std::unordered_map<IterDomain*, IterDomain*> mapConsumerToProducer(
-      const TensorDomain* consumer,
-      const TensorDomain* producer,
-      const std::unordered_set<IterDomain*>& root_dims_to_map) const;
-
-  //! Return a map from a consumer TensorDomain to a producer
-  //! TensorDomain
-  //!
-  //! \param consumer A consumer TensorDomain
-  //! \param producer A producer TensorDomain
-  std::unordered_map<IterDomain*, IterDomain*> mapConsumerToProducer(
-      const TensorDomain* consumer,
-      const TensorDomain* producer) const;
-
- protected:
-  //! Return a map between root IterDomains of a producer-consumer
-  //! pair.
-  //!
-  //! \param producer A producer TensorDomain
-  //! \param consumer A consumer TensorDomain
-  //! \param root_dims_to_map Maps only from IterDomains in this set
-  //! \param producer_to_consumer Maps from producer to consumer if true
-  virtual std::unordered_map<IterDomain*, IterDomain*> map(
-      const TensorDomain* producer,
-      const TensorDomain* consumer,
-      const std::unordered_set<IterDomain*>& root_dims_to_map,
-      bool producer_to_consumer) const = 0;
-};
-
-//! Maps root domains of a producer-consumer pair. This class only
-//! looks at the given pair of TensorViews and does not take into
-//! consideration the constraints of the computeAt transformation,
-//! i.e., unable to compute the same tensors multiple times. This
-//! should not be used for transformations implementing computeAt, but
-//! should be valid otherwise.
-class TORCH_CUDA_CU_API PairwiseRootDomainMap : public RootDomainMap {
- public:
-  //! \param producer The producer tensor of a producer-consumer pair.
-  //! \param consumer The consumer tensor of a producer-consumer pair.
-  explicit PairwiseRootDomainMap(
-      const TensorView* producer,
-      const TensorView* consumer);
-
-  const TensorView* producer() const {
-    return producer_tv_;
-  }
-
-  const TensorView* consumer() const {
-    return consumer_tv_;
-  }
-
- protected:
-  std::unordered_map<IterDomain*, IterDomain*> map(
-      const TensorDomain* producer,
-      const TensorDomain* consumer,
-      const std::unordered_set<IterDomain*>& root_dims_to_map,
-      bool producer_to_consumer) const override;
-
-  std::unordered_map<IterDomain*, IterDomain*> mapTranspose(
-      const TensorDomain* producer,
-      const TensorDomain* consumer,
-      const std::unordered_set<IterDomain*>& root_dims_to_map,
-      bool producer_to_consumer) const;
-
- private:
-  const TensorView* producer_tv_ = nullptr;
-  const TensorView* consumer_tv_ = nullptr;
-};
-
-std::string toString(const PairwiseRootDomainMap& root_map);
-
-//! Represents an iteration domain of a TensorDomain. Only used for
-//! root domain mapping.
-//!
-//! Note that an IterDomain object may be reused
-//! across multiple TensorDomains, but an IterDomain in a
-//! TensorDomain may not be necessarily mappable to the same
-//! IterDomain used in a different TensorDomain. Thus, for the purpose
-//! of root domain mapping, an iteration domain needs to be identified
-//! with an IterDomain and its TensorDomain.
-class DomainKey {
- public:
-  DomainKey() = default;
-  DomainKey(
-      const TensorDomain* td,
-      const IterDomain* id,
-      const IterDomain* concrete_id = nullptr)
-      : td_(td), id_(id), concrete_id_(concrete_id) {}
-  const TensorDomain* td() const {
-    return td_;
-  }
-  const IterDomain* id() const {
-    return id_;
-  }
-  const IterDomain* concreteId() const {
-    return concrete_id_;
-  }
-  bool operator==(const DomainKey& other) const {
-    return td() == other.td() && id() == other.id() &&
-        concreteId() == other.concreteId();
-  }
-
- private:
-  const TensorDomain* td_ = nullptr;
-  const IterDomain* id_ = nullptr;
-  const IterDomain* concrete_id_ = nullptr;
-};
-
-std::string toString(const DomainKey& key);
-
-struct DomainKeyHash {
-  std::size_t operator()(const DomainKey& key) const {
-    return std::hash<const TensorDomain*>{}(key.td()) ^
-        std::hash<const IterDomain*>{}(key.id());
-  }
-};
-
-using DomainKeySet = std::unordered_set<DomainKey, DomainKeyHash>;
-
-template <typename Mapped>
-using DomainKeyMap = std::unordered_map<DomainKey, Mapped, DomainKeyHash>;
-
-class ComputeAtRootDomainMap;
-
-//! A helper class to find all DomainKeys that are consumers of
-//! reduction outputs. Such consumer IterDomains may not be mapped to
-//! the producer reduction domain since the corresponding reduction
-//! loop must be closed before any of the consumers can appear.
-class TORCH_CUDA_CU_API UnmappableReductionDomains : private IterVisitor {
- public:
-  UnmappableReductionDomains();
-  virtual ~UnmappableReductionDomains() = default;
-
-  //! Returns true when mapping consumer domains would cause a
-  //! reduction output domain to be mapped with a consumer domain of
-  //! the redution. It needs to be avoided as computing consumers of
-  //! reduction outputs within the corresponding reduction loop is not
-  //! possible. This routine is used to build root domain mappings.
-  bool isReductionOutputMapped(
-      const std::vector<DomainKey>& consumer_domains,
-      const ComputeAtRootDomainMap& root_map) const;
-
- private:
-  using IterVisitor::handle;
-  void handle(ReductionOp* op) override;
-  void handle(WelfordOp* op) override;
-
-  void handleReductionOutput(TensorView* out_tv);
-
- private:
-  //! Map from Reduction output DomainKeys to consumer DomainKeys
-  DomainKeyMap<DomainKeySet> reduction_domains_;
-  //! Map from Reduction output DomainKeys to producer DomainKeys
-  DomainKeyMap<DomainKeySet> reduction_domain_inputs_;
-};
-
-//! Models root-domain mappings for computeAt
-//!
-//! Two iteration domains are mapped when computeAt of one iteration
-//! domain is possible at another iteration domain. Consider a simple
-//! example:
-//!    T2 [i0,i1] = T1[i2,i3] + T0[i4,i5]
-//! This will create mappings between i0, i2 and i4.
-class TORCH_CUDA_CU_API ComputeAtRootDomainMap : public RootDomainMap {
-  friend class ComputeAtRootDomainMapBuilder;
-  friend std::string toString(const ComputeAtRootDomainMap&);
-
- public:
-  //! Builds a mapping table by analyzing the current
-  //! fusion. Overwrite a previous table if any.
-  //!
-  //! \param map_through_reduction If set
-  //!   true, will disable UnmappableReductionDomains check.
-  //!   This is only for re-using logic in detecting
-  //!   normalization fusions, which deviates slightly from
-  //!   intended use of this class. Should always be true
-  //!   in compute_at use cases.
-  void build(bool map_through_reduction = false);
-
-  //! Returns if key(td_a, id_a) and key(td_b, id_b) are mapped to eachother
-  //! (equivalent), or are the same key.
-  //!
-  //! \param td_a A TensorDomain
-  //! \param id_a An IterDomain in td_a
-  //! \param td_b Another TensorDomain
-  //! \param id_b An IterDomain in td_b
-  //! \returns Boolean representing if they are mapped
-  bool canMap(
-      const TensorDomain* td_a,
-      const IterDomain* id_a,
-      const TensorDomain* td_b,
-      const IterDomain* id_b) const;
-
-  //! Make a TensorDomain an alias of another TensorDomain
-  //!
-  //! This is for the computeAt transformation, where TensorViews are
-  //! updated with new TensorDomains. Since they keep using the same
-  //! root doamins, the root mapping remains valid but needs to
-  //! reflect the use of new TensorDomains as aliases of the existing
-  //! ones.
-  //!
-  //! \param td An existing TensorDomain
-  //! \param td_alias An alias of td
-  void setAlias(const TensorDomain* td, const TensorDomain* td_alias);
-
-  //! Return a map between TensorDomains
-  //!
-  //! Unlike the other map functions, two TensorDomains do not need to
-  //! be a producer-consumer pair. Since they may not be a
-  //! producer-consumer pair, this function requires proper root
-  //! domains, which may be root or rfactor domains. Also, no error
-  //! check is done as we do not assume producer-consumer relationship.
-  //!
-  //! \param from_td A TensorDomain from which a map is created
-  //! \param from_root A root domain of from_td
-  //! \param to_td A TensorDomain to which a map is created
-  //! \param to_root A root domain of to_td
-  std::unordered_map<IterDomain*, IterDomain*> mapBestEffort(
-      const TensorDomain* from_td,
-      const std::vector<IterDomain*>& from_root,
-      const TensorDomain* to_td,
-      const std::vector<IterDomain*>& to_root) const;
-
-  // Returns an unordered set of all iter domains in producer and consumer that
-  // can map to eachother
-  std::unordered_set<IterDomain*> getMappableDims(
-      const TensorDomain* producer,
-      const TensorDomain* consumer) const;
-
- private:
-  //! Returns if key_a and key(td_b, id_b) are mapped to eachother (equivalent),
-  //! or are the same key.
-  //!
-  //! \param key_a A DomainKey
-  //! \param td_b Another TensorDomain
-  //! \param id_b An IterDomain in td_b
-  //! \returns Boolean representing if they are mapped
-  bool canMap(
-      const DomainKey& key_a,
-      const TensorDomain* td_b,
-      const IterDomain* id_b) const;
-
-  //! Returns if key_a and key_b are mapped to eachother (equivalent), or are
-  //! the same key.
-  bool canMap(const DomainKey& key_a, const DomainKey& key_b) const;
-
-  //! Returns the set of (non-broadcast) DomainKeys that id in td is
-  //! broadcasted to. Can result in more than one "concrete" DomainKey.
-  std::vector<DomainKey> getConcretizedKeys(
-      const TensorDomain* td,
-      const IterDomain* id) const;
-
-  //! Returns the set of (non-broadcast) iter domains that id in td is
-  //! broadcasted to. Can result in more than one "concrete" iter domain.
-  std::unordered_set<const IterDomain*>& getConcretizedDomains(
-      const TensorDomain* td,
-      const IterDomain* id);
-
-  //! Return a map between root IterDomains of a producer-consumer
-  //! pair.
-  //!
-  //! \param producer A producer TensorDomain
-  //! \param consumer A consumer TensorDomain
-  //! \param root_dims_to_map Maps only from IterDomains in this set
-  //! \param producer_to_consumer Maps from producer to consumer if true
-  std::unordered_map<IterDomain*, IterDomain*> map(
-      const TensorDomain* producer,
-      const TensorDomain* consumer,
-      const std::unordered_set<IterDomain*>& root_dims_to_map,
-      bool producer_to_consumer) const override;
-
- private:
-  //! Disjoint set of all mapped <TD, ID> keys to determine axes equivalency
-  DisjointSet<DomainKey, DomainKeyHash> eq_set_;
-
-  //! All IterDomains in the mapping that are a broadcast ID
-  DomainKeyMap<std::unordered_set<const IterDomain*>> bcast_map_;
-
-  //! Broadcast iter domain that does not match dimensions in its produer,
-  //! meaning it is a brand new domain in its TensorDomain.
-  DomainKeySet new_broadcast_domains_;
-
-  //! Keep track of window axes so that the map function can ignore them.
-  std::unordered_set<IterDomain*> window_axes_;
-};
-
-std::string toString(const ComputeAtRootDomainMap& root_map);
-
-//! Create a DisjointSet of root IterDomains by traversing the
-//! current fusion entirely. IterDomains that can be mapped each
-//! other with computeAt are grouped into the same subset in the
-//! DisjointSet.
-class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder
-    : private BackwardVisitor {
- public:
-  explicit ComputeAtRootDomainMapBuilder(
-      ComputeAtRootDomainMap& root_map,
-      bool map_through_reduction = false);
-
- private:
-  //! Initialize the bcast map for fusion outputs
-  void initializeBcastMap(const TensorView* tv, const IterDomain* id);
-
-  //! Set a pair of producer-consumer domain keys as mappable
-  void setMapped(const DomainKey& producer, const DomainKey& consumer);
-
-  //! Track a pair of producer-consumer domains as potentially mappable. Inserts
-  //! entries into pending_map_, but does not add anything into the root_map_
-  //! (added when handle is called on a TensorView). Maybe mapped will, however,
-  //! immediately propagate broadcast iter domains.
-  void setMaybeMapped(
-      const TensorDomain* producer_td,
-      const IterDomain* producer_id,
-      const TensorDomain* consumer_td,
-      const IterDomain* consumer_id);
-
-  void addToPendingList(const DomainKey& producer, const DomainKey& consumer);
-
-  //! Map pointwise IterDomains from inputs of expressions to outputs.
-  //! Do not map reduction IterDomains in inputs.
-  void mapPointwiseOrReductionOp(Expr* e);
-
-  using BackwardVisitor::handle;
-
-  void handle(Expr* e) override;
-
-  void handle(UnaryOp* uop) override {
-    mapPointwiseOrReductionOp(uop);
-  }
-
-  void handle(BinaryOp* bop) override {
-    mapPointwiseOrReductionOp(bop);
-  }
-
-  void handle(TernaryOp* top) override {
-    mapPointwiseOrReductionOp(top);
-  }
-
-  void handle(ReductionOp* op) override {
-    mapPointwiseOrReductionOp(op);
-  }
-
-  void handle(WelfordOp* wop) override {
-    mapPointwiseOrReductionOp(wop);
-  }
-
-  void handle(ShiftOp* op) override {
-    mapPointwiseOrReductionOp(op);
-  }
-
-  void handle(BroadcastOp* op) override;
-
-  void handle(TransposeOp* op) override;
-
-  void handle(GatherOp* op) override;
-
-  void handle(TensorView* tv) override;
-
-  //! Maps all consumers with a producer.
-  //! This is called for each of TensorViews in a backward traversal,
-  //! recursively building mappings from the output tensors to the
-  //! input tensors.
-  bool mapAllConsumers(const DomainKey& producer_key);
-
-  bool hasMatchingDomains(const std::vector<DomainKey>& unique_domains);
-
-  bool safeToMap(const DomainKeySet& domains);
-
- private:
-  ComputeAtRootDomainMap& root_map_;
-  //! Keep track of what we want to try and map. Set in attemptToProveId.
-  DomainKeyMap<DomainKeySet> pending_map_;
-  std::unordered_set<Expr*> visited_;
-  UnmappableReductionDomains incompatible_domains_;
-
-  //! Disable UnmappableReductions check, should
-  //!  always be false for compute_at use cases
-  bool map_through_reduction_ = false;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index 899f75e..480a99e 100644 (file)
@@ -23,13 +23,12 @@ template <
     typename _dim3bd>
 __device__ void blockReduce(
     T& out,
-    const T& inp_val,
+    const T inp_val,
     Func reduction_op,
     const _dim3ti& thread_idx,
     const _dim3bd& block_dim,
     T* shared_mem,
-    bool read_pred,
-    bool write_pred,
+    bool read_write_pred,
     T init_val) {
   unsigned int reduction_size = (X_REDUCE ? block_dim.x : 1) *
       (Y_REDUCE ? block_dim.y : 1) * (Z_REDUCE ? block_dim.z : 1);
@@ -73,12 +72,12 @@ __device__ void blockReduce(
 
   assert(reduction_stride != 0);
 
-  if (read_pred) {
+  if (read_write_pred) {
     shared_mem[linear_tid] = inp_val;
   } else {
     shared_mem[linear_tid] = init_val;
   }
-  block_sync::sync();
+  __syncthreads();
   // Reduce down to nearest power of 2:
   int np2 = 1 << (31 - __clz(reduction_size));
 
@@ -89,54 +88,17 @@ __device__ void blockReduce(
           shared_mem[linear_tid + np2 * reduction_stride]);
     }
   }
-  block_sync::sync();
-  // loop peel the final iteration to save one syncthread for the end
-  for (int factor = np2 / 2; factor > 1; factor >>= 1) {
+  __syncthreads();
+  // for (int factor = np2/2; factor > contig_threads / 2; factor>>=1) {
+  for (int factor = np2 / 2; factor > 0; factor >>= 1) {
     if (reduction_tid < factor) {
       reduction_op(
           shared_mem[linear_tid],
           shared_mem[linear_tid + factor * reduction_stride]);
     }
-    block_sync::sync();
-  }
-
-  if (should_write && write_pred) {
-    T result = out;
-    reduction_op(result, shared_mem[linear_tid]);
-    if (reduction_size > 1) {
-      reduction_op(result, shared_mem[linear_tid + 1 * reduction_stride]);
-    }
-    out = result;
+    __syncthreads();
   }
-  block_sync::sync();
-}
 
-// Use the same pred for both reads and writes
-template <
-    bool X_REDUCE,
-    bool Y_REDUCE,
-    bool Z_REDUCE,
-    typename T,
-    typename Func,
-    typename _dim3ti,
-    typename _dim3bd>
-__device__ void blockReduce(
-    T& out,
-    const T& inp_val,
-    Func reduction_op,
-    const _dim3ti& thread_idx,
-    const _dim3bd& block_dim,
-    T* shared_mem,
-    bool read_write_pred,
-    T init_val) {
-  blockReduce<X_REDUCE, Y_REDUCE, Z_REDUCE, T, Func, _dim3ti, _dim3bd>(
-      out,
-      inp_val,
-      reduction_op,
-      thread_idx,
-      block_dim,
-      shared_mem,
-      read_write_pred,
-      read_write_pred,
-      init_val);
+  if (should_write && read_write_pred)
+    out = shared_mem[linear_tid];
 }
diff --git a/torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu b/torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu
deleted file mode 100644 (file)
index 637a64d..0000000
+++ /dev/null
@@ -1,51 +0,0 @@
-
-// Counter-based block synchronization. Only meant to be used for
-// debugging and validating synchronization. This should be replaced
-// with cuda::barrier::arrive_and_wait as that should be more robust.
-
-namespace block_sync {
-
-using CounterType = unsigned int;
-static constexpr CounterType COUNTER_TYPE_MAX = ~(CounterType)0;
-__shared__ CounterType sync_counter;
-
-__device__ void init() {
-  const unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x +
-      threadIdx.z * blockDim.x * blockDim.y;
-  if (tid == 0) {
-    sync_counter = 0;
-  }
-  __syncthreads();
-}
-
-// Emulate __syncthreads() with a synchronization counter
-__device__ void sync() {
-  unsigned int backoff = 8;
-  const unsigned int backoff_max = 256;
-  const unsigned int num_threads = blockDim.x * blockDim.y * blockDim.z;
-
-  __threadfence_block();
-
-  // Use counter range only up to a limit so that the next val won't
-  // overflow.
-
-  const auto counter_max = (COUNTER_TYPE_MAX / num_threads) * num_threads;
-  const auto old = atomicInc(&sync_counter, counter_max - 1);
-
-  const auto next = (old / num_threads) * num_threads + num_threads;
-
-  auto local_sync_counter = *(volatile CounterType*)(&sync_counter);
-
-  // sync_counter may wrap around, which means local_sync_counter
-  // becomes smaller than old. In that case, it's guaranteed that all
-  // threads have incremented the counter.
-  while (local_sync_counter < next && old < local_sync_counter) {
-    __nanosleep(backoff);
-    if (backoff < backoff_max) {
-      backoff *= 2;
-    }
-    local_sync_counter = *(volatile CounterType*)(&sync_counter);
-  }
-}
-
-} // namespace block_sync
diff --git a/torch/csrc/jit/codegen/cuda/runtime/block_sync_default.cu b/torch/csrc/jit/codegen/cuda/runtime/block_sync_default.cu
deleted file mode 100644 (file)
index ea371a5..0000000
+++ /dev/null
@@ -1,12 +0,0 @@
-
-// Default block synchronization. Just use __barrier_sync
-namespace block_sync {
-
-__forceinline__ __device__ void init() {}
-
-// Thread-block synchronization
-__forceinline__ __device__ void sync() {
-  __barrier_sync(0);
-}
-
-} // namespace block_sync
index 15962fb..9a13b02 100644 (file)
@@ -1,4 +1,3 @@
-
 namespace broadcast {
 
 template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
@@ -24,28 +23,19 @@ __host__ __device__ unsigned offset_of_source(
 // out: Per-thread output location
 //
 template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename T>
-__device__ void blockBroadcast(
-    T& out,
-    const T& inp_val,
-    T* shared_mem,
-    bool read_write_pred) {
+__device__ void blockBroadcast(T& out, T inp_val, T* shared_mem) {
   const bool has_valid_data = (!X_THREAD || threadIdx.x == 0) &&
       (!Y_THREAD || threadIdx.y == 0) && (!Z_THREAD || threadIdx.z == 0);
 
   const auto shared_offset =
       offset_of_source<X_THREAD, Y_THREAD, Z_THREAD>(blockDim, threadIdx);
 
-  if (has_valid_data && read_write_pred) {
+  if (has_valid_data)
     shared_mem[shared_offset] = inp_val;
-  }
-
-  block_sync::sync();
 
-  if (read_write_pred) {
-    out = shared_mem[shared_offset];
-  }
+  __syncthreads();
 
-  block_sync::sync();
+  out = shared_mem[shared_offset];
 }
 
 } // namespace broadcast
index 4bd402e..ba23678 100644 (file)
@@ -1,17 +1,8 @@
-
-#define __NVFUSER_HALF_TO_US(var) *(reinterpret_cast<unsigned short*>(&(var)))
-#define __NVFUSER_HALF_TO_CUS(var) \
-  *(reinterpret_cast<const unsigned short*>(&(var)))
-
-struct __half;
-__device__ __half __float2half(const float);
+#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short*>(&(var)))
+#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short*>(&(var)))
 
 struct __align__(2) __half {
-  __half() = default;
-
-  __device__ __half(const float f) {
-    __x = __float2half(f).__x;
-  }
+  __host__ __device__ __half() {}
 
  protected:
   unsigned short __x;
@@ -19,25 +10,12 @@ struct __align__(2) __half {
 
 __device__ __half __float2half(const float f) {
   __half val;
-  asm("{  cvt.rn.f16.f32 %0, %1;}\n"
-      : "=h"(__NVFUSER_HALF_TO_US(val))
-      : "f"(f));
+  asm("{  cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(val)) : "f"(f));
   return val;
 }
 
 __device__ float __half2float(const __half h) {
   float val;
-  asm("{  cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__NVFUSER_HALF_TO_CUS(h)));
+  asm("{  cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
   return val;
 }
-
-// aligned vector generates vectorized load/store on CUDA
-template <typename scalar_t, int vec_size>
-struct alignas(sizeof(scalar_t) * vec_size) Array {
-  scalar_t val[vec_size];
-  __device__ void set(scalar_t v) {
-    for (int i = 0; i < vec_size; ++i) {
-      val[i] = v;
-    }
-  }
-};
index 3d2067e..4915fa5 100644 (file)
@@ -45,22 +45,20 @@ namespace reduction {
 
 // Utility functions
 template <typename _dim3>
-__device__ __forceinline__ nvfuser_index_t size(const _dim3& d) {
-  return (nvfuser_index_t)d.x * (nvfuser_index_t)d.y * (nvfuser_index_t)d.z;
+__device__ __forceinline__ size_t size(const _dim3& d) {
+  return (size_t)d.x * (size_t)d.y * (size_t)d.z;
 }
 
-#define isize(d) ((d).x * (d).y * (d).z)
+#define isize(d) d.x* d.y* d.z
 
 template <typename _dim3pos, typename _dim3dim>
-__device__ __forceinline__ nvfuser_index_t
+__device__ __forceinline__ size_t
 offset(const _dim3pos& pos, const _dim3dim& dim) {
-  return (nvfuser_index_t)pos.x +
-      (nvfuser_index_t)pos.y * (nvfuser_index_t)dim.x +
-      (nvfuser_index_t)pos.z * (nvfuser_index_t)dim.x * (nvfuser_index_t)dim.y;
+  return (size_t)pos.x + (size_t)pos.y * (size_t)dim.x +
+      (size_t)pos.z * (size_t)dim.x * (size_t)dim.y;
 }
 
-#define ioffset(pos, dim) \
-  ((pos).x + (pos).y * (dim).x + (pos).z * (dim).x * (dim).y)
+#define ioffset(pos, dim) pos.x + pos.y* dim.x + pos.z* dim.x* dim.y
 
 // Returns dim3 of each reduction segment.
 template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
@@ -73,14 +71,14 @@ __device__ dim3 dimension_of_reduction_segment(const _dim3& grid_dim) {
 
 // Returns the number of blocks in each reduction segment.
 template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
-__device__ nvfuser_index_t size_of_reduction_segment(const _dim3& grid_dim) {
+__device__ size_t size_of_reduction_segment(const _dim3& grid_dim) {
   return size(
       dimension_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(grid_dim));
 }
 
 // Returns the total number of reduction segments.
 template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
-__device__ nvfuser_index_t number_of_reduction_segments(const _dim3& grid_dim) {
+__device__ size_t number_of_reduction_segments(const _dim3& grid_dim) {
   return (X_BLOCK ? 1 : grid_dim.x) * (Y_BLOCK ? 1 : grid_dim.y) *
       (Z_BLOCK ? 1 : grid_dim.z);
 }
@@ -92,9 +90,9 @@ template <
     bool Z_BLOCK,
     typename _dim3bi,
     typename _dim3gd>
-__device__ nvfuser_index_t
+__device__ size_t
 index_of_reduction_segment(const _dim3bi& block_idx, const _dim3gd& grid_dim) {
-  nvfuser_index_t seg_idx = 0;
+  size_t seg_idx = 0;
   if (!Z_BLOCK)
     seg_idx += block_idx.z;
   if (!Y_BLOCK)
@@ -111,9 +109,9 @@ template <
     bool Z_BLOCK,
     typename _dim3bi,
     typename _dim3gd>
-__device__ nvfuser_index_t
+__device__ size_t
 offset_in_reduction_segment(const _dim3bi& block_idx, const _dim3gd& grid_dim) {
-  nvfuser_index_t offset = 0;
+  size_t offset = 0;
   if (Z_BLOCK)
     offset = offset * grid_dim.z + block_idx.z;
   if (Y_BLOCK)
@@ -197,10 +195,10 @@ template <
 __device__ void gridReduceLastBlock(
     T& out,
     const T* in,
-    const nvfuser_index_t in_size,
+    const size_t in_size,
     Func reduction_op,
     T* shared_buf,
-    bool write_pred,
+    bool read_write_pred,
     T init_val) {
   const int tid = ioffset(threadIdx, blockDim);
   const int block_size = isize(blockDim);
@@ -211,7 +209,7 @@ __device__ void gridReduceLastBlock(
   if (tid < in_size) {
     inp = in[tid];
   }
-  for (nvfuser_index_t i = tid + block_size; i < in_size; i += block_size) {
+  for (size_t i = tid + block_size; i < in_size; i += block_size) {
     reduction_op(inp, in[i]);
   }
 
@@ -223,9 +221,8 @@ __device__ void gridReduceLastBlock(
   if (rem_size > 1) {
     const int rblock_offset = tid % rblock_size;
     const int rblock_idx = tid / rblock_size;
-    T inp_tmp = init_val;
     blockReduce<false, true, false>(
-        inp_tmp,
+        inp,
         inp,
         reduction_op,
         dim3{(unsigned)rblock_offset, (unsigned)rblock_idx, 0},
@@ -233,20 +230,19 @@ __device__ void gridReduceLastBlock(
         shared_buf,
         true,
         init_val);
-    block_sync::sync();
-    inp = inp_tmp;
+    __syncthreads();
     if (tid < rblock_size) {
       shared_buf[tid] = inp;
     }
-    block_sync::sync();
+    __syncthreads();
     if (should_write) {
       inp = shared_buf[offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(
           threadIdx, blockDim)];
     }
   }
 
-  if (should_write && write_pred) {
-    reduction_op(out, inp);
+  if (should_write && read_write_pred) {
+    out = inp;
   }
 }
 
@@ -309,13 +305,12 @@ template <
     typename Func>
 __device__ bool gridReduce(
     T& out,
-    const T& inp_val,
+    T inp_val,
     Func reduction_op,
     volatile T* work_buf,
     Tensor<int64_t, 1> sync_flags,
     T* shared_buf,
-    bool read_pred,
-    bool write_pred,
+    bool read_write_pred,
     T init_val) {
   // Number of values to reduce in the grid dimensions
   const auto seg_size =
@@ -342,13 +337,13 @@ __device__ bool gridReduce(
         offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(
             threadIdx, blockDim);
     auto work_buf_offset = rblock_size * rblock_offset + thread_offset;
-    if (read_pred) {
+    if (read_write_pred) {
       work_buf[work_buf_offset] = inp_val;
     } else {
       work_buf[work_buf_offset] = init_val;
     }
   }
-  block_sync::sync();
+  __syncthreads();
 
   __shared__ bool last_block;
   if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
@@ -358,7 +353,7 @@ __device__ bool gridReduce(
     last_block = old + 1 == seg_size;
     // printf("Last_block = %d + 1 == %d\n", (int)old, (int)seg_size);
   }
-  block_sync::sync();
+  __syncthreads();
 
   if (last_block) {
     // printf("Last block %d %d %d %d\n", blockIdx.x, blockIdx.y, blockIdx.z);
@@ -369,7 +364,7 @@ __device__ bool gridReduce(
         seg_size * rblock_size,
         reduction_op,
         shared_buf,
-        write_pred,
+        read_write_pred,
         init_val);
     return true;
   } else {
@@ -379,6 +374,3 @@ __device__ bool gridReduce(
 }
 
 } // namespace reduction
-
-#undef isize
-#undef ioffset
index 15ae469..15b33b2 100644 (file)
-#define NVFUSER_DEFINE_MAGIC_ZERO          \
-  __shared__ int nvfuser_zero_s;           \
-  if (threadIdx.x == 0)                    \
-    nvfuser_zero_s = 0;                    \
-  __syncthreads();                         \
-  atomicMin(&nvfuser_zero_s, threadIdx.x); \
-  int nvfuser_zero = nvfuser_zero_s;
-
-#define NVFUSER_UPDATE_MAGIC_ZERO \
-  do {                            \
-    nvfuser_zero <<= 1;           \
-  } while (0);
-
 __device__ constexpr int ceilDiv(int a, int b) {
   return (a + b - 1) / b;
 }
 
-__device__ constexpr int64_t ceilDiv(int64_t a, int64_t b) {
-  return (a + b - 1) / b;
-}
-
-__device__ constexpr int64_t ceilDiv(int64_t a, int b) {
-  return ceilDiv(a, (int64_t)b);
-}
-
-__device__ constexpr int64_t ceilDiv(int a, int64_t b) {
-  return ceilDiv((int64_t)a, b);
-}
-
 __device__ constexpr int alignBufferSize(int buffer, int size) {
   return (buffer + (size - 1)) & ~(size - 1);
 }
 
-__device__ double clamp(double x, double minv, double maxv) {
-  return x < minv ? minv : (x > maxv ? maxv : x);
-}
-
-__device__ float clamp(float x, double minv, double maxv) {
+__device__ float clamp(float x, float minv, float maxv) {
   return x < minv ? minv : (x > maxv ? maxv : x);
 }
 
-__device__ double frac(double x) {
-  return x - trunc(x);
-}
-
 __device__ float frac(float x) {
-  return x - trunc(x);
-}
-
-__device__ double gelu(double x) {
-  return x * normcdf(x);
+  return x - truncf(x);
 }
 
 __device__ float gelu(float x) {
   return x * normcdf(x);
 }
 
-__device__ double reciprocal(double x) {
-  return 1 / x;
-}
-
 __device__ float reciprocal(float x) {
-  return 1 / x;
-}
-
-__device__ double relu(double x) {
-  return x <= 0 ? 0 : x;
+  return 1.f / x;
 }
 
 __device__ float relu(float x) {
-  return x <= 0 ? 0 : x;
-}
-
-__device__ double remainder(double a, double b) {
-  auto mod = ::fmod(a, b);
-  if ((mod != 0) && ((b < 0) != (mod < 0)))
-    mod += b;
-  return mod;
+  return x <= 0.f ? 0.f : x;
 }
 
 __device__ float remainder(float a, float b) {
-  auto mod = ::fmod(a, b);
-  if ((mod != 0) && ((b < 0) != (mod < 0)))
-    mod += b;
-  return mod;
-}
-
-__device__ double sigmoid(double x) {
-  return 1 / (1 + exp(-x));
+  return a - b * floorf(a / b);
 }
 
 __device__ float sigmoid(float x) {
-  return 1 / (1 + exp(-x));
+  return 1.f / (1.f + expf(-x));
 }
 
-__device__ double silu(double x) {
-  return x * sigmoid(x);
-}
-
-__device__ float silu(float x) {
-  return x * sigmoid(x);
-}
-
-__device__ double threshold(double x, double t, double v) {
+__device__ float threshold(float x, float t, float v) {
   return x <= t ? v : x;
 }
 
-__device__ float threshold(float x, double t, double v) {
-  return x <= t ? v : x;
-}
-
-__device__ double where(bool c, double a, double b) {
-  return c ? a : b;
-}
-
 __device__ float where(bool c, float a, float b) {
   return c ? a : b;
 }
 
-__device__ int64_t where(bool c, int64_t a, int64_t b) {
-  return c ? a : b;
-}
-
-__device__ double randLike(Philox rnd) {
-  return uniform(rnd(), rnd());
-}
-
-__device__ float randLikef(Philox rnd) {
-  return uniformf(rnd());
-}
-
-__device__ constexpr int64_t remainder(int64_t a, int64_t b) {
-  auto mod = a % b;
-  if ((mod != 0) && ((b < 0) != (mod < 0)))
-    mod += b;
-  return mod;
-}
-
-__device__ constexpr int remainder(int a, int b) {
-  auto mod = a % b;
-  if ((mod != 0) && ((b < 0) != (mod < 0)))
-    mod += b;
-  return mod;
+__device__ float randLike(Philox rnd) {
+  return uniform(rnd());
 }
index bbea265..d690145 100644 (file)
@@ -98,14 +98,7 @@ class Philox {
   unsigned int STATE = 0;
 };
 
-__device__ float uniformf(unsigned int x) {
+__device__ float uniform(unsigned int x) {
   constexpr float kRanInvM32 = 2.3283064e-10f; // Inverse of 2^32.
   return x * kRanInvM32;
 }
-
-__device__ double uniform(unsigned int x, unsigned int y) {
-  constexpr double kRan2Pow53Inv = 1.1102230246251565e-16;
-  const unsigned long long z =
-      (unsigned long long)x ^ ((unsigned long long)y << (53 - 32));
-  return z * kRan2Pow53Inv + (kRan2Pow53Inv / 2.0);
-}
index aab51a8..76731c8 100644 (file)
@@ -1,19 +1,24 @@
+typedef unsigned char uint8_t;
+typedef signed char int8_t;
+typedef short int int16_t;
+typedef long long int int64_t;
+
 template <typename T, int N>
 struct Tensor {
-  __device__ T& operator[](nvfuser_index_t ind) {
+  __device__ T& operator[](int64_t ind) {
     return data[ind];
   };
 
   T* data;
-  nvfuser_index_t size[N];
-  nvfuser_index_t stride[N];
+  int64_t size[N];
+  int64_t stride[N];
 };
 
 // Specialization for 0-dim case as it does not need size and stride arrays.
 // They will be an error as well since zero-length arrays are not allowed.
 template <typename T>
 struct Tensor<T, 0> {
-  __device__ T& operator[](nvfuser_index_t) {
+  __device__ T& operator[](int64_t) {
     return *data;
   };
 
diff --git a/torch/csrc/jit/codegen/cuda/runtime/welford.cu b/torch/csrc/jit/codegen/cuda/runtime/welford.cu
deleted file mode 100644 (file)
index e0cbab6..0000000
+++ /dev/null
@@ -1,482 +0,0 @@
-// -----------------------------------------------------------------------------------------------
-//  Block Welford Primitives
-// -----------------------------------------------------------------------------------------------
-// Basic utility for welford update. Can be used to scan one value, or two merge
-// two welford results
-template <typename T, typename TN>
-__inline__ __device__ void welfordCombine(
-    T& a_avg,
-    T& a_M2,
-    TN& a_N,
-    const T& b_avg,
-    const T& b_M2,
-    TN b_N) {
-  if (b_N == 0) {
-    return;
-  }
-  TN ab_N = a_N + b_N;
-  T delta = b_avg - a_avg;
-  a_avg += delta * b_N / ab_N;
-  a_M2 += b_M2 + delta * delta * a_N * b_N / ab_N;
-  a_N = ab_N;
-}
-
-// [Z,Y,X]_THREADS is the number of participating threads in the z, y, x
-// dimension of the block.
-template <
-    bool X_REDUCE,
-    bool Y_REDUCE,
-    bool Z_REDUCE,
-    typename T,
-    typename TN,
-    typename _dim3ti,
-    typename _dim3bd>
-__inline__ __device__ void blockWelford(
-    T& out_avg,
-    T& out_M2,
-    TN& out_N,
-    const T& in_avg,
-    const T& in_M2,
-    const TN& in_N,
-    const _dim3ti& thread_idx,
-    const _dim3bd& block_dim,
-    T* shared_mem_avg,
-    T* shared_mem_M2,
-    TN* shared_mem_N,
-    bool read_pred,
-    bool write_pred,
-    T init_val) {
-  unsigned int reduction_size = (X_REDUCE ? block_dim.x : 1) *
-      (Y_REDUCE ? block_dim.y : 1) * (Z_REDUCE ? block_dim.z : 1);
-  // If this thread will output a final result
-  bool should_write = true;
-  if (X_REDUCE)
-    should_write = should_write && thread_idx.x == 0;
-  if (Y_REDUCE)
-    should_write = should_write && thread_idx.y == 0;
-  if (Z_REDUCE)
-    should_write = should_write && thread_idx.z == 0;
-  unsigned int reduction_stride;
-  unsigned int reduction_tid;
-  unsigned int linear_tid;
-  if (X_REDUCE && !Y_REDUCE && Z_REDUCE) {
-    // Transpose Z and Y in the shared memory so Z and X dims are contiguous in
-    // smem
-    reduction_stride = 1;
-    linear_tid = threadIdx.y * blockDim.z * blockDim.x +
-        threadIdx.z * blockDim.x + threadIdx.x;
-    reduction_tid = threadIdx.z * blockDim.x + threadIdx.x;
-  } else {
-    // Normal reduction in order
-    reduction_stride =
-        (X_REDUCE ? 1
-                  : (Y_REDUCE ? block_dim.x
-                              : (Z_REDUCE ? block_dim.x * block_dim.y : 0)));
-    linear_tid = thread_idx.z * block_dim.y * block_dim.x +
-        thread_idx.y * block_dim.x + thread_idx.x;
-    reduction_tid = (Z_REDUCE ? thread_idx.z : 0) *
-            (Y_REDUCE ? block_dim.y : 1) * (X_REDUCE ? block_dim.x : 1) +
-        (Y_REDUCE ? thread_idx.y : 0) * (X_REDUCE ? block_dim.x : 1) +
-        (X_REDUCE ? thread_idx.x : 0);
-  }
-  assert(reduction_stride != 0);
-  if (read_pred) {
-    shared_mem_avg[linear_tid] = in_avg;
-    shared_mem_M2[linear_tid] = in_M2;
-    shared_mem_N[linear_tid] = in_N;
-  } else {
-    shared_mem_avg[linear_tid] = init_val;
-    shared_mem_M2[linear_tid] = init_val;
-    shared_mem_N[linear_tid] = 0;
-  }
-  block_sync::sync();
-  // Reduce down to nearest power of 2:
-  int np2 = 1 << (31 - __clz(reduction_size));
-  if (reduction_tid < np2) {
-    if (reduction_tid + np2 < reduction_size) {
-      welfordCombine(
-          shared_mem_avg[linear_tid],
-          shared_mem_M2[linear_tid],
-          shared_mem_N[linear_tid],
-          shared_mem_avg[linear_tid + np2 * reduction_stride],
-          shared_mem_M2[linear_tid + np2 * reduction_stride],
-          shared_mem_N[linear_tid + np2 * reduction_stride]);
-    }
-  }
-  block_sync::sync();
-
-  // loop peel the final iteration to save one syncthread for the end
-  for (int factor = np2 / 2; factor > 1; factor >>= 1) {
-    if (reduction_tid < factor) {
-      welfordCombine(
-          shared_mem_avg[linear_tid],
-          shared_mem_M2[linear_tid],
-          shared_mem_N[linear_tid],
-          shared_mem_avg[linear_tid + factor * reduction_stride],
-          shared_mem_M2[linear_tid + factor * reduction_stride],
-          shared_mem_N[linear_tid + factor * reduction_stride]);
-    }
-    block_sync::sync();
-  }
-  if (should_write && write_pred) {
-    T res_avg = out_avg;
-    T res_M2 = out_M2;
-    TN res_N = out_N;
-    welfordCombine(
-        res_avg,
-        res_M2,
-        res_N,
-        shared_mem_avg[linear_tid],
-        shared_mem_M2[linear_tid],
-        shared_mem_N[linear_tid]);
-    if (reduction_size > 1) {
-      welfordCombine(
-          res_avg,
-          res_M2,
-          res_N,
-          shared_mem_avg[linear_tid + reduction_stride],
-          shared_mem_M2[linear_tid + reduction_stride],
-          shared_mem_N[linear_tid + reduction_stride]);
-    }
-    out_avg = res_avg;
-    out_M2 = res_M2;
-    out_N = res_N;
-  }
-  block_sync::sync();
-}
-
-// Use the same pred for both reads and writes
-template <
-    bool X_REDUCE,
-    bool Y_REDUCE,
-    bool Z_REDUCE,
-    typename T,
-    typename TN,
-    typename _dim3ti,
-    typename _dim3bd>
-__inline__ __device__ void blockWelford(
-    T& out_avg,
-    T& out_M2,
-    TN& out_N,
-    const T& in_avg,
-    const T& in_M2,
-    const TN& in_N,
-    const _dim3ti& thread_idx,
-    const _dim3bd& block_dim,
-    T* shared_mem_avg,
-    T* shared_mem_M2,
-    TN* shared_mem_N,
-    bool read_write_pred,
-    T init_val) {
-  blockWelford<X_REDUCE, Y_REDUCE, Z_REDUCE, T, TN, _dim3ti, _dim3bd>(
-      out_avg,
-      out_M2,
-      out_N,
-      in_avg,
-      in_M2,
-      in_N,
-      thread_idx,
-      block_dim,
-      shared_mem_avg,
-      shared_mem_M2,
-      shared_mem_N,
-      read_write_pred,
-      read_write_pred,
-      init_val);
-}
-// -----------------------------------------------------------------------------------------------
-//  Grid Welford Prototype
-// -----------------------------------------------------------------------------------------------
-namespace welford {
-// Utility functions
-template <typename _dim3>
-__host__ __device__ __forceinline__ nvfuser_index_t size(const _dim3& d) {
-  return (nvfuser_index_t)d.x * (nvfuser_index_t)d.y * (nvfuser_index_t)d.z;
-}
-
-#define isize(d) ((d).x * (d).y * (d).z)
-
-template <typename _dim3pos, typename _dim3dim>
-__host__ __device__ __forceinline__ nvfuser_index_t
-offset(const _dim3pos& pos, const _dim3dim& dim) {
-  return (nvfuser_index_t)pos.x +
-      (nvfuser_index_t)pos.y * (nvfuser_index_t)dim.x +
-      (nvfuser_index_t)pos.z * (nvfuser_index_t)dim.x * (nvfuser_index_t)dim.y;
-}
-
-#define ioffset(pos, dim) \
-  ((pos).x + (pos).y * (dim).x + (pos).z * (dim).x * (dim).y)
-
-// Returns dim3 of each reduction segment.
-template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
-__host__ __device__ dim3 dimension_of_reduction_segment(const _dim3& grid_dim) {
-  return dim3{
-      X_BLOCK ? grid_dim.x : 1,
-      Y_BLOCK ? grid_dim.y : 1,
-      Z_BLOCK ? grid_dim.z : 1};
-}
-
-// Returns the number of blocks in each reduction segment.
-template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
-__host__ __device__ nvfuser_index_t
-size_of_reduction_segment(const _dim3& grid_dim) {
-  return size(
-      dimension_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(grid_dim));
-}
-
-// Returns the total number of reduction segments.
-template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
-__host__ __device__ nvfuser_index_t
-number_of_reduction_segments(const _dim3& grid_dim) {
-  return (X_BLOCK ? 1 : grid_dim.x) * (Y_BLOCK ? 1 : grid_dim.y) *
-      (Z_BLOCK ? 1 : grid_dim.z);
-}
-
-// Returns the 1-D index of the segment of thread block of block_idx.
-template <
-    bool X_BLOCK,
-    bool Y_BLOCK,
-    bool Z_BLOCK,
-    typename _dim3bi,
-    typename _dim3gd>
-__host__ __device__ nvfuser_index_t
-index_of_reduction_segment(const _dim3bi& block_idx, const _dim3gd& grid_dim) {
-  nvfuser_index_t seg_idx = 0;
-  if (!Z_BLOCK)
-    seg_idx += block_idx.z;
-  if (!Y_BLOCK)
-    seg_idx = seg_idx * grid_dim.y + block_idx.y;
-  if (!X_BLOCK)
-    seg_idx = seg_idx * grid_dim.x + block_idx.x;
-  return seg_idx;
-}
-
-// Returns the offset of thread block in its reduction segment.
-template <
-    bool X_BLOCK,
-    bool Y_BLOCK,
-    bool Z_BLOCK,
-    typename _dim3bi,
-    typename _dim3gd>
-__host__ __device__ nvfuser_index_t
-offset_in_reduction_segment(const _dim3bi& block_idx, const _dim3gd& grid_dim) {
-  nvfuser_index_t offset = 0;
-  if (Z_BLOCK)
-    offset = offset * grid_dim.z + block_idx.z;
-  if (Y_BLOCK)
-    offset = offset * grid_dim.y + block_idx.y;
-  if (X_BLOCK)
-    offset = offset * grid_dim.x + block_idx.x;
-  return offset;
-}
-
-// Returns dim3 of each reduction block.
-template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename _dim3>
-__host__ __device__ dim3 dimension_of_reduction_block(const _dim3& block_dim) {
-  return dim3{
-      X_THREAD ? block_dim.x : 1,
-      Y_THREAD ? block_dim.y : 1,
-      Z_THREAD ? block_dim.z : 1};
-}
-
-// Returns the number of threads of each reduction block.
-template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename _dim3>
-__host__ __device__ int size_of_reduction_block(const _dim3& block_dim) {
-  auto tmp_dim =
-      dimension_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(block_dim);
-  return isize(tmp_dim);
-}
-
-// Returns the linear offset of a thread in a reduction block.
-template <
-    bool X_THREAD,
-    bool Y_THREAD,
-    bool Z_THREAD,
-    typename _dim3ti,
-    typename _dim3bd>
-__host__ __device__ int offset_in_reduction_block(
-    const _dim3ti& thread_idx,
-    const _dim3bd& block_dim) {
-  int offset = 0;
-  if (Z_THREAD)
-    offset += thread_idx.z;
-  if (Y_THREAD)
-    offset = offset * block_dim.y + thread_idx.y;
-  if (X_THREAD)
-    offset = offset * block_dim.x + thread_idx.x;
-  return offset;
-}
-
-template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename T, typename TN>
-__device__ void gridWelfordLastBlock(
-    T& out_avg,
-    T& out_M2,
-    TN& out_N,
-    const T* in_avg,
-    const T* in_M2,
-    const TN* in_N,
-    const nvfuser_index_t in_size,
-    T* shared_buf_avg,
-    T* shared_buf_M2,
-    TN* shared_buf_N,
-    bool write_pred,
-    T init_val) {
-  const int tid = ioffset(threadIdx, blockDim);
-  const int block_size = isize(blockDim);
-  const int rblock_size =
-      size_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(blockDim);
-
-  T inp_avg = init_val;
-  T inp_M2 = init_val;
-  TN inp_N = 0;
-  if (tid < in_size) {
-    inp_avg = in_avg[tid];
-    inp_M2 = in_M2[tid];
-    inp_N = in_N[tid];
-  }
-  for (nvfuser_index_t i = tid + block_size; i < in_size; i += block_size) {
-    welfordCombine(inp_avg, inp_M2, inp_N, in_avg[i], in_M2[i], in_N[i]);
-  }
-  const auto should_write = (X_THREAD || threadIdx.x == 0) &&
-      (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0);
-
-  auto rem_size = block_size / rblock_size;
-
-  if (rem_size > 1) {
-    const int rblock_offset = tid % rblock_size;
-    const int rblock_idx = tid / rblock_size;
-    T inp_avg_tmp = init_val;
-    T inp_M2_tmp = init_val;
-    TN inp_N_tmp = 0;
-    blockWelford<false, true, false>(
-        inp_avg_tmp,
-        inp_M2_tmp,
-        inp_N_tmp,
-        inp_avg,
-        inp_M2,
-        inp_N,
-        dim3{(unsigned)rblock_offset, (unsigned)rblock_idx, 0},
-        dim3{(unsigned)rblock_size, (unsigned)rem_size},
-        shared_buf_avg,
-        shared_buf_M2,
-        shared_buf_N,
-        true,
-        init_val);
-    block_sync::sync();
-    if (tid < rblock_size) {
-      shared_buf_avg[tid] = inp_avg_tmp;
-      shared_buf_M2[tid] = inp_M2_tmp;
-      shared_buf_N[tid] = inp_N_tmp;
-    }
-    block_sync::sync();
-    if (should_write) {
-      nvfuser_index_t offset_write =
-          offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(
-              threadIdx, blockDim);
-      inp_avg = shared_buf_avg[offset_write];
-      inp_M2 = shared_buf_M2[offset_write];
-      inp_N = shared_buf_N[offset_write];
-    }
-  }
-
-  if (should_write && write_pred) {
-    welfordCombine(out_avg, out_M2, out_N, inp_avg, inp_M2, inp_N);
-  }
-}
-
-//    Grid welford combine
-template <
-    bool X_BLOCK,
-    bool Y_BLOCK,
-    bool Z_BLOCK,
-    bool X_THREAD,
-    bool Y_THREAD,
-    bool Z_THREAD,
-    typename T,
-    typename TN>
-__device__ bool gridWelford(
-    T& out_avg,
-    T& out_M2,
-    TN& out_N,
-    const T& inp_avg,
-    const T& inp_M2,
-    const TN& inp_N,
-    volatile T* work_buf_avg,
-    volatile T* work_buf_M2,
-    volatile TN* work_buf_N,
-    Tensor<int64_t, 1> sync_flags,
-    T* shared_buf_avg,
-    T* shared_buf_M2,
-    TN* shared_buf_N,
-    bool read_pred,
-    bool write_pred,
-    T init_val) {
-  // Number of values to reduce in the grid dimensions
-  const auto seg_size =
-      size_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);
-
-  // Index of the reduction we're performing out of the seg_size
-  const auto seg_idx =
-      index_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
-
-  // Number of threads we can use in final reduction, Seems to assume all
-  // threads in the block participate
-  const auto rblock_size =
-      size_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(blockDim);
-
-  work_buf_avg += seg_idx * seg_size * rblock_size;
-  work_buf_M2 += seg_idx * seg_size * rblock_size;
-  work_buf_N += seg_idx * seg_size * rblock_size;
-
-  if ((X_THREAD || threadIdx.x == 0) && (Y_THREAD || threadIdx.y == 0) &&
-      (Z_THREAD || threadIdx.z == 0)) {
-    auto rblock_offset = offset_in_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(
-        blockIdx, gridDim);
-    auto thread_offset =
-        offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(
-            threadIdx, blockDim);
-    auto work_buf_offset = rblock_size * rblock_offset + thread_offset;
-    if (read_pred) {
-      work_buf_avg[work_buf_offset] = inp_avg;
-      work_buf_M2[work_buf_offset] = inp_M2;
-      work_buf_N[work_buf_offset] = inp_N;
-    } else {
-      work_buf_avg[work_buf_offset] = init_val;
-      work_buf_M2[work_buf_offset] = init_val;
-      work_buf_N[work_buf_offset] = 0;
-    }
-  }
-  block_sync::sync();
-
-  __shared__ bool last_block;
-  if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
-    __threadfence();
-    auto old = (int64_t)atomicAdd((unsigned long long*)&sync_flags[seg_idx], 1);
-    last_block = old + 1 == seg_size;
-  }
-  block_sync::sync();
-
-  if (last_block) {
-    // final reduction
-    gridWelfordLastBlock<X_THREAD, Y_THREAD, Z_THREAD>(
-        out_avg,
-        out_M2,
-        out_N,
-        (T*)work_buf_avg,
-        (T*)work_buf_M2,
-        (TN*)work_buf_N,
-        seg_size * rblock_size,
-        shared_buf_avg,
-        shared_buf_M2,
-        shared_buf_N,
-        write_pred,
-        init_val);
-    return true;
-  } else {
-    return false;
-  }
-}
-} // namespace welford
-
-#undef isize
-#undef ioffset
diff --git a/torch/csrc/jit/codegen/cuda/scheduler.cpp b/torch/csrc/jit/codegen/cuda/scheduler.cpp
new file mode 100644 (file)
index 0000000..199e564
--- /dev/null
@@ -0,0 +1,689 @@
+#include <torch/csrc/jit/codegen/cuda/scheduler.h>
+
+#include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
+#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
+#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
+#include <torch/csrc/jit/codegen/cuda/parser.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/util/irange.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+constexpr int kUnrollFactor = 1;
+
+namespace {
+
+std::vector<int> reductionAxes(TensorView* tv) {
+  size_t n_dims = tv->nDims();
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  std::vector<int> reduction_axes;
+  for (const auto i : c10::irange(n_dims)) {
+    if (tv->axis(i)->isReduction()) {
+      reduction_axes.emplace_back(i);
+    }
+  }
+  return reduction_axes;
+}
+
+// Merge all reduction to the right side and returns total number of
+// reduction axes
+size_t mergeReduction(TensorView* tv) {
+  int prev_i = -1;
+  size_t num_merged = 0;
+  for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
+    if (!tv->axis(i)->isReduction()) {
+      continue;
+    }
+    if (prev_i == -1) {
+      prev_i = i;
+    } else {
+      tv->merge(i, prev_i);
+      prev_i = i;
+      num_merged++;
+    }
+  }
+  if (prev_i == 0) {
+    tv->reorder({{prev_i, -1}});
+  }
+
+  return prev_i == -1 ? 0 : num_merged + 1;
+}
+
+// merge all non-reduction axes to the left side and returns total number of
+// iteration axes
+size_t mergeNonReduction(TensorView* tv) {
+  int prev_i = -1;
+  size_t num_merged = 0;
+  for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
+    if (tv->axis(i)->isReduction()) {
+      continue;
+    }
+    if (prev_i == -1) {
+      prev_i = i;
+    } else {
+      tv->merge(i, prev_i);
+      prev_i = i;
+      num_merged++;
+    }
+  }
+  if (prev_i != 0) {
+    tv->reorder({{prev_i, 0}});
+  }
+
+  return prev_i == -1 ? 0 : num_merged + 1;
+}
+
+} // namespace
+
+// This one is a total mess and it should go.
+bool scheduleFusion(Fusion* fusion, const at::ArrayRef<c10::IValue> inputs) {
+  FUSER_PERF_SCOPE("scheduleFusion");
+
+  FusionGuard fg(fusion);
+  // maybe has_reduction for scheudling should be done on a per output tensor
+  // basis.
+  TORCH_INTERNAL_ASSERT(
+      !fusion->hasReduction(), "This scheduler only handles pointwise ops.");
+  const bool disable_unroll = fusion->isStochastic();
+
+  for (auto out_val : fusion->outputs()) {
+    auto out = out_val->as<TensorView>();
+
+    // Merge all dimensions because we're only supporting pointwise
+    while (out->nDims() > 1) {
+      out->merge(-2, -1);
+    }
+  }
+
+  // Run through outputs, grab all inputs of outputs
+  // squeeze with computeAt to set overall structure.
+  for (auto output : fusion->outputs()) {
+    if (output->getValType() != ValType::TensorView)
+      continue;
+    TensorView* out_tv = output->as<TensorView>();
+
+    // Split into 128 which will be bockDim.x
+    out_tv->split(0, kPwThreadX);
+    // Split by another 4 which will be our unroll factor
+    auto ur_factor = disable_unroll ? 1 : kUnrollFactor;
+    out_tv->split(0, ur_factor);
+  }
+
+  for (auto output : fusion->outputs()) {
+    if (output->getValType() != ValType::TensorView)
+      continue;
+    TensorView* out_tv = output->as<TensorView>();
+    for (Val* inp : fusion->inputsOf(output)) {
+      if (inp->getValType().value() == ValType::TensorView)
+        inp->as<TensorView>()->computeAt(out_tv, -1);
+    }
+    out_tv->axis(0)->parallelize(ParallelType::BIDx);
+    out_tv->axis(1)->parallelize(ParallelType::Unroll);
+    out_tv->axis(2)->parallelize(ParallelType::TIDx);
+  }
+
+  return true;
+}
+
+namespace {
+// Largest Power of 2 less-than n
+constexpr int lastPow2(int n) {
+  n |= (n >> 1);
+  n |= (n >> 2);
+  n |= (n >> 4);
+  n |= (n >> 8); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+  n |= (n >> 16); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+  return std::max(1, n - (n >> 1));
+}
+
+ReductionParams reductionHeuristic(
+    int red_elems,
+    int red_outputs,
+    bool red_on_fastest_dim) {
+  ReductionParams rparams;
+  rparams.fastest_dim = red_on_fastest_dim;
+
+  int gdimx = LaunchParams::UNINITIALIZED_VAL;
+  int gdimy = LaunchParams::UNINITIALIZED_VAL;
+  int bdimx = LaunchParams::UNINITIALIZED_VAL;
+  int bdimy = LaunchParams::UNINITIALIZED_VAL;
+
+  // 1. Initial Assumptions
+
+  // Evaluate Dimensions of Reduction TensorView
+  TORCH_INTERNAL_ASSERT(red_elems > 0 && red_outputs > 0);
+
+  // 2. Initial Definition of Block Dimensions
+
+  // Is fastest dimension a reduction dimension?
+  if (rparams.fastest_dim) {
+    if (red_elems < rparams.loop_unroll) {
+      rparams.loop_unroll = 1;
+    }
+    bdimx = ceilDiv(red_elems, rparams.loop_unroll);
+    bdimy = red_outputs;
+  } else {
+    bdimx = red_outputs;
+    bdimy = red_elems;
+  }
+
+  // 3. Applying Power of 2 Blocking based on the Maximum Number of threads
+
+  constexpr int kMaxNumThreads = 512;
+  int num_threads = kMaxNumThreads;
+  int device_warp_size = at::cuda::warp_size();
+
+  if (bdimx < num_threads) {
+    bdimx = lastPow2(bdimx);
+  } else {
+    bdimx = num_threads;
+  }
+
+  if (bdimy < num_threads) {
+    bdimy = lastPow2(bdimy);
+  } else {
+    bdimy = num_threads;
+  }
+
+  int bdimx_prev = bdimx;
+  bdimx = std::min(bdimx, device_warp_size);
+  bdimy = std::min(bdimy, num_threads / bdimx);
+  bdimx = std::min(bdimx_prev, num_threads / bdimy);
+
+  // 4. Distributing work across a block
+
+  // Magic numbers of calculations allowed per thread.
+  constexpr int kMinValuesPerThread = 16;
+  constexpr int kMaxValuesPerThread = 256;
+
+  int inputs_consumed_per_block_iter = 1;
+  int red_elems_per_thread = red_elems;
+
+  int outputs_produced_per_block_iter = 1;
+
+  // Reduction is performed across warp threads (cross-thread reduction)
+  if (rparams.fastest_dim) {
+    inputs_consumed_per_block_iter *= bdimx;
+    red_elems_per_thread =
+        ceilDiv(red_elems_per_thread, inputs_consumed_per_block_iter);
+    // Warp threads are applied across the output
+  } else {
+    outputs_produced_per_block_iter *= bdimx;
+  }
+
+  // Decision to do a cross-warp reduction per block
+  if (red_elems_per_thread >= (bdimy * kMinValuesPerThread) ||
+      red_elems_per_thread >= kMaxValuesPerThread || !rparams.fastest_dim) {
+    // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
+    inputs_consumed_per_block_iter *= bdimy;
+    red_elems_per_thread = ceilDiv(red_elems_per_thread, bdimy);
+    rparams.cross_block = true;
+    rparams.mul_reds_per_blk = false;
+    // Do multiple reductions per block
+  } else {
+    rparams.cross_block = false;
+    rparams.mul_reds_per_blk = true;
+    outputs_produced_per_block_iter *= bdimy;
+  }
+
+  // 5. Distributing work across blocks
+
+  // WARNING: Current device for codegen may not be the target device
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  int device_max_threads_per_multiprocessor =
+      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor;
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  int device_multiprocessor_count =
+      at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+
+  int blocks_per_sm = device_max_threads_per_multiprocessor / (bdimx * bdimy);
+  int target_grid_size = device_multiprocessor_count * blocks_per_sm;
+
+  // Setting the number of blocks based on the number of outputs
+  gdimx = ceilDiv(red_outputs, outputs_produced_per_block_iter);
+
+  // Cross-block reductions (if necessary)
+  if (rparams.cross_block && red_elems_per_thread >= kMaxValuesPerThread &&
+      gdimx <= target_grid_size) {
+    int blks_per_out_1 = ceilDiv(target_grid_size, gdimx);
+    int blks_per_out_2 = ceilDiv(red_elems_per_thread, kMinValuesPerThread);
+    int blks_per_out_3 = ceilDiv(red_elems_per_thread, kMaxValuesPerThread);
+    int blks_per_output =
+        std::max(std::min(blks_per_out_1, blks_per_out_2), blks_per_out_3);
+
+    gdimy = std::max(1, blks_per_output);
+    // If a cross-block reduction was generated
+    if (blks_per_output > 1) {
+      rparams.cross_grid = true;
+    }
+  }
+
+  const char* debug_env = getenv("PYTORCH_CUDA_FUSER_RED_SCHED_DEBUG");
+  if (debug_env && atoi(debug_env)) {
+    std::cout << "\n===== Reduction Parameters ========" << std::endl
+              << "Inputs:" << std::endl
+              << "\tRed Elems: " << red_elems << " Red Outputs: " << red_outputs
+              << " Red On Fastest Dim? " << red_on_fastest_dim << std::endl
+              << "Reduction Characteristics:" << std::endl
+              << "\tMultiple Reds Per Block? " << rparams.mul_reds_per_blk
+              << " Cross Block? " << rparams.cross_block << " Cross Grid? "
+              << rparams.cross_grid << std::endl
+              << "Recommended Blocking:" << std::endl
+              << "\tGridX: " << gdimx << " GridY: " << gdimy
+              << " BlckX: " << bdimx << " BlckY: " << bdimy << std::endl
+              << "====================================" << std::endl;
+  }
+
+  rparams.lparams = LaunchParams(
+      LaunchParams::UNINITIALIZED_VAL,
+      gdimy,
+      LaunchParams::UNINITIALIZED_VAL,
+      bdimx,
+      bdimy,
+      LaunchParams::UNINITIALIZED_VAL);
+  return rparams;
+}
+} // anonymous namespace
+
+TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
+    Fusion* fusion,
+    const at::ArrayRef<c10::IValue>& fusion_inputs,
+    TensorView* red_tv) {
+  FUSER_PERF_SCOPE("scheduleReduction");
+
+  FusionGuard fg(fusion);
+
+  if (!fusion->hasReduction()) {
+    return c10::nullopt;
+  }
+
+  auto red_root_dom = red_tv->getRootDomain();
+  const bool red_on_fastest_dim =
+      red_root_dom[red_root_dom.size() - 1]->isReduction();
+
+  TORCH_INTERNAL_ASSERT(
+      red_tv != nullptr, "Reduction TensorView wasn't found.");
+
+  if (!fusion->hasReduction()) {
+    return c10::nullopt;
+  }
+
+  TORCH_INTERNAL_ASSERT(
+      red_tv->hasReduction(), "TensorView doesn't have a reduction.");
+  const auto red_expr = fusion->origin(red_tv);
+
+  TORCH_INTERNAL_ASSERT(
+      red_expr->getExprType() != c10::nullopt &&
+          red_expr->getExprType().value() == ExprType::ReductionOp,
+      "TensorView doesn't have a reduction.");
+
+  StatefulExpressionEvaluator evaluator(
+      executor_utils::statefulBindInputs(fusion_inputs, fusion));
+
+  int64_t red_outputs = 1;
+  int64_t red_elements = 1;
+
+  for (auto id : red_tv->getRootDomain()) {
+    auto inferred_val = evaluator.inferValue(id->rawExtent());
+    TORCH_INTERNAL_ASSERT(
+        inferred_val.has_value(), "Error inferring reduction size.");
+    if (id->isReduction()) {
+      red_elements *= inferred_val.value();
+    } else {
+      red_outputs *= inferred_val.value();
+    }
+  }
+
+  return reductionHeuristic(red_elements, red_outputs, red_on_fastest_dim);
+}
+
+// fusion is the input IR that will be modified by this function
+void scheduleReduction(
+    Fusion* fusion,
+    const ReductionParams& rparams,
+    TensorView* red_tv,
+    std::vector<TensorView*> outs_of_red) {
+  FusionGuard fg(fusion);
+
+  // We coalesc all reduction axes to the right;
+  mergeReduction(red_tv);
+
+  // Merge all iteration dimensions
+  mergeNonReduction(red_tv);
+  for (auto iter_tv : outs_of_red) {
+    mergeNonReduction(iter_tv);
+  }
+
+  // Evaluate Dimensions of Reduction TensorView
+  auto red_ids = red_tv->domain()->domain();
+
+  TORCH_INTERNAL_ASSERT(
+      red_ids.size() == 2, "We coalesced all dimensions into 2 previously.");
+
+  constexpr int kLoopUnrollSplit = 4;
+
+  // Scheduling the Reduction
+  if (rparams.fastest_dim) {
+    // Do multiple reductions per block
+    if (rparams.mul_reds_per_blk) {
+      // Reduction Splits
+      //      [outputs, |rF-Leftover, X-Warp, rf-Unroll|]
+      // Idx:     0     |   1(-1)      2(-2)     3(-1) |
+      //                --------------------------------
+      //                Reduction Dimensions
+      red_tv->split(1, rparams.loop_unroll);
+      red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+
+      // Output Splits
+      //      [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
+      // Idx:  |     0             1      |   2(-2) -- 3(-1)
+      //       ----------------------------
+      //       Output Dimensions
+      red_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDy));
+      for (auto iter_tv : outs_of_red) {
+        iter_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDy));
+      }
+
+      auto red_tv_rf = red_tv->rFactor({-3, -1});
+
+      // WARNING: computeAt will coalesce the rFactored dimensions
+      // rFactored Reduction Tensor after computeAt():
+      //      [<output dims>, | rF-Leftover, X-Warp, rF-Unroll|]
+      // Idx:      0 -- 1     |    2(-3)      3(-2)     4(-1)  |
+      //                      ---------------------------------
+      //                      Reduction Dimensions
+      red_tv_rf->computeAt(red_tv, -1);
+
+      // After the Reduction Tensor has rFactoring applied
+      // Reduction Output Tensor:
+      //      [Out-Leftover, Out-PerBlock, X-Warp]
+      // Idx:       0              1       2(-1)
+      if (!outs_of_red.empty()) {
+        red_tv->computeAt(outs_of_red[0], -1);
+      }
+
+      red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
+
+      red_tv->axis(0)->parallelize(ParallelType::BIDx);
+      for (auto iter_tv : outs_of_red) {
+        iter_tv->axis(0)->parallelize(ParallelType::BIDx);
+      }
+      red_tv->axis(1)->parallelize(ParallelType::TIDy);
+      for (auto iter_tv : outs_of_red) {
+        iter_tv->axis(1)->parallelize(ParallelType::TIDy);
+      }
+      red_tv->axis(-1)->parallelize(ParallelType::TIDx);
+
+      // Bind Inputs to Reduction
+      for (auto input : fusion->inputsOf(red_tv_rf)) {
+        if (input->getValType().value() == ValType::TensorView) {
+          input->as<TensorView>()->computeAt(red_tv_rf, -1);
+        }
+      }
+      // Do a cross-warp reduction per block
+    } else {
+      if (rparams.cross_grid) {
+        // Reduction Splits
+        //      [outputs, |rF-Leftover, X-Grid, X-Block, X-Warp, rf-Unroll|]
+        // Idx:     0     |   1(-5)      2(-4)    3(-3)   4(-2)     5(-1) |
+        //                -------------------------------------------------
+        //                Reduction Dimensions
+        red_tv->split(1, rparams.loop_unroll);
+        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
+        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::BIDy));
+
+        auto red_tv_rf = red_tv->rFactor(
+            {-5, -1}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+
+        // WARNING: computeAt will coalesce the rFactored dimensions
+        // rFactored Reduction Tensor after computeAt():
+        //      [Outputs, |X-Grid, X-Block, X-Warp, rF-Leftover, rF-Unroll|]
+        // Idx:     0     | 1(-5)    2(-4)   3(-3)      4(-2)      5(-1)  |
+        //                -------------------------------------------------
+        //                Reduction Dimensions
+        red_tv_rf->computeAt(red_tv, -1);
+
+        // After the Reduction Tensor has rFactoring applied
+        // Reduction Output Tensor:
+        //      [Outputs, X-Grid, X-Block, X-Warp]
+        // Idx:     0      1(-3)   2(-2)    3(-1)
+
+        if (!outs_of_red.empty()) {
+          red_tv->computeAt(outs_of_red[0], -1);
+        }
+
+        red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
+
+        red_tv->axis(0)->parallelize(ParallelType::BIDx);
+        for (auto iter_tv : outs_of_red) {
+          iter_tv->axis(0)->parallelize(ParallelType::BIDx);
+        }
+        red_tv->axis(-1)->parallelize(ParallelType::TIDx);
+        red_tv->axis(-2)->parallelize(ParallelType::TIDy);
+        red_tv->axis(-3)->parallelize(ParallelType::BIDy);
+
+        // Bind Inputs to Reduction
+        for (auto input : fusion->inputsOf(red_tv_rf)) {
+          if (input->getValType().value() == ValType::TensorView) {
+            input->as<TensorView>()->computeAt(red_tv_rf, -1);
+          }
+        }
+      } else {
+        // Reduction Splits
+        //      [outputs, |rF-Leftover, X-Block, X-Warp, rf-Unroll|]
+        // Idx:     0     |   1(-4)       2(-3)   3(-2)     4(-1) |
+        //                -----------------------------------------
+        //                Reduction Dimensions
+        red_tv->split(1, rparams.loop_unroll);
+        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
+
+        auto red_tv_rf = red_tv->rFactor({-4, -1});
+
+        // WARNING: computeAt will coalesce the rFactored dimensions
+        // rFactored Reduction Tensor after computeAt():
+        //      [Outputs, |X-Block, X-Warp, rF-Leftover, rF-Unroll|]
+        // Idx:     0     | 1(-4)   2(-3)      3(-2)       4(-1)  |
+        //                -----------------------------------------
+        //                Reduction Dimensions
+        red_tv_rf->computeAt(red_tv, -1);
+
+        // After the Reduction Tensor has rFactoring applied
+        // Reduction Output Tensor:
+        //      [Outputs, X-Block, X-Warp]
+        // Idx:     0      1(-2)    2(-1)
+
+        if (!outs_of_red.empty()) {
+          red_tv->computeAt(outs_of_red[0], -1);
+        }
+
+        red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
+
+        red_tv->axis(0)->parallelize(ParallelType::BIDx);
+        for (auto iter_tv : outs_of_red) {
+          iter_tv->axis(0)->parallelize(ParallelType::BIDx);
+        }
+        red_tv->axis(-1)->parallelize(ParallelType::TIDx);
+        red_tv->axis(-2)->parallelize(ParallelType::TIDy);
+
+        // Bind Inputs to Reduction
+        for (auto input : fusion->inputsOf(red_tv_rf)) {
+          if (input->getValType().value() == ValType::TensorView) {
+            input->as<TensorView>()->computeAt(red_tv_rf, -1);
+          }
+        }
+      }
+    }
+  } else {
+    if (rparams.cross_block) {
+      if (rparams.cross_grid) {
+        // Reduction Splits
+        //      [outputs, |rF-Leftover, rf-Unroll, X-Grid, X-Block|]
+        // Idx:     0     |   1(-4)       2(-3)     3(-2)   4(-1) |
+        //                -----------------------------------------
+        //                Reduction Dimensions
+        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
+        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::BIDy));
+        red_tv->split(1, kLoopUnrollSplit);
+
+        // Reordering the Unroll dimension eases applying computeAt()
+        // for preceeding operations and the rFactored Tensor.
+        //                                 |--- Reordered ----|
+        //                                 V                  V
+        //      [outputs, |rF-Leftover, X-Block, X-Grid, rF-Unroll|]
+        // Idx:     0     |   1(-4)      2(-3)   3(-2)     4(-1)  |
+        //                -----------------------------------------
+        //                Reduction Dimensions
+        red_tv->reorder({{-1, -3}, {-3, -1}});
+
+        // Output Splits
+        //      [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
+        // Idx:  |     0             1      |   2(-4) -- 5(-1)
+        //       ----------------------------
+        //       Output Dimensions
+        red_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
+        for (auto iter_tv : outs_of_red) {
+          iter_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
+        }
+
+        auto red_tv_rf = red_tv->rFactor({-4, -1});
+
+        // WARNING: computeAt will coalesce the rFactored dimensions
+        // rFactored Reduction Tensor after computeAt():
+        //      [<output dims>, |X-Block, X-Grid, rF-Leftover, rF-Unroll|]
+        // Idx:      0 -- 1     | 2(-4)   3(-3)      4(-2)       5(-1)  |
+        //                      -----------------------------------------
+        //                      Reduction Dimensions
+        red_tv_rf->computeAt(red_tv, -1);
+
+        // After the Reduction Tensor has rFactoring applied
+        // Reduction Output Tensor:
+        //      [Out-Leftover, Out-PerBlock, X-Block, X-Grid]
+        // Idx:       0              1        2(-2)   3(-1)
+
+        if (!outs_of_red.empty()) {
+          red_tv->computeAt(outs_of_red[0], -1);
+        }
+
+        red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
+
+        red_tv->axis(0)->parallelize(ParallelType::BIDx);
+        for (auto iter_tv : outs_of_red) {
+          iter_tv->axis(0)->parallelize(ParallelType::BIDx);
+          iter_tv->axis(1)->parallelize(ParallelType::TIDx);
+        }
+
+        red_tv->axis(-3)->parallelize(ParallelType::TIDx);
+        red_tv->axis(-2)->parallelize(ParallelType::TIDy);
+        red_tv->axis(-1)->parallelize(ParallelType::BIDy);
+
+        // Bind Inputs to Reduction
+        for (auto input : fusion->inputsOf(red_tv_rf)) {
+          if (input->getValType().value() == ValType::TensorView) {
+            input->as<TensorView>()->computeAt(red_tv_rf, -1);
+          }
+        }
+      } else {
+        // Reduction Splits
+        //      [outputs, |rF-Leftover, rf-Unroll, X-Block|]
+        // Idx:     0     |   1(-3)       2(-2)     3(-1) |
+        //                ---------------------------------
+        //                Reduction Dimensions
+        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
+        red_tv->split(1, kLoopUnrollSplit);
+
+        // Reordering the Unroll dimension eases applying computeAt()
+        // for preceeding operations and the rFactored Tensor.
+        //                               |- Reordered -|
+        //                               V             V
+        //      [outputs, |rF-Leftover, X-Block, rF-Unroll|]
+        // Idx:     0     |   1(-3)      2(-2)     3(-1)  |
+        //                ---------------------------------
+        //                Reduction Dimensions
+        red_tv->reorder({{-1, -2}, {-2, -1}});
+
+        // Output Splits
+        //      [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
+        // Idx:  |     0             1      |   2(-3) -- 4(-1)
+        //       ----------------------------
+        //       Output Dimensions
+        red_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
+        for (auto iter_tv : outs_of_red) {
+          iter_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
+        }
+
+        auto red_tv_rf = red_tv->rFactor({-3, -1});
+
+        // WARNING: computeAt will coalesce the rFactored dimensions
+        // rFactored Reduction Tensor after computeAt():
+        //      [<output dims>, |X-Block, rF-Leftover, rF-Unroll|]
+        // Idx:      0 -- 1     | 2(-3)      3(-2)       4(-1)  |
+        //                      ---------------------------------
+        //                      Reduction Dimensions
+        red_tv_rf->computeAt(red_tv, -1);
+
+        // After the Reduction Tensor has rFactoring applied
+        // Reduction Output Tensor:
+        //      [Out-Leftover, Out-PerBlock, X-Block]
+        // Idx:       0              1        2(-1)
+
+        if (!outs_of_red.empty()) {
+          red_tv->computeAt(outs_of_red[0], -1);
+        }
+
+        red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
+
+        red_tv->axis(0)->parallelize(ParallelType::BIDx);
+        for (auto iter_tv : outs_of_red) {
+          iter_tv->axis(0)->parallelize(ParallelType::BIDx);
+          iter_tv->axis(1)->parallelize(ParallelType::TIDx);
+        }
+        red_tv->axis(-2)->parallelize(ParallelType::TIDx);
+        red_tv->axis(-1)->parallelize(ParallelType::TIDy);
+
+        // Bind Inputs to Reduction
+        for (auto input : fusion->inputsOf(red_tv_rf)) {
+          if (input->getValType().value() == ValType::TensorView) {
+            input->as<TensorView>()->computeAt(red_tv_rf, -1);
+          }
+        }
+      }
+    } else {
+      red_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
+      for (auto iter_tv : outs_of_red) {
+        iter_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
+      }
+
+      if (!outs_of_red.empty()) {
+        red_tv->computeAt(outs_of_red[0], -1);
+      }
+
+      red_tv->axis(0)->parallelize(ParallelType::BIDx);
+      red_tv->axis(1)->parallelize(ParallelType::TIDx);
+      for (auto iter_tv : outs_of_red) {
+        iter_tv->axis(0)->parallelize(ParallelType::BIDx);
+        iter_tv->axis(1)->parallelize(ParallelType::TIDx);
+      }
+
+      for (auto input : fusion->inputsOf(red_tv)) {
+        if (input->getValType().value() == ValType::TensorView) {
+          input->as<TensorView>()->computeAt(red_tv, -1);
+        }
+      }
+    }
+  }
+}
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler.h b/torch/csrc/jit/codegen/cuda/scheduler.h
new file mode 100644 (file)
index 0000000..5efb239
--- /dev/null
@@ -0,0 +1,73 @@
+#pragma once
+
+#include <ATen/core/ivalue.h>
+#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
+#include <torch/csrc/jit/codegen/cuda/fusion.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+// return true or false on whether given fusion could be scheduled;
+TORCH_CUDA_CU_API bool scheduleFusion(
+    Fusion* fusion,
+    const at::ArrayRef<c10::IValue> inputs);
+
+// Parameters the Reduction Heuristic Generates to describe the optimial
+// schedule. Warning: equal operator is intended for use in caching the kernel
+// associated with these reduction parameteres. It does not check if the launch
+// parameters are equivelent!
+struct ReductionParams {
+  // Reducing inner most dimension?
+  bool fastest_dim = true;
+  // Reduce across the block?
+  bool cross_block = false;
+  // Reduce across the grid?
+  bool cross_grid = false;
+  // Perform multiple reductions per block?
+  bool mul_reds_per_blk = false;
+  // Unrolling factor
+  int loop_unroll = 4;
+
+  LaunchParams lparams;
+
+  // Warning: Does not check launch parameters!
+  bool operator==(const ReductionParams& other) const {
+    bool attr_equal = other.fastest_dim == fastest_dim &&
+        other.cross_block == cross_block && other.cross_grid == cross_grid &&
+        other.mul_reds_per_blk == mul_reds_per_blk &&
+        other.loop_unroll == loop_unroll;
+    return attr_equal;
+  }
+};
+
+// Warning: Hash is not based on launch parameters!
+class ReductionParamsHash {
+ public:
+  size_t operator()(const ReductionParams& rp) const {
+    constexpr size_t bits = sizeof(std::size_t) * 8;
+    size_t attr_hash = static_cast<size_t>(rp.fastest_dim) << (bits - 1) |
+        static_cast<size_t>(rp.cross_block) << (bits - 2) |
+        static_cast<size_t>(rp.cross_grid) << (bits - 3) |
+        static_cast<size_t>(rp.mul_reds_per_blk) << (bits - 4);
+    return attr_hash;
+  }
+};
+
+TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
+    Fusion* fusion,
+    const at::ArrayRef<c10::IValue>& fusion_inputs,
+    TensorView* red_tv);
+
+TORCH_CUDA_CU_API void scheduleReduction(
+    Fusion* fusion,
+    const ReductionParams& rparams,
+    TensorView* red_tv,
+    std::vector<TensorView*> outs_of_red);
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h b/torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h
deleted file mode 100644 (file)
index c7482c0..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-#pragma once
-#include <torch/csrc/jit/codegen/cuda/scheduler/normalization.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/reduction.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-enum class TORCH_CUDA_CU_API ScheduleHeuristic {
-  PointWise,
-  Reduction,
-  Normalization
-};
-
-}
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp
deleted file mode 100644 (file)
index 72f6b3e..0000000
+++ /dev/null
@@ -1,868 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/scheduler/reduction.h>
-
-#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-
-#include <ATen/cuda/CUDAContext.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-// Copied from reduction scheduler, should generalize. Simply needed to take out
-// grid reductions.
-ReductionParams innerNormalizationHeuristic(
-    const int64_t num_elems_in_reduction,
-    const int64_t num_outputs_for_reduction,
-    const int64_t n_tensor_inputs,
-    const int64_t max_input_dtype_size,
-    bool persistence_required,
-    const int64_t max_persistent_buffer_size,
-    size_t vectorize_factor) {
-  // Set some targets for parallelization
-  const int64_t n_elems = num_elems_in_reduction * num_outputs_for_reduction;
-
-  // WARNING: Current device for codegen may not be the target device
-  const int64_t device_max_threads_per_multiprocessor =
-      (int64_t)at::cuda::getCurrentDeviceProperties()
-          ->maxThreadsPerMultiProcessor;
-
-  const int64_t device_multiprocessor_count =
-      (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
-
-  auto const max_unroll = ceilDiv(
-      // Available unrolling based on size of data type
-      (int64_t)16 / (int64_t)max_input_dtype_size,
-      // Reduce unrolling if we have many inputs, start reduction at 4 inputs
-      std::max(
-          (scheduler_utils::lastPow2((int64_t)n_tensor_inputs - 1) >> 1),
-          (int64_t)1));
-
-  // Conservative value, could be set to larger based on arch if necessary.
-  constexpr int64_t l1_cache = 32 * 1024;
-  // Could change per generation, but for l1 we want to consider active threads,
-  // not resident
-  constexpr int64_t active_threads = 1024;
-
-  // if data fits in l2 and we need more parallelization in the reduction dim,
-  // we can use a smaller warp size. While thread local data fits in l1, and
-  // reduction dim is really small, we can use <32 threads per warp.
-  const bool fits_in_l2 = n_elems * max_input_dtype_size * n_tensor_inputs <
-      at::cuda::getCurrentDeviceProperties()->l2CacheSize;
-
-  // If it fits in l2, we just want to make sure each warp uses 32Bytes. Set
-  // minimum warp as 16 threads instead of 32 as if we have a small reduction
-  // dim going a bit smaller than 32 usually helps.
-  const int64_t warp_size_based_on_l2 =
-      fits_in_l2 ? (int64_t)32 / max_input_dtype_size : 16;
-
-  // Check how many elements it would take per thread to start thrashing l1
-  // set that to minimum number we want to reduce per thread.
-  const int64_t warp_size_based_on_l1 = std::min(
-      ceilDiv(
-          num_elems_in_reduction,
-          std::max(
-              l1_cache /
-                  (n_tensor_inputs * max_input_dtype_size * active_threads),
-              (int64_t)1)),
-      (int64_t)16);
-
-  const int64_t warp_size =
-      std::min(warp_size_based_on_l1, warp_size_based_on_l2);
-
-  // Initialization
-  int64_t target_blocks = 1;
-  int64_t target_unroll = 1;
-  int64_t target_iterations = 1;
-
-  // Try to set a minmum amount of work for each thread, as cross thread
-  // communication is slow so it shouldn't be done for every element in the
-  // reduction.
-  int64_t min_target_iterations =
-      std::max((int64_t)32 / (int64_t)max_input_dtype_size, (int64_t)1);
-
-  // Start trying to break parallelization up across threads,
-  // unrolling/iterations, and blocks.
-
-  // max_threads_in_block is the cap on a thread block, the minimum is based on
-  // warp_size
-  int64_t max_threads_in_block = std::max(
-      warp_size, ceilDiv(num_elems_in_reduction, min_target_iterations));
-
-  // If we have one warp per block, check if that's enough to saturate the SMs
-  target_blocks = ceilDiv(n_elems, warp_size);
-
-  // If we have more than a wave of blocks, put parallelism into unrolling and
-  // target iterations
-  if (target_blocks > device_multiprocessor_count) {
-    auto available_unroll = std::max(
-        n_elems / (warp_size * device_multiprocessor_count), (int64_t)1);
-
-    // Spread across unrolling and iterations, want a balance of the two so flip
-    // back and forth to alternate adding to them.
-    bool flip = true;
-
-    while (available_unroll > 1 &&
-           (target_unroll < max_unroll ||
-            // Prefer unrolling
-            target_iterations < ceilDiv(min_target_iterations, max_unroll))) {
-      if (target_unroll * 2 <= max_unroll && flip) {
-        target_unroll *= 2;
-      }
-
-      if (target_iterations * 2 <= ceilDiv(min_target_iterations, max_unroll) &&
-          !flip) {
-        target_iterations *= 2;
-      }
-
-      available_unroll = std::max(
-          n_elems /
-              (warp_size * device_multiprocessor_count * target_unroll *
-               target_iterations),
-          (int64_t)1);
-
-      flip = !flip;
-    }
-
-    // Recompute target blocks
-    target_blocks =
-        ceilDiv(n_elems, warp_size * target_unroll * target_iterations);
-  }
-
-  // Cap target blocks to 4 waves
-  target_blocks = std::min(target_blocks, device_multiprocessor_count * 4);
-
-  if (target_blocks * target_unroll * target_iterations < n_elems) {
-    // targetting 4 waves, so try to use a quarter of available threads
-    max_threads_in_block = std::min(
-        ceilDiv(n_elems, target_blocks * target_unroll),
-        ceilDiv(device_max_threads_per_multiprocessor, (int64_t)4));
-  }
-
-  // Compute maximum number of reductions we could do in the same kernel based
-  // on persistent buffer size
-  const int64_t max_multi_reduction_factor = std::max(
-      (persistence_required ? (scheduler_utils::register_file_size * 3) /
-               (max_persistent_buffer_size * 4)
-                            : std::numeric_limits<int64_t>::max()),
-      (int64_t)1);
-
-  // To get to target threads:
-  // Prioritize
-  // (1) x dim in reduction
-  // (2) unrolling in reduction
-  // (3) y in output
-  // To get target blocks:
-  // Prioritize
-  // (1) x dim in multiple outputs
-  // (2) y dim in multiple reductions
-
-  // Blocks for outputs
-  int64_t godim = 1;
-
-  // Threads for outputs
-  int64_t bdimy = 1;
-  // Threads for reduction
-  int64_t bdimx = 1;
-
-  // Should we unroll from reduction axis, or outs axis
-  bool unroll_reduction = true;
-
-  // Unroll amount
-  int64_t unroll_factor = 1;
-
-  // Grab what we can out of reduction domain, but don't go over a warp size yet
-  bdimx = std::min(num_elems_in_reduction, (int64_t)warp_size);
-
-  // Put everything else in bdimy for now
-  bdimy = std::min(
-      std::max(max_threads_in_block / bdimx, (int64_t)1),
-      max_multi_reduction_factor);
-
-  int64_t remainder_in_reduction = ceilDiv(num_elems_in_reduction, bdimx);
-  int64_t remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimy);
-
-  // Adjust blocking and setup unrolling
-  // Disable unrolling on iteration domain for persistent kernels for now.
-  // TODO: Re-enable.
-  if (remainder_in_reduction == 1 && !persistence_required) {
-    // Small number of reduction elements, try unrolling output dimension
-    unroll_factor = std::min(target_unroll, remainder_in_output);
-
-    if (unroll_factor > 1) {
-      unroll_reduction = false;
-      remainder_in_output =
-          ceilDiv(num_outputs_for_reduction, unroll_factor * bdimy);
-    }
-  } else {
-    // If there are reduction elements left after unrolling a warp, re-adjust
-    // the block dims to put more threads into the reduction
-    bdimx = std::min(
-        std::max(
-            ceilDiv(num_elems_in_reduction, target_iterations * target_unroll),
-            warp_size),
-        max_threads_in_block);
-
-    // Don't exceed target threads in a block.
-    bdimy = std::min(
-        std::max(max_threads_in_block / bdimx, (int64_t)1),
-        max_multi_reduction_factor);
-    remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimy);
-
-    remainder_in_reduction = ceilDiv(num_elems_in_reduction, bdimx);
-    unroll_factor = std::min(remainder_in_reduction, target_unroll);
-
-    // If there's no longer any space for unrolling the reduction dimension, try
-    // unrolling the iteration (output) dimension.
-    // Disable unrolling on iteration domain for persistent kernels for now.
-    // TODO: Re-enable.
-    if (unroll_factor == 1 && !persistence_required) {
-      // If we can't unroll reduction dim, unroll output dim
-      unroll_factor = std::min(remainder_in_output, target_unroll);
-      if (unroll_factor > 1) {
-        unroll_reduction = false;
-      }
-      remainder_in_output =
-          ceilDiv(num_outputs_for_reduction, bdimy * unroll_factor);
-      // Clang-tidy
-      //   remainder_in_reduction =
-      //       ceilDiv(num_elems_in_reduction, bdimx *
-      //       target_iterations);
-    }
-    // else {
-    //   remainder_in_reduction = ceilDiv(
-    //       num_elems_in_reduction,
-    //       bdimx * std::max(unroll_factor, target_iterations));
-    // }
-  }
-
-  godim = remainder_in_output;
-
-  bool vectorize = false;
-
-  // Move unrolling factor into vectorization upto vectorization limit.
-  if (vectorize_factor > 1 && unroll_factor > 1 && unroll_reduction) {
-    vectorize = true;
-    unroll_factor = std::min(
-        scheduler_utils::lastPow2(unroll_factor), (int64_t)vectorize_factor);
-  }
-
-  // Set size of persistent per thread buffer
-  int64_t batches_per_block = ceilDiv(
-      num_elems_in_reduction,
-      bdimx * (unroll_reduction ? unroll_factor : (int64_t)1));
-  // round up to multiple of 8 or pow2 whichever smaller
-  auto round_up_pow2 = scheduler_utils::lastPow2(batches_per_block);
-  if (round_up_pow2 < batches_per_block) {
-    round_up_pow2 *= 2;
-  }
-
-  constexpr int64_t kEight = 8; // clang tidy
-
-  auto round_up_8 = batches_per_block % kEight == 0
-      ? batches_per_block
-      : batches_per_block + (kEight - batches_per_block % kEight);
-
-  batches_per_block = std::min(round_up_8, round_up_pow2);
-
-  // Prefer putting iterations into unrolling over having a very large
-  // persistent buffer. Likely this should be more carefully adjusted to not
-  // blow out registers, but can revisit if we see any kernels with local memory
-  // use.
-  while (persistence_required && !vectorize && unroll_factor < max_unroll &&
-         batches_per_block % 2 == 0) {
-    batches_per_block /= 2;
-    unroll_factor *= 2;
-  }
-
-  ReductionParams rparams;
-  rparams.fastest_dim = true;
-  rparams.cross_block = true;
-  rparams.cross_grid = false;
-  rparams.multiple_reds_per_blk =
-      bdimy > 1 || (!unroll_reduction && unroll_factor);
-  rparams.loop_unroll = unroll_factor;
-  rparams.vectorize = vectorize;
-  rparams.reduction_unroll = unroll_reduction;
-  rparams.batches_per_block = batches_per_block;
-  rparams.persistent_kernel = persistence_required;
-
-  // Check if we need to split grid-x binding
-  rparams.split_grid_dim = godim > scheduler_utils::x_grid_limit;
-
-  rparams.lparams = LaunchParams(
-      LaunchParams::UNINITIALIZED_VAL,
-      LaunchParams::UNINITIALIZED_VAL,
-      LaunchParams::UNINITIALIZED_VAL,
-      persistence_required ? LaunchParams::UNINITIALIZED_VAL : bdimx,
-      bdimy,
-      LaunchParams::UNINITIALIZED_VAL);
-
-  rparams.tag = persistence_required ? "Inner normalization heuristic.\n"
-                                     : "Multi inner reduction (norm heuristic)";
-
-  if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
-    std::cerr << "\n===== Reduction Stats ========\n"
-              << "num_elems_in_reduction: " << num_elems_in_reduction << "\n"
-              << "num_outputs_for_reduction: " << num_outputs_for_reduction
-              << "\n"
-              << "n_tensor_inputs: " << n_tensor_inputs << "\n"
-              << "max_input_dtype_size: " << max_input_dtype_size << "\n"
-              << "persistence_required: " << persistence_required << "\n"
-              << "max_persistent_buffer_size: " << max_persistent_buffer_size
-              << std::endl;
-    std::cerr << rparams.toString() << std::endl;
-  }
-
-  return rparams;
-}
-
-// Copied from reduction scheduler, should generalize. Simply needed to take out
-// grid reductions.
-ReductionParams OuterNormalizationHeuristic(
-    const int64_t num_elems_in_reduction,
-    const int64_t num_outputs_for_reduction,
-    const int64_t n_tensor_inputs,
-    const int64_t max_input_dtype_size,
-    bool persistence_required,
-    const int64_t max_persistent_buffer_size,
-    size_t vectorize_factor) {
-  // Set some targets for parallelization
-  const int64_t n_elems = num_elems_in_reduction * num_outputs_for_reduction;
-
-  // WARNING: Current device for codegen may not be the target device
-  const int64_t device_max_threads_per_multiprocessor =
-      (int64_t)at::cuda::getCurrentDeviceProperties()
-          ->maxThreadsPerMultiProcessor;
-
-  const int64_t device_multiprocessor_count =
-      (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
-
-  auto const max_unroll = ceilDiv(
-      // Available unrolling based on size of data type
-      (int64_t)16 / (int64_t)max_input_dtype_size,
-      // Reduce unrolling if we have many inputs, start reduction at 4 inputs
-      std::max(
-          (scheduler_utils::lastPow2((int64_t)n_tensor_inputs - 1) >> 1),
-          (int64_t)1));
-
-  // If it fits in l2, we just want to make sure each warp uses 32Bytes. Set
-  // minimum warp as 16 threads instead of 32 as if we have a small reduction
-  // dim going a bit smaller than 32 usually helps.
-  const int64_t warp_size = n_elems * max_input_dtype_size * n_tensor_inputs <
-          at::cuda::getCurrentDeviceProperties()->l2CacheSize
-      ? (int64_t)32 / max_input_dtype_size
-      : 16;
-
-  // Initialization
-  int64_t target_blocks = 1;
-  int64_t target_unroll = 1;
-  int64_t max_threads_in_block = warp_size;
-
-  // If we have one warp per block, check if that's enough to saturate the SMs
-  target_blocks = ceilDiv(n_elems, (int64_t)warp_size);
-
-  // If we have more than a wave of blocks, put parallelism into unrolling
-  if (target_blocks > device_multiprocessor_count) {
-    target_unroll = std::min(
-        max_unroll, ceilDiv(target_blocks, device_multiprocessor_count));
-    target_blocks = ceilDiv(target_blocks, target_unroll);
-  }
-
-  // Cap target blocks to 4 waves
-  target_blocks = std::min(target_blocks, device_multiprocessor_count * 4);
-
-  if (target_blocks * target_unroll * max_threads_in_block < n_elems) {
-    // targetting 4 waves, so try to use a quarter of available threads
-    max_threads_in_block = std::min(
-        ceilDiv(n_elems, target_blocks * target_unroll),
-        ceilDiv(device_max_threads_per_multiprocessor, (int64_t)4));
-  }
-
-  // Compute maximum number of reductions we could do in the same kernel based
-  // on persistent buffer size
-
-  const int64_t max_multi_reduction_factor = std::max(
-      (persistence_required ? (scheduler_utils::register_file_size * 3) /
-               (max_persistent_buffer_size * 4)
-                            : std::numeric_limits<int64_t>::max()),
-      (int64_t)1);
-
-  // To get to target threads:
-  // Prioritize
-  // (1) x dim in iter domain
-  // (2) unrolling in iter domain
-  // (3) y in reduction domain
-  // To get target blocks:
-  // Prioritize
-  // (1) x dim in multiple outputs
-  // (2) y dim in multiple reductions - need to flip unrolling to reduction
-  // domain for this
-
-  // Blocks for outputs
-  // int64_t gdimx = 1; // unused at this time, comment for clang tidy
-
-  // Threads for reduction
-  int64_t bdimy = 1;
-  // Threads for output
-  int64_t bdimx = 1;
-
-  // Should we unroll from reduction axis, or outs axis
-  bool unroll_reduction = false;
-
-  // Unroll amount
-  int64_t unroll_factor = 1;
-
-  int64_t remainder_in_reduction = num_elems_in_reduction;
-  int64_t remainder_in_output = num_outputs_for_reduction;
-
-  if (ceilDiv(num_outputs_for_reduction, warp_size) <
-      device_multiprocessor_count) {
-    // If we can't hit a full wave, leave bdimx as warp_size, and prioritize
-    // bdimy.
-    bdimx = std::min(
-        std::min(num_outputs_for_reduction, warp_size),
-        max_multi_reduction_factor);
-  } else {
-    bdimx = std::min(
-        max_threads_in_block,
-        ceilDiv(num_outputs_for_reduction, target_blocks));
-    bdimx = std::min(std::max(bdimx, warp_size), max_multi_reduction_factor);
-  }
-
-  // Fill bdimy with left over threads
-  bdimy = std::min(
-      std::max(max_threads_in_block / bdimx, (int64_t)1),
-      num_elems_in_reduction);
-
-  // Clang tidy
-  // remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimx);
-  remainder_in_reduction = ceilDiv(remainder_in_reduction, bdimy);
-
-  if (num_outputs_for_reduction >=
-      device_multiprocessor_count * max_threads_in_block) {
-    // If we easily saturate the GPU, don't use block dim y and unroll output
-    // dimension TODO: this could be a more gentle transition starting earlier
-    bdimx = std::min(max_threads_in_block, max_multi_reduction_factor);
-    remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimx);
-
-    // TODO: This should probably still be based on max threads in a block
-    // especially if we're limited by max_multi_reduction_factor
-    bdimy = 1;
-    remainder_in_reduction = num_elems_in_reduction;
-
-    // Assume unroll in output, switch to remainder if cross grid
-    // Don't unroll if we don't have 2 full waves
-    //
-    // Disable unrolling on iteration domain for persistent kernels for now.
-    // TODO: Re-enable.
-    unroll_factor = persistence_required
-        ? 1
-        : std::min(
-              ceilDiv(remainder_in_output, device_multiprocessor_count * 2),
-              target_unroll);
-    if (unroll_factor == 1 && remainder_in_reduction > 1) {
-      // Try unrolling in reduction dimension
-      unroll_factor = std::min(remainder_in_reduction, unroll_factor);
-      // Clang tidy
-      // remainder_in_reduction = ceilDiv(remainder_in_reduction,
-      // unroll_factor);
-      if (unroll_factor > 1) {
-        unroll_reduction = true;
-      }
-    }
-    //  else {
-    // remainder_in_output =
-    //     ceilDiv(num_outputs_for_reduction, bdimx * unroll_factor);
-    // unused, comment for clang tidy
-    // }
-  } else {
-    // Not many output elements, try unrolling reduction dimension, would
-    // typically go cross grid, but can't for multi-reduction and normalization
-    // kernels.
-    // TODO: Enable cross reduction for multi-reduction cases
-    unroll_factor = std::min(max_unroll, remainder_in_reduction);
-    if (unroll_factor > 1) {
-      unroll_reduction = true;
-    }
-  }
-
-  if (unroll_factor == 1) {
-    unroll_reduction = true;
-  }
-
-  // Persistence size from buffers
-  int64_t batches_per_block = 1;
-  if (persistence_required) {
-    batches_per_block = ceilDiv(
-        num_elems_in_reduction,
-        bdimy * (unroll_reduction ? unroll_factor : (int64_t)1));
-    // round up to multiple of 8 or pow2 whichever smaller
-  }
-
-  auto round_up_pow2 = scheduler_utils::lastPow2(batches_per_block);
-  if (round_up_pow2 < batches_per_block) {
-    round_up_pow2 *= 2;
-  }
-
-  constexpr int64_t kEight = 8; // clang tidy
-
-  auto round_up_8 = batches_per_block % kEight == 0
-      ? batches_per_block
-      : batches_per_block + (kEight - batches_per_block % kEight);
-
-  batches_per_block = std::min(round_up_8, round_up_pow2);
-
-  bool vectorize = false;
-
-  if (vectorize_factor > 1 && unroll_factor > 1 && !unroll_reduction) {
-    vectorize = true;
-    unroll_factor = std::min(
-        scheduler_utils::lastPow2(unroll_factor), (int64_t)vectorize_factor);
-  }
-
-  ReductionParams rparams;
-  rparams.fastest_dim = false;
-  rparams.cross_block = bdimy > 1;
-  rparams.cross_grid = false;
-  rparams.multiple_reds_per_blk =
-      bdimx > 1 || (!unroll_reduction && unroll_factor);
-  rparams.loop_unroll = unroll_factor;
-  rparams.vectorize = vectorize;
-  rparams.reduction_unroll = unroll_reduction;
-  rparams.batches_per_block = batches_per_block;
-  rparams.persistent_kernel = persistence_required;
-
-  rparams.lparams = LaunchParams(
-      LaunchParams::UNINITIALIZED_VAL,
-      LaunchParams::UNINITIALIZED_VAL,
-      LaunchParams::UNINITIALIZED_VAL,
-      bdimx,
-      persistence_required ? LaunchParams::UNINITIALIZED_VAL : bdimy,
-      LaunchParams::UNINITIALIZED_VAL);
-
-  rparams.tag = persistence_required ? "Outer normalization heuristic.\n"
-                                     : "Multi outer reduction (norm heuristic)";
-
-  if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
-    std::cerr << "\n===== Reduction Stats ========\n"
-              << "num_elems_in_reduction: " << num_elems_in_reduction << "\n"
-              << "num_outputs_for_reduction: " << num_outputs_for_reduction
-              << "\n"
-              << "n_tensor_inputs: " << n_tensor_inputs << "\n"
-              << "max_input_dtype_size: " << max_input_dtype_size << "\n"
-              << "persistence_required: " << persistence_required << "\n"
-              << "max_persistent_buffer_size: " << max_persistent_buffer_size
-              << std::endl;
-    std::cerr << rparams.toString() << std::endl;
-  }
-
-  return rparams;
-}
-
-} // namespace
-
-ReductionParams NormalizationHeuristic(
-    int64_t num_elems_in_reduction,
-    int64_t num_outputs_for_reduction,
-    bool fastest_dim_reduction,
-    size_t n_tensor_inputs,
-    size_t max_input_dtype_size,
-    bool persistence_required,
-    const int64_t max_persistent_buffer_size,
-    size_t vectorize_factor) {
-  if (fastest_dim_reduction) {
-    return innerNormalizationHeuristic(
-        num_elems_in_reduction,
-        num_outputs_for_reduction,
-        n_tensor_inputs,
-        max_input_dtype_size,
-        persistence_required,
-        max_persistent_buffer_size,
-        vectorize_factor);
-  } else {
-    return OuterNormalizationHeuristic(
-        num_elems_in_reduction,
-        num_outputs_for_reduction,
-        n_tensor_inputs,
-        max_input_dtype_size,
-        persistence_required,
-        max_persistent_buffer_size,
-        vectorize_factor);
-  }
-}
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getNormalizationHeuristics(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    HeuristicSummary* data_cache) {
-  FUSER_PERF_SCOPE("getNormalizationHeuristics");
-
-  FusionGuard fg(fusion);
-
-  HeuristicCacheAccessor<std::vector<TensorView*>> reduction_tv_data;
-  // TODO: move all these boilerplate code into the accessor class
-  // (follow up)
-  if (data_cache && !data_cache->isRecording()) {
-    reduction_tv_data.writeTemporary(data_cache->getReductionTVs());
-  } else {
-    reduction_tv_data.writeNew(scheduler_utils::getReductionTvs(fusion));
-    if (data_cache && data_cache->isRecording()) {
-      data_cache->setReductionTVs(reduction_tv_data.read());
-    }
-  }
-
-  auto& reduction_tvs = reduction_tv_data.read();
-
-  TORCH_INTERNAL_ASSERT(
-      !reduction_tvs.empty(), "Need reduction tensor views to schedule.");
-
-  auto first_red_tv = reduction_tvs[0];
-
-  TORCH_INTERNAL_ASSERT(
-      first_red_tv != nullptr, "Reduction TensorView wasn't found.");
-
-  TORCH_INTERNAL_ASSERT(
-      first_red_tv->hasReduction(), "TensorView doesn't have a reduction.");
-  const auto red_expr = first_red_tv->definition();
-
-  TORCH_INTERNAL_ASSERT(
-      red_expr->getExprType() != c10::nullopt &&
-          (red_expr->getExprType().value() == ExprType::ReductionOp ||
-           red_expr->getExprType().value() == ExprType::WelfordOp),
-      "TensorView doesn't have a reduction.");
-
-  size_t max_dtype_size = 1;
-  size_t n_tensor_inputs = 0;
-  for (auto inp : fusion->inputs()) {
-    if (inp->isA<TensorView>()) {
-      max_dtype_size =
-          std::max(max_dtype_size, dataTypeSize(inp->getDataType().value()));
-      n_tensor_inputs++;
-    }
-  }
-
-  TORCH_INTERNAL_ASSERT(
-      n_tensor_inputs > 0,
-      "Tried to schedule a fusion with no tensor inputs, currently not supported.");
-
-  HeuristicCacheAccessor<scheduler_utils::PersistentBufferInfo>
-      persistent_buffer_data;
-
-  // TODO: move all these boilerplate code into the accessor class
-  // (follow up)
-  if (data_cache && !data_cache->isRecording()) {
-    persistent_buffer_data.writeTemporary(
-        data_cache->getPersistentBufferInfo());
-  } else {
-    persistent_buffer_data.writeNew(scheduler_utils::persistentBuffers(fusion));
-    if (data_cache && data_cache->isRecording()) {
-      data_cache->setPersistentBufferInfo(persistent_buffer_data.read());
-    }
-  }
-
-  auto& persistent_buffers = persistent_buffer_data.read();
-  bool requires_persistence = !persistent_buffers.buffers.empty();
-
-  auto properties =
-      scheduler_utils::getProperties(fusion, runtime_info, first_red_tv);
-
-  auto max_persistent_size = scheduler_utils::persistentBufferSize(
-      fusion, runtime_info, persistent_buffers, data_cache);
-
-  HeuristicCacheAccessor<std::vector<TensorView*>>
-      vectorizable_inputs_outputs_data;
-
-  // TODO: move all these boilerplate code into the accessor class
-  // (follow up)
-  if (data_cache && !data_cache->isRecording()) {
-    vectorizable_inputs_outputs_data.writeTemporary(
-        data_cache->getVectorizableInputsOutputs());
-  } else {
-    vectorizable_inputs_outputs_data.writeNew(
-        scheduler_utils::getVectorizableInputsOutputs(first_red_tv));
-    if (data_cache && data_cache->isRecording()) {
-      data_cache->setVectorizableInputsOutputs(
-          vectorizable_inputs_outputs_data.read());
-    }
-  }
-
-  auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_data.read();
-
-  // Vectorize as much as we can
-  size_t vectorize_factor = std::numeric_limits<size_t>::max();
-
-  for (auto tv : vectorizable_inputs_outputs) {
-    const auto tv_vectorize_factor = runtime_info.getVectorizableWidth(tv);
-    vectorize_factor = std::min(vectorize_factor, tv_vectorize_factor);
-  }
-
-  if (vectorize_factor == std::numeric_limits<size_t>::max()) {
-    vectorize_factor = 1;
-  }
-
-  return NormalizationHeuristic(
-      properties.reduction_numel,
-      properties.iteration_numel,
-      properties.fastest_dim_reduction,
-      n_tensor_inputs,
-      max_dtype_size,
-      requires_persistence,
-      max_persistent_size,
-      vectorize_factor);
-}
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getNormalizationHeuristics(
-    Fusion* fusion,
-    const at::ArrayRef<c10::IValue>& runtime_inputs,
-    HeuristicSummary* data_cache) {
-  FUSER_PERF_SCOPE("getNormalizationHeuristicsFromIValue");
-  SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true);
-  return getNormalizationHeuristics(fusion, runtime_info, data_cache);
-}
-
-namespace {
-
-void schedulePersistentNormalization(
-    Fusion* fusion,
-    const ReductionParams& rparams) {
-  FUSER_PERF_SCOPE("schedulePersistentNormalization");
-  FusionGuard fg(fusion);
-  // Cache tensors before grabbing any references to reductions as cache_before
-  // can invalidate the references since when applied to a reduction tensor view
-  // the new tensor view contains the reduction and original doesn't.
-
-  // Cache inputs if unrolled
-  auto cached_inputs =
-      scheduler_utils::cacheInputs(fusion, rparams.loop_unroll > 1);
-
-  // Cache and fork  outputs
-  std::vector<std::pair<TensorView*, TensorView*>> cached_outputs =
-      scheduler_utils::cacheAndForkOutputs(fusion, rparams.loop_unroll > 1);
-
-  // Make sure we don't have global memory set on intermediate tensors from
-  // fusion segmentation
-  scheduler_utils::clearMemorySpace(fusion);
-
-  auto reduction_tvs = scheduler_utils::getReductionTvs(fusion);
-
-  TORCH_INTERNAL_ASSERT(reduction_tvs.size());
-  auto reduction_tv = reduction_tvs[0];
-
-  auto dim_analysis =
-      scheduler_utils::canonicalDimReduction(fusion, reduction_tv);
-  bool has_iter_axis = dim_analysis.first;
-  bool has_red_axis = dim_analysis.second;
-
-  TORCH_INTERNAL_ASSERT(
-      has_red_axis,
-      "Could not find reduction axis in tensor used for reduction scheduler.");
-
-  if (!has_iter_axis) {
-    TORCH_INTERNAL_ASSERT(
-        rparams.fastest_dim,
-        "If all dims are reduction, should be sending it to fastest dim scheduler.");
-  }
-
-  TensorView* reference_tv = scheduler_utils::scheduleReductionTV(
-      rparams, reduction_tv, has_iter_axis);
-
-  // Reduction tensor views and rfactor tensor views are setup. Let's finish off
-  // the scheduling, particularly inlining and unrolling.
-  TORCH_INTERNAL_ASSERT(
-      reference_tv != nullptr && reduction_tv != nullptr,
-      "Need these two tensor views to finish the scheduling.");
-
-  scheduler_utils::multiReductionInliner(
-      fusion,
-      rparams,
-      reduction_tv,
-      reference_tv,
-      reduction_tvs,
-      cached_inputs,
-      cached_outputs);
-}
-
-void scheduleMultiReduction(Fusion* fusion, const ReductionParams& rparams) {
-  FUSER_PERF_SCOPE("scheduleMultiReduction");
-  FusionGuard fg(fusion);
-  // Cache tensors before grabbing any references to reductions as cache_before
-  // can invalidate the references since when applied to a reduction tensor view
-  // the new tensor view contains the reduction and original doesn't.
-
-  // Cache inputs if unrolled
-  auto cached_inputs =
-      scheduler_utils::cacheInputs(fusion, rparams.loop_unroll > 1);
-
-  // Cache and fork  outputs
-  std::vector<std::pair<TensorView*, TensorView*>> cached_outputs =
-      scheduler_utils::cacheAndForkOutputs(fusion, rparams.loop_unroll > 1);
-
-  // Make sure we don't have global memory set on intermediate tensors from
-  // fusion segmentation
-  scheduler_utils::clearMemorySpace(fusion);
-
-  auto reduction_tvs = scheduler_utils::getReductionTvs(fusion);
-
-  TORCH_INTERNAL_ASSERT(reduction_tvs.size());
-  auto reduction_tv = reduction_tvs[0];
-
-  auto dim_analysis =
-      scheduler_utils::canonicalDimReduction(fusion, reduction_tv);
-  bool has_iter_axis = dim_analysis.first;
-  bool has_red_axis = dim_analysis.second;
-
-  TORCH_INTERNAL_ASSERT(
-      has_red_axis,
-      "Could not find reduction axis in tensor used for reduction scheduler.");
-
-  if (!has_iter_axis) {
-    TORCH_INTERNAL_ASSERT(
-        rparams.fastest_dim,
-        "If all dims are reduction, should be sending it to fastest dim scheduler.");
-  }
-
-  TensorView* reference_tv = scheduler_utils::scheduleReductionTV(
-      rparams, reduction_tv, has_iter_axis);
-
-  // Reduction tensor views and rfactor tensor views are setup. Let's finish off
-  // the scheduling, particularly inlining and unrolling.
-  TORCH_INTERNAL_ASSERT(
-      reference_tv != nullptr && reduction_tv != nullptr,
-      "Need these two tensor views to finish the scheduling.");
-
-  scheduler_utils::multiReductionInliner(
-      fusion,
-      rparams,
-      reduction_tv,
-      reference_tv,
-      reduction_tvs,
-      cached_inputs,
-      cached_outputs);
-}
-} // namespace
-
-// fusion is the input IR that will be modified by this function
-TORCH_CUDA_CU_API void scheduleNormalization(
-    Fusion* fusion,
-    const ReductionParams& rparams) {
-  if (rparams.persistent_kernel) {
-    schedulePersistentNormalization(fusion, rparams);
-  } else {
-    scheduleMultiReduction(fusion, rparams);
-  }
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.h b/torch/csrc/jit/codegen/cuda/scheduler/normalization.h
deleted file mode 100644 (file)
index 290cb1b..0000000
+++ /dev/null
@@ -1,38 +0,0 @@
-#pragma once
-
-#include <ATen/core/ivalue.h>
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>
-
-// TODO: If caching inputs would require persistence we are sending it to the
-// persistent kerenl scheduler. This isn't necessary if the only persistent
-// buffers are inputs as we could re-read them from global memory. Need to
-// consider if this is worth implementing.
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class SchedulerRuntimeInfo;
-class HeuristicSummary;
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getNormalizationHeuristics(
-    Fusion* fusion,
-    const at::ArrayRef<c10::IValue>& runtime_inputs,
-    HeuristicSummary* data_cache = nullptr);
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getNormalizationHeuristics(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    HeuristicSummary* data_cache = nullptr);
-
-TORCH_CUDA_CU_API void scheduleNormalization(
-    Fusion* fusion,
-    const ReductionParams& rparams);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
deleted file mode 100644 (file)
index 054833f..0000000
+++ /dev/null
@@ -1,730 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise.h>
-
-#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-
-#include <ATen/cuda/CUDAContext.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-// constexpr int64_t x_grid_limit = ((int64_t)1 << (int64_t)31) - (int64_t)1;
-// Unused at the moment, commenting for clang tidy
-constexpr int64_t kThreadX = 128;
-} // namespace
-
-c10::optional<PointwiseParams> getPointwiseHeuristics(
-    Fusion* fusion,
-    const at::ArrayRef<c10::IValue>& runtime_inputs,
-    HeuristicSummary* data_cache) {
-  SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true);
-  return getPointwiseHeuristics(fusion, runtime_info, data_cache);
-}
-
-c10::optional<PointwiseParams> getPointwiseHeuristics(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    HeuristicSummary* data_cache) {
-  FUSER_PERF_SCOPE("getPointwiseHeuristics");
-
-  FusionGuard fg(fusion);
-  TensorView* largest_out = nullptr;
-  int max_dims = -1;
-
-  auto in_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
-  auto out_tvs_it = ir_utils::filterByType<TensorView>(fusion->outputs());
-  // Will want to access this with direct indexing later, convert now.
-  std::vector<TensorView*> out_tvs(out_tvs_it.begin(), out_tvs_it.end());
-
-  for (auto out_tv : out_tvs) {
-    int n_dims = 0;
-    for (auto id : out_tv->getMaybeRFactorDomain()) {
-      if (id->isReduction() || id->isBroadcast()) {
-        continue;
-      }
-      n_dims++;
-    }
-    if (n_dims > max_dims) {
-      largest_out = out_tv;
-      max_dims = n_dims;
-    }
-  }
-
-  TORCH_INTERNAL_ASSERT(largest_out != nullptr);
-
-  // If zero dimensional, return default parameters
-  if (TensorDomain::noReductions(
-          TensorDomain::noBroadcasts(largest_out->domain()->domain()))
-          .size() == 0) {
-    if (data_cache && data_cache->isRecording()) {
-      data_cache->setVectorizableInputsOutputs(std::vector<TensorView*>());
-      data_cache->setMappedInputOutputDims(std::vector<int64_t>());
-    }
-    return PointwiseParams();
-  }
-
-  auto ref_root = largest_out->getMaybeRFactorDomain();
-
-  std::vector<int64_t> elem_counts(ref_root.size(), 1);
-  int64_t n_elems = 1;
-  for (size_t ref_i = 0; ref_i < ref_root.size(); ref_i++) {
-    auto inferred_val =
-        runtime_info.expressionEvaluator().evaluate(ref_root[ref_i]->extent());
-    TORCH_INTERNAL_ASSERT(
-        inferred_val.has_value(),
-        "Error inferring size for pointwise scheduler.");
-    elem_counts[ref_i] = inferred_val.value();
-    n_elems *= inferred_val.value();
-  }
-
-  const int64_t device_multiprocessor_count =
-      (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
-
-  // TODO: Set to 1?
-  int64_t max_input_dtype_size = 2;
-  size_t n_tensors = 0;
-
-  for (auto inp : in_tvs) {
-    max_input_dtype_size = std::max(
-        max_input_dtype_size,
-        (int64_t)dataTypeSize(inp->getDataType().value()));
-    n_tensors++;
-  }
-  n_tensors += std::distance(out_tvs.begin(), out_tvs.end());
-
-  constexpr int64_t kSixteen = 16; // clang tidy
-
-  auto max_unroll_factor = ceilDiv(
-      // Available unrolling based on size of data type
-      (int64_t)kSixteen / max_input_dtype_size,
-      // Reduce unrolling if we have many inputs, start reduction at 4 inputs
-      std::max(
-          (scheduler_utils::lastPow2((int64_t)n_tensors) >> 2), (int64_t)1));
-
-  // Don't unroll at the cost of getting a full wave on the GPU
-  if (n_elems < device_multiprocessor_count * kThreadX &&
-      max_unroll_factor > 1) {
-    max_unroll_factor = std::min(
-        max_unroll_factor,
-        ceilDiv(n_elems, device_multiprocessor_count * kThreadX));
-  }
-
-  // If we use RNG don't unroll so we can do correctness testing
-  if (fusion->isStochastic() && disableRNGUnrolling()) {
-    max_unroll_factor = 1;
-  }
-
-  PointwiseParams params;
-  params.tag = "Pointwise heuristics";
-
-  // Don't try to vectorize if it's not recommended
-  params.inner_factor = 1;
-
-  // Vectorize as much as we can
-  size_t vectorize_factor = max_unroll_factor;
-
-  HeuristicCacheAccessor<std::vector<TensorView*>>
-      vectorizable_inputs_outputs_data;
-
-  // TODO: move all these boilerplate code into the accessor class
-  // (follow up)
-  if (data_cache && !data_cache->isRecording()) {
-    vectorizable_inputs_outputs_data.writeTemporary(
-        data_cache->getVectorizableInputsOutputs());
-  } else {
-    vectorizable_inputs_outputs_data.writeNew(
-        scheduler_utils::getVectorizableInputsOutputs(largest_out));
-    if (data_cache && data_cache->isRecording()) {
-      data_cache->setVectorizableInputsOutputs(
-          vectorizable_inputs_outputs_data.read());
-    }
-  }
-
-  auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_data.read();
-
-  for (auto tv : vectorizable_inputs_outputs) {
-    const auto tv_vectorize_factor = runtime_info.getVectorizableWidth(tv);
-    vectorize_factor = std::min(vectorize_factor, tv_vectorize_factor);
-  }
-
-  if (vectorize_factor == 1) {
-    params.vectorize = false;
-    params.inner_factor = max_unroll_factor;
-  } else {
-    params.vectorize = true;
-    params.inner_factor = vectorize_factor;
-  }
-  /*
-   * 2D pointwise scheduling logic. What is expected is there's some
-   * broadcasting pattern which would make scheduling as a 2D problem more
-   * efficient than scheduling simply as a 1D problem.
-   *
-   * Mapping count holds how many bytes are in each dimension for both inputs
-   * and outputs relative to the reference tensor. What we're looking for is a
-   * break point in reference_tvs dimensions which separates the outer dimension
-   * and inner dimension of the problem mapped to 2D.
-   *
-   * break_point is computed assuming no reuse, ignoring parallelization
-   * limitations, and simply figures out which point best separates broadcasted
-   * dimensions. In other words, where's the point where we isolate the most
-   * broadcasted elements to one side.
-   *
-   * Once a break point is found, simply schedule the pointwise op as 2D
-   * balancing parallelization as best as possible.
-   */
-
-  // Ideal break point location
-  int64_t break_point = 0;
-
-  // Elements on the right of break point (without break point all are on the
-  // right)
-  int64_t right_elem_count = 0;
-
-  int64_t bdimx = kThreadX;
-
-  // bdimy may be used if the right side of the break point is not large and we
-  // need to expand block level parallelism into the left side of the break
-  // point.
-  int64_t bdimy = 1;
-
-  // In 2D scheduler gdimx is used to parallelize the left side of the break
-  // point.
-  int64_t gdimx = 1;
-
-  // gdimy is used if there's too much parallelization in the right side of the
-  // break point. We will expand grid parallelization into the right side of the
-  // break point with gdimx and use gdimy for the left side of the break point.
-  int64_t gdimy = 1;
-
-  HeuristicCacheAccessor<std::vector<int64_t>> mapping_count_accessor;
-  // TODO: move all these boilerplate code into the accessor class
-  // (follow up)
-  if (data_cache && !data_cache->isRecording()) {
-    mapping_count_accessor.writeTemporary(
-        data_cache->getMappedInputOutputDims());
-  } else {
-    mapping_count_accessor.writeNew(
-        scheduler_utils::mappedInputsOutputs(largest_out));
-    if (data_cache && data_cache->isRecording()) {
-      data_cache->setMappedInputOutputDims(mapping_count_accessor.read());
-    }
-  }
-
-  auto mapping_count = mapping_count_accessor.read();
-
-  {
-    // How much would this transfer cost if it was done as a 1-D schedule
-    int64_t transfer_size_1d = 1;
-
-    auto max_dims =
-        std::max_element(mapping_count.begin(), mapping_count.end());
-
-    for (int64_t i = 0; i < (int64_t)ref_root.size(); i++) {
-      transfer_size_1d = transfer_size_1d * elem_counts[i] * (*max_dims);
-    }
-
-    // If there isn't very much parallelism available, just use 1D scheduler
-    if (true || n_elems * 2 > device_multiprocessor_count * kThreadX) {
-      int64_t min_total_transfer = std::numeric_limits<int64_t>::max();
-
-      for (int64_t break_point_i = 0; break_point_i < (int64_t)ref_root.size();
-           break_point_i++) {
-        // Number of elements in the right side of reference tv with
-        // break_point_i
-        int64_t cur_right_elem_count = 1;
-        for (int64_t right_i = break_point_i;
-             right_i < (int64_t)ref_root.size();
-             right_i++) {
-          cur_right_elem_count = cur_right_elem_count * elem_counts[right_i];
-        }
-
-        if (cur_right_elem_count <= 1) {
-          continue;
-        }
-
-        auto cur_left_elem_count = n_elems / cur_right_elem_count;
-        if (cur_left_elem_count <= 1) {
-          continue;
-        }
-
-        auto left_max_dims = std::max_element(
-            mapping_count.begin(), mapping_count.begin() + break_point_i);
-
-        auto right_max_dims = std::max_element(
-            mapping_count.begin() + break_point_i, mapping_count.end());
-
-        // Estimate transfer cost with this break point
-        int64_t cur_transfer_size = 1;
-
-        for (int64_t left_i = 0; left_i < break_point_i; left_i++) {
-          cur_transfer_size =
-              cur_transfer_size * elem_counts[left_i] * (*left_max_dims);
-        }
-
-        for (int64_t right_i = break_point_i;
-             right_i < (int64_t)ref_root.size();
-             right_i++) {
-          cur_transfer_size =
-              cur_transfer_size * elem_counts[right_i] * (*right_max_dims);
-        }
-
-        //  Continue if this break point doesn't save at least 10% of 1D
-        //  scheduling.
-        if (cur_transfer_size >= min_total_transfer ||
-            cur_transfer_size * 10 >= transfer_size_1d * 9) {
-          continue;
-        }
-
-        // Don't limit unroll factor with break point
-        if (cur_right_elem_count < max_unroll_factor) {
-          continue;
-        }
-
-        bdimx = std::min(
-            ceilDiv(cur_right_elem_count, max_unroll_factor), kThreadX);
-        bdimy = 1;
-        gdimy = 1;
-        // Put remainder in bdimy if there's at least a wave of grid level
-        // parallelism.
-        if (cur_left_elem_count > device_multiprocessor_count) {
-          bdimy = kThreadX / bdimx;
-        }
-        auto remainder_left = ceilDiv(cur_left_elem_count, bdimy);
-        auto remainder_right =
-            ceilDiv(cur_right_elem_count, bdimy * bdimx * max_unroll_factor);
-
-        // Use this break point
-        break_point = break_point_i;
-        min_total_transfer = cur_transfer_size;
-        right_elem_count = cur_right_elem_count;
-
-        gdimx = remainder_left;
-        if (remainder_right > 1 && bdimy <= 1) {
-          gdimy = remainder_right;
-        }
-      }
-    }
-  }
-
-  TORCH_INTERNAL_ASSERT(right_elem_count > 0 || params.break_point == 0);
-
-  TORCH_INTERNAL_ASSERT(!(bdimy > 1 && gdimy > 1));
-  params.break_point = break_point;
-  params.split_block = bdimy > 1;
-
-  params.lparams.bind(bdimx, ParallelType::TIDx);
-  if (params.split_block) {
-    params.lparams.bind(bdimy, ParallelType::TIDy);
-  }
-  if (gdimy > 65535) {
-    params.split_grid_y_dim = true;
-  }
-
-  if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
-    std::cerr << "\n===== Pointwise Stats ========\n"
-              << "num_elems: " << n_elems << "\n"
-              << "mapping_count: " << mapping_count << "\n"
-              << "elem_counts: " << elem_counts << "\n"
-              << "n_tensor_inputs: " << n_tensors << "\n"
-              << "max_input_dtype_size: " << max_input_dtype_size << "\n"
-              << "vectorize_factor: " << vectorize_factor << std::endl;
-    std::cerr << params.toString() << std::endl;
-  }
-
-  return params;
-}
-
-// TODO: remove or return launch parameters
-LaunchParams schedulePointwise(
-    Fusion* fusion,
-    const at::ArrayRef<c10::IValue>& runtime_inputs) {
-  FUSER_PERF_SCOPE("scheduleFusion");
-  auto params = getPointwiseHeuristics(fusion, runtime_inputs);
-  TORCH_INTERNAL_ASSERT(
-      params.has_value(), "Could not schedule pointwise operation.");
-  schedulePointwise(fusion, params.value());
-  return params.value().lparams;
-}
-
-namespace {
-// Returns number of non-reduction/non-broadcast dims in rfactor domain
-size_t nRootDims(const TensorView* tv) {
-  auto root_dom = tv->getMaybeRFactorDomain();
-  size_t tv_n_dims = 0;
-  for (auto dim : root_dom) {
-    if (!dim->isReduction() && !dim->isBroadcast()) {
-      tv_n_dims++;
-    }
-  }
-  return tv_n_dims;
-}
-} // namespace
-
-// TODO: Inline intermediate operations (avoid inlining unrolled/vectorized
-// input/output caches)
-void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
-  FusionGuard fg(fusion);
-  // fusion->printMath();
-  // Make sure we don't have global memory set on intermediate tensors from
-  // fusion segmentation
-  scheduler_utils::clearMemorySpace(fusion);
-
-  // maybe has_reduction for scheduling should be done on a per output tensor
-  // basis.
-  TORCH_INTERNAL_ASSERT(
-      !fusion->hasReduction(), "This scheduler only handles pointwise ops.");
-
-  // For intermediate outputs, apply cache_fork
-  auto outs = fusion->outputs();
-  for (const auto output : outs) {
-    if (!output->uses().empty() && output->definition() != nullptr) {
-      if (output->getValType().value() == ValType::TensorView) {
-        output->as<TensorView>()->cache_fork();
-      }
-    }
-  }
-
-  std::vector<TensorView*> input_tvs;
-  {
-    auto filtered_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
-    // Remove hanging tensor views
-    for (auto tv : filtered_tvs) {
-      if (tv->uses().empty()) {
-        continue;
-      }
-      input_tvs.push_back(tv);
-    }
-  }
-  auto output_tvs = ir_utils::filterByType<TensorView>(fusion->outputs());
-
-  size_t max_dims = 0;
-  for (auto inp : input_tvs) {
-    max_dims = std::max(nRootDims(inp), max_dims);
-  }
-
-  for (auto out : output_tvs) {
-    max_dims = std::max(nRootDims(out), max_dims);
-  }
-
-  // If everything is zero dim tensors, just return.
-  if (max_dims == 0) {
-    return;
-  }
-
-  TensorView* reference_tv = nullptr;
-  for (auto out : output_tvs) {
-    if (out->definition() == nullptr) {
-      continue;
-    }
-    if (nRootDims(out) == max_dims) {
-      reference_tv = out;
-      break;
-    }
-  }
-
-  TORCH_INTERNAL_ASSERT(
-      reference_tv != nullptr,
-      "Could not find a fully broadcasted output to reference schedule on.");
-
-  IterDomain* inner_most_id = nullptr;
-  for (auto it = reference_tv->domain()->domain().rbegin();
-       it != reference_tv->domain()->domain().rend();
-       it++) {
-    if ((*it)->isReduction()) {
-      continue;
-    }
-    if ((*it)->isBroadcast() && inner_most_id == nullptr) {
-      inner_most_id = *it;
-    }
-    inner_most_id = *it;
-    break;
-  }
-
-  TORCH_INTERNAL_ASSERT(inner_most_id != nullptr);
-  auto vectorizable_dims =
-      scheduler_utils::FindAllMappedDims::from(reference_tv, inner_most_id);
-
-  // Caches of inputs
-  std::vector<TensorView*> cached_inputs;
-
-  // Output, cache_before of output
-  std::vector<std::pair<TensorView*, TensorView*>> cached_outputs;
-
-  // Track what should be vectorized versus unrolled
-  std::unordered_set<TensorView*> vectorized_tensor;
-
-  // Figure out which inputs to cache for unrolling or vectorization
-  for (auto inp : input_tvs) {
-    if (inp->uses().empty()) {
-      continue;
-    }
-    // Need to check before caching.
-    bool vectorize = params.vectorize &&
-        scheduler_utils::shouldVectorize(inp, vectorizable_dims);
-    cached_inputs.emplace_back(inp->cache_after());
-    if (vectorize) {
-      vectorized_tensor.emplace(cached_inputs.back());
-    }
-  }
-
-  // Figure out which outputs to cache for unrolling or vectorization
-  for (auto out : output_tvs) {
-    if (out->definition() == nullptr) {
-      continue;
-    }
-    // Need to check before caching.
-    bool vectorize = params.vectorize &&
-        scheduler_utils::shouldVectorize(out, vectorizable_dims);
-    cached_outputs.emplace_back(std::make_pair(out, out->cache_before()));
-    if (vectorize) {
-      vectorized_tensor.emplace(out);
-    }
-  }
-
-  auto all_tvs = ir_utils::allTvs(fusion);
-
-  // Merge right side of break point
-  int rhs_i = -1;
-  for (int i = (int)reference_tv->nDims(); i > (int)params.break_point; i--) {
-    auto axis_i = i - 1;
-    if (reference_tv->axis(axis_i)->isBroadcast() ||
-        reference_tv->axis(axis_i)->isReduction()) {
-      continue;
-    }
-    if (rhs_i == -1) {
-      rhs_i = axis_i;
-    } else {
-      reference_tv->merge(axis_i, rhs_i);
-      rhs_i = axis_i;
-    }
-  }
-  if (rhs_i >= 0) {
-    // If there's an rhs
-    reference_tv->reorder({{rhs_i, -1}});
-  }
-
-  // Merge left side of break point
-  int lhs_i = -1;
-  for (int i = (int)params.break_point; i > 0; i--) {
-    auto axis_i = i - 1;
-    if (reference_tv->axis(axis_i)->isBroadcast() ||
-        reference_tv->axis(axis_i)->isReduction()) {
-      continue;
-    }
-    if (lhs_i == -1) {
-      lhs_i = axis_i;
-    } else {
-      reference_tv->merge(axis_i, lhs_i);
-      lhs_i = axis_i;
-    }
-  }
-
-  // Right (inner merged) dimension is at inner most position, left (outer
-  // merged) dimension is at lhs_i. Order as [lhs_i, rhs_i, unmerged...]
-  reference_tv->reorder({{lhs_i, 0}, {-1, 1}});
-
-  if (params.break_point) {
-    // 2D parallelization scheme
-    TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0);
-
-    if (params.vectorize) {
-      reference_tv->split(1, params.inner_factor);
-      reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
-      reference_tv->split(0, 1);
-      // [outer, Unswitch | i-remainder, TIDx, Vectorization]
-      reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
-      reference_tv->axis(3)->parallelize(ParallelType::TIDx);
-
-      // Aggressively mark with vectorized and cleanup later. That way we
-      // don't have to manually specify parallelization outside the reference.
-      reference_tv->axis(4)->parallelize(ParallelType::Vectorize);
-
-      // [outer, Unswitch | i-remainder, TIDx, Vectorization]
-      // To make consistent with unrolling:
-      reference_tv->reorder({{1, 2}, {2, 1}, {3, 4}, {4, 3}});
-      //[outer | i-remainder, Unswitch, Vectorization, TIDx]
-    } else {
-      reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
-      reference_tv->split(1, params.inner_factor);
-
-      reference_tv->split(0, 1);
-      // [outer, unswitch | i-remainder, unroll, TIDx ]
-      reference_tv->reorder({{1, 2}});
-      // [outer, i-remainder, unswitch, unroll, TIDx ]
-      reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
-      reference_tv->axis(4)->parallelize(ParallelType::TIDx);
-
-      //[outer | i-remainder, Unswitch, Unroll, TIDx]
-    }
-
-    // Move out of the way to furthest left point
-    reference_tv->reorder({{1, 0}});
-
-    //[i-remainder | outer | Unswitch, Unroll, TIDx]
-    if (params.split_block) {
-      reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
-      // [i-remainder | BIDx TIDy | Unswitch, Unroll, TIDx]
-      reference_tv->axis(1)->parallelize(ParallelType::BIDx);
-      reference_tv->axis(2)->parallelize(ParallelType::TIDy);
-    } else {
-      // [BIDy | BIDx | Unswitch, Unroll, TIDx]
-      reference_tv->axis(1)->parallelize(ParallelType::BIDx);
-      if (params.split_grid_y_dim) {
-        reference_tv->split(0, 65535);
-        reference_tv->axis(1)->parallelize(ParallelType::BIDy);
-      } else {
-        reference_tv->axis(0)->parallelize(ParallelType::BIDy);
-      }
-    }
-
-  } else {
-    // 1D Scheduler
-    TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i == -1);
-    // right hand side exists and is the only axis we care to schedule, move it
-    // from the inner most position to left most.
-    reference_tv->reorder({{-1, 0}});
-
-    if (params.vectorize) {
-      // Vectorize
-      reference_tv->split(0, params.inner_factor);
-      // Unswitch
-      reference_tv->split(0, 1);
-      // Threads
-      reference_tv->split(0, kThreadX);
-
-      reference_tv->axis(0)->parallelize(ParallelType::BIDx);
-      reference_tv->axis(1)->parallelize(ParallelType::TIDx);
-      reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
-      // Aggressively mark with vectorized and cleanup later. That way we don't
-      // have to manually specify parallelization outside the reference.
-      reference_tv->axis(-1)->parallelize(ParallelType::Vectorize);
-
-      //[BIDx, TIDx, Unswitch, Vectorization]
-      // To make consistent with unrolling:
-      reference_tv->reorder({{1, 3}, {2, 1}, {3, 2}});
-      //[BIDx, Unswitch, Vectorization, TIDx]
-    } else {
-      // Threads
-      reference_tv->split(0, kThreadX);
-      // Unroll
-      reference_tv->split(0, params.inner_factor);
-      // Unswitch
-      reference_tv->split(0, 1);
-
-      // [BIDx, Unswitch, Unroll, TIDx]
-      reference_tv->axis(0)->parallelize(ParallelType::BIDx);
-      reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
-      reference_tv->axis(3)->parallelize(ParallelType::TIDx);
-    }
-  }
-  TransformPropagator::from(reference_tv);
-  scheduler_utils::parallelizeAllLike(reference_tv, all_tvs);
-
-  if (params.vectorize) {
-    // Clear vectorize on tensors that shouldn't have it
-    for (auto tv : all_tvs) {
-      if (!vectorized_tensor.count(tv)) {
-        for (auto id : tv->domain()->domain()) {
-          if (id->getParallelType() == ParallelType::Vectorize) {
-            id->parallelize(ParallelType::Serial);
-          }
-        }
-      }
-    }
-  }
-
-  // Compute at into cached inputs
-  std::vector<TensorView*> consumers_of_cached_inputs;
-  // Cache of input, and one of its consumers
-  std::vector<std::pair<TensorView*, TensorView*>> input_cache_and_consumer;
-  {
-    // Avoid duplicate additions, so track what we add
-    std::unordered_set<TensorView*> added;
-    for (auto cached_input : cached_inputs) {
-      auto consumer_tvs = ir_utils::consumerTvsOf(cached_input);
-      TORCH_INTERNAL_ASSERT(
-          consumer_tvs.size(),
-          "Input was not succesfully filtered out for scheduling but wasn't used.");
-
-      // Grab a consumer which will be used for computeAt structure of cached
-      // input into a consumer
-      input_cache_and_consumer.emplace_back(
-          std::make_pair(cached_input, consumer_tvs[0]));
-
-      // Grab all consumers which will be used for inlining computeAt for the
-      // body of the computation (excluding caching inputs/outputs)
-      for (auto consumer_tv : consumer_tvs) {
-        // Don't duplicate
-        if (added.insert(consumer_tv).second) {
-          consumers_of_cached_inputs.emplace_back(consumer_tv);
-        }
-      }
-    }
-  }
-
-  for (auto entry : input_cache_and_consumer) {
-    // Compute at inside unswitch position:
-    auto input_cache = entry.first;
-    auto input_cache_consumer = entry.second;
-
-    auto unswitch_it = std::find_if(
-        input_cache_consumer->domain()->domain().begin(),
-        input_cache_consumer->domain()->domain().end(),
-        [](IterDomain* id) {
-          return id->getParallelType() == ParallelType::Unswitch;
-        });
-    auto unswitch_pos =
-        unswitch_it == input_cache_consumer->domain()->domain().end()
-        ? -1
-        : std::distance(
-              input_cache_consumer->domain()->domain().begin(), unswitch_it) +
-            1;
-
-    input_cache->computeAt(
-        input_cache_consumer, unswitch_pos, ComputeAtMode::BestEffort);
-  }
-
-  // Producers for inlined computeAt
-  std::vector<TensorView*> compute_from = consumers_of_cached_inputs;
-
-  // Consumers for inlined computeAt
-  std::vector<TensorView*> compute_to;
-  // Compute at cached outputs
-  //[BIDx, Unswitch, Vectorization, TIDx]
-  for (auto entry : cached_outputs) {
-    auto cached_output = entry.second;
-    auto output = entry.first;
-
-    auto unswitch_it = std::find_if(
-        output->domain()->domain().begin(),
-        output->domain()->domain().end(),
-        [](IterDomain* id) {
-          return id->getParallelType() == ParallelType::Unswitch;
-        });
-    auto unswitch_pos = unswitch_it == output->domain()->domain().end()
-        ? -1
-        : std::distance(output->domain()->domain().begin(), unswitch_it) + 1;
-
-    cached_output->computeAt(output, unswitch_pos, ComputeAtMode::BestEffort);
-    compute_to.push_back(cached_output);
-  }
-
-  scheduler_utils::computeAtBetween(
-      compute_from, compute_to, -1, ComputeAtMode::BestEffort);
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h
deleted file mode 100644 (file)
index cb62655..0000000
+++ /dev/null
@@ -1,37 +0,0 @@
-#pragma once
-
-#include <ATen/core/ivalue.h>
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class SchedulerRuntimeInfo;
-class HeuristicSummary;
-
-TORCH_CUDA_CU_API c10::optional<PointwiseParams> getPointwiseHeuristics(
-    Fusion* fusion,
-    const at::ArrayRef<c10::IValue>& runtime_inputs,
-    HeuristicSummary* data_cache = nullptr);
-
-TORCH_CUDA_CU_API c10::optional<PointwiseParams> getPointwiseHeuristics(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    HeuristicSummary* data_cache = nullptr);
-
-TORCH_CUDA_CU_API void schedulePointwise(
-    Fusion* fusion,
-    const PointwiseParams& params);
-
-TORCH_CUDA_CU_API LaunchParams schedulePointwise(
-    Fusion* fusion,
-    const at::ArrayRef<c10::IValue>& runtime_inputs);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h
deleted file mode 100644 (file)
index dc5d9db..0000000
+++ /dev/null
@@ -1,93 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
-
-#include <sstream>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-// Parameters the Reduction Heuristic Generates to describe the optimial
-// schedule. Warning: equal operator is intended for use in caching the kernel
-// associated with these reduction parameters. It does not check if the launch
-// parameters are equivelent!
-class PointwiseParams {
- public:
-  // vectorize if true, otherwise unroll
-  bool vectorize = false;
-
-  // Treat pointwise operation as 2-Dimensional, this is the location where we
-  // split from left side of the domain to right. i.e. 0 means problem is
-  // treated as 1-D, 1 of 3 would mean we treat the first dimension as the outer
-  // dimension, and all the others as an inner dimension.
-  int break_point = 0;
-
-  // Split block across left and right dimension
-  bool split_block = false;
-
-  // Split grid y dimension, if otherwise it would be too large
-  bool split_grid_y_dim = false;
-
-  // Unroll or vectorization factor
-  int64_t inner_factor = 1;
-
-  std::string tag = "";
-
-  LaunchParams lparams;
-
-  // Warning: Does not check launch parameters!
-  bool operator==(const PointwiseParams& other) const {
-    bool attr_equal = other.vectorize == vectorize &&
-        other.break_point == break_point && other.split_block == split_block &&
-        other.split_grid_y_dim == split_grid_y_dim &&
-        other.inner_factor == inner_factor;
-    return attr_equal;
-  }
-
-  std::string toString() const {
-    std::stringstream ss;
-    ss << "\n===== Pointwise Parameters ========\n"
-       << (tag == "" ? "" : "Tag: ") << tag << " Pointwise Characteristics:\n"
-       << " Gridx: " << lparams.gdimx() << " BlckY: " << lparams.bdimy()
-       << " BlckX: " << lparams.bdimx() << "\n";
-    if (break_point) {
-      ss << "2D Schedule\n"
-         << "  Bcast break point: " << break_point << "\n";
-      if (split_block) {
-        ss << "Split block into y-dim\n";
-      }
-      if (split_grid_y_dim) {
-        ss << "  Split y grid dim\n";
-      }
-    }
-    if (inner_factor > 1) {
-      if (vectorize) {
-        ss << "Vectorize, Factor: " << inner_factor << "\n";
-      } else {
-        ss << "Unroll, Factor: " << inner_factor << "\n";
-      }
-    }
-    ss << "====================================\n";
-    return ss.str();
-  }
-};
-
-// Warning: Hash is not based on launch parameters!
-class PointwiseParamsHash {
- public:
-  size_t operator()(const PointwiseParams& pp) const {
-    size_t attr_hash = static_cast<size_t>(pp.vectorize) ^
-        static_cast<size_t>(pp.break_point) << 4 ^
-        static_cast<size_t>(pp.split_block) << 5 ^
-        static_cast<size_t>(pp.split_grid_y_dim) << 6 ^
-        static_cast<size_t>(pp.inner_factor) << 9;
-    return attr_hash;
-  }
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp
deleted file mode 100644 (file)
index cf81f86..0000000
+++ /dev/null
@@ -1,792 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/scheduler/reduction.h>
-
-#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-
-#include <ATen/cuda/CUDAContext.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-ReductionParams innerReductionHeuristic(
-    const int64_t num_elems_in_reduction,
-    const int64_t num_outputs_for_reduction,
-    const int64_t n_tensor_inputs,
-    const int64_t max_input_dtype_size,
-    size_t vectorize_factor) {
-  // Set some targets for parallelization
-
-  const int64_t n_elems = num_elems_in_reduction * num_outputs_for_reduction;
-
-  // WARNING: Current device for codegen may not be the target device
-  const int64_t device_max_threads_per_multiprocessor =
-      (int64_t)at::cuda::getCurrentDeviceProperties()
-          ->maxThreadsPerMultiProcessor;
-
-  const int64_t device_multiprocessor_count =
-      (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
-
-  auto const max_unroll = ceilDiv(
-      // Available unrolling based on size of data type
-      (int64_t)16 / (int64_t)max_input_dtype_size,
-      // Reduce unrolling if we have many inputs, start reduction at 4 inputs
-      std::max(
-          (scheduler_utils::lastPow2((int64_t)n_tensor_inputs) >> 2),
-          (int64_t)1));
-
-  // Conservative value, could be set to larger based on arch if necessary.
-  constexpr int64_t l1_cache = 32 * 1024;
-  // Could change per generation, but for l1 we want to consider active threads,
-  // not resident
-  constexpr int64_t active_threads = 1024;
-
-  // if data fits in l2 and we need more parallelization in the reduction dim,
-  // we can use a smaller warp size. While thread local data fits in l1, and
-  // reduction dim is really small, we can use <32 threads per warp.
-  const bool fits_in_l2 = n_elems * max_input_dtype_size * n_tensor_inputs <
-      at::cuda::getCurrentDeviceProperties()->l2CacheSize;
-
-  // If it fits in l2, we just want to make sure each thread uses 32Bytes.
-  const int64_t warp_size_based_on_l2 =
-      fits_in_l2 ? (int64_t)32 / max_input_dtype_size : 32;
-
-  // Check how many elements it would take per thread to start thrashing l1
-  // set that to minimum number we want to reduce per thread.
-  const int64_t warp_size_based_on_l1 = std::min(
-      ceilDiv(
-          num_elems_in_reduction,
-          std::max(
-              l1_cache /
-                  (n_tensor_inputs * max_input_dtype_size * active_threads),
-              (int64_t)1)),
-      (int64_t)16);
-
-  // Take the smaller
-  const int64_t warp_size =
-      std::min(warp_size_based_on_l1, warp_size_based_on_l2);
-
-  // Initialization
-  int64_t target_blocks = 1;
-  int64_t target_unroll = 1;
-  int64_t target_iterations = 1;
-
-  // Try to set a minmum amount of work for each thread, as cross thread
-  // communication is slow so it shouldn't be done for every element in the
-  // reduction.
-  int64_t min_target_iterations =
-      std::max((int64_t)32 / (int64_t)max_input_dtype_size, (int64_t)1);
-
-  // Start trying to break parallelization up across threads,
-  // unrolling/iterations, and blocks.
-
-  // max_threads_in_block is the cap on a thread block, the minimum is based on
-  // warp_size
-  int64_t max_threads_in_block = std::max(
-      warp_size, ceilDiv(num_elems_in_reduction, min_target_iterations));
-
-  // If we have one warp per block, check if that's enough to saturate the SMs
-  target_blocks = ceilDiv(n_elems, warp_size);
-
-  // If we have more than a wave of blocks, put parallelism into unrolling and
-  // target iterations
-  if (target_blocks > device_multiprocessor_count) {
-    auto available_unroll = std::max(
-        n_elems / (warp_size * device_multiprocessor_count), (int64_t)1);
-
-    // Spread across unrolling and iterations, want a balance of the two so flip
-    // back and forth to alternate adding to them.
-    bool flip = true;
-
-    while (available_unroll > 1 &&
-           (target_unroll < max_unroll ||
-            // Prefer unrolling
-            target_iterations < ceilDiv(min_target_iterations, max_unroll))) {
-      if (target_unroll * 2 <= max_unroll && flip) {
-        target_unroll *= 2;
-      }
-
-      if (target_iterations * 2 <= ceilDiv(min_target_iterations, max_unroll) &&
-          !flip) {
-        target_iterations *= 2;
-      }
-
-      available_unroll = std::max(
-          n_elems /
-              (warp_size * device_multiprocessor_count * target_unroll *
-               target_iterations),
-          (int64_t)1);
-
-      flip = !flip;
-    }
-
-    // Recompute target blocks
-    target_blocks =
-        ceilDiv(n_elems, warp_size * target_unroll * target_iterations);
-  }
-
-  // Cap target blocks to 4 waves
-  target_blocks = std::min(target_blocks, device_multiprocessor_count * 4);
-
-  if (target_blocks * target_unroll * target_iterations < n_elems) {
-    // targetting 4 waves, so try to use a quarter of available threads
-    max_threads_in_block = std::min(
-        ceilDiv(n_elems, target_blocks * target_unroll),
-        ceilDiv(device_max_threads_per_multiprocessor, (int64_t)4));
-  }
-
-  // To get to target threads:
-  // Prioritize
-  // (1) x dim in reduction
-  // (2) unrolling in reduction
-  // (3) y in output
-  // To get target blocks:
-  // Prioritize
-  // (1) x dim in multiple outputs
-  // (2) y dim in multiple reductions
-
-  // Blocks for reductions
-  int64_t grdim = 1;
-  // Blocks for outputs
-  int64_t godim = 1;
-
-  // Threads for outputs
-  int64_t bdimy = 1;
-  // Threads for reduction
-  int64_t bdimx = 1;
-
-  // Should we unroll from reduction axis, or outs axis
-  bool unroll_reduction = true;
-
-  // Unroll amount
-  int64_t unroll_factor = 1;
-
-  // Grab what we can out of reduction domain, but don't go over a warp size yet
-  bdimx = std::min(num_elems_in_reduction, (int64_t)warp_size);
-  // Put everything else in bdimy for now
-  bdimy = std::max(max_threads_in_block / bdimx, (int64_t)1);
-  int64_t remainder_in_reduction = ceilDiv(num_elems_in_reduction, bdimx);
-  int64_t remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimy);
-
-  // Adjust blocking and setup unrolling
-  if (remainder_in_reduction == 1) {
-    // Small number of reduction elements, try unrolling output dimension
-    unroll_factor = std::min(target_unroll, remainder_in_output);
-    if (unroll_factor > 1) {
-      unroll_reduction = false;
-      remainder_in_output =
-          ceilDiv(num_outputs_for_reduction, unroll_factor * bdimy);
-    }
-  } else {
-    // If there are reduction elements left after unrolling a warp, re-adjust
-    // the block dims to put more threads into the reduction
-    bdimx = std::min(
-        std::max(
-            ceilDiv(num_elems_in_reduction, target_iterations * target_unroll),
-            warp_size),
-        max_threads_in_block);
-
-    // Don't exceed target.
-    bdimy = std::max(max_threads_in_block / bdimx, (int64_t)1);
-    remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimy);
-
-    remainder_in_reduction = ceilDiv(num_elems_in_reduction, bdimx);
-    unroll_factor = std::min(remainder_in_reduction, target_unroll);
-    if (unroll_factor == 1) {
-      // If we can't unroll reduction dim, unroll output dim
-      unroll_factor = std::min(remainder_in_output, target_unroll);
-      if (unroll_factor > 1) {
-        unroll_reduction = false;
-      }
-      remainder_in_output =
-          ceilDiv(num_outputs_for_reduction, bdimy * unroll_factor);
-      remainder_in_reduction =
-          ceilDiv(num_elems_in_reduction, bdimx * target_iterations);
-    } else {
-      remainder_in_reduction = ceilDiv(
-          num_elems_in_reduction,
-          bdimx * std::max(unroll_factor, target_iterations));
-    }
-  }
-
-  godim = remainder_in_output;
-
-  // Clang tidy
-  constexpr int64_t kEight = 8;
-  constexpr int64_t kThirtyTwo = 32;
-
-  // Cross grid reduction if we haven't hit our target blocks, and we have many
-  // reduction elements.
-  if ((godim < target_blocks && remainder_in_reduction > kEight &&
-       remainder_in_reduction < kThirtyTwo) ||
-      (remainder_in_reduction >= kThirtyTwo)) {
-    // Grid reductions do not support unrolling iteration dimension, revert if
-    // set.
-    if (!unroll_reduction) {
-      unroll_reduction = true;
-      unroll_factor = 1;
-      remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimy);
-      remainder_in_reduction =
-          ceilDiv(num_elems_in_reduction, bdimx * target_iterations);
-    }
-    if (remainder_in_reduction >= kThirtyTwo) {
-      // Do at least 2 iterations of unrolling per thread before we go cross
-      // grid. Limit cross grid to a multiple of the block size so cleanup on
-      // the last block doesn't take too long.
-      grdim = std::min(
-          ceilDiv(remainder_in_reduction, (int64_t)2), bdimx * bdimy * kEight);
-      // Clang tidy
-      // remainder_in_reduction = ceilDiv(remainder_in_reduction, grdim);
-    } else {
-      grdim = ceilDiv(remainder_in_reduction, (int64_t)4);
-    }
-    // Clang tidy
-    //
-    // remainder_in_reduction = ceilDiv(
-    //     num_elems_in_reduction,
-    //     bdimx *
-    //         std::max(
-    //             unroll_reduction ? unroll_factor : 1,
-    //             min_red_elems_per_thread) *
-    //         grdim);
-  }
-
-  // Try to do some cleanup of ragged waves on device
-  // godim is a remainder of a split, so can only control bdimy
-  if (
-      // If we have less than 8 waves of blocks
-      grdim * godim < device_multiprocessor_count * kEight &&
-      // And we don't have an even divisible number of blocks
-      (grdim * godim) % device_multiprocessor_count != 0 &&
-      // And we have more than one wave
-      grdim * godim > device_multiprocessor_count) {
-    // round waves down
-    auto waves =
-        std::max((godim * grdim) / device_multiprocessor_count, (int64_t)1);
-    auto new_grdim =
-        std::max((waves * device_multiprocessor_count) / godim, (int64_t)1);
-    if (
-        // If difference is less than 25% of the original grdim
-        (new_grdim - grdim) * 4 < grdim &&
-        // and difference is less than 25% of the original number of blocks
-        ((new_grdim * godim) - (grdim * godim)) * 4 < grdim * godim) {
-      grdim = new_grdim;
-    }
-  }
-
-  bool vectorize = false;
-
-  if (vectorize_factor > 1 && unroll_factor > 1 && unroll_reduction) {
-    vectorize = true;
-    unroll_factor = std::min(
-        scheduler_utils::lastPow2(unroll_factor), (int64_t)vectorize_factor);
-  }
-
-  ReductionParams rparams;
-  rparams.fastest_dim = true;
-  rparams.cross_block = true;
-  rparams.cross_grid = grdim > 1;
-  rparams.multiple_reds_per_blk = bdimy > 1;
-  rparams.loop_unroll = unroll_factor;
-  rparams.vectorize = vectorize;
-  rparams.reduction_unroll = unroll_reduction;
-
-  // If we have a cross grid case we want to have gdimy assigned to godim and
-  // gdimx assigned to grdim. Otherwise it's helpful to pull godim into gdimx in
-  // case it's larger than gdimy can hold, as not doing so can thrash the cache.
-  int64_t gdimx = LaunchParams::UNINITIALIZED_VAL;
-  int64_t gdimy = LaunchParams::UNINITIALIZED_VAL;
-
-  if (rparams.cross_grid) {
-    gdimx = grdim;
-    rparams.split_grid_dim = gdimy > scheduler_utils::y_grid_limit;
-  } else {
-    rparams.split_grid_dim = gdimx > scheduler_utils::x_grid_limit;
-  }
-
-  rparams.lparams = LaunchParams(
-      gdimx,
-      gdimy,
-      LaunchParams::UNINITIALIZED_VAL,
-      bdimx,
-      bdimy,
-      LaunchParams::UNINITIALIZED_VAL);
-  if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
-    std::cerr << "\n===== Reduction Stats ========\n"
-              << "num_elems_in_reduction: " << num_elems_in_reduction << "\n"
-              << "num_outputs_for_reduction: " << num_outputs_for_reduction
-              << "\n"
-              << "n_tensor_inputs: " << n_tensor_inputs << "\n"
-              << "max_input_dtype_size: " << max_input_dtype_size << std::endl;
-    std::cerr << rparams.toString() << std::endl;
-  }
-
-  return rparams;
-}
-
-ReductionParams OuterReductionHeuristic(
-    const int64_t num_elems_in_reduction,
-    const int64_t num_outputs_for_reduction,
-    const int64_t n_tensor_inputs,
-    const int64_t max_input_dtype_size,
-    size_t vectorize_factor) {
-  // Set some targets for parallelization
-
-  const int64_t n_elems = num_elems_in_reduction * num_outputs_for_reduction;
-  const int64_t l2_cache_size =
-      at::cuda::getCurrentDeviceProperties()->l2CacheSize;
-
-  const int64_t warp_size =
-      n_elems * max_input_dtype_size * n_tensor_inputs < l2_cache_size
-      ? (int64_t)32 / max_input_dtype_size
-      : 32;
-
-  int64_t target_blocks = 1;
-  int64_t target_unroll = 1;
-  int64_t max_threads_in_block = warp_size;
-
-  // WARNING: Current device for codegen may not be the target device
-  const int64_t device_max_threads_per_multiprocessor =
-      (int64_t)at::cuda::getCurrentDeviceProperties()
-          ->maxThreadsPerMultiProcessor;
-
-  const int64_t device_multiprocessor_count =
-      (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
-
-  auto const max_unroll = ceilDiv(
-      // Available unrolling based on size of data type
-      (int64_t)16 / (int64_t)max_input_dtype_size,
-      // Reduce unrolling if we have many inputs, start reduction at 4 inputs
-      std::max(
-          (scheduler_utils::lastPow2((int64_t)n_tensor_inputs) >> 2),
-          (int64_t)1));
-
-  // If we have one warp per block, how many blocks would that be?
-  target_blocks = ceilDiv(n_elems, (int64_t)warp_size);
-
-  // If we have more than a wave, put parallelism into unrolling
-  if (target_blocks > device_multiprocessor_count) {
-    target_unroll = std::min(
-        max_unroll, ceilDiv(target_blocks, device_multiprocessor_count));
-    target_blocks = ceilDiv(target_blocks, target_unroll);
-  }
-
-  // Cap target blocks to 4 waves
-  target_blocks = std::min(target_blocks, device_multiprocessor_count * 4);
-
-  if (target_blocks * target_unroll * max_threads_in_block < n_elems) {
-    // targetting 4 waves, so try to use a quarter of available threads
-    max_threads_in_block = std::min(
-        ceilDiv(n_elems, target_blocks * target_unroll),
-        ceilDiv(device_max_threads_per_multiprocessor, (int64_t)4));
-  }
-
-  // To get to target threads:
-  // Prioritize
-  // (1) x dim in iter domain
-  // (2) unrolling in iter domain
-  // (3) y in reduction domain
-  // To get target blocks:
-  // Prioritize
-  // (1) x dim in multiple outputs
-  // (2) y dim in multiple reductions - need to flip unrolling to reduction
-  // domain for this
-
-  // Blocks for reductions
-  int64_t gdimy = 1;
-  // Blocks for outputs
-  int64_t gdimx = 1;
-
-  // Threads for reduction
-  int64_t bdimy = 1;
-  // Threads for output
-  int64_t bdimx = 1;
-
-  // Should we unroll from reduction axis, or outs axis
-  bool unroll_reduction = false;
-
-  // Unroll amount
-  int64_t unroll_factor = 1;
-
-  int64_t remainder_in_reduction = num_elems_in_reduction;
-  int64_t remainder_in_output = num_outputs_for_reduction;
-
-  if (ceilDiv(num_outputs_for_reduction, warp_size) <
-      device_multiprocessor_count) {
-    // If we can't hit a full wave, leave bdimx as warp_size, and prioritize
-    // bdimy. TODO: Re-evaluate, should it be bdimx = warp_size?
-    bdimx = std::min(num_outputs_for_reduction, warp_size);
-  } else {
-    bdimx = std::min(
-        max_threads_in_block,
-        ceilDiv(num_outputs_for_reduction, target_blocks));
-    bdimx = std::max(bdimx, warp_size);
-  }
-
-  bdimy = std::min(
-      std::max(max_threads_in_block / bdimx, (int64_t)1),
-      num_elems_in_reduction);
-
-  // Clang tidy
-  // remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimx);
-  remainder_in_reduction = ceilDiv(remainder_in_reduction, bdimy);
-
-  if (num_outputs_for_reduction >=
-      device_multiprocessor_count * max_threads_in_block) {
-    // If we easily saturate the GPU, don't use block dim y and unroll output
-    // dimension, this could be a more gentle transition starting earlier
-    bdimx = max_threads_in_block;
-    remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimx);
-
-    bdimy = 1;
-    remainder_in_reduction = num_elems_in_reduction;
-
-    // Assume unroll in output, switch to remainder if cross grid
-    // Don't unroll if we don't have 2 full waves
-    unroll_factor = std::min(
-        ceilDiv(remainder_in_output, device_multiprocessor_count * 2),
-        target_unroll);
-
-    if (unroll_factor == 1 && remainder_in_reduction > 1) {
-      // Try unrolling in reduction dimension
-      unroll_factor = std::min(remainder_in_reduction, unroll_factor);
-      // Clang tidy
-      // remainder_in_reduction = ceilDiv(remainder_in_reduction,
-      // unroll_factor);
-      if (unroll_factor > 1) {
-        unroll_reduction = true;
-      }
-    }
-    // Clang tidy
-    // else {
-    //   remainder_in_output =
-    //       ceilDiv(num_outputs_for_reduction, bdimx * unroll_factor);
-    // }
-  } else {
-    // Not many output elements, so we want to try expand grid level parallelism
-    // first go after unrolling
-    unroll_factor = std::min(max_unroll, remainder_in_reduction);
-    if (unroll_factor > 1) {
-      unroll_reduction = true;
-    }
-
-    remainder_in_reduction =
-        ceilDiv(num_elems_in_reduction, bdimy * unroll_factor);
-
-    // Go cross grid
-    gdimy = ceilDiv(remainder_in_reduction, (int64_t)4);
-    // Clang tidy
-    // remainder_in_reduction =
-    //     ceilDiv(num_elems_in_reduction, bdimy * unroll_factor * gdimy);
-  }
-
-  // Clang tidy
-  constexpr int64_t kEight = 8;
-  constexpr int64_t kSixteen = 16;
-  constexpr int64_t kThirtyTwo = 32;
-
-  if (ceilDiv(num_elems_in_reduction, bdimy * unroll_factor) >= kThirtyTwo) {
-    // Many reduction elements, go cross grid
-    int64_t min_gdimy = 1;
-    if (gdimy > 1) {
-      // already cross grid, don't go below target or what was already set
-      min_gdimy = std::min(gdimy, ceilDiv(target_blocks, gdimx));
-    }
-    gdimy = std::max(
-        min_gdimy,
-        ceilDiv(
-            ceilDiv(num_elems_in_reduction, bdimy * unroll_factor),
-            (int64_t)kSixteen));
-    // Don't go too far above number of threads in a block since that's how many
-    // threads are available to do final reduction iteration
-    // This is good!
-    gdimy = std::min(gdimy, bdimx * bdimy * kEight);
-  }
-
-  // Try to do some cleanup of ragged waves on device
-  if (
-      // If we have less than 8 waves of blocks
-      gdimy * gdimx < device_multiprocessor_count * kEight &&
-      // And we don't have an even divisible number of blocks
-      (gdimy * gdimx) % device_multiprocessor_count != 0 &&
-      // And we have more than one wave
-      gdimy * gdimx > device_multiprocessor_count) {
-    // round waves down
-    auto waves =
-        std::max((gdimx * gdimy) / device_multiprocessor_count, (int64_t)1);
-    auto new_gdimy =
-        std::max((waves * device_multiprocessor_count) / gdimx, (int64_t)1);
-    if (
-        // If difference is less than 25% of the original gdimy
-        (new_gdimy - gdimy) * 4 < gdimy &&
-        // and difference is less than 25% of the original number of blocks
-        ((new_gdimy * gdimx) - (gdimy * gdimx)) * 4 < gdimy * gdimx) {
-      gdimy = new_gdimy;
-    }
-  }
-
-  // Cannot unroll with cross grid reductions
-  if (gdimy > 1 && !unroll_reduction) {
-    unroll_reduction = true;
-    unroll_factor = 1;
-  }
-
-  bool vectorize = false;
-
-  if (vectorize_factor > 1 && unroll_factor > 1 && !unroll_reduction) {
-    vectorize = true;
-    unroll_factor = std::min(
-        scheduler_utils::lastPow2(unroll_factor), (int64_t)vectorize_factor);
-  }
-
-  ReductionParams rparams;
-  rparams.fastest_dim = false;
-  // cross grid implies cross block
-  rparams.cross_block = bdimy > 1 || gdimy > 1;
-  rparams.cross_grid = gdimy > 1;
-  rparams.multiple_reds_per_blk = bdimx > 1;
-  rparams.loop_unroll = unroll_factor;
-  rparams.vectorize = vectorize;
-  rparams.reduction_unroll = unroll_reduction;
-
-  rparams.lparams = LaunchParams(
-      LaunchParams::UNINITIALIZED_VAL,
-      gdimy,
-      LaunchParams::UNINITIALIZED_VAL,
-      bdimx,
-      bdimy,
-      LaunchParams::UNINITIALIZED_VAL);
-
-  if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
-    std::cerr << "\n===== Reduction Stats ========\n"
-              << "num_elems_in_reduction: " << num_elems_in_reduction << "\n"
-              << "num_outputs_for_reduction: " << num_outputs_for_reduction
-              << "\n"
-              << "n_tensor_inputs: " << n_tensor_inputs << "\n"
-              << "max_input_dtype_size: " << max_input_dtype_size << std::endl;
-    std::cerr << rparams.toString() << std::endl;
-  }
-  return rparams;
-}
-
-} // namespace
-
-ReductionParams reductionHeuristic(
-    int64_t num_elems_in_reduction,
-    int64_t num_outputs_for_reduction,
-    bool fastest_dim_reduction,
-    size_t n_tensor_inputs,
-    size_t max_input_dtype_size,
-    size_t vectorize_factor) {
-  if (fastest_dim_reduction) {
-    return innerReductionHeuristic(
-        num_elems_in_reduction,
-        num_outputs_for_reduction,
-        n_tensor_inputs,
-        max_input_dtype_size,
-        vectorize_factor);
-  } else {
-    return OuterReductionHeuristic(
-        num_elems_in_reduction,
-        num_outputs_for_reduction,
-        n_tensor_inputs,
-        max_input_dtype_size,
-        vectorize_factor);
-  }
-}
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
-    Fusion* fusion,
-    const at::ArrayRef<c10::IValue>& runtime_inputs,
-    HeuristicSummary* data_cache) {
-  FUSER_PERF_SCOPE("getReductionHeuristics");
-
-  SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true);
-
-  return getReductionHeuristics(fusion, runtime_info, data_cache);
-}
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    HeuristicSummary* data_cache) {
-  FUSER_PERF_SCOPE("getReductionHeuristics");
-
-  FusionGuard fg(fusion);
-
-  HeuristicCacheAccessor<std::vector<TensorView*>> reduction_tv_data;
-  // TODO: move all these boilerplate code into the accessor class
-  // (follow up)
-  if (data_cache && !data_cache->isRecording()) {
-    reduction_tv_data.writeTemporary(data_cache->getReductionTVs());
-  } else {
-    reduction_tv_data.writeNew(scheduler_utils::getReductionTvs(fusion));
-    if (data_cache && data_cache->isRecording()) {
-      data_cache->setReductionTVs(reduction_tv_data.read());
-    }
-  }
-
-  auto& reduction_tvs = reduction_tv_data.read();
-
-  TORCH_INTERNAL_ASSERT(
-      reduction_tvs.size() == 1, "Need reduction tensor views to schedule.");
-
-  auto reduction_tv = reduction_tvs[0];
-
-  TORCH_INTERNAL_ASSERT(reduction_tv != nullptr);
-
-  auto red_root_dom = reduction_tv->getRootDomain();
-  bool fastest_dim_reduction = true;
-  for (size_t i = red_root_dom.size(); i > 0; i--) {
-    if (red_root_dom[i - 1]->isBroadcast() ||
-        red_root_dom[i - 1]->isTrivialReduction()) {
-      continue;
-    } else if (red_root_dom[i - 1]->isReduction()) {
-      fastest_dim_reduction = true;
-      break;
-    } else {
-      fastest_dim_reduction = false;
-      break;
-    }
-  }
-
-  TORCH_INTERNAL_ASSERT(
-      reduction_tv != nullptr, "Reduction TensorView wasn't found.");
-
-  TORCH_INTERNAL_ASSERT(
-      reduction_tv->hasReduction(), "TensorView doesn't have a reduction.");
-  const auto red_expr = reduction_tv->definition();
-
-  TORCH_INTERNAL_ASSERT(
-      red_expr->getExprType() != c10::nullopt &&
-          (red_expr->getExprType().value() == ExprType::ReductionOp ||
-           red_expr->getExprType().value() == ExprType::WelfordOp),
-      "TensorView doesn't have a reduction.");
-
-  int64_t num_outputs_for_reduction = 1;
-  int64_t red_elements = 1;
-
-  for (auto id : reduction_tv->getRootDomain()) {
-    auto inferred_val =
-        runtime_info.expressionEvaluator().evaluate(id->extent());
-    TORCH_INTERNAL_ASSERT(
-        inferred_val.has_value(), "Error inferring reduction size.");
-    if (id->isReduction()) {
-      red_elements *= inferred_val.value();
-    } else {
-      num_outputs_for_reduction *= inferred_val.value();
-    }
-  }
-
-  size_t max_dtype_size = 1;
-  size_t n_tensor_inputs = 0;
-  for (auto inp : fusion->inputs()) {
-    if (inp->isA<TensorView>()) {
-      max_dtype_size =
-          std::max(max_dtype_size, dataTypeSize(inp->getDataType().value()));
-      n_tensor_inputs++;
-    }
-  }
-
-  TORCH_INTERNAL_ASSERT(
-      n_tensor_inputs > 0,
-      "Tried to schedule a fusion with no tensor inputs, currently not supported.");
-
-  auto vectorizable_inputs_outputs =
-      scheduler_utils::getVectorizableInputsOutputs(reduction_tv);
-
-  // Vectorize as much as we can
-  size_t vectorize_factor = std::numeric_limits<size_t>::max();
-
-  for (auto tv : vectorizable_inputs_outputs) {
-    const auto tv_vectorize_factor = runtime_info.getVectorizableWidth(tv);
-    vectorize_factor = std::min(vectorize_factor, tv_vectorize_factor);
-  }
-
-  if (vectorize_factor == std::numeric_limits<size_t>::max()) {
-    vectorize_factor = 1;
-  }
-
-  return reductionHeuristic(
-      red_elements,
-      num_outputs_for_reduction,
-      fastest_dim_reduction,
-      n_tensor_inputs,
-      max_dtype_size,
-      vectorize_factor);
-}
-
-// fusion is the input IR that will be modified by this function
-void scheduleReduction(Fusion* fusion, const ReductionParams& rparams) {
-  FUSER_PERF_SCOPE("scheduleReduction");
-  FusionGuard fg(fusion);
-
-  // Cache inputs if unrolled
-  auto cached_inputs =
-      scheduler_utils::cacheInputs(fusion, rparams.loop_unroll > 1);
-
-  // Cache and fork  outputs
-  std::vector<std::pair<TensorView*, TensorView*>> cached_outputs =
-      scheduler_utils::cacheAndForkOutputs(fusion, rparams.loop_unroll > 1);
-
-  // Make sure we don't have global memory set on intermediate tensors from
-  // fusion segmentation
-  scheduler_utils::clearMemorySpace(fusion);
-
-  auto reduction_tvs = scheduler_utils::getReductionTvs(fusion);
-
-  TORCH_INTERNAL_ASSERT(
-      reduction_tvs.size() <= 1,
-      "Found multiple reductions sent to reduction heuristics",
-      " (and reductions are not from a multi-output expr).");
-  TORCH_INTERNAL_ASSERT(reduction_tvs.size());
-
-  auto reduction_tv = reduction_tvs[0];
-
-  auto dim_analysis =
-      scheduler_utils::canonicalDimReduction(fusion, reduction_tv);
-  bool has_iter_axis = dim_analysis.first;
-  bool has_red_axis = dim_analysis.second;
-
-  TORCH_INTERNAL_ASSERT(
-      has_red_axis,
-      "Could not find reduction axis in tensor used for reduction scheduler.");
-
-  if (!has_iter_axis) {
-    TORCH_INTERNAL_ASSERT(
-        rparams.fastest_dim,
-        "If all dims are reduction, should be sending it to fastest dim scheduler.");
-  }
-
-  TensorView* reference_tv = scheduler_utils::scheduleReductionTV(
-      rparams, reduction_tv, has_iter_axis);
-
-  // Reduction tensor views and rfactor tensor views are setup. Let's finish off
-  // the scheduling, particularly inlining and unrolling.
-  TORCH_INTERNAL_ASSERT(
-      reference_tv != nullptr && reduction_tv != nullptr,
-      "Need these two tensor views to finish the scheduling.");
-  scheduler_utils::multiReductionInliner(
-      fusion,
-      rparams,
-      reduction_tv,
-      reference_tv,
-      reduction_tvs,
-      cached_inputs,
-      cached_outputs);
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction.h b/torch/csrc/jit/codegen/cuda/scheduler/reduction.h
deleted file mode 100644 (file)
index 7e517b1..0000000
+++ /dev/null
@@ -1,32 +0,0 @@
-#pragma once
-
-#include <ATen/core/ivalue.h>
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class SchedulerRuntimeInfo;
-class HeuristicSummary;
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
-    Fusion* fusion,
-    const at::ArrayRef<c10::IValue>& runtime_inputs,
-    HeuristicSummary* data_cache = nullptr);
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    HeuristicSummary* data_cache = nullptr);
-
-TORCH_CUDA_CU_API void scheduleReduction(
-    Fusion* fusion,
-    const ReductionParams& rparams);
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h b/torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h
deleted file mode 100644 (file)
index 3d9402e..0000000
+++ /dev/null
@@ -1,111 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
-
-#include <sstream>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-// Parameters the Reduction Heuristic Generates to describe the optimial
-// schedule. Warning: equal operator is intended for use in caching the kernel
-// associated with these reduction parameters. It does not check if the launch
-// parameters are equivelent!
-class ReductionParams {
- public:
-  // Reducing inner most dimension?
-  bool fastest_dim = true;
-  // Reduce across the block?
-  bool cross_block = false;
-  // Reduce across the grid?
-  bool cross_grid = false;
-  // Perform multiple reductions per block?
-  bool multiple_reds_per_blk = false;
-  // Unrolling factor
-  int64_t loop_unroll = 1;
-  // Should unrolling be done on reduction dimension
-  bool reduction_unroll = true;
-  // vectorize instead of unroll
-  bool vectorize = false;
-  // Number of batches for each block
-  int64_t batches_per_block = 1;
-  // Number of warps per block
-  // TODO: Remove or repurpose
-  int64_t num_warps = 1;
-  // Store input in shared memory or registers to reduce global memory reads
-  bool persistent_kernel = false;
-
-  // Split grid dim in case it's too large for cuda
-  bool split_grid_dim = false;
-
-  std::string tag = "";
-
-  LaunchParams lparams;
-
- public:
-  // Warning: Does not check launch parameters!
-  bool operator==(const ReductionParams& other) const {
-    bool attr_equal = other.fastest_dim == fastest_dim &&
-        other.cross_block == cross_block && other.cross_grid == cross_grid &&
-        other.multiple_reds_per_blk == multiple_reds_per_blk &&
-        other.loop_unroll == loop_unroll && other.vectorize == vectorize &&
-        other.batches_per_block == batches_per_block &&
-        other.num_warps == num_warps &&
-        other.persistent_kernel == persistent_kernel &&
-        other.reduction_unroll == reduction_unroll &&
-        other.split_grid_dim == split_grid_dim;
-    return attr_equal;
-  }
-
-  std::string toString() const {
-    std::stringstream ss;
-    ss << "\n===== Reduction Parameters ========\n"
-       << (tag == "" ? "" : "Tag: ") << tag
-       << (fastest_dim ? "Red On Fastest Dim\n" : "Red On Slow Dim\n")
-       << "Reduction Characteristics:\n"
-       << (multiple_reds_per_blk ? "Multiple Reds Per Block\n" : "")
-       << (cross_block ? "Cross block reduction\n" : "")
-       << (cross_grid ? "Cross grid reduction\n" : "");
-    if (persistent_kernel) {
-      ss << "Persistent Kernel\n"
-         << "Batches per block: " << batches_per_block << "\n";
-    }
-    ss << "Blocking:\n"
-       << " GridY: " << lparams.gdimy() << " BlckY: " << lparams.bdimy()
-       << " BlckX: " << lparams.bdimx() << "\n";
-    if (loop_unroll > 1) {
-      ss << (vectorize ? "Vectorize " : "Unroll ")
-         << (reduction_unroll ? " reduction dim, " : " iter dim, ")
-         << "Factor: " << loop_unroll << "\n";
-    }
-    ss << "====================================\n";
-    return ss.str();
-  }
-};
-
-// Warning: Hash is not based on launch parameters!
-class ReductionParamsHash {
- public:
-  size_t operator()(const ReductionParams& rp) const {
-    constexpr size_t bits = sizeof(std::size_t) * 8;
-    size_t attr_hash = static_cast<size_t>(rp.fastest_dim) << (bits - 1) ^
-        static_cast<size_t>(rp.cross_block) << (bits - 2) ^
-        static_cast<size_t>(rp.cross_grid) << (bits - 3) ^
-        static_cast<size_t>(rp.multiple_reds_per_blk) << (bits - 4) ^
-        static_cast<size_t>(rp.loop_unroll) ^
-        static_cast<size_t>(rp.reduction_unroll) << (bits - 5) ^
-        static_cast<size_t>(rp.vectorize) << (bits - 6) ^
-        static_cast<size_t>(rp.batches_per_block) ^
-        static_cast<size_t>(rp.num_warps) ^
-        static_cast<size_t>(rp.persistent_kernel) << (bits - 7) ^
-        static_cast<size_t>(rp.split_grid_dim) << (bits - 8);
-    return attr_hash;
-  }
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
deleted file mode 100644 (file)
index 9646fa2..0000000
+++ /dev/null
@@ -1,1067 +0,0 @@
-#include <c10/util/irange.h>
-#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-
-#include <limits>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-// TODO: Deduplicate from compute_at.cpp
-std::deque<std::deque<TensorView*>> tvChains(
-    std::deque<std::deque<Val*>> val_chains) {
-  std::deque<std::deque<TensorView*>> tv_chains(val_chains.size());
-  for (size_t i = 0; i < val_chains.size(); i++) {
-    auto tv_iterable = ir_utils::filterByType<TensorView>(val_chains[i]);
-    tv_chains[i] =
-        std::deque<TensorView*>(tv_iterable.begin(), tv_iterable.end());
-  }
-  return tv_chains;
-}
-
-class SchedulerTopologyChecker {
- public:
-  // Checks if any broadcasts are resolved after a reduction that don't follow
-  // the normalization pattern
-  static bool hasNonNormalizePostReductionBCast(Fusion* fusion) {
-    auto all_vals = fusion->usedMathVals();
-    std::vector<TensorView*> reduction_tvs;
-    for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
-      if (tv->hasReduction() && !fusion->hasInput(tv)) {
-        reduction_tvs.push_back(tv);
-      }
-    }
-
-    // All tensor views that are eventually consumed to produce a reduction,
-    // includes reduction tensor views.
-    std::unordered_set<TensorView*> pre_reduction_tvs;
-
-    {
-      auto pre_reduction_vals = DependencyCheck::getAllValsBetween(
-          {fusion->inputs().begin(), fusion->inputs().end()},
-          {reduction_tvs.begin(), reduction_tvs.end()});
-      auto pre_reduction_tv_vector =
-          ir_utils::filterByType<TensorView>(pre_reduction_vals);
-      pre_reduction_tvs = std::unordered_set<TensorView*>(
-          pre_reduction_tv_vector.begin(), pre_reduction_tv_vector.end());
-    }
-
-    // Track which tensor views we've validated so we don't do it again.
-    std::unordered_set<TensorView*> validated_resolved_tvs;
-
-    // Run forward (towards outputs) from reductions on any path that isn't
-    // before another reduction. Look for resolved broadcasts. If a resolved
-    // broadcast is found, start there and propagate backwards. Track the id's
-    // that were resolved and make sure there's a mapping to a TensorView before
-    // a reduction.
-    for (auto red_tv : reduction_tvs) {
-      auto forward_tv_chains =
-          tvChains(DependencyCheck::getAllUseChains(red_tv));
-      // Propagate forward from reduction through all uses of the reduction
-      for (auto forward_tv_dep_chain : forward_tv_chains) {
-        TensorView* forward_running_producer = nullptr;
-        TensorView* forward_running_consumer = forward_tv_dep_chain.front();
-        forward_tv_dep_chain.pop_front();
-        while (!forward_tv_dep_chain.empty()) {
-          forward_running_producer = forward_running_consumer;
-          forward_running_consumer = forward_tv_dep_chain.front();
-          forward_tv_dep_chain.pop_front();
-
-          if (std::none_of(
-                  forward_running_producer->getMaybeRFactorDomain().begin(),
-                  forward_running_producer->getMaybeRFactorDomain().end(),
-                  [](IterDomain* id) { return id->isBroadcast(); })) {
-            // If there's no broadcast axes in producer it doesn't need to be
-            // checked
-            continue;
-          }
-
-          // If consumer is before another reduction it doesn't need to be
-          // checked
-          if (pre_reduction_tvs.count(forward_running_consumer)) {
-            break;
-          }
-
-          // If consumer was already validated it doesn't need to be checked
-          if (validated_resolved_tvs.count(forward_running_consumer)) {
-            continue;
-          }
-
-          auto forward_pairwise_root_map = PairwiseRootDomainMap(
-              forward_running_producer, forward_running_consumer);
-          auto forward_p2c_root_map =
-              forward_pairwise_root_map.mapProducerToConsumer(
-                  forward_running_producer->domain(),
-                  forward_running_consumer->domain());
-
-          // These are the ids we will have to resolve. As we resolve them we'll
-          // remove them from this vector. If this vector ends up empty, then
-          // we've resolved everything we need to. This is a pair so as we
-          // traverse we can map the id through the traversal. The first entry
-          // in the pair will be the original id so we can reset it if it's not
-          // resolved before the next traversal. The second ID will be
-          // propagated as we map the IDs through the backward traversal.
-          std::vector<std::pair<IterDomain*, IterDomain*>> ids_to_resolve;
-
-          // Check if any TensorViews have a resolved broadcast
-          for (auto entry : forward_p2c_root_map) {
-            auto p_id = entry.first;
-            auto c_id = entry.second;
-            if (p_id->isBroadcast() &&
-                (!c_id->isBroadcast() && !c_id->isTrivialReduction())) {
-              ids_to_resolve.emplace_back(std::make_pair(c_id, c_id));
-            }
-          }
-
-          if (ids_to_resolve.empty()) {
-            continue;
-          }
-
-          // Only because of api limitations in getAllDependencyChains
-          auto inputs_of_forward_running_consumer =
-              IterVisitor::getInputsTo({forward_running_consumer});
-          auto tv_inputs_of_forward_running_consumer =
-              ir_utils::filterByType<TensorView>(
-                  inputs_of_forward_running_consumer);
-
-          for (auto input_of_forward_running_consumer :
-               tv_inputs_of_forward_running_consumer) {
-            if (pre_reduction_tvs.find(input_of_forward_running_consumer) ==
-                pre_reduction_tvs.end()) {
-              // If this input isn't an input to a reduction, no point
-              // traversing the dependency chains as we know we can't validate
-              // this broadcast through chains to this input
-              continue;
-            }
-
-            auto backward_tv_chains =
-                tvChains(DependencyCheck::getAllDependencyChains(
-                    input_of_forward_running_consumer,
-                    forward_running_consumer));
-
-            for (auto backward_tv_chain : backward_tv_chains) {
-              if (ids_to_resolve.empty()) {
-                break;
-              }
-
-              for (auto& pair : ids_to_resolve) {
-                pair.second = pair.first;
-              }
-
-              TensorView* backward_running_producer = backward_tv_chain.back();
-              TensorView* backward_running_consumer = nullptr;
-              backward_tv_chain.pop_back();
-
-              TORCH_INTERNAL_ASSERT(
-                  backward_running_producer == forward_running_consumer);
-
-              while (!backward_tv_chain.empty()) {
-                backward_running_consumer = backward_running_producer;
-                backward_running_producer = backward_tv_chain.back();
-                backward_tv_chain.pop_back();
-
-                std::vector<IterDomain*> running_resolved_ids;
-
-                auto backward_pairwise_root_map = PairwiseRootDomainMap(
-                    backward_running_producer, backward_running_consumer);
-
-                auto backward_c2p_root_map =
-                    backward_pairwise_root_map.mapConsumerToProducer(
-                        backward_running_consumer->domain(),
-                        backward_running_producer->domain());
-
-                // Mark if producer is a producer of a reduction
-                bool producer_resolves =
-                    pre_reduction_tvs.count(backward_running_producer);
-
-                bool at_leat_one_id_mapped = false;
-                for (size_t entry_i = ids_to_resolve.size(); entry_i > 0;
-                     entry_i--) {
-                  auto orig_id = ids_to_resolve[entry_i - 1].first;
-                  auto running_id = ids_to_resolve[entry_i - 1].second;
-                  if (backward_c2p_root_map.find(running_id) !=
-                      backward_c2p_root_map.end()) {
-                    at_leat_one_id_mapped = true;
-                    if (producer_resolves &&
-                        !backward_c2p_root_map.at(running_id)->isBroadcast()) {
-                      // If mapped, and producer is a producer of a reduction,
-                      // we can resolve this id
-                      ids_to_resolve.erase(
-                          ids_to_resolve.begin() + (entry_i - 1));
-                    } else {
-                      ids_to_resolve[entry_i - 1] = std::make_pair(
-                          orig_id, backward_c2p_root_map.at(running_id));
-                    }
-                  }
-                }
-                if (!at_leat_one_id_mapped) {
-                  // If no id's map any more, go to the next chain
-                  break;
-                }
-
-                if (ids_to_resolve.empty()) {
-                  break;
-                }
-              }
-            }
-          } // for(auto input_of_forward_running_consumer :
-            // tv_inputs_of_forward_running_consumer){
-
-          // if all ids were not resolved, then we've found an instance of a
-          // bad broadcast resolution after reduction
-          if (ids_to_resolve.size()) {
-            return true;
-          }
-
-        } // while (!forward_tv_dep_chain.empty()) {
-      } // for (auto forward_tv_dep_chain : forward_tv_chains) {
-    } // for (auto red_tv : reduction_tvs)
-    return false;
-  }
-
-  // Checks if any broadcasts are resolved after a reduction, this shouldn't be
-  // accepted in the single reduction or multi-reduction scheduler
-  static bool hasPostReductionBCast(Fusion* fusion) {
-    auto all_vals = fusion->usedMathVals();
-    for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
-      // Welford can have 2 outputs, so do this on all found reduction tensor
-      // views
-      if (tv->hasReduction() && !tv->isFusionInput()) {
-        auto tv_chains = tvChains(DependencyCheck::getAllUseChains(tv));
-        // Propagate forward from reduction through all uses of the reduction
-        for (auto tv_dep_chain : tv_chains) {
-          TensorView* running_producer = nullptr;
-          TensorView* running_consumer = tv_dep_chain.front();
-          tv_dep_chain.pop_front();
-          while (!tv_dep_chain.empty()) {
-            running_producer = running_consumer;
-            running_consumer = tv_dep_chain.front();
-            tv_dep_chain.pop_front();
-
-            auto pairwise_root_map =
-                PairwiseRootDomainMap(running_producer, running_consumer);
-            auto p2c_root_map = pairwise_root_map.mapProducerToConsumer(
-                running_producer->domain(), running_consumer->domain());
-
-            // Check if any TensorViews have a resolved broadcast
-            for (auto entry : p2c_root_map) {
-              auto p_id = entry.first;
-              auto c_id = entry.second;
-              if (p_id->isBroadcast() &&
-                  (!c_id->isBroadcast() && !c_id->isTrivialReduction())) {
-                return true;
-              }
-            }
-          }
-        }
-      }
-    }
-    return false;
-  }
-
-  // Checks if there's any unsupported operations post reduction. If outer
-  // reduction we can fuse some pointwise ops if they don't require
-  // broadcasting (checked in hasPostReductionBCast). For inner reductions we
-  // cannot fuse any binary like operation (includes operations like shift that
-  // we're not fusing right now) involving "new" inputs (not going through a
-  // reduction).
-  static bool supportedPostReductionFusion(
-      Fusion* fusion,
-      std::vector<TensorView*> reduction_tvs) {
-    TORCH_INTERNAL_ASSERT(reduction_tvs.size());
-    bool fastest_dim_reduction = true;
-    auto red_root_dom = reduction_tvs[0]->getRootDomain();
-    for (size_t i = red_root_dom.size(); i > 0; i--) {
-      if (red_root_dom[i - 1]->isBroadcast() ||
-          red_root_dom[i - 1]->isTrivialReduction()) {
-        continue;
-      } else if (red_root_dom[i - 1]->isReduction()) {
-        fastest_dim_reduction = true;
-        break;
-      } else {
-        fastest_dim_reduction = false;
-        break;
-      }
-    }
-
-    // If reductions are on fastest dim, don't fuse any operations (after
-    // reductions) that requires an input that is not an input to the
-    // reductions.
-    if (fastest_dim_reduction) {
-      auto post_reduction_vals = DependencyCheck::getAllValsBetween(
-          {reduction_tvs.begin(), reduction_tvs.end()},
-          {fusion->outputs().begin(), fusion->outputs().end()});
-
-      if (post_reduction_vals.empty()) {
-        return true;
-      }
-
-      auto reduction_inputs = IterVisitor::getInputsTo(
-          {reduction_tvs.begin(), reduction_tvs.end()});
-
-      for (auto tv : ir_utils::filterByType<TensorView>(
-               post_reduction_vals.begin(), post_reduction_vals.end())) {
-        if (tv->definition() == nullptr) {
-          continue;
-        }
-
-        auto tv_inputs = IterVisitor::getInputsTo({tv});
-
-        if (std::any_of(
-                tv_inputs.begin(),
-                tv_inputs.end(),
-                [&reduction_inputs](Val* inp) {
-                  return inp->isA<TensorView>() &&
-                      std::find(
-                          reduction_inputs.begin(),
-                          reduction_inputs.end(),
-                          inp) == reduction_inputs.end();
-                })) {
-          return false;
-        }
-      }
-    }
-
-    return true;
-  }
-};
-} // namespace
-
-SchedulerRuntimeInfo::SchedulerRuntimeInfo(
-    Fusion* complete_fusion,
-    const at::ArrayRef<IValue>& inputs,
-    bool create_expr_evaluator)
-    : complete_fusion_(complete_fusion) {
-  collectVectorizationInfo(inputs);
-  if (create_expr_evaluator) {
-    initializeExpressionEvaluator(inputs);
-  }
-  collectIndexModeInfo(inputs);
-}
-
-SchedulerRuntimeInfo::SchedulerRuntimeInfo(
-    const SchedulerRuntimeInfo& copy_from)
-    : complete_fusion_(copy_from.complete_fusion_),
-      alignment_map_(copy_from.alignment_map_),
-      common_alignment_size_(copy_from.common_alignment_size_) {
-  expression_evaluator_ =
-      std::make_unique<ExpressionEvaluator>(complete_fusion_);
-}
-
-size_t SchedulerRuntimeInfo::getAlignmentSize(TensorView* tv) {
-  auto alignment_entry = alignment_map_.find(tv);
-  if (alignment_entry == alignment_map_.end()) {
-    return max_alignment_size_in_byte;
-  } else {
-    return alignment_entry->second;
-  }
-}
-
-void SchedulerRuntimeInfo::initializeExpressionEvaluator(
-    const at::ArrayRef<IValue>& inputs) {
-  // TODO: refactor bindFusionInputs to better support this
-  //  use case, i.e. support construct and bind input.
-  expression_evaluator_ =
-      std::make_unique<ExpressionEvaluator>(complete_fusion_);
-  *expression_evaluator_ =
-      executor_utils::bindFusionInputs(inputs, complete_fusion_);
-}
-
-size_t SchedulerRuntimeInfo::collectAlignmentSize(
-    const at::Tensor& tensor) const {
-  const size_t address = reinterpret_cast<size_t>(tensor.data_ptr());
-  size_t alignment_size = 1;
-  size_t next_alignment_size = 2;
-
-  while (alignment_size <= max_alignment_size_in_byte &&
-         address % next_alignment_size == 0) {
-    alignment_size = next_alignment_size;
-    next_alignment_size *= 2;
-  }
-
-  return alignment_size;
-}
-
-void SchedulerRuntimeInfo::collectVectorizationInfo(
-    const at::ArrayRef<IValue>& inputs) {
-  common_alignment_size_ = max_alignment_size_in_byte;
-  size_t number_of_inputs = complete_fusion_->inputs().size();
-  std::unordered_map<TensorView*, size_t> cg_tensor_to_at_tensor_index;
-
-  for (auto input_index : c10::irange(number_of_inputs)) {
-    if (auto input_tensor = dynamic_cast<TensorView*>(
-            complete_fusion_->inputs()[input_index])) {
-      if (input_tensor->nDims() == 0) {
-        // A 0-dim tensor input would not need vectorization
-        continue;
-      }
-      if (input_tensor->domain()
-              ->domain()[input_tensor->nDims() - 1]
-              ->isBroadcast()) {
-        // skip the tensors with innermost iterdomain broadcasted,
-        //  as we will not vectorize these.
-        continue;
-      }
-
-      // Collect strides of the input tensor
-      TORCH_INTERNAL_ASSERT(inputs[input_index].isTensor());
-      const auto& at_tensor = inputs[input_index].toTensor();
-
-      cg_tensor_to_at_tensor_index.emplace(
-          std::make_pair(input_tensor, input_index));
-
-      // Collect alignment of the input tensor
-      auto alignment_size = collectAlignmentSize(at_tensor);
-      common_alignment_size_ = std::min(alignment_size, common_alignment_size_);
-      alignment_map_[input_tensor] = alignment_size;
-    }
-  }
-
-  // Compute max vector word size for each input,
-  //  tensors with inner most broadcast already
-  //  filtered out.  common_alignment_size_ is
-  //  computed up to this point.
-  for (auto it : cg_tensor_to_at_tensor_index) {
-    vectorword_map_[it.first] = collectMaxVectorizeSize(
-        inputs[it.second].toTensor(), common_alignment_size_);
-  }
-}
-
-size_t SchedulerRuntimeInfo::collectMaxVectorizeSize(
-    const at::Tensor& tensor,
-    size_t max_vector_size_in_byte) {
-  size_t vector_size = 1;
-  size_t next_vector_size = 2;
-  bool next_size_compatible = true;
-
-  while (next_size_compatible &&
-         next_vector_size * tensor.itemsize() <= max_vector_size_in_byte) {
-    // If inner most dimension size is not divisible by new word size
-    //  then we cannot vectorize with this width. But we do not
-    //  care if all dimensions of this tensor is 1, i.e.
-    //  input is actually a un-squeezed 0-dim tensor.
-    for (size_t i = tensor.ndimension(); i > 0; i--) {
-      if (tensor.size(i - 1) != 1) {
-        if (tensor.size(tensor.ndimension() - 1) % next_vector_size != 0 ||
-            tensor.stride(tensor.ndimension() - 1) != 1) {
-          next_size_compatible = false;
-        }
-        break;
-      }
-    }
-
-    if (!next_size_compatible) {
-      break;
-    }
-
-    // If any stride is not divisible by the next word size,
-    //  we cannot vectorize with this width.
-    for (auto stride : tensor.strides()) {
-      if (stride != 1 && stride % next_vector_size != 0) {
-        next_size_compatible = false;
-        break;
-      }
-    }
-
-    if (next_size_compatible) {
-      vector_size = next_vector_size;
-      next_vector_size *= 2;
-    }
-  }
-
-  return vector_size;
-}
-
-size_t SchedulerRuntimeInfo::getVectorizableWidth(TensorView* tv) {
-  auto recorded_size_it = vectorword_map_.find(tv);
-  if (recorded_size_it != vectorword_map_.end()) {
-    return recorded_size_it->second;
-  }
-
-  // If we don't have an record, either it is a tv with innermost
-  //  broadcast, or it is an intermediate tensor allocated by fuser
-  auto tv_root = TensorDomain::noReductions(tv->getRootDomain());
-  auto tv_root_size = tv_root.size();
-
-  // Filter out 0-dim tensors
-  if (tv_root_size < 1) {
-    return 1;
-  }
-
-  // Filter out mismatched contiguity info
-  if (tv_root_size != tv->domain()->contiguity().size()) {
-    return 1;
-  }
-
-  // Filter out innermost broadcast tensors
-  auto inner_dimension = tv_root[tv_root_size - 1];
-  if (inner_dimension->isBroadcast()) {
-    return 1;
-  }
-
-  // Handle intermediate or output tensors that
-  //  will be allocated by fuser
-  auto maybe_data_type = tv->getDataType();
-
-  // Do not vectorize on data with unknown type
-  if (!maybe_data_type.has_value()) {
-    return 1;
-  }
-
-  size_t item_size = dataTypeSize(maybe_data_type.value());
-  // Assume we don't have non-divisible types for now.
-  TORCH_INTERNAL_ASSERT(max_alignment_size_in_byte % item_size == 0);
-  size_t max_vector_size = max_alignment_size_in_byte / item_size;
-
-  // Assuming intermediate tensors have friendly alignment, and
-  //  all contiguity true. Determine the largest power of 2 below
-  //  innermost dimension size for the word size of vectorizaiton
-  size_t vector_size = 1;
-  size_t next_vector_size = 2;
-  auto maybe_inner_dimension_size =
-      expression_evaluator_->evaluate(inner_dimension->extent());
-  TORCH_INTERNAL_ASSERT(maybe_inner_dimension_size.has_value());
-  size_t inner_dimension_size = maybe_inner_dimension_size.value();
-
-  while (next_vector_size <= max_vector_size &&
-         next_vector_size <= inner_dimension_size &&
-         inner_dimension_size % next_vector_size == 0) {
-    vector_size = next_vector_size;
-    next_vector_size *= 2;
-  }
-
-  // save output to avoid re-compute
-  vectorword_map_[tv] = vector_size;
-
-  return vector_size;
-}
-
-void SchedulerRuntimeInfo::collectIndexModeInfo(
-    const at::ArrayRef<at::IValue>& inputs) {
-  // Save 1 more bit besides the sign bit to be conservative
-  constexpr int64_t most_positive_int32_index =
-      std::numeric_limits<int>::max() / 2;
-  constexpr int64_t most_negative_int32_index =
-      std::numeric_limits<int>::min() / 2;
-
-  // Start by setting index mode to int32
-  index_mode_ = KernelIndexMode::INT32;
-
-  // Check all runtime inputs, and if any one of
-  //  the input's index exceeds max_int32 will
-  //  fall back to int64 indexing
-  for (auto ivalue_input : inputs) {
-    if (ivalue_input.isTensor()) {
-      auto tensor_input = ivalue_input.toTensor();
-      int64_t tensor_most_positive_index = 0;
-      int64_t tensor_most_negative_index = 0;
-      for (auto dim_i = 0; dim_i < tensor_input.ndimension(); dim_i++) {
-        // Ignore broadcast dimensions
-        if (tensor_input.size(dim_i) > 1) {
-          // accumulate based on the sign of stride
-          if (tensor_input.stride(dim_i) > 0) {
-            // Acuumulate positive stride
-            tensor_most_positive_index +=
-                (tensor_input.size(dim_i) - 1) * tensor_input.stride(dim_i);
-          } else {
-            // Acuumulate negative stride
-            tensor_most_negative_index +=
-                (tensor_input.size(dim_i) - 1) * tensor_input.stride(dim_i);
-          }
-        }
-      }
-
-      // Fall back to int64 if it can be either too positive
-      //  or too negative.
-      if (tensor_most_positive_index > most_positive_int32_index ||
-          tensor_most_negative_index < most_negative_int32_index) {
-        index_mode_ = KernelIndexMode::INT64;
-        return;
-      }
-    }
-  }
-}
-
-bool SchedulerEntry::sameAs(const SchedulerEntry* other) {
-  if (heuristc_ != other->heuristc_) {
-    return false;
-  }
-  if (index_mode_ != other->index_mode_) {
-    return false;
-  }
-  // Heuristic equal should imply has_reduction_param_ equal,
-  //  need to double check if it is the case before removing
-  //  the below one.
-  if (has_reduction_param_ != other->has_reduction_param_) {
-    return false;
-  }
-  if (has_reduction_param_) {
-    return rparams_ == other->rparams_;
-  } else {
-    return pparams_ == other->pparams_;
-  }
-  return true;
-}
-
-namespace {
-template <typename REDUCTION_OP = ReductionOp>
-inline bool isTrivialReduction(REDUCTION_OP* red) {
-  auto o_tv = red->out()->template as<TensorView>();
-  // Assuming graph unscheduled at this point.
-  for (auto id : o_tv->getRootDomain()) {
-    if (id->isReduction() && !id->extent()->isOneInt()) {
-      return false;
-    }
-  }
-  return true;
-}
-
-template <typename REDUCTION_OP = ReductionOp>
-std::vector<REDUCTION_OP*> findReductionOps(Fusion* fusion) {
-  std::vector<REDUCTION_OP*> red_ops;
-  for (auto expr : fusion->exprs()) {
-    if (auto red = dynamic_cast<REDUCTION_OP*>(expr)) {
-      if (!isTrivialReduction(red)) {
-        red_ops.push_back(red);
-      }
-    }
-  }
-  return red_ops;
-}
-
-class SingleReductionScheduler : public SchedulerEntry {
- public:
-  explicit SingleReductionScheduler(
-      Fusion* fusion,
-      SchedulerRuntimeInfo& runtime_info,
-      HeuristicSummary* data_cache = nullptr)
-      : SchedulerEntry(ScheduleHeuristic::Reduction, true) {
-    computeHeuristics(fusion, runtime_info, data_cache);
-  }
-
-  //! Check if the reduction heuristics apply in given fusion
-  static bool canSchedule(
-      Fusion* fusion,
-      SchedulerRuntimeInfo& runtime_info,
-      HeuristicSummary* data_cache = nullptr) {
-    if (data_cache) {
-      return true;
-    }
-
-    auto red_ops = findReductionOps(fusion);
-    auto welford_ops = findReductionOps<WelfordOp>(fusion);
-    if (red_ops.size() + welford_ops.size() != 1) {
-      return false;
-    }
-
-    bool is_welford = welford_ops.size() > 0;
-
-    if (SchedulerTopologyChecker::hasPostReductionBCast(fusion)) {
-      return false;
-    }
-
-    auto reduction_tv = is_welford ? welford_ops[0]->out()->as<TensorView>()
-                                   : red_ops[0]->out()->as<TensorView>();
-
-    if (!SchedulerTopologyChecker::supportedPostReductionFusion(
-            fusion, {reduction_tv})) {
-      return false;
-    }
-
-    return true;
-  }
-
-  void schedule(Fusion* fusion) override {
-    FUSER_PERF_SCOPE("Schedule Single Reduction");
-    scheduleReduction(fusion, rparams_);
-  }
-
- private:
-  void computeHeuristics(
-      Fusion* fusion,
-      SchedulerRuntimeInfo& runtime_info,
-      HeuristicSummary* data_cache = nullptr) {
-    auto param = getReductionHeuristics(fusion, runtime_info, data_cache);
-    TORCH_INTERNAL_ASSERT(param.has_value());
-    rparams_ = param.value();
-  }
-};
-
-class PointWiseScheduler : public SchedulerEntry {
- public:
-  explicit PointWiseScheduler(
-      Fusion* fusion,
-      SchedulerRuntimeInfo& runtime_info,
-      HeuristicSummary* data_cache = nullptr)
-      : SchedulerEntry(ScheduleHeuristic::PointWise, false) {
-    computeHeuristics(fusion, runtime_info, data_cache);
-  }
-
-  static bool canSchedule(
-      Fusion* fusion,
-      SchedulerRuntimeInfo& runtime_info,
-      HeuristicSummary* data_cache = nullptr) {
-    if (data_cache) {
-      return true;
-    }
-    auto red_ops = findReductionOps(fusion);
-    auto welford_ops = findReductionOps<WelfordOp>(fusion);
-    return red_ops.empty() && welford_ops.empty();
-  }
-
-  void schedule(Fusion* fusion) override {
-    FUSER_PERF_SCOPE("Schedule PointWise Fusion");
-    schedulePointwise(fusion, pparams_);
-  }
-
-  void computeHeuristics(
-      Fusion* fusion,
-      SchedulerRuntimeInfo& runtime_info,
-      HeuristicSummary* data_cache = nullptr) {
-    auto pparam = getPointwiseHeuristics(fusion, runtime_info, data_cache);
-    TORCH_INTERNAL_ASSERT(pparam.has_value());
-    pparams_ = pparam.value();
-  }
-};
-
-class NormalizationScheduler : public SchedulerEntry {
- public:
-  explicit NormalizationScheduler(
-      Fusion* fusion,
-      SchedulerRuntimeInfo& runtime_info,
-      HeuristicSummary* data_cache = nullptr)
-      : SchedulerEntry(ScheduleHeuristic::Normalization, true) {
-    computeHeuristics(fusion, runtime_info, data_cache);
-  }
-
-  void schedule(Fusion* fusion) override {
-    FUSER_PERF_SCOPE("Schedule Normalization Fusion");
-    scheduleNormalization(fusion, rparams_);
-  }
-
-  static bool canSchedule(
-      Fusion* fusion,
-      SchedulerRuntimeInfo& runtime_info,
-      HeuristicSummary* data_cache = nullptr) {
-    FUSER_PERF_SCOPE("NormalizationScheduler::canSchedule");
-
-    HeuristicCacheAccessor<std::vector<TensorView*>> reduction_tv_data;
-    // TODO: move all these boilerplate code into the accessor class
-    // (follow up)
-    if (data_cache && !data_cache->isRecording()) {
-      reduction_tv_data.writeTemporary(data_cache->getReductionTVs());
-    } else {
-      reduction_tv_data.writeNew(scheduler_utils::getReductionTvs(fusion));
-      if (data_cache && data_cache->isRecording()) {
-        data_cache->setReductionTVs(reduction_tv_data.read());
-      }
-    }
-
-    auto& reduction_tvs = reduction_tv_data.read();
-
-    if (!data_cache) {
-      if (reduction_tvs.size() == 0) {
-        // Use single reduction or pointwise logic
-        return false;
-      }
-
-      if (SchedulerTopologyChecker::hasNonNormalizePostReductionBCast(fusion)) {
-        return false;
-      }
-
-      // Before examining the reduction axes want to quickly
-      //   check the reductions have the same axis width
-      //   to avoid building root domain map in easier cases
-      bool valid_axis_count = false;
-      size_t axis_count = 0;
-      auto reduction_root_size = [](TensorView* red_tv) {
-        size_t count = 0;
-        for (auto id : red_tv->getRootDomain()) {
-          if (!id->isBroadcast()) {
-            count++;
-          }
-        }
-        return count;
-      };
-
-      for (auto red : reduction_tvs) {
-        if (!valid_axis_count) {
-          valid_axis_count = true;
-          axis_count = reduction_root_size(red);
-        } else {
-          if (reduction_root_size(red) != axis_count) {
-            return false;
-          }
-        }
-      }
-
-      // Use root domain map to check the reduction ops have the same axes
-      FusionGuard fg(fusion);
-      ComputeAtRootDomainMap root_map;
-      root_map.build(true);
-
-      // red_ops.size()>1 checked before
-      for (size_t it = 1; it < reduction_tvs.size(); it++) {
-        if (!checkEquivalence(
-                reduction_tvs[it - 1], reduction_tvs[it], root_map)) {
-          return false;
-        }
-      }
-    }
-
-    // TODO: move all these boilerplate code into the accessor class
-    // (follow up)
-    // Note: this persistent buffer is actually cached from
-    //  getNormalizationHeuristics. Will need to create a separate
-    //  cache entry if they are not the same.
-    HeuristicCacheAccessor<scheduler_utils::PersistentBufferInfo>
-        persistent_buffer_data;
-
-    if (data_cache && !data_cache->isRecording()) {
-      persistent_buffer_data.writeTemporary(
-          data_cache->getPersistentBufferInfo());
-    } else {
-      persistent_buffer_data.writeNew(
-          scheduler_utils::persistentBuffers(fusion));
-      if (data_cache && data_cache->isRecording()) {
-        data_cache->setPersistentBufferInfo(persistent_buffer_data.read());
-      }
-    }
-    auto& persistent_buffers = persistent_buffer_data.read();
-
-    auto persistent_buffer_size = scheduler_utils::persistentBufferSize(
-        fusion, runtime_info, persistent_buffers, data_cache);
-    if (persistent_buffer_size * 4 > scheduler_utils::register_file_size * 3) {
-      return false;
-    }
-
-    // TODO: really need to make inserting an entry into data_cache easier to do
-    HeuristicCacheAccessor<bool> has_post_reduction_bcast_data;
-
-    if (data_cache && !data_cache->isRecording()) {
-      has_post_reduction_bcast_data.writeTemporary(
-          data_cache->getHasPostReductionBCast());
-    } else {
-      has_post_reduction_bcast_data.writeNew(
-          SchedulerTopologyChecker::hasPostReductionBCast(fusion));
-      if (data_cache && data_cache->isRecording()) {
-        data_cache->setHasPostReductionBCast(
-            has_post_reduction_bcast_data.read());
-      }
-    }
-
-    HeuristicCacheAccessor<bool> supported_post_reduction_fusion_data;
-
-    if (data_cache && !data_cache->isRecording()) {
-      supported_post_reduction_fusion_data.writeTemporary(
-          data_cache->getSupportedPostReductionFusion());
-    } else {
-      supported_post_reduction_fusion_data.writeNew(
-          SchedulerTopologyChecker::supportedPostReductionFusion(
-              fusion, reduction_tvs));
-      if (data_cache && data_cache->isRecording()) {
-        data_cache->setSupportedPostReductionFusion(
-            supported_post_reduction_fusion_data.read());
-      }
-    }
-
-    auto has_post_reduction_bcast = has_post_reduction_bcast_data.read();
-    auto supported_post_reduction_fusion =
-        supported_post_reduction_fusion_data.read();
-
-    // Multi reduction scheduler has the same limitations as single reduction
-    // scheduler here
-    if (persistent_buffer_size <= 1) {
-      if (has_post_reduction_bcast) {
-        return false;
-      }
-
-      if (!supported_post_reduction_fusion) {
-        return false;
-      }
-    }
-
-    return true;
-  }
-
- private:
-  void computeHeuristics(
-      Fusion* fusion,
-      SchedulerRuntimeInfo& runtime_info,
-      HeuristicSummary* data_cache = nullptr) {
-    auto rparams = getNormalizationHeuristics(fusion, runtime_info, data_cache);
-    TORCH_INTERNAL_ASSERT(rparams.has_value());
-    rparams_ = rparams.value();
-  }
-
-  static bool checkEquivalence(
-      TensorView* out_tv0,
-      TensorView* out_tv1,
-      const ComputeAtRootDomainMap& root_map) {
-    const auto& out_root0 = out_tv0->getRootDomain();
-    const auto& out_root1 = out_tv1->getRootDomain();
-    const auto domain0 = out_tv0->domain();
-    const auto domain1 = out_tv1->domain();
-
-    auto it0 = out_root0.begin();
-    auto it1 = out_root1.begin();
-
-    auto skip_broadcast = [&]() {
-      while (it0 != out_root0.end() && (*it0)->isBroadcast()) {
-        it0++;
-      }
-      while (it1 != out_root1.end() && (*it1)->isBroadcast()) {
-        it1++;
-      }
-    };
-
-    skip_broadcast();
-    while (it0 != out_root0.end() && it1 != out_root1.end()) {
-      if ((*it0)->isReduction() != (*it1)->isReduction()) {
-        return false;
-      }
-      if (!root_map.canMap(domain0, (*it0), domain1, (*it1))) {
-        return false;
-      }
-      it0++;
-      it1++;
-      skip_broadcast();
-    }
-
-    return it0 == out_root0.end() && it1 == out_root1.end();
-  }
-};
-
-// Schedule Table
-const std::vector<ScheduleHeuristic>& all_heuristics() {
-  static const std::vector<ScheduleHeuristic> hlist = {
-      ScheduleHeuristic::Reduction,
-      ScheduleHeuristic::PointWise,
-      ScheduleHeuristic::Normalization};
-  return hlist;
-}
-
-} // namespace
-
-// Simple dispatcher interface
-bool SchedulerEntry::canSchedule(
-    ScheduleHeuristic sh,
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    HeuristicSummary* data_cache) {
-  switch (sh) {
-    case ScheduleHeuristic::PointWise:
-      return PointWiseScheduler::canSchedule(fusion, runtime_info, data_cache);
-    case ScheduleHeuristic::Reduction:
-      return SingleReductionScheduler::canSchedule(
-          fusion, runtime_info, data_cache);
-    case ScheduleHeuristic::Normalization:
-      return NormalizationScheduler::canSchedule(
-          fusion, runtime_info, data_cache);
-    default:
-      TORCH_INTERNAL_ASSERT(false, "unreachable");
-      return false;
-  }
-  return false;
-}
-
-std::unique_ptr<SchedulerEntry> SchedulerEntry::makeEntry(
-    ScheduleHeuristic sh,
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    HeuristicSummary* data_cache) {
-  std::unique_ptr<SchedulerEntry> scheduler_entry = nullptr;
-  switch (sh) {
-    case ScheduleHeuristic::PointWise:
-      scheduler_entry = std::make_unique<PointWiseScheduler>(
-          fusion, runtime_info, data_cache);
-      break;
-    case ScheduleHeuristic::Reduction:
-      scheduler_entry = std::make_unique<SingleReductionScheduler>(
-          fusion, runtime_info, data_cache);
-      break;
-    case ScheduleHeuristic::Normalization:
-      scheduler_entry = std::make_unique<NormalizationScheduler>(
-          fusion, runtime_info, data_cache);
-      break;
-    default:
-      TORCH_INTERNAL_ASSERT(false, "unreachable");
-  }
-
-  scheduler_entry->index_mode_ = runtime_info.getIndexMode();
-  return scheduler_entry;
-}
-
-// Simply loop through the list as baseline strategy
-c10::optional<ScheduleHeuristic> SchedulerEntry::proposeHeuristics(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info) {
-  for (auto sh : all_heuristics()) {
-    if (canSchedule(sh, fusion, runtime_info)) {
-      return sh;
-    }
-  }
-  return c10::nullopt;
-}
-
-size_t SchedulerEntryHash::operator()(const SchedulerEntry& se) const {
-  if (se.hasReductionParam()) {
-    return ReductionParamsHash()(se.reductionParams());
-  } else {
-    return PointwiseParamsHash()(se.pointwiseParams());
-  }
-}
-
-std::string toString(ScheduleHeuristic sh) {
-  switch (sh) {
-    case ScheduleHeuristic::PointWise:
-      return "pointwise";
-    case ScheduleHeuristic::Reduction:
-      return "reduction";
-    case ScheduleHeuristic::Normalization:
-      return "normalization";
-    default:
-      TORCH_INTERNAL_ASSERT(false, "undefined schedule");
-  }
-  return "";
-}
-
-HeuristicSummary::HeuristicSummary(
-    Fusion* fusion,
-    ScheduleHeuristic heuristic,
-    SchedulerRuntimeInfo& runtime_info)
-    : heuristic_(heuristic) {
-  recording_ = true;
-  switch (heuristic) {
-    case ScheduleHeuristic::PointWise:
-      getPointwiseHeuristics(fusion, runtime_info, this);
-      PointWiseScheduler::canSchedule(fusion, runtime_info, this);
-      break;
-    case ScheduleHeuristic::Reduction:
-      getReductionHeuristics(fusion, runtime_info, this);
-      SingleReductionScheduler::canSchedule(fusion, runtime_info, this);
-      break;
-    case ScheduleHeuristic::Normalization:
-      getNormalizationHeuristics(fusion, runtime_info, this);
-      NormalizationScheduler::canSchedule(fusion, runtime_info, this);
-      break;
-    default:
-      TORCH_INTERNAL_ASSERT(false, "unknown heuristic");
-  }
-  validate();
-  recording_ = false;
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.h b/torch/csrc/jit/codegen/cuda/scheduler/registry.h
deleted file mode 100644 (file)
index eb353e0..0000000
+++ /dev/null
@@ -1,389 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class SegmentedGroup;
-class ExpressionEvaluator;
-
-//!  SchedulerRuntimeInfo is the abstraction introduced in
-//! this PR for passing runtime input dependent information
-//! to the schedulers and kernel caches.
-//!
-//! Note:
-//!  if any additional info needed,  or maybe just the inputs themselves it
-//!    could just be added to this class, and they will be distributed to the
-//!    segmenter and schedulers.
-//!  It is important that input id encoding should be up to date with any change
-//!   of this class to avoid launching compiled kernels with illegal inputs.
-class TORCH_CUDA_CU_API SchedulerRuntimeInfo {
- public:
-  // Max vector size we will consider, in bytes,
-  //  currently set to 16B = 128b
-  const size_t max_alignment_size_in_byte = 16;
-
-  //! Create runtime info for given fusion and input. Creating and binding
-  //! evaluator is optional. The evaluator is used to manage intermediate
-  //!  integers in the fusion. We need them for segmenter and schedulers,
-  //!  but we don't need them when we are just using this class to provide
-  //!  additional encoding for kernel cache lookup.
-  SchedulerRuntimeInfo(
-      Fusion* complete_fusion,
-      const at::ArrayRef<at::IValue>& inputs,
-      bool create_expr_evaluator = false);
-
-  //! Create runtime info by copying all the global
-  //! input meta data (i.e. alignment), but not the
-  //! expression evaluator.
-  SchedulerRuntimeInfo(const SchedulerRuntimeInfo& global_runtime_info);
-
-  //! Lookup for the alignment sizes of the given tv. Currently only returns
-  //!  actual alignment info for input tensors to the complete fusion,
-  //!  and for other intermediate/fuser-allocated tensors will
-  //!  return max_alignment_size_in_byte.
-  size_t getAlignmentSize(TensorView* tv);
-
-  //! Take the minimum of input tv alignment sizes. This is both information for
-  //! vectorization and
-  //!  a signature for kernel cache id lookup. May need to be updated with
-  //!  vectorization logic.
-  size_t getCommonAlignmentSize() const {
-    return common_alignment_size_;
-  }
-
-  //! Returns the max width the given tensor view can be vectorized,
-  //!  for input tensors will use the pre-computed value based on
-  //!  the given tensor alignment and strides. For intermediate tensors
-  //!  will assume it is contiguous and aligned to 128bit/16Byte
-  size_t getVectorizableWidth(TensorView* tv);
-
-  KernelIndexMode getIndexMode() {
-    return index_mode_;
-  }
-
-  Fusion* fusion() {
-    return complete_fusion_;
-  }
-
-  ExpressionEvaluator& expressionEvaluator() {
-    TORCH_INTERNAL_ASSERT(expression_evaluator_ != nullptr);
-    return *expression_evaluator_;
-  }
-
- private:
-  // Bind full fusion inputs to the internal expression evaluator
-  void initializeExpressionEvaluator(const at::ArrayRef<at::IValue>& inputs);
-
-  // Compute alignment data for all input tensors of full fusion
-  void collectVectorizationInfo(const at::ArrayRef<at::IValue>& inputs);
-
-  // Compute alignment data for given tensor
-  size_t collectAlignmentSize(const at::Tensor& tensor) const;
-
-  // Compute max vectorization word size for each an input tensor
-  size_t collectMaxVectorizeSize(
-      const at::Tensor& tensor,
-      size_t max_word_size_in_byte);
-
-  // check if input is compatible with 32b index mode
-  void collectIndexModeInfo(const at::ArrayRef<at::IValue>& inputs);
-
- private:
-  std::unique_ptr<ExpressionEvaluator> expression_evaluator_ = nullptr;
-  Fusion* complete_fusion_;
-  std::unordered_map<TensorView*, size_t> alignment_map_;
-  std::unordered_map<TensorView*, size_t> vectorword_map_;
-  size_t common_alignment_size_;
-  KernelIndexMode index_mode_ = KernelIndexMode::INT64;
-};
-
-class HeuristicSummary;
-
-//! Virtual base class for schedule heuristics
-//!   heuristic implementations derive from this
-//!   class and implement a schedule(Fusion*)
-//!   and a bool canSchedule(Fusion*) interface
-class TORCH_CUDA_CU_API SchedulerEntry {
- public:
-  //! Fusion runtime facing API,
-  //!   builds a new entry with the given heuristics
-  //!   corresponding to the given fusion
-  static std::unique_ptr<SchedulerEntry> makeEntry(
-      ScheduleHeuristic sh,
-      Fusion* fusion,
-      SchedulerRuntimeInfo& runtime_info,
-      HeuristicSummary* data_cache = nullptr);
-
-  virtual ~SchedulerEntry() = default;
-
-  //! External access for canSchedule utilities through SchedulerEntry
-  //!  to avoid exposing a single function to the namespace
-  static bool canSchedule(
-      ScheduleHeuristic sh,
-      Fusion* fusion,
-      SchedulerRuntimeInfo& runtime_info,
-      HeuristicSummary* data_cache = nullptr);
-
-  //! Fusion segmenter facing API,
-  //!   returns a schedule that applies in the given fusion, returns a nullopt
-  //!   if no schedule in the registry can handle.
-  static c10::optional<ScheduleHeuristic> proposeHeuristics(
-      Fusion* fusion,
-      SchedulerRuntimeInfo& runtime_info);
-
-  //! Fusion runtime facing API,
-  //!   schedule the given fusion with heuristics owned
-  //!   by this entry, for actual heuristics to override
-  virtual void schedule(Fusion* fusion) = 0;
-
-  //! Heuristic comparison
-  bool sameAs(const SchedulerEntry* other);
-
-  bool hasReductionParam() const {
-    return has_reduction_param_;
-  }
-
-  ScheduleHeuristic heuristc() const {
-    return heuristc_;
-  }
-
-  KernelIndexMode indexMode() const {
-    return index_mode_;
-  }
-
-  const ReductionParams& reductionParams() const {
-    TORCH_INTERNAL_ASSERT(
-        has_reduction_param_, "This schedule heuristic is not reduction.");
-    return rparams_;
-  }
-
-  const PointwiseParams& pointwiseParams() const {
-    TORCH_INTERNAL_ASSERT(
-        !has_reduction_param_, "This schedule heuristic is not pointwise.");
-    return pparams_;
-  }
-
-  void updateLaunchConstraint(const LaunchParams& launch_params) {
-    if (hasReductionParam()) {
-      rparams_.lparams = launch_params;
-    } else {
-      pparams_.lparams = launch_params;
-    }
-  }
-
- protected:
-  explicit SchedulerEntry(ScheduleHeuristic heuristic, bool has_reduction_param)
-      : heuristc_(heuristic), has_reduction_param_(has_reduction_param) {}
-
-  //! What kind of heuristics does this entry have?
-  const ScheduleHeuristic heuristc_;
-
-  //! Has reduction params if true, else has pointwise params
-  const bool has_reduction_param_;
-
-  //! Reduction parameters if applicable
-  ReductionParams rparams_;
-
-  //! Pointwise parameters if applicable
-  PointwiseParams pparams_;
-
-  //! Kernel Index Mode
-  KernelIndexMode index_mode_;
-};
-
-//! Hash function for a scheduler entry
-class TORCH_CUDA_CU_API SchedulerEntryHash {
- public:
-  size_t operator()(const SchedulerEntry& se) const;
-};
-
-//! Debug print function for heuristics
-std::string toString(ScheduleHeuristic sh);
-
-class TORCH_CUDA_CU_API HeuristicSummary {
-  using ValToFactorMap = std::unordered_map<Val*, int>;
-  using ValToFactorMapPtr = std::unique_ptr<ValToFactorMap>;
-  using ScopedPersistenceFactorMap =
-      std::unordered_map<Val*, ValToFactorMapPtr>;
-
- public:
-  HeuristicSummary(
-      Fusion* fusion,
-      ScheduleHeuristic heuristic,
-      SchedulerRuntimeInfo& runtime_info);
-  // Recording scheme:
-  bool isRecording() {
-    return recording_;
-  }
-
-  // Validate post recording:
-  //  make sure we have collected all the needed fields
-  void validate() {
-    switch (heuristic_) {
-      case ScheduleHeuristic::PointWise:
-        TORCH_INTERNAL_ASSERT(vectorizable_inputs_outputs_);
-        TORCH_INTERNAL_ASSERT(mapped_input_output_dims_);
-        break;
-      case ScheduleHeuristic::Reduction:
-        TORCH_INTERNAL_ASSERT(reduction_tvs_);
-        break;
-      case ScheduleHeuristic::Normalization:
-        TORCH_INTERNAL_ASSERT(vectorizable_inputs_outputs_);
-        TORCH_INTERNAL_ASSERT(reduction_tvs_);
-        TORCH_INTERNAL_ASSERT(persistent_buffer_info_);
-        TORCH_INTERNAL_ASSERT(has_post_reduction_bcast_);
-        TORCH_INTERNAL_ASSERT(supported_post_reduction_fusion_);
-        break;
-    }
-  }
-
-  // Accessors (un-protected for now)
-  void setVectorizableInputsOutputs(const std::vector<TensorView*>& input) {
-    TORCH_INTERNAL_ASSERT(recording_);
-
-    if (!vectorizable_inputs_outputs_) {
-      vectorizable_inputs_outputs_ =
-          std::make_unique<std::vector<TensorView*>>(input);
-    }
-  }
-
-  auto* getVectorizableInputsOutputs() {
-    return vectorizable_inputs_outputs_.get();
-  }
-
-  void setReductionTVs(const std::vector<TensorView*>& input) {
-    TORCH_INTERNAL_ASSERT(recording_);
-
-    if (!reduction_tvs_) {
-      reduction_tvs_ = std::make_unique<std::vector<TensorView*>>(input);
-    }
-  }
-
-  auto* getReductionTVs() {
-    return reduction_tvs_.get();
-  }
-
-  void setPersistentBufferInfo(
-      const scheduler_utils::PersistentBufferInfo& input) {
-    TORCH_INTERNAL_ASSERT(recording_);
-
-    if (!persistent_buffer_info_) {
-      persistent_buffer_info_ =
-          std::make_unique<scheduler_utils::PersistentBufferInfo>(input);
-    }
-  }
-
-  auto* getPersistentBufferInfo() {
-    return persistent_buffer_info_.get();
-  }
-
-  void setSupportedPostReductionFusion(bool input) {
-    TORCH_INTERNAL_ASSERT(recording_);
-
-    if (!supported_post_reduction_fusion_) {
-      supported_post_reduction_fusion_ = std::make_unique<bool>(input);
-    }
-  }
-
-  auto* getSupportedPostReductionFusion() {
-    return supported_post_reduction_fusion_.get();
-  }
-
-  void setHasPostReductionBCast(bool input) {
-    TORCH_INTERNAL_ASSERT(recording_);
-
-    if (!has_post_reduction_bcast_) {
-      has_post_reduction_bcast_ = std::make_unique<bool>(input);
-    }
-  }
-
-  auto* getHasPostReductionBCast() {
-    return has_post_reduction_bcast_.get();
-  }
-
-  void setScopedPersistenceFactorMap(const ScopedPersistenceFactorMap& input) {
-    TORCH_INTERNAL_ASSERT(recording_);
-
-    scope_persistence_factor_map_ =
-        std::make_unique<ScopedPersistenceFactorMap>();
-    for (const auto& it : input) {
-      ValToFactorMap& to_copy = *(it.second);
-      scope_persistence_factor_map_->operator[](it.first) =
-          std::make_unique<ValToFactorMap>(to_copy);
-    }
-  }
-
-  auto* getScopedPersistenceFactorMap() {
-    return scope_persistence_factor_map_.get();
-  }
-
-  void setMappedInputOutputDims(const std::vector<int64_t>& input) {
-    TORCH_INTERNAL_ASSERT(recording_);
-
-    if (!mapped_input_output_dims_) {
-      mapped_input_output_dims_ = std::make_unique<std::vector<int64_t>>(input);
-    }
-  }
-
-  auto* getMappedInputOutputDims() {
-    return mapped_input_output_dims_.get();
-  }
-
- private:
-  ScheduleHeuristic heuristic_;
-  bool recording_ = true;
-
-  // Actual data payload, could be folded into subclasses later.
-  std::unique_ptr<std::vector<TensorView*>> vectorizable_inputs_outputs_;
-  std::unique_ptr<std::vector<TensorView*>> reduction_tvs_;
-  std::unique_ptr<scheduler_utils::PersistentBufferInfo>
-      persistent_buffer_info_;
-  std::unique_ptr<bool> has_post_reduction_bcast_;
-  std::unique_ptr<bool> supported_post_reduction_fusion_;
-  std::unique_ptr<ScopedPersistenceFactorMap> scope_persistence_factor_map_;
-  std::unique_ptr<std::vector<int64_t>> mapped_input_output_dims_;
-};
-
-// A temporary utility class to save some boilerplate code when
-//  using HeuristicSummary. Can be significantly improved in a follow up.
-template <typename T>
-class HeuristicCacheAccessor {
- public:
-  HeuristicCacheAccessor() = default;
-
-  T& read() {
-    if (temporary_data_) {
-      return *temporary_data_;
-    } else {
-      return *owned_data_;
-    }
-  }
-
-  void writeNew(T data) {
-    owned_data_ = std::make_unique<T>(std::move(data));
-  }
-
-  void takeNew(std::unique_ptr<T>& data) {
-    owned_data_ = std::move(data);
-  }
-
-  void writeTemporary(T* data) {
-    temporary_data_ = data;
-  }
-
- private:
-  std::unique_ptr<T> owned_data_ = nullptr;
-  T* temporary_data_ = nullptr;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
deleted file mode 100644 (file)
index 1faa90c..0000000
+++ /dev/null
@@ -1,1643 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-namespace scheduler_utils {
-size_t mergeReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
-  int prev_i = -1;
-  size_t num_merged = 0;
-  for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (!tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
-      continue;
-    }
-    if (prev_i == -1) {
-      prev_i = i;
-    } else {
-      tv->merge(i, prev_i);
-      prev_i = i;
-      num_merged++;
-    }
-  }
-  if (prev_i != 0) {
-    tv->reorder({{prev_i, 0}});
-  }
-
-  return prev_i == -1 ? 0 : num_merged + 1;
-}
-
-size_t mergeNonReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
-  int prev_i = -1;
-  size_t num_merged = 0;
-  if (tv->nDims() == 0) {
-    return 0;
-  }
-  for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
-      continue;
-    }
-    if (prev_i == -1) {
-      prev_i = i;
-    } else {
-      tv->merge(i, prev_i);
-      prev_i = i;
-      num_merged++;
-    }
-  }
-  if (prev_i != 0) {
-    tv->reorder({{prev_i, 0}});
-  }
-
-  return prev_i == -1 ? 0 : num_merged + 1;
-}
-
-void parallelizeAllLike(
-    TensorView* reference_tv,
-    const std::vector<TensorView*>& all_tvs) {
-  FusionGuard fg(reference_tv->fusion());
-
-  auto ca_loop_map = ComputeAtMap(ComputeAtMap::MappingMode::LOOP);
-  ca_loop_map.build(FusionGuard::getCurFusion());
-  for (auto id : reference_tv->domain()->domain()) {
-    ca_loop_map.getConcreteMappedID(id)->parallelize(id->getParallelType());
-  }
-
-  for (auto tv : all_tvs) {
-    if (tv->isFusionInput()) {
-      continue;
-    }
-    for (size_t i = 0; i < tv->domain()->domain().size(); i++) {
-      tv->axis(i)->parallelize(
-          ca_loop_map.getConcreteMappedID(tv->axis(i))->getParallelType());
-    }
-  }
-}
-
-void computeAtInputs(TensorView* consumer, int pos, ComputeAtMode mode) {
-  for (auto inp_tv : ir_utils::inputTvsOf(consumer)) {
-    inp_tv->computeAt(consumer, pos, mode);
-  }
-}
-
-void computeWithOutputs(TensorView* producer, int pos, ComputeAtMode mode) {
-  for (auto out_tv : ir_utils::outputTvsOf(producer)) {
-    producer->computeWith(out_tv, pos, mode);
-  }
-}
-
-PersistentBufferInfo persistentBuffers(Fusion* fusion) {
-  FusionGuard fg(fusion);
-
-  PersistentBufferInfo info;
-
-  ComputeAtRootDomainMap root_map;
-  root_map.build();
-
-  auto all_tvs = ir_utils::allTvs(fusion);
-
-  for (auto producer : all_tvs) {
-    bool mappable = true;
-    auto consumers = ir_utils::consumerTvsOf(producer);
-    if (consumers.empty()) {
-      continue;
-    }
-
-    auto mappable_roots =
-        root_map.getMappableDims(producer->domain(), consumers[0]->domain());
-
-    auto p_root = producer->getMaybeRFactorDomain();
-
-    for (auto p_root_id : p_root) {
-      if (p_root_id->isReduction()) {
-        continue;
-      }
-      if (!mappable_roots.count(p_root_id)) {
-        mappable = false;
-        info.unmappable_dims.emplace(p_root_id);
-      }
-    }
-
-    if (!mappable) {
-      info.buffers.push_back(producer);
-    }
-  }
-  return info;
-}
-
-TvProperties getProperties(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    TensorView* tv) {
-  TvProperties properties;
-  FusionGuard fg(fusion);
-
-  auto red_root_dom = tv->getRootDomain();
-  for (size_t i = red_root_dom.size(); i > 0; i--) {
-    if (red_root_dom[i - 1]->isBroadcast()) {
-      continue;
-    } else if (red_root_dom[i - 1]->isReduction()) {
-      break;
-    } else {
-      properties.fastest_dim_reduction = false;
-      break;
-    }
-  }
-
-  bool hit_reduction = false;
-  auto root_dom = tv->getMaybeRFactorDomain();
-  for (auto it = root_dom.rbegin(); it != root_dom.rend(); ++it) {
-    auto id = *it;
-
-    auto inferred_val =
-        runtime_info.expressionEvaluator().evaluate(id->extent());
-    TORCH_INTERNAL_ASSERT(
-        inferred_val.has_value(), "Error inferring reduction size.");
-    if (id->isReduction()) {
-      hit_reduction = true;
-      properties.reduction_numel *= inferred_val.value();
-    } else {
-      auto dim_size = inferred_val.value();
-      properties.iteration_numel *= dim_size;
-      if (hit_reduction) {
-        properties.iter_outside_red *= dim_size;
-      } else {
-        properties.iter_inside_red *= dim_size;
-      }
-    }
-  }
-
-  if (properties.reduction_numel == 1) {
-    properties.iter_outside_red =
-        properties.iter_outside_red * properties.iter_inside_red;
-    properties.iter_inside_red = 1;
-    properties.fastest_dim_reduction = true;
-  }
-
-  return properties;
-}
-
-void computeAtBetween(
-    const std::vector<TensorView*>& producers,
-    const std::vector<TensorView*>& overall_consumers,
-    int pos,
-    ComputeAtMode mode,
-    std::unordered_set<IterDomain*> mapped_to_trivial_reduction) {
-  for (auto producer : producers) {
-    // Figure out what's between producer and overall_consumers, will not give
-    // back any consumers that are not downstream from producer
-    auto all_vals_between = DependencyCheck::getAllValsBetween(
-        {producer}, {overall_consumers.begin(), overall_consumers.end()});
-
-    std::unordered_set<Val*> all_vals_between_set(
-        all_vals_between.begin(), all_vals_between.end());
-
-    for (auto consumer : overall_consumers) {
-      if (all_vals_between_set.count(consumer)) {
-        // The way we generate producers and consumers is that we inch away from
-        // inputs/outputs. There's a chance we could meet in the middle.
-        if (producer == consumer) {
-          continue;
-        }
-
-        auto pos_it = std::find_if(
-            consumer->domain()->domain().begin(),
-            consumer->domain()->domain().end(),
-            [&mapped_to_trivial_reduction](IterDomain* id) {
-              return mapped_to_trivial_reduction.count(id);
-            });
-
-        pos = pos_it == consumer->domain()->domain().end()
-            ? pos
-            : std::min(
-                  (int)std::distance(
-                      consumer->domain()->domain().begin(), pos_it) +
-                      1,
-                  (pos < 0 ? pos + (int)consumer->nDims() : pos));
-        // Assume we don't want to reset computeAt on tensors that have already
-        // performed it.
-        producer->computeAt(consumer, pos, mode);
-      }
-    }
-  }
-}
-
-int64_t persistentBufferSize(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    PersistentBufferInfo& persistent_buffers,
-    HeuristicSummary* data_cache) {
-  FUSER_PERF_SCOPE("scheduler_utils::persistentBufferSize");
-
-  if (persistent_buffers.buffers.empty()) {
-    return 0;
-  }
-
-  int64_t persistent_buffer_size = 0;
-
-  using ValToFactorMap = std::unordered_map<Val*, int>;
-  using ValToFactorMapPtr = std::unique_ptr<ValToFactorMap>;
-  using ScopedPersistenceFactorMap =
-      std::unordered_map<Val*, ValToFactorMapPtr>;
-
-  HeuristicCacheAccessor<ScopedPersistenceFactorMap>
-      scoped_persistent_factor_data;
-  // TODO: move all these boilerplate code into the accessor class
-  // (follow up)
-
-  // Caching traversal result in this case.
-  //  This one is slightly more involving. The end result we want is all the
-  //  concrete
-  //   int values in scoped_persistence. Essentially:
-  //     scoped_persistence [val] = sum_over_all_persistent_tv (
-  //     contrubution_from_tv_to_val * persistent_size_of_tv  )
-  //  Here contrubution_from_tv_to_val can be determined at compile time.
-  //  persistent_size_of_tv is a runtime value but
-  //   doesn't require heavy graph traversal.
-  //  So in this cache entry we try to save a matrix of contribution factors,
-  //  i.e.
-  //
-  //   new_persistent_factor_map[tv][val] = contribution_from_tv_to_val, from
-  //   compile time and we combine the factor
-  //
-  //   with runtime persistent buffer sizes at runtime.
-  if (data_cache && !data_cache->isRecording()) {
-    scoped_persistent_factor_data.writeTemporary(
-        data_cache->getScopedPersistenceFactorMap());
-  } else {
-    // Compute new scoped persisitence factor:
-    auto new_persistent_factor_map_ptr =
-        std::make_unique<ScopedPersistenceFactorMap>();
-    auto& new_persistent_factor_map = *new_persistent_factor_map_ptr;
-
-    for (auto tv : persistent_buffers.buffers) {
-      auto& consumer_tv_to_factor_map_ptr = new_persistent_factor_map[tv];
-      consumer_tv_to_factor_map_ptr = std::make_unique<ValToFactorMap>();
-      auto& consumer_tv_to_factor_map = *consumer_tv_to_factor_map_ptr;
-
-      // All expressions between tv and its consumers must have tv's persistent
-      // buffer allocated. This is an optimistic view on how many registers we
-      // need allocated in the kernel, since if we ordered two persistent
-      // buffers that are completely independent to somehow overlap with
-      // eachother we would assume we wouldn't need those two buffers active at
-      // the same time, even though they would be.
-      //
-      // Unfortunately this limitation is hard to work around as we would have
-      // to actually generate the kernel before we know if it would fit
-      // persistently in registers. In practice, though, this should not happen
-      // as inlining loop structures where the persistent buffer is used should
-      // prevent muiltiple persistent buffers from being merged togther if not
-      // necessary.
-      auto consumers_of_tv = ir_utils::consumerTvsOf(tv);
-      for (auto val : DependencyCheck::getAllValsBetween(
-               {tv}, {consumers_of_tv.begin(), consumers_of_tv.end()})) {
-        // Persistent normalization kernels imply that all persistent buffers
-        // have the same dimensionality. Assume if a persistent buffer is
-        // consumed by another we can alias and reuse the memory.
-        if (val == tv) {
-          continue;
-        }
-
-        if (consumer_tv_to_factor_map.count(val)) {
-          consumer_tv_to_factor_map.at(val) += 1;
-        } else {
-          consumer_tv_to_factor_map[val] = 1;
-        }
-      }
-    }
-
-    // Caching boilerplate (TO be cleaned up in a follow up)
-    scoped_persistent_factor_data.takeNew(new_persistent_factor_map_ptr);
-    if (data_cache && data_cache->isRecording()) {
-      data_cache->setScopedPersistenceFactorMap(
-          scoped_persistent_factor_data.read());
-    }
-  }
-
-  auto& scoped_persistence_factor = scoped_persistent_factor_data.read();
-
-  // Runtime: convert the persistent factor to actual values
-  std::unordered_map<Val*, int64_t> scoped_persistence;
-
-  for (auto tv : persistent_buffers.buffers) {
-    int64_t tv_persistent_numel = -1;
-    for (auto id : tv->getMaybeRFactorDomain()) {
-      if (id->isReduction() || id->isBroadcast()) {
-        continue;
-      }
-      // Unmappable dimensions are those that we cannot inline into other
-      // tensor views. So they're the ones that need to be persistent.
-      if (!persistent_buffers.unmappable_dims.count(id)) {
-        continue;
-      }
-
-      auto id_size = runtime_info.expressionEvaluator().evaluate(id->extent());
-      TORCH_INTERNAL_ASSERT(
-          id_size.has_value(),
-          "Cannot generate heuristics if we don't have input information.");
-      if (tv_persistent_numel == -1) {
-        tv_persistent_numel = id_size.value();
-      } else {
-        tv_persistent_numel *= id_size.value();
-      }
-    }
-
-    persistent_buffer_size =
-        tv_persistent_numel * dataTypeSize(tv->getDataType().value());
-
-    // Look up the contribution part from the cached matrix:
-    auto scoped_factor_it = scoped_persistence_factor.find(tv);
-    if (scoped_factor_it != scoped_persistence_factor.end()) {
-      // now looking at scoped_persistence_factor[tv]
-      for (auto val_to_factor_it : *(scoped_factor_it->second)) {
-        // (val_to_factor_it) is (val, factor)
-        int64_t persistent_buffer_size_contribution =
-            persistent_buffer_size * val_to_factor_it.second;
-
-        //  try to write factor * persistent_buffer_size into
-        //  scoped_persistence[val]
-        auto val_it = scoped_persistence.find(val_to_factor_it.first);
-        if (val_it == scoped_persistence.end()) {
-          scoped_persistence[val_to_factor_it.first] =
-              persistent_buffer_size_contribution;
-        } else {
-          val_it->second += persistent_buffer_size_contribution;
-        }
-      }
-    }
-  }
-
-  // Find the maximum persistent buffer use
-  int64_t max_persistence_size = 0;
-  for (auto persistent_entry : scoped_persistence) {
-    max_persistence_size =
-        std::max(max_persistence_size, persistent_entry.second);
-  }
-
-  return max_persistence_size;
-}
-
-std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion) {
-  auto all_tvs = ir_utils::allTvs(fusion);
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction;
-  for (auto tv : all_tvs) {
-    // root domain vs domain shouldn't matter as at this point we shouldn't have
-    // any transformations.
-    for (auto id : tv->getRootDomain()) {
-      if (id->isTrivialReduction()) {
-        mapped_to_trivial_reduction.emplace(id);
-      }
-    }
-  }
-
-  if (!mapped_to_trivial_reduction.empty()) {
-    // Shouldn't matter which compute at map we use
-    auto ca_index_map = ComputeAtMap(ComputeAtMap::MappingMode::INDEX);
-    ca_index_map.build(fusion);
-    // Make a copy we need to check mappings of all
-    auto trivial_ids = mapped_to_trivial_reduction;
-    for (auto tv : all_tvs) {
-      for (auto id : tv->getRootDomain()) {
-        if (!id->extent()->isOneInt()) {
-          continue;
-        }
-        if (std::any_of(
-                trivial_ids.begin(),
-                trivial_ids.end(),
-                [&ca_index_map, &id](IterDomain* trivial_id) {
-                  return ca_index_map.areMapped(id, trivial_id);
-                })) {
-          mapped_to_trivial_reduction.emplace(id);
-        }
-      }
-    }
-  }
-  return mapped_to_trivial_reduction;
-}
-
-std::pair<bool, bool> canonicalDimReduction(Fusion* fusion, TensorView* tv) {
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
-      getTrivialReductionMap(fusion);
-
-  TORCH_INTERNAL_ASSERT(tv != nullptr);
-
-  // We coalesce all reduction axes to the right;
-  bool has_red_axis = mergeReduction(tv, mapped_to_trivial_reduction) > 0;
-
-  bool has_iter_axis = mergeNonReduction(tv, mapped_to_trivial_reduction) > 0;
-  return {has_iter_axis, has_red_axis};
-}
-
-std::vector<TensorView*> getReductionTvs(Fusion* fusion) {
-  auto all_tvs = ir_utils::allTvs(fusion);
-  std::vector<TensorView*> reduction_tvs;
-  for (auto tv : all_tvs) {
-    if (!tv->isFusionInput() &&
-        std::any_of(
-            tv->domain()->domain().begin(),
-            tv->domain()->domain().end(),
-            [](IterDomain* id) {
-              return id->isReduction() && !id->isTrivialReduction();
-            })) {
-      reduction_tvs.emplace_back(tv);
-    }
-  }
-
-  // Remove multi outputs from reduction tensor views
-  std::unordered_set<Expr*> seen_reduction_exprs;
-  reduction_tvs.erase(
-      std::remove_if(
-          reduction_tvs.begin(),
-          reduction_tvs.end(),
-          [&seen_reduction_exprs](TensorView* tv) {
-            TORCH_INTERNAL_ASSERT(
-                tv->definition() != nullptr,
-                "Somehow a tensor view without a definition but a reduction snuck into the scheduler reduction list.");
-            if (!seen_reduction_exprs.emplace(tv->definition()).second) {
-              return true;
-            }
-            return false;
-          }),
-      reduction_tvs.end());
-  return reduction_tvs;
-}
-
-TensorView* scheduleReductionTV(
-    const ReductionParams& rparams,
-    TensorView* reduction_tv,
-    bool has_iter_axis) {
-  TensorView* reference_tv = nullptr;
-  if (rparams.fastest_dim) {
-    const int iter_axis = 0;
-    const int reduce_axis = has_iter_axis ? 1 : 0;
-
-    // Do multiple reductions per block
-    if (rparams.multiple_reds_per_blk) {
-      if (rparams.reduction_unroll) {
-        // Fastest dim, multiple reductions per block
-        // Output Dimensions
-        // [x-BIDx, x-TIDy
-        //  0       1
-        //
-        //  Reduction Dimensions
-        //  rF-Remain, rf-Unswitch, rf-Unroll, X-TIDx]
-        //  2(r)          3(r+1)     4(r+2)    5(r+3)
-        //  Reduction Dimensions
-        //  rF-Remain, rf-Unswitch, X-TIDx, rf-Vectorize]
-        //  2(r)          3(r+1)     4(r+2)    5(r+3)
-
-        //  X-TIDx, rF-Remain, rf-Unswitch, rf-Unroll/Vect]
-        //   2(r)     3(r+1)       4(r+2)      5(r+3)
-
-        if (!rparams.persistent_kernel) {
-          if (rparams.vectorize) {
-            reduction_tv->split(reduce_axis, rparams.loop_unroll);
-            reduction_tv->split(
-                reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
-          } else {
-            reduction_tv->split(
-                reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
-            reduction_tv->split(reduce_axis, rparams.loop_unroll);
-          }
-          // Unswitch axis which gives us finer control on allocations with
-          // unrolling
-          reduction_tv->split(reduce_axis, 1);
-        } else {
-          if (rparams.vectorize) {
-            reduction_tv->split(reduce_axis, rparams.batches_per_block, false);
-            reduction_tv->split(reduce_axis + 1, rparams.loop_unroll);
-          } else {
-            reduction_tv->split(
-                reduce_axis,
-                rparams.batches_per_block * rparams.loop_unroll,
-                false);
-            reduction_tv->split(reduce_axis, rparams.loop_unroll);
-          }
-          // Unswitch axis which gives us finer control on allocations with
-          // unrolling
-          reduction_tv->split(reduce_axis, 1);
-        }
-
-        if (rparams.vectorize) {
-          reduction_tv->reorder(
-              {{reduce_axis, reduce_axis + 1},
-               {reduce_axis + 1, reduce_axis + 2},
-               {reduce_axis + 2, reduce_axis}});
-        } else {
-          reduction_tv->reorder(
-              {{reduce_axis + 3, reduce_axis},
-               {reduce_axis, reduce_axis + 1},
-               {reduce_axis + 1, reduce_axis + 2},
-               {reduce_axis + 2, reduce_axis + 3}});
-        }
-
-        reference_tv = ir_utils::rfactorHelper(
-            reduction_tv, {reduce_axis + 1, reduce_axis + 2, reduce_axis + 3});
-
-        reference_tv->axis(reduce_axis)->parallelize(ParallelType::TIDx);
-
-        if (rparams.vectorize) {
-          reference_tv->axis(reduce_axis + 3)
-              ->parallelize(ParallelType::Vectorize);
-        } else {
-          reference_tv->axis(reduce_axis + 3)
-              ->parallelize(ParallelType::Unroll);
-        }
-        reference_tv->axis(reduce_axis + 2)
-            ->parallelize(ParallelType::Unswitch);
-
-        if (has_iter_axis) {
-          reference_tv->split(
-              iter_axis, NamedScalar::getParallelDim(ParallelType::TIDy));
-          reference_tv->axis(iter_axis + 1)->parallelize(ParallelType::TIDy);
-          if (rparams.split_grid_dim) {
-            reference_tv->split(iter_axis, x_grid_limit);
-            reference_tv->axis(iter_axis + 1)->parallelize(ParallelType::BIDx);
-          } else {
-            reference_tv->axis(iter_axis)->parallelize(ParallelType::BIDx);
-          }
-        }
-      } else {
-        TORCH_INTERNAL_ASSERT(
-            has_iter_axis,
-            "This scheduler requires an outer dim to the reduction.");
-        // Fastest dim, Multiple reductions per block iter unroll
-        // Output Dimensions
-        // [x-BIDx, x-Unswitch, x-Unroll, x-TIDy
-        //  0       1           2         3
-        //
-        //  Reduction Dimensions
-        //  rF-Remain, r-TIDx]
-        //  4(r)     5(r+1)
-        if (!rparams.persistent_kernel) {
-          reduction_tv->split(
-              1, NamedScalar::getParallelDim(ParallelType::TIDx));
-        } else {
-          reduction_tv->split(1, rparams.batches_per_block, false);
-        }
-
-        reference_tv = ir_utils::rfactorHelper(reduction_tv, {1});
-
-        reference_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDy));
-        reference_tv->split(0, rparams.loop_unroll);
-        // Unswitch axis which gives us finer control on allocations with
-        // unrolling
-        reference_tv->split(0, 1);
-
-        // [x-BIDx, x-Unswitch, x-Unroll, x-TIDy, rF-Remain, r-TIDx]
-        //     0         1          2        3        4         5
-        // -> [x-BIDx, x-TIDy, rF-Remain, x-Unswitch, x-Unroll, r-TIDx]
-        //       0        1         2           3          4       5
-
-        reference_tv->reorder({{1, 3}, {2, 4}, {3, 1}, {4, 2}});
-
-        reference_tv->axis(1)->parallelize(ParallelType::TIDy);
-        reference_tv->axis(3)->parallelize(ParallelType::Unswitch);
-        reference_tv->axis(4)->parallelize(ParallelType::Unroll);
-        reference_tv->axis(5)->parallelize(ParallelType::TIDx);
-
-        if (rparams.split_grid_dim) {
-          reference_tv->split(0, x_grid_limit);
-          reference_tv->axis(1)->parallelize(ParallelType::BIDx);
-        } else {
-          reference_tv->axis(0)->parallelize(ParallelType::BIDx);
-        }
-      }
-    } else {
-      // Not multiple reductions per block
-      if (rparams.cross_grid) {
-        TORCH_INTERNAL_ASSERT(
-            rparams.reduction_unroll,
-            "Unrolling on iter domain not supported in this scheduler.");
-
-        TORCH_INTERNAL_ASSERT(
-            !rparams.persistent_kernel,
-            "Grid reductions not implemented yet for persistent kernels.");
-
-        // Fastest dim, cross grid, cross block
-        //      [outputs,
-        // Idx:     0
-        //   | rf-Remain, r-BIDx, r-TIDy, rf-Unswitch, rf-Unroll, r-TIDx]
-        //       1(r)     2(r+1)  3(r+2)     4(r+3)      5(r+4)   6(r+5)|
-        //   | rf-Remain, r-BIDx, r-TIDy, rf-Unswitch, r-TIDx, r-Vectorize]
-        //       1(r)     2(r+1)  3(r+2)     4(r+3)    5(r+4)    6(r+5)|
-        //                Reduction Dimensions
-
-        //   | r-BIDx, r-TIDy, r-TIDx, rf-Remain, rf-Unswitch, rf-Unroll/Vect]
-        //       1(r)  2(r+1)  3(r+2)   4(r+3)       5(r+4)     6(r+5)  |
-        //                Reduction Dimensions
-
-        if (rparams.vectorize) {
-          reduction_tv->split(reduce_axis, rparams.loop_unroll);
-          reduction_tv->split(
-              reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
-        } else {
-          reduction_tv->split(
-              reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
-          reduction_tv->split(reduce_axis, rparams.loop_unroll);
-        }
-        reduction_tv->split(reduce_axis, 1);
-        // Unswitch axis which gives us finer control on allocations with
-        // unrolling
-        reduction_tv->split(
-            reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDy));
-        reduction_tv->split(
-            reduce_axis, NamedScalar::getParallelDim(ParallelType::BIDx));
-
-        if (rparams.vectorize) {
-          reduction_tv->reorder(
-              {{reduce_axis, reduce_axis + 3},
-               {reduce_axis + 1, reduce_axis},
-               {reduce_axis + 2, reduce_axis + 1},
-               {reduce_axis + 3, reduce_axis + 4},
-               {reduce_axis + 4, reduce_axis + 2}});
-        } else {
-          reduction_tv->reorder(
-              {{reduce_axis, reduce_axis + 3},
-               {reduce_axis + 1, reduce_axis},
-               {reduce_axis + 2, reduce_axis + 1},
-               {reduce_axis + 3, reduce_axis + 4},
-               {reduce_axis + 4, reduce_axis + 5},
-               {reduce_axis + 5, reduce_axis + 2}});
-        }
-
-        reference_tv = ir_utils::rfactorHelper(
-            reduction_tv, {reduce_axis + 3, reduce_axis + 4, reduce_axis + 5});
-
-        if (rparams.vectorize) {
-          reference_tv->axis(reduce_axis + 5)
-              ->parallelize(ParallelType::Vectorize);
-        } else {
-          reference_tv->axis(reduce_axis + 5)
-              ->parallelize(ParallelType::Unroll);
-        }
-        reference_tv->axis(reduce_axis + 4)
-            ->parallelize(ParallelType::Unswitch);
-
-        reference_tv->axis(reduce_axis + 2)->parallelize(ParallelType::TIDx);
-        reference_tv->axis(reduce_axis + 1)->parallelize(ParallelType::TIDy);
-        reference_tv->axis(reduce_axis)->parallelize(ParallelType::BIDx);
-
-        if (has_iter_axis) {
-          if (rparams.split_grid_dim) {
-            reference_tv->split(iter_axis, y_grid_limit);
-            reference_tv->axis(iter_axis + 1)->parallelize(ParallelType::BIDy);
-          } else {
-            reference_tv->axis(iter_axis)->parallelize(ParallelType::BIDy);
-          }
-        }
-
-      } else {
-        // Not cross grid
-        if (rparams.reduction_unroll) {
-          // Fastest dim, Reduction unroll
-          // Output Dimensions
-          // [BIDx
-          //  0
-          //
-          // Reduction Dimensions
-          // rF-Remain, rf-Unswitch, rf-Unroll, r-TIDx]
-          // 1(r)      2(r+1)        3(r+2)      4(r+3)
-          // rF-Remain, rf-Unswitch, r-TIDx, rf-Vectorize]
-          // 1(r)      2(r+1)        3(r+2)      4(r+3)
-
-          //  r-TIDx, rF-Leftover, rf-Unswitch, rf-Unroll]
-          //  1(r)       2(r+1)      3(r+2)       4(r+3)
-
-          if (!rparams.persistent_kernel) {
-            if (rparams.vectorize) {
-              reduction_tv->split(reduce_axis, rparams.loop_unroll);
-              reduction_tv->split(
-                  reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
-            } else {
-              reduction_tv->split(
-                  reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
-              reduction_tv->split(reduce_axis, rparams.loop_unroll);
-            }
-            // Unswitch axis which gives us finer control on allocations with
-            // unrolling
-            reduction_tv->split(reduce_axis, 1);
-          } else {
-            if (rparams.vectorize) {
-              reduction_tv->split(
-                  reduce_axis, rparams.batches_per_block, false);
-              reduction_tv->split(reduce_axis + 1, rparams.loop_unroll);
-            } else {
-              reduction_tv->split(
-                  reduce_axis,
-                  rparams.batches_per_block * rparams.loop_unroll,
-                  false);
-              reduction_tv->split(reduce_axis, rparams.loop_unroll);
-            }
-            // Unswitch axis which gives us finer control on allocations with
-            // unrolling
-            reduction_tv->split(reduce_axis, 1);
-          }
-
-          if (rparams.vectorize) {
-            reduction_tv->reorder(
-                {{reduce_axis + 2, reduce_axis},
-                 {reduce_axis, reduce_axis + 1},
-                 {reduce_axis + 1, reduce_axis + 2}});
-          } else {
-            reduction_tv->reorder(
-                {{reduce_axis + 3, reduce_axis},
-                 {reduce_axis, reduce_axis + 1},
-                 {reduce_axis + 1, reduce_axis + 2},
-                 {reduce_axis + 2, reduce_axis + 3}});
-          }
-
-          reference_tv = ir_utils::rfactorHelper(
-              reduction_tv,
-              {reduce_axis + 1, reduce_axis + 2, reduce_axis + 3});
-
-          reference_tv->axis(reduce_axis)->parallelize(ParallelType::TIDx);
-          if (rparams.vectorize) {
-            reference_tv->axis(reduce_axis + 3)
-                ->parallelize(ParallelType::Vectorize);
-          } else {
-            reference_tv->axis(reduce_axis + 3)
-                ->parallelize(ParallelType::Unroll);
-          }
-          reference_tv->axis(reduce_axis + 2)
-              ->parallelize(ParallelType::Unswitch);
-
-          if (has_iter_axis) {
-            if (rparams.split_grid_dim) {
-              reference_tv->split(iter_axis, x_grid_limit);
-              reference_tv->axis(iter_axis + 1)
-                  ->parallelize(ParallelType::BIDx);
-            } else {
-              reference_tv->axis(iter_axis)->parallelize(ParallelType::BIDx);
-            }
-          }
-        } else {
-          TORCH_INTERNAL_ASSERT(
-              has_iter_axis, "Need iteration axis for iteration unroll.");
-          // Fastest dim, Reduction Splits
-          // Output Dimensions
-          // [BIDx, x-Unswitch, x-Unroll
-          //  0
-          //
-          // Reduction Dimensions
-          // rF-Remain, r-TIDx]
-          // 1(r)       2(r+1)
-
-          if (!rparams.persistent_kernel) {
-            reduction_tv->split(
-                reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
-          } else {
-            reduction_tv->split(reduce_axis, rparams.batches_per_block, false);
-          }
-
-          reduction_tv->split(iter_axis, rparams.loop_unroll);
-          // Unswitch axis which gives us finer control on allocations with
-          // unrolling
-          reduction_tv->split(iter_axis, 1);
-
-          // [x-BIDx, x-Unswitch, x-Unroll, rF-Remain, r-TIDx]
-          //     0         1          2        3        4
-          // -> [x-BIDx, rF-Remain, x-Unswitch, x-Unroll, r-TIDx]
-          //       0        1          2           3          4
-
-          reduction_tv->reorder({{1, 2}, {2, 3}, {3, 1}});
-
-          reference_tv = ir_utils::rfactorHelper(reduction_tv, {1});
-
-          reference_tv->axis(4)->parallelize(ParallelType::TIDx);
-          reference_tv->axis(3)->parallelize(ParallelType::Unroll);
-          reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
-
-          if (rparams.split_grid_dim) {
-            reference_tv->split(0, x_grid_limit);
-            reference_tv->axis(1)->parallelize(ParallelType::BIDx);
-          } else {
-            reference_tv->axis(0)->parallelize(ParallelType::BIDx);
-          }
-        }
-      }
-    }
-  } else {
-    if (rparams.cross_block) {
-      if (rparams.cross_grid) {
-        TORCH_INTERNAL_ASSERT(
-            rparams.reduction_unroll,
-            "Unrolling on iter domain not supported in this scheduler.");
-
-        TORCH_INTERNAL_ASSERT(
-            !rparams.persistent_kernel,
-            "Grid reductions not implemented yet for persistent kernels.");
-
-        // Outer Dim, cross grid, cross block
-
-        // Unrolling in this case can only be applied to the reduction
-        // dimension since currently, grid reductions cannot be called
-        // multiple times
-        //
-        // Output Dimensions
-        // [x-BIDx, x-TIDx,
-        //  0         1
-        //
-        // Reduction Dimensions
-        // rF-Leftover, r-BIDy, r-TIDy, rf-Unswitch, rf-Unroll]
-        // 2(-5)        3(-4)   4(-3)   5(-2)        6(-1)
-
-        // r-BIDy, r-TIDy, rF-Leftover, rf-Unswitch, rf-Unroll]
-        // 2(-5)    3(-4)      4(-3)       5(-2)        6(-1)
-
-        reduction_tv->split(1, rparams.loop_unroll);
-        // Unswitch axis which gives us finer control on allocations with
-        // unrolling
-        reduction_tv->split(1, 1);
-        reduction_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
-        reduction_tv->split(1, NamedScalar::getParallelDim(ParallelType::BIDy));
-
-        reduction_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
-
-        reduction_tv->reorder({{2, 4}, {3, 2}, {4, 3}});
-
-        reference_tv = ir_utils::rfactorHelper(
-            reduction_tv,
-            {4, 5, 6}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-
-        reference_tv->axis(6)->parallelize(ParallelType::Unroll);
-        reference_tv->axis(5)->parallelize(ParallelType::Unswitch);
-        reference_tv->axis(3)->parallelize(ParallelType::TIDy);
-        reference_tv->axis(2)->parallelize(ParallelType::BIDy);
-        reference_tv->axis(1)->parallelize(ParallelType::TIDx);
-        reference_tv->axis(0)->parallelize(ParallelType::BIDx);
-      } else {
-        if (rparams.reduction_unroll || rparams.loop_unroll == 1) {
-          // Outer Dim, cross block, unroll reduction dimension
-
-          // Reduction Splits
-          // Output Dimensions
-          // [x-BIDx, x-TIDx
-          //  0       1
-          //
-          // Reduction Dimensions
-          // rF-Leftover, r-TIDy, rf-Unswitch, rf-Unroll]
-          // 2(-4)        3(-3)   4(-2)       5(-1)
-
-          // r-TIDy, rF-Leftover, rf-Unswitch, rf-Unroll]
-          // 2(-4)      3(-3)       4(-2)       5(-1)
-          if (!rparams.persistent_kernel) {
-            reduction_tv->split(1, rparams.loop_unroll);
-            // Unswitch axis which gives us finer control on allocations with
-            // unrolling
-            reduction_tv->split(1, 1);
-            reduction_tv->split(
-                1, NamedScalar::getParallelDim(ParallelType::TIDy));
-          } else {
-            reduction_tv->split(1, rparams.batches_per_block, false);
-            reduction_tv->split(2, rparams.loop_unroll);
-            reduction_tv->split(2, 1);
-          }
-
-          reduction_tv->split(
-              0, NamedScalar::getParallelDim(ParallelType::TIDx));
-
-          reduction_tv->reorder({{2, 3}, {3, 2}});
-
-          reference_tv = ir_utils::rfactorHelper(
-              reduction_tv,
-              {3, 4, 5}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-
-          reference_tv->axis(5)->parallelize(ParallelType::Unroll);
-          reference_tv->axis(4)->parallelize(ParallelType::Unswitch);
-          reference_tv->axis(2)->parallelize(ParallelType::TIDy);
-          reference_tv->axis(1)->parallelize(ParallelType::TIDx);
-          reference_tv->axis(0)->parallelize(ParallelType::BIDx);
-        } else {
-          // Outer Dim, cross block, unroll iter dimension
-
-          // Output Dimensions
-          // [x-BIDx, x-Unswitch, x-Unroll, x-TIDx
-          //  0       1           2         3
-          // [x-BIDx, x-Unswitch, x-TIDx, x-Vectorize
-          //  0       1           2         3
-          //
-          // Reduction Dimensions
-          // rF-Leftover, r-TIDy]
-          // 4(-2)        5(-1)
-
-          // The unroll/unswitch dimension needs to be within the rF-Leftover
-          // dimension
-          //    [x-BIDx, x-Unswitch, x-Unroll, x-TIDx, rF-Leftover, r-TIDy]
-          //      0(-6)     1(-5)      2(-4)    3(-3)     4(-2)      5(-1)
-          //    [x-BIDx, x-Unswitch, x-TIDx, x-Vectorize, rF-Leftover, r-TIDy]
-          //      0(-6)     1(-5)      2(-4)    3(-3)     4(-2)      5(-1)
-          // -> [x-BIDx, x-TIDx, rF-Leftover, x-Unswitch, x-Unroll/Vect,
-          // r-TIDy]
-          //      0(-6)   1(-5)     2(-4)        3(-3)      4(-2)        5(-1)
-
-          if (!rparams.persistent_kernel) {
-            reduction_tv->split(
-                1, NamedScalar::getParallelDim(ParallelType::TIDy));
-          } else {
-            reduction_tv->split(1, rparams.batches_per_block, false);
-          }
-          if (rparams.vectorize) {
-            reduction_tv->split(0, rparams.loop_unroll);
-            reduction_tv->split(
-                0, NamedScalar::getParallelDim(ParallelType::TIDx));
-
-          } else {
-            reduction_tv->split(
-                0, NamedScalar::getParallelDim(ParallelType::TIDx));
-            reduction_tv->split(0, rparams.loop_unroll);
-          }
-          // Unswitch axis which gives us finer control on allocations with
-          // unrolling
-          reduction_tv->split(0, 1);
-
-          if (rparams.vectorize) {
-            reduction_tv->reorder({{1, 3}, {2, 1}, {3, 4}, {4, 2}});
-          } else {
-            reduction_tv->reorder({{1, 3}, {2, 4}, {3, 1}, {4, 2}});
-          }
-
-          reference_tv = ir_utils::rfactorHelper(
-              reduction_tv,
-              {2}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-
-          reference_tv->axis(5)->parallelize(ParallelType::TIDy);
-          reference_tv->axis(1)->parallelize(ParallelType::TIDx);
-          if (rparams.vectorize) {
-            reference_tv->axis(4)->parallelize(ParallelType::Vectorize);
-          } else {
-            reference_tv->axis(4)->parallelize(ParallelType::Unroll);
-          }
-          reference_tv->axis(3)->parallelize(ParallelType::Unswitch);
-          reference_tv->axis(0)->parallelize(ParallelType::BIDx);
-        }
-      }
-    } else {
-      if (rparams.reduction_unroll) {
-        // Outer Dim, no parallelization on reduction, unroll reduction axis
-        // Output Dimensions
-        // [x-BIDx, x-TIDx
-        //  0       1
-        //
-        // Reduction Dimensions
-        // rf-Leftover, rf-Unswitch, r-Unroll]
-        //       2            3         4
-        if (rparams.persistent_kernel) {
-          reduction_tv->split(1, rparams.batches_per_block, false);
-          reduction_tv->split(2, rparams.loop_unroll);
-          // Reduction Dimensions
-          // rf-Leftover, r-TIDy, rf-Unroll]
-          //       2         3         4
-        } else {
-          reduction_tv->split(1, rparams.loop_unroll);
-          // Unswitch axis which gives us finer control on allocations with
-          // unrolling
-          reduction_tv->split(1, 1);
-        }
-
-        reduction_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
-
-        if (rparams.persistent_kernel) {
-          // [x-BIDx, x-TIDx, rf-Leftover, r-TIDy, rf-Unroll]
-          //     0       1         2         3         4
-          reduction_tv->reorder({{3, 2}, {2, 3}});
-          // [x-BIDx, x-TIDx, r-TIDy, rf-Leftover, rf-Unroll]
-          //     0       1       2           3         4
-          reference_tv = ir_utils::rfactorHelper(
-              reduction_tv,
-              {3, 4}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-          reference_tv->axis(0)->parallelize(ParallelType::BIDx);
-          reference_tv->axis(1)->parallelize(ParallelType::TIDx);
-          reference_tv->axis(2)->parallelize(ParallelType::TIDy);
-          reference_tv->axis(3)->parallelize(ParallelType::Unswitch);
-          reference_tv->axis(4)->parallelize(ParallelType::Unroll);
-        } else {
-          reference_tv = ir_utils::rfactorHelper(
-              reduction_tv,
-              {2, 3}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-          reference_tv->axis(0)->parallelize(ParallelType::BIDx);
-          reference_tv->axis(1)->parallelize(ParallelType::TIDx);
-          reference_tv->axis(3)->parallelize(ParallelType::Unswitch);
-          reference_tv->axis(4)->parallelize(ParallelType::Unroll);
-        }
-      } else {
-        // No parallelization on reduction, unroll iter axis
-        // Output Dimensions
-        // [x-BIDx, x-Unswitch, x-Unroll, x-TIDx
-        //  0       1           2         3
-        // [x-BIDx, x-Unswitch, x-TIDx, x-Vectorize
-        //  0       1           2         3
-        //
-        // Reduction Dimensions
-        // rf-Leftover, r-{1}]
-        // 4(-1)
-        //
-        // Fake an rfactor to make scheduling more consistent.
-        //
-        // The unroll/unswitch dimension needs to be within the rF-Leftover
-        // dimension
-        if (rparams.persistent_kernel) {
-          reduction_tv->split(1, rparams.batches_per_block, false);
-        } else {
-          reduction_tv->split(1, 1);
-        }
-
-        if (rparams.vectorize) {
-          reduction_tv->split(0, rparams.loop_unroll);
-          reduction_tv->split(
-              0, NamedScalar::getParallelDim(ParallelType::TIDx));
-        } else {
-          reduction_tv->split(
-              0, NamedScalar::getParallelDim(ParallelType::TIDx));
-          reduction_tv->split(0, rparams.loop_unroll);
-        }
-
-        reduction_tv->split(0, 1);
-
-        // [x-BIDx, x-Unswitch, x-Unroll, x-TIDx, rf-Leftover, r-1]
-        //   0         1          2        3            4       5
-        // [x-BIDx, x-Unswitch, x-TIDx, x-Vectorize, rf-Leftover, r-1]
-        //   0         1          2        3              4        5
-
-        if (rparams.vectorize) {
-          reduction_tv->reorder({{1, 3}, {2, 1}, {3, 4}, {4, 2}});
-        } else {
-          reduction_tv->reorder({{1, 3}, {2, 4}, {3, 1}, {4, 2}});
-        }
-
-        // [x-BIDx, x-TIDx, rf-Leftover, x-Unswitch, x-Unroll, r-1(TIDy)]
-        //   0       1            2           3          4      5
-
-        reference_tv = ir_utils::rfactorHelper(reduction_tv, {2});
-        if (rparams.persistent_kernel) {
-          reference_tv->axis(5)->parallelize(ParallelType::TIDy);
-        }
-
-        reference_tv->axis(0)->parallelize(ParallelType::BIDx);
-        reference_tv->axis(1)->parallelize(ParallelType::TIDx);
-        reference_tv->axis(3)->parallelize(ParallelType::Unswitch);
-        if (rparams.vectorize) {
-          reference_tv->axis(4)->parallelize(ParallelType::Vectorize);
-        } else {
-          reference_tv->axis(4)->parallelize(ParallelType::Unroll);
-        }
-      }
-    }
-  }
-  return reference_tv;
-}
-
-// Reset inputs and outputs to global memory, everything else to local.
-void clearMemorySpace(Fusion* fusion) {
-  for (auto tv : ir_utils::allTvs(fusion)) {
-    if (tv->isFusionInput() || tv->isFusionOutput()) {
-      tv->setMemoryType(MemoryType::Global);
-    } else {
-      tv->setMemoryType(MemoryType::Local);
-    }
-  }
-}
-
-// Returns cached after tensors of the fusion inputs if unrolled. Otherwise
-// return empty vector.
-std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll) {
-  if (!unroll) {
-    return {};
-  }
-
-  std::vector<TensorView*> cached_inputs;
-  // If we're going to unroll, make a cache of the inputs
-  auto in_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
-  for (auto tv : in_tvs) {
-    if (tv->uses().empty()) {
-      continue;
-    }
-    auto cached_tv = tv->cache_after();
-    cached_inputs.emplace_back(cached_tv);
-  }
-  return cached_inputs;
-}
-
-// Returns the pairs of <cache of each fusion output, corresponding output> for
-// all outputs.
-std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
-    Fusion* fusion,
-    bool unroll) {
-  std::vector<std::pair<TensorView*, TensorView*>> cached_outputs;
-  // For intermediate outputs, apply cache_fork
-  for (const auto output :
-       ir_utils::filterByType<TensorView>(fusion->outputs())) {
-    if (output->definition() == nullptr) {
-      continue;
-    }
-    if (!output->uses().empty()) {
-      auto cached_output = output->as<TensorView>()->cache_fork();
-      cached_outputs.emplace_back(std::make_pair(output, cached_output));
-    } else if (unroll) {
-      auto cached_output = output->as<TensorView>()->cache_before();
-      cached_outputs.emplace_back(std::make_pair(cached_output, output));
-    }
-  }
-  return cached_outputs;
-}
-
-void multiReductionInliner(
-    Fusion* fusion,
-    const ReductionParams& rparams,
-    TensorView* reduction_tv,
-    TensorView* reference_tv,
-    std::vector<TensorView*> reduction_tvs,
-    std::vector<TensorView*> cached_inputs,
-    std::vector<std::pair<TensorView*, TensorView*>> cached_outputs) {
-  TransformPropagator::from(reference_tv);
-
-  // Apply rfactor to all reductions if applicable
-  std::vector<TensorView*> rfactor_tvs;
-
-  if (reference_tv != reduction_tv) {
-    std::vector<int> rfactor_axes;
-    for (size_t i = 0; i < reference_tv->nDims(); i++) {
-      if (reference_tv->axis((int)i)->isReduction() &&
-          reference_tv->axis((int)i)->isRFactorProduct()) {
-        rfactor_axes.push_back((int)i);
-      }
-    }
-
-    for (auto reduction_tv_ : reduction_tvs) {
-      if (reduction_tv_ == reduction_tv) {
-        // The reduction tv
-        rfactor_tvs.push_back(reference_tv);
-        continue;
-      } else {
-        rfactor_tvs.push_back(
-            ir_utils::rfactorHelper(reduction_tv_, rfactor_axes));
-      }
-    }
-
-    TORCH_INTERNAL_ASSERT(
-        reduction_tvs.size() == rfactor_tvs.size(),
-        "Expected all reductions to contain rfactor.");
-  }
-
-  // Propagate parallelization
-  parallelizeAllLike(reference_tv, ir_utils::allTvs(fusion));
-
-  // Find iter domains that are mapped to a trivial reduction, these should
-  // never be inlined.
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
-      getTrivialReductionMap(fusion);
-
-  if (rparams.loop_unroll > 1) {
-    // Inline Input caches to their consumers outside unswitched/vectorization
-    // position Inline consumers of input caches to rfactor tensors
-
-    // Mark which tensor views are actual input caches to leave vectorization on
-    // them
-    std::unordered_set<TensorView*> keep_unrolled;
-
-    std::vector<TensorView*> compute_from;
-
-    // Grab all tensor views that should be vectorized
-    auto vecotrizable_inputs_outputs =
-        getVectorizableInputsOutputs(reference_tv);
-
-    // Inputs to cache
-    for (auto cached_input : cached_inputs) {
-      auto consumers_of_input_cache = ir_utils::consumerTvsOf(cached_input);
-      for (auto consumer : consumers_of_input_cache) {
-        auto unswitch_it = std::find_if(
-            consumer->domain()->domain().begin(),
-            consumer->domain()->domain().end(),
-            [&mapped_to_trivial_reduction](IterDomain* id) {
-              return id->getParallelType() == ParallelType::Unswitch ||
-                  id->getParallelType() == ParallelType::Unroll ||
-                  id->getParallelType() == ParallelType::Vectorize ||
-                  id->getParallelType() == ParallelType::MisalignedVectorize ||
-                  mapped_to_trivial_reduction.count(id);
-            });
-        auto unswitch_pos = unswitch_it == consumer->domain()->domain().end()
-            ? -1
-            : std::distance(consumer->domain()->domain().begin(), unswitch_it) +
-                1;
-
-        cached_input->computeAt(
-            consumer, unswitch_pos, ComputeAtMode::BestEffort);
-        compute_from.push_back(consumer);
-
-        if (rparams.vectorize) {
-          auto producer_tvs = ir_utils::producerTvsOf(cached_input);
-          if (producer_tvs.size() == 1 &&
-              std::find(
-                  vecotrizable_inputs_outputs.begin(),
-                  vecotrizable_inputs_outputs.end(),
-                  producer_tvs[0]) != vecotrizable_inputs_outputs.end()) {
-            keep_unrolled.emplace(cached_input);
-          }
-        } else {
-          keep_unrolled.emplace(cached_input);
-        }
-      }
-    }
-
-    // Inline output caches into outputs
-    std::vector<TensorView*> compute_to;
-    for (auto cached_output_pair : cached_outputs) {
-      auto cached_output = cached_output_pair.first;
-      auto output = cached_output_pair.second;
-
-      // If an output has multiple consumers don't process here, we want only
-      // terminating outputs
-      if (cached_output->uses().size() > 1) {
-        continue;
-      }
-
-      auto pos_it = std::find_if(
-          output->domain()->domain().begin(),
-          output->domain()->domain().end(),
-          [&mapped_to_trivial_reduction](IterDomain* id) {
-            return id->getParallelType() == ParallelType::Unswitch ||
-                id->getParallelType() == ParallelType::Unroll ||
-                id->getParallelType() == ParallelType::Vectorize ||
-                id->getParallelType() == ParallelType::MisalignedVectorize ||
-                mapped_to_trivial_reduction.count(id);
-          });
-      auto pos = pos_it == output->domain()->domain().end()
-          ? -1
-          : std::distance(output->domain()->domain().begin(), pos_it) + 1;
-
-      cached_output->computeAt(output, pos, ComputeAtMode::BestEffort);
-
-      compute_to.push_back(cached_output);
-      if (rparams.vectorize) {
-        if (std::find(
-                vecotrizable_inputs_outputs.begin(),
-                vecotrizable_inputs_outputs.end(),
-                output) != vecotrizable_inputs_outputs.end()) {
-          keep_unrolled.emplace(output);
-        }
-      } else {
-        keep_unrolled.emplace(output);
-      }
-    }
-
-    // Before compute at-ing the internal structure, remove vectorization
-    // anywhere it doesn't belong. Otherwise it will mess up our inlining. Clear
-    // explicit unroll or vectorization when not for input or output GMEM
-    // transfers.
-    for (auto tv : ir_utils::allTvs(fusion)) {
-      if (!keep_unrolled.count(tv)) {
-        for (size_t i = 0; i < tv->nDims(); i++) {
-          auto id = tv->axis((int)i);
-          if (id->getParallelType() == ParallelType::Unroll ||
-              id->getParallelType() == ParallelType::Vectorize ||
-              id->getParallelType() == ParallelType::MisalignedVectorize) {
-            tv->axis((int)i)->parallelize(ParallelType::Serial);
-          }
-        }
-      }
-    }
-
-    // Make sure not to completely inline if there's trivial reductions in the
-    // fusion
-    auto pos_it = std::find_if(
-        reference_tv->domain()->domain().begin(),
-        reference_tv->domain()->domain().end(),
-        [&mapped_to_trivial_reduction](IterDomain* id) {
-          return mapped_to_trivial_reduction.count(id);
-        });
-
-    auto pos = pos_it == reference_tv->domain()->domain().end()
-        ? -1
-        : std::distance(reference_tv->domain()->domain().begin(), pos_it) + 1;
-
-    // Compute at inputs to rfactor dimensions
-    computeAtBetween(
-        compute_from, rfactor_tvs, pos, ComputeAtMode::MostInlined);
-
-    // Inline rfactor into reduction
-    if (reference_tv != reduction_tv) {
-      // Compute at rfactor into following reduction, keep outside first
-      // reduction iter domain in the rfactor tensor view
-      for (size_t i = 0; i < rfactor_tvs.size(); i++) {
-        if (!rparams.reduction_unroll) {
-          auto rfactor_tv = rfactor_tvs[i];
-          auto rfactor_tv_dom = rfactor_tv->domain()->domain();
-          auto reduction_it = std::find_if(
-              rfactor_tv_dom.begin(), rfactor_tv_dom.end(), [](IterDomain* id) {
-                return id->isReduction();
-              });
-          TORCH_INTERNAL_ASSERT(
-              reduction_it != rfactor_tv_dom.end(),
-              "Expected reduction axis in ",
-              rfactor_tv);
-          auto pos = std::distance(rfactor_tv_dom.begin(), reduction_it);
-          rfactor_tv->computeWith(
-              reduction_tvs[i], pos, ComputeAtMode::Standard);
-        } else {
-          rfactor_tvs[i]->computeWith(
-              reduction_tvs[i], -1, ComputeAtMode::BestEffort);
-        }
-      }
-    }
-
-    // Remove anything before a reduction from compute_from
-    {
-      auto producers_of_reductions = DependencyCheck::getAllValsBetween(
-          {fusion->inputs().begin(), fusion->inputs().end()},
-          {reduction_tvs.begin(), reduction_tvs.end()});
-
-      auto producer_tvs_of_reductions =
-          ir_utils::filterByType<TensorView>(producers_of_reductions);
-      compute_from.erase(
-          std::remove_if(
-              compute_from.begin(),
-              compute_from.end(),
-              [&producer_tvs_of_reductions](TensorView* compute_from_tv) {
-                return std::find(
-                           producer_tvs_of_reductions.begin(),
-                           producer_tvs_of_reductions.end(),
-                           compute_from_tv) != producer_tvs_of_reductions.end();
-              }),
-          compute_from.end());
-    }
-
-    // Add reduction tensor views to compute from
-    compute_from.insert(
-        compute_from.end(), reduction_tvs.begin(), reduction_tvs.end());
-
-    // Compute between reductions and output caches
-    computeAtBetween(
-        compute_from,
-        compute_to,
-        -1,
-        ComputeAtMode::BestEffort,
-        mapped_to_trivial_reduction);
-
-  } else {
-    // Want to inline, especially backwards based on reduction_tv, otherwise
-    // rfactor tv may not be inlined correctly
-    auto ref_tvs = rfactor_tvs.size() ? rfactor_tvs : reduction_tvs;
-    for (auto red_tv : ref_tvs) {
-      auto pos_it = std::find_if(
-          red_tv->domain()->domain().begin(),
-          red_tv->domain()->domain().end(),
-          [&mapped_to_trivial_reduction](IterDomain* id) {
-            return id->getParallelType() == ParallelType::Unswitch ||
-                id->getParallelType() == ParallelType::Unroll ||
-                id->getParallelType() == ParallelType::Vectorize ||
-                id->getParallelType() == ParallelType::MisalignedVectorize ||
-                mapped_to_trivial_reduction.count(id);
-          });
-      auto pos = pos_it == red_tv->domain()->domain().end()
-          ? -1
-          : std::distance(red_tv->domain()->domain().begin(), pos_it) + 1;
-
-      computeAtInputs(red_tv, pos, ComputeAtMode::MostInlined);
-      computeWithOutputs(red_tv, pos, ComputeAtMode::BestEffort);
-    }
-  }
-}
-
-FindAllMappedDims::FindAllMappedDims(TensorView* from, IterDomain* id)
-    : starting_tv(from), starting_id(id) {
-  std::deque<TensorView*> to_visit{starting_tv};
-  std::unordered_set<TensorView*> visited;
-  mapped_ids.emplace(std::make_pair(starting_tv, starting_id));
-
-  // Propagate mapping of id
-  while (!to_visit.empty()) {
-    auto tv = to_visit.front();
-    to_visit.pop_front();
-
-    if (!visited.emplace(tv).second) {
-      continue;
-    }
-
-    auto tv_id = mapped_ids.at(tv);
-
-    for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
-      if (visited.find(consumer_tv) != visited.end()) {
-        continue;
-      }
-
-      if (mapped_ids.find(consumer_tv) != mapped_ids.end()) {
-        continue;
-      }
-
-      PairwiseRootDomainMap root_map(tv, consumer_tv);
-      auto p2c_map =
-          root_map.mapProducerToConsumer(tv->domain(), consumer_tv->domain());
-
-      auto c_it = p2c_map.find(tv_id);
-      if (c_it != p2c_map.end()) {
-        mapped_ids.emplace(std::make_pair(consumer_tv, c_it->second));
-        to_visit.emplace_back(consumer_tv);
-      }
-    }
-
-    for (auto producer_tv : ir_utils::producerTvsOf(tv)) {
-      if (visited.find(producer_tv) != visited.end()) {
-        continue;
-      }
-
-      if (mapped_ids.find(producer_tv) != mapped_ids.end()) {
-        continue;
-      }
-
-      PairwiseRootDomainMap root_map(producer_tv, tv);
-      auto c2p_map =
-          root_map.mapConsumerToProducer(tv->domain(), producer_tv->domain());
-      auto p_it = c2p_map.find(tv_id);
-      if (p_it != c2p_map.end()) {
-        mapped_ids.emplace(std::make_pair(producer_tv, p_it->second));
-        to_visit.emplace_back(producer_tv);
-      }
-    }
-  }
-}
-
-std::unordered_set<IterDomain*> FindAllMappedDims::from(
-    TensorView* tv,
-    IterDomain* id) {
-  TORCH_INTERNAL_ASSERT(
-      std::find_if(
-          tv->getRootDomain().begin(),
-          tv->getRootDomain().end(),
-          [&id](IterDomain* root_id) { return root_id == id; }) !=
-          tv->getRootDomain().end(),
-      "Tried to map out ",
-      id,
-      " from TV ",
-      tv,
-      " to the rest of the fusion, but id does not belong to this tv.");
-
-  FindAllMappedDims mapped_dims(tv, id);
-
-  std::unordered_set<IterDomain*> mapped_id_set;
-  for (auto entry : mapped_dims.mapped_ids) {
-    mapped_id_set.emplace(entry.second);
-  }
-  return mapped_id_set;
-}
-
-bool shouldVectorize(
-    TensorView* tv,
-    std::unordered_set<IterDomain*> vector_dims) {
-  const auto& root_dom = TensorDomain::noBroadcasts(
-      TensorDomain::noReductions(tv->getRootDomain()));
-
-  // Don't vectorize 0-dim tensors
-  if (root_dom.size() == 0) {
-    return false;
-  }
-
-  auto inner_most_dim = root_dom[root_dom.size() - 1];
-
-  // Make sure inner most dimension is in the vector_dim set
-  if (vector_dims.count(inner_most_dim) == 0) {
-    return false;
-  }
-
-  auto root_pos_it = std::find_if(
-      tv->getRootDomain().begin(),
-      tv->getRootDomain().end(),
-      [&inner_most_dim](IterDomain* id) { return inner_most_dim == id; });
-
-  TORCH_INTERNAL_ASSERT(root_pos_it != tv->getRootDomain().end());
-  auto inner_most_dim_pos =
-      std::distance(tv->getRootDomain().begin(), root_pos_it);
-
-  const auto& contiguity = tv->domain()->contiguity();
-
-  TORCH_INTERNAL_ASSERT(contiguity.size() == tv->getRootDomain().size());
-
-  // Don't vectorize if inner most dimension is not contiguous
-  if (!contiguity[inner_most_dim_pos]) {
-    return false;
-  }
-
-  return true;
-}
-
-std::vector<TensorView*> getVectorizableInputsOutputs(
-    TensorView* reference_tv) {
-  if (reference_tv->nDims() == 0) {
-    return {};
-  }
-
-  IterDomain* inner_most_id = nullptr;
-  for (auto it = reference_tv->getRootDomain().rbegin();
-       it != reference_tv->getRootDomain().rend();
-       it++) {
-    if ((*it)->isReduction() && reference_tv->isFusionInput()) {
-      continue;
-    }
-    if ((*it)->isBroadcast() && inner_most_id == nullptr) {
-      inner_most_id = *it;
-    }
-    inner_most_id = *it;
-    break;
-  }
-
-  if (inner_most_id == nullptr) {
-    return {};
-  }
-
-  auto vectorizable_dims = FindAllMappedDims::from(reference_tv, inner_most_id);
-
-  std::vector<TensorView*> vectorizable_tensors;
-
-  for (auto input_tv :
-       ir_utils::filterByType<TensorView>(reference_tv->fusion()->inputs())) {
-    if (shouldVectorize(input_tv, vectorizable_dims)) {
-      vectorizable_tensors.push_back(input_tv);
-    }
-  }
-
-  for (auto output_tv :
-       ir_utils::filterByType<TensorView>(reference_tv->fusion()->outputs())) {
-    if (shouldVectorize(output_tv, vectorizable_dims)) {
-      vectorizable_tensors.push_back(output_tv);
-    }
-  }
-
-  return vectorizable_tensors;
-}
-
-std::vector<int64_t> mappedInputsOutputs(TensorView* reference_tv) {
-  auto fusion = reference_tv->fusion();
-  FusionGuard fg(fusion);
-
-  // All input or output tensor views
-  std::vector<TensorView*> in_out_tvs;
-  {
-    auto inp_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
-    in_out_tvs.insert(in_out_tvs.end(), inp_tvs.begin(), inp_tvs.end());
-    auto out_tvs = ir_utils::filterByType<TensorView>(fusion->outputs());
-    in_out_tvs.insert(in_out_tvs.end(), out_tvs.begin(), out_tvs.end());
-  }
-
-  // Shouldn't matter which compute at map we use
-  auto ca_index_map = ComputeAtMap(ComputeAtMap::MappingMode::INDEX);
-  ca_index_map.build(fusion);
-
-  auto ref_root_domain = reference_tv->getMaybeRFactorDomain();
-  std::vector<int64_t> mapping_count(ref_root_domain.size(), 0);
-
-  // Map all inputs and output domains to reference tv domains
-  for (auto in_out_tv : in_out_tvs) {
-    auto in_out_tv_domain = in_out_tv->getRootDomain();
-    auto in_out_tv_domain_list = std::list<IterDomain*>(
-        in_out_tv_domain.begin(), in_out_tv_domain.end());
-    auto in_out_dtype_size = dataTypeSize(in_out_tv->getDataType().value());
-
-    for (size_t ref_i = 0; ref_i < ref_root_domain.size(); ref_i++) {
-      auto ref_id = ref_root_domain[ref_i];
-
-      // If reference id is broadcast or reduction
-      if (ref_id->isBroadcast() || ref_id->isReduction()) {
-        continue;
-      }
-      auto map_it = std::find_if(
-          in_out_tv_domain_list.begin(),
-          in_out_tv_domain_list.end(),
-          [&ref_id, &ca_index_map](IterDomain* in_out_tv_id) {
-            return ca_index_map.areMapped(in_out_tv_id, ref_id);
-          });
-
-      if (map_it == in_out_tv_domain_list.end()) {
-        continue;
-      }
-
-      // If input/output id is broadcast or reduction
-      if ((*map_it)->isBroadcast() || (*map_it)->isReduction()) {
-        continue;
-      }
-
-      mapping_count[ref_i] = mapping_count[ref_i] + (int64_t)in_out_dtype_size;
-      in_out_tv_domain_list.erase(map_it);
-    }
-  }
-  return mapping_count;
-}
-
-} // namespace scheduler_utils
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.h b/torch/csrc/jit/codegen/cuda/scheduler/utils.h
deleted file mode 100644 (file)
index 37599ef..0000000
+++ /dev/null
@@ -1,198 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class SchedulerRuntimeInfo;
-
-namespace scheduler_utils {
-
-constexpr int64_t register_file_size = 256 * 1024;
-constexpr int64_t x_grid_limit = ((int64_t)1 << (int64_t)31) - (int64_t)1;
-constexpr int64_t y_grid_limit = 65535;
-
-// Largest Power of 2 less-than n
-constexpr int64_t lastPow2(int64_t n) {
-  TORCH_INTERNAL_ASSERT(n >= 0);
-  n |= (n >> 1);
-  n |= (n >> 2);
-  n |= (n >> 4);
-  n |= (n >> 8); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-  n |= (n >> 16); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-  n |= (n >> 32); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-  return std::max((int64_t)1, n - (n >> 1));
-}
-
-// Merge all reduction to the right side and returns total number of
-// reduction axes. Don't merge is typically used for trivial reductions.
-size_t mergeReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge = {});
-
-// merge all non-reduction axes to the left side and returns total number of
-// iteration axes. Don't merge is typically used for trivial reductions.
-size_t mergeNonReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge = {});
-
-TORCH_CUDA_CU_API void parallelizeAllLike(
-    TensorView* reference_tv,
-    const std::vector<TensorView*>& all_tvs);
-
-TORCH_CUDA_CU_API void computeAtInputs(
-    TensorView* consumer,
-    int pos,
-    ComputeAtMode mode = ComputeAtMode::Standard);
-
-TORCH_CUDA_CU_API void computeWithOutputs(
-    TensorView* producer,
-    int pos,
-    ComputeAtMode mode = ComputeAtMode::Standard);
-
-struct PersistentBufferInfo {
-  std::vector<TensorView*> buffers;
-  std::unordered_set<IterDomain*> unmappable_dims;
-};
-
-// Buffers whos roots can't map to all producer roots based on compute at. These
-// are the buffers we would make persistent in a persistent kerenl or would have
-// to recompute if we can't make a persistent kernel. This function will also
-// return inputs as being marked persistent if they follow this pattern. It is
-// important to note however inputs don't strictly have to be persistent as they
-// can simply be read multiple times from GMEM in the same kernel.
-PersistentBufferInfo persistentBuffers(Fusion* fusion);
-
-struct TvProperties {
-  // How many elements in tensor view are there to reduce
-  int64_t reduction_numel = 1;
-  // How many reductions do we need to perform, i.e. how many iter dimension
-  // elements are there
-  int64_t iteration_numel = 1;
-  // Do we reduce the fastest dimension, if no reduction mark true
-  bool fastest_dim_reduction = true;
-  // What's the iter numel to the left of the reduction (if there is one)
-  int64_t iter_outside_red = 1;
-  // What's the iter numel to the right of the reduction (if this is or isn't
-  // one)
-  int64_t iter_inside_red = 1;
-};
-
-// Fill TvProperties structure about tv
-TvProperties getProperties(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    TensorView* tv);
-
-// Will call computeAt once on each producer, with the first consumer found that
-// is a consumer of the individual producer
-void computeAtBetween(
-    const std::vector<TensorView*>& producers,
-    const std::vector<TensorView*>& consumers,
-    int pos,
-    ComputeAtMode mode,
-    std::unordered_set<IterDomain*> mapped_to_trivial_reduction = {});
-
-// Compute the amount of register space would be needed to perform this kernel
-// persistently, only based on buffers that must be persistent, and based on the
-// maximum of all minimum size requirement. i.e. if must be persistent, only
-// hold persistent dimension.
-int64_t persistentBufferSize(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    PersistentBufferInfo& persistent_buffers,
-    HeuristicSummary* data_cache = nullptr);
-
-// Returns a set of all iteration domains (in roots of tensors) that map to a
-// trivial reduction
-std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion);
-
-// Merges tensor view to the form:
-// [IterationDomain, ReductionDomain, TrivialReductionDim0,
-// TrivialReductionDim1, ...] Returns if <iteration dimensions, reduction
-// dimensions>
-std::pair<bool, bool> canonicalDimReduction(Fusion* fusion, TensorView* tv);
-
-// Return a list of tensor views that are outputs of reduction operations. If
-// multiple outputs of an expression are found, only include one in the list
-// (WelfordOp)
-std::vector<TensorView*> getReductionTvs(Fusion* fusion);
-
-// Consistent parallelization based on provided reduction parameters. Provided
-// tensor is expected to be reduced by canonicalDimReduction before sending
-// here. reduction_tv should be provided as the tensorview to reduce.
-// RFactor of reduction_tv will be returned if applicable otherwise reduction_tv
-// is returned
-TensorView* scheduleReductionTV(
-    const ReductionParams& rparams,
-    TensorView* reduction_tv,
-    bool has_iter_axis);
-
-// Reset inputs and outputs to global memory, everything else to local.
-void clearMemorySpace(Fusion* fusion);
-
-// Returns cached after tensors of the fusion inputs if unrolled. Otherwise
-// return empty vector.
-std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll);
-
-// Returns the pairs of <cache of each fusion output, corresponding output> for
-// all outputs.
-std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
-    Fusion* fusion,
-    bool unroll);
-
-// Inlining function intended for single or multi reduction fusions.
-void multiReductionInliner(
-    Fusion* fusion,
-    const ReductionParams& rparams,
-    TensorView* reduction_tv,
-    TensorView* reference_tv,
-    std::vector<TensorView*> reduction_tvs,
-    std::vector<TensorView*> cached_inputs,
-    std::vector<std::pair<TensorView*, TensorView*>> cached_outputs);
-
-// Uses a lot of logic from TransformPropagator in the implementation
-class FindAllMappedDims {
- private:
-  FindAllMappedDims(TensorView* from, IterDomain* starting_id);
-
- private:
-  std::unordered_map<TensorView*, IterDomain*> mapped_ids;
-  TensorView* starting_tv = nullptr;
-  IterDomain* starting_id = nullptr;
-
- public:
-  // Looks through fusion and finds all dims that match to the one provided in
-  // the tensorview provided. Iter domain must be a root domain.
-  static std::unordered_set<IterDomain*> from(TensorView* tv, IterDomain* id);
-};
-
-// Checks if tensor view has an iteration domain in vector dims in its inner
-// most root position (excluding broadcast and reduction), and checks if it is a
-// contiguous dimension
-bool shouldVectorize(
-    TensorView* tv,
-    std::unordered_set<IterDomain*> vector_dims);
-
-// Returns all inputs and outputs that share the inner most dimension of the
-// provided reference. If reference is an input it ignores reduction axes, will
-// ignore all broadcast axes.
-std::vector<TensorView*> getVectorizableInputsOutputs(TensorView* reference_tv);
-
-// Returns a vector of counts, size = reference_tv->getRootDomain().size(), each
-// entry [i] is the number of inputs/outputs that have a non-broadcast dimension
-// mapped to the corresponding dimension in reference_tv. Count includes
-// reference_tv if reference_tv is an input or output. Count is multiplied by
-// data type size.
-std::vector<int64_t> mappedInputsOutputs(TensorView* reference_tv);
-
-} // namespace scheduler_utils
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index 4a64636..5430d0e 100644 (file)
@@ -15,8 +15,8 @@ namespace cuda {
 
 namespace {
 
-bool hasTypeAndDevice(const TensorTypePtr& op) {
-  return op->device().has_value() && op->scalarType().has_value();
+bool hasTypeAndDim(const TensorTypePtr& op) {
+  return op->sizes().size().has_value() && op->scalarType().has_value();
 }
 
 /* NaiveTypePropagator
@@ -51,7 +51,6 @@ class NaiveTypePropagator {
       }
       // unary operations that forward meta info:
       case aten::neg:
-      case aten::bitwise_not:
       case aten::abs:
       case aten::log:
       case aten::log10:
@@ -81,23 +80,20 @@ class NaiveTypePropagator {
       case aten::relu:
       case aten::sigmoid:
       case aten::threshold:
-      case aten::softplus:
       case aten::clamp:
       case aten::gelu:
-      case aten::gelu_backward:
-      case aten::silu:
       case aten::tanh: {
         TORCH_CHECK(
-            hasTypeAndDevice(node->input(0)->type()->cast<TensorType>()),
-            "Type and device propagation has failed, or was not provided enough information.");
+            hasTypeAndDim(node->input(0)->type()->cast<TensorType>()),
+            "Type, device, and dimensionality propagation has failed, or was not provided enough information.");
         node->output()->setType(node->input(0)->type()->cast<TensorType>());
         break;
       }
       // TODO: rand_like should support cast.
       case aten::rand_like: {
         TORCH_CHECK(
-            hasTypeAndDevice(node->input(0)->type()->cast<TensorType>()),
-            "Type and device propagation has failed, or was not provided enough information.");
+            hasTypeAndDim(node->input(0)->type()->cast<TensorType>()),
+            "Type, device, and dimensionality propagation has failed, or was not provided enough information.");
         node->output()->setType(node->input(0)->type()->cast<TensorType>());
         break;
       }
@@ -122,30 +118,7 @@ class NaiveTypePropagator {
         node->output()->setType(promoted_type);
         break;
       }
-      // Type can be int or bool for "and" and "or", if both are bool should be
-      // bool, if both int should be int, otherwise would have errored
-      case aten::__and__:
-      case aten::__or__: {
-        const auto promoted_type = binary_broadcast_type(
-            node->input(0)->type()->cast<TensorType>(),
-            node->input(1)->type()->cast<TensorType>(),
-            node->input(0)->type()->cast<TensorType>()->scalarType() ==
-                    at::ScalarType::Bool
-                ? at::ScalarType::Bool
-                : at::ScalarType::Int);
-        break;
-      }
-      // Real int ops
-      case aten::__xor__:
-      case aten::__lshift__:
-      case aten::__rshift__: {
-        const auto promoted_type = binary_broadcast_type(
-            node->input(0)->type()->cast<TensorType>(),
-            node->input(1)->type()->cast<TensorType>(),
-            at::ScalarType::Int);
-        node->output()->setType(promoted_type);
-        break;
-      }
+      // TODO: double check type casting logic for operations commented out.
       case aten::lt:
       case aten::le:
       case aten::gt:
@@ -175,180 +148,6 @@ class NaiveTypePropagator {
         node->output()->setType(promoted_type);
         break;
       }
-      case aten::dropout: {
-        auto out_type = node->input(0)->type()->cast<TensorType>();
-        node->output()->setType(out_type);
-        break;
-      }
-      case aten::instance_norm:
-      case aten::batch_norm: {
-        auto out_type = node->input(0)->type()->cast<TensorType>();
-        node->output()->setType(out_type);
-        break;
-      }
-      case aten::_batch_norm_impl_index_backward: {
-        auto grad_input_type = node->input(1)->type()->cast<TensorType>();
-        TORCH_CHECK(
-            hasTypeAndDevice(grad_input_type),
-            "Type and device propagation has failed, or was not provided enough information.");
-        node->output(0)->setType(grad_input_type);
-
-        // TODO: double check with type promotion
-        auto mean_rstd_type = TensorType::create(
-            *grad_input_type->scalarType(),
-            *grad_input_type->device(),
-            c10::nullopt,
-            c10::nullopt);
-
-        node->output(1)->setType(mean_rstd_type);
-        node->output(2)->setType(mean_rstd_type);
-
-        break;
-      }
-      case aten::_batch_norm_impl_index: {
-        auto out_type = node->input(0)->type()->cast<TensorType>();
-        TORCH_CHECK(
-            hasTypeAndDevice(out_type),
-            "Type and device propagation has failed, or was not provided enough information.");
-        node->output(0)->setType(out_type);
-
-        auto mean_rstd_type = TensorType::create(
-            *out_type->scalarType(),
-            *out_type->device(),
-            c10::nullopt,
-            c10::nullopt);
-
-        node->output(1)->setType(mean_rstd_type);
-        node->output(2)->setType(mean_rstd_type);
-        // TODO: not that it matters, but mark the right type here;
-        // node->output(3)->setType(out_type->withScalarType());
-        node->output(3)->setType(out_type);
-        node->output(4)->setType(IntType::get());
-
-        break;
-      }
-      case aten::native_batch_norm: {
-        auto out_type = node->input(0)->type()->cast<TensorType>();
-        TORCH_CHECK(
-            hasTypeAndDevice(out_type),
-            "Type and device propagation has failed, or was not provided enough information.");
-        node->output(0)->setType(out_type);
-
-        auto mean_rstd_type = TensorType::create(
-            *out_type->scalarType(),
-            *out_type->device(),
-            c10::nullopt,
-            c10::nullopt);
-
-        node->output(1)->setType(mean_rstd_type);
-        node->output(2)->setType(mean_rstd_type);
-
-        break;
-      }
-      case aten::native_batch_norm_backward: {
-        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-        auto out_mask_list = constant_as<c10::List<bool>>(node->input(9));
-        TORCH_INTERNAL_ASSERT(
-            out_mask_list.has_value(), "output mask for batch_norm_backward");
-        std::vector<int> output_mask;
-        for (const auto value : out_mask_list->vec()) {
-          output_mask.emplace_back(static_cast<int>(value));
-        }
-
-        if (output_mask[0]) {
-          auto in_type = node->input(1)->type()->cast<TensorType>();
-          node->output(0)->setType(in_type);
-        }
-
-        if (output_mask[1]) {
-          auto weight_type = node->input(2)->type()->cast<TensorType>();
-          node->output(1)->setType(weight_type);
-        }
-
-        if (output_mask[2]) {
-          auto weight_type = node->input(2)->type()->cast<TensorType>();
-          auto bias_type = TensorType::create(
-              *weight_type->scalarType(),
-              *weight_type->device(),
-              *weight_type->dim(),
-              output_mask[2]);
-          node->output(2)->setType(bias_type);
-        }
-        break;
-      }
-      case aten::layer_norm: {
-        auto out_type = node->input(0)->type()->cast<TensorType>();
-        node->output()->setType(out_type);
-        break;
-      }
-      case aten::native_layer_norm: {
-        auto out_type = node->input(0)->type()->cast<TensorType>();
-        TORCH_CHECK(
-            hasTypeAndDevice(out_type),
-            "Type and device propagation has failed, or was not provided enough information.");
-        node->output(0)->setType(out_type);
-
-        auto mean_rstd_type = TensorType::create(
-            *out_type->scalarType(), *out_type->device(), c10::nullopt, false);
-
-        node->output(1)->setType(mean_rstd_type);
-        node->output(2)->setType(mean_rstd_type);
-
-        break;
-      }
-      case aten::native_layer_norm_backward: {
-        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-        auto out_mask_list = constant_as<c10::List<bool>>(node->input(7));
-        TORCH_INTERNAL_ASSERT(
-            out_mask_list.has_value(), "output mask for layer_norm_backward");
-        std::vector<int> output_mask;
-        for (const auto value : out_mask_list->vec()) {
-          output_mask.emplace_back(static_cast<int>(value));
-        }
-
-        if (output_mask[0]) {
-          auto out_type = node->input(0)->type()->cast<TensorType>();
-          node->output(0)->setType(out_type);
-        }
-
-        if (output_mask[1] &&
-            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-            !node->input(5)->type()->isSubtypeOf(
-                static_cast<c10::TypePtr>(NoneType::get()))) {
-          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-          auto weight_type = node->input(5)->type()->cast<TensorType>();
-          node->output(1)->setType(weight_type);
-        }
-
-        if (output_mask[2] &&
-            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-            !node->input(6)->type()->isSubtypeOf(
-                static_cast<c10::TypePtr>(NoneType::get()))) {
-          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-          auto bias_type = node->input(6)->type()->cast<TensorType>();
-          node->output(2)->setType(bias_type);
-        }
-        break;
-      }
-      case aten::softmax: {
-        auto out_type = node->input(0)->type()->cast<TensorType>();
-
-        // accept dtype input to `aten::softmax` node
-        if (!node->input(2)->type()->isSubtypeOf(
-                static_cast<c10::TypePtr>(NoneType::get()))) {
-          if (auto opt_ivalue = toIValue(node->input(2))) {
-            out_type = out_type->withScalarType(opt_ivalue->toScalarType());
-          }
-        }
-        node->output()->setType(out_type);
-        break;
-      }
-      case aten::_softmax_backward_data: {
-        auto out_type = node->input(0)->type()->cast<TensorType>();
-        node->output()->setType(out_type);
-        break;
-      }
-      case aten::mean:
       case aten::sum: {
         auto out_type = node->input(0)->type()->cast<TensorType>();
 
@@ -362,53 +161,16 @@ class NaiveTypePropagator {
         const auto dims = constant_as<c10::List<int64_t>>(node->input(1));
         const auto keepdim = constant_as<bool>(node->input(2));
         TORCH_CHECK(
-            dims.has_value() && keepdim.has_value(),
+            dims.has_value() && keepdim.has_value() && !keepdim.value(),
             "Shape inference cannot handle options.");
         node->output()->setType(
             unary_reduce_type(out_type, dims->vec(), keepdim.value()));
         break;
       }
-      case aten::sum_to_size:
-      case aten::_grad_sum_to_size: {
-        auto out_type = node->input(0)->type()->cast<TensorType>();
-        node->output()->setType(out_type->withDim(c10::nullopt));
-        break;
-      }
-      case aten::type_as: {
-        const auto type0 = node->input(0)->type()->cast<TensorType>();
-        const auto type1 = node->input(1)->type()->cast<TensorType>();
-        TORCH_CHECK(
-            type0 != nullptr && type1 != nullptr &&
-                type1->scalarType().has_value(),
-            "input to type_as needs to be a tensor");
-        node->output()->setType(type0->withScalarType(type1->scalarType()));
-        break;
-      }
-      case aten::to: {
-        const auto type0 = node->input(0)->type()->cast<TensorType>();
-        const auto out_dtype = toIValue(node->input(1));
-        TORCH_CHECK(out_dtype, "No output type specified");
-        node->output()->setType(
-            type0->withScalarType(out_dtype->toScalarType()));
-        break;
-      }
-      case prim::add_optional: {
-        const auto type0 = node->input(0)->type()->cast<TensorType>();
-        const auto type1 = node->input(1)->type()->cast<TensorType>();
-        TORCH_CHECK(type0 != nullptr);
-        if (type1 != nullptr) {
-          node->output()->setType(type0);
-        } else {
-          const auto promoted_type = binary_broadcast_type(type0, type1);
-          node->output()->setType(promoted_type);
-        }
-        break;
-      }
       default:
         TORCH_CHECK(
             false,
-            "type inference failed, unrecognized operation encountered:",
-            node->kind().toDisplayString());
+            "type inference failed, unrecognized operation encountered.");
         // TODO: generate a proper error log, as this probably means something
         //       went unexpected.
         break;
@@ -424,11 +186,18 @@ class NaiveTypePropagator {
       const TensorTypePtr& op,
       const std::vector<int64_t>& dims,
       bool keepdim) {
-    TORCH_CHECK(
-        hasTypeAndDevice(op),
-        "Type and device propagation has failed, or was not provided enough information.");
+    TORCH_CHECK(hasTypeAndDim(op), "requires complete shape on input");
+    auto input_size = op->sizes();
+    int64_t ndims = keepdim ? input_size.size().value() : 0;
+    if (!keepdim) {
+      for (size_t i = 0; i < input_size.size(); i++) {
+        if (std::find(dims.begin(), dims.end(), i) == dims.end()) {
+          ndims++;
+        }
+      }
+    }
     return TensorType::create(
-        *op->scalarType(), *op->device(), c10::nullopt, c10::nullopt);
+        *op->scalarType(), *op->device(), ndims, c10::nullopt);
   }
 
   // TODO: we should comply to codegen type promotion.
@@ -442,23 +211,27 @@ class NaiveTypePropagator {
 
     if (op0 != nullptr && op1 != nullptr) {
       TORCH_CHECK(
-          hasTypeAndDevice(op0) && hasTypeAndDevice(op1),
-          "Type and device propagation has failed, or was not provided enough information.");
+          op0->sizes().size().has_value() && op1->sizes().size().has_value(),
+          "Cannot process input tensor without concrete number of dimensions.");
+      int64_t ndims = *op0->sizes().size() > *op1->sizes().size()
+          ? *op0->sizes().size()
+          : *op1->sizes().size();
+
       auto promoted_scalar_type = scalar_type.has_value()
           ? *scalar_type
           : c10::promoteTypes(*op0->scalarType(), *op1->scalarType());
 
       return TensorType::create(
-          promoted_scalar_type, *op0->device(), c10::nullopt, c10::nullopt);
+          promoted_scalar_type, *op0->device(), ndims, c10::nullopt);
     } else {
       auto ptr = (op0 != nullptr) ? op0 : op1;
       TORCH_CHECK(
-          hasTypeAndDevice(ptr),
-          "Type and device propagation has failed, or was not provided enough information.");
+          hasTypeAndDim(ptr),
+          "Type, device, and dimensionality propagation has failed, or was not provided enough information.");
       return TensorType::create(
           scalar_type.has_value() ? *scalar_type : *ptr->scalarType(),
           *ptr->device(),
-          c10::nullopt,
+          *ptr->sizes().size(),
           c10::nullopt);
     }
   }
index 5c72fab..cbf31ff 100644 (file)
@@ -24,19 +24,8 @@ DataType aten_opt_type_map(const c10::optional<at::ScalarType>& scalar_type) {
 }
 } // namespace
 
-TensorView::TensorView(TensorDomain* domain, DataType dtype, MemoryType mtype)
-    : Val(ValType::TensorView, dtype), domain_(domain), memory_type_(mtype) {
-  // Don't do this after transforms
-  if (domain_->domain() == domain_->getRootDomain()) {
-    // Mark the size-1 axes as broadcast to support implicit broadcast semantic
-    for (auto* id : domain_->domain()) {
-      if (!id->isBroadcast() && !id->isReduction() && !id->isGather() &&
-          id->extent()->isOneInt()) {
-        id->convertToBroadcast();
-      }
-    }
-  }
-}
+TensorView::TensorView(TensorDomain* _domain, DataType dtype, MemoryType mtype)
+    : Val(ValType::TensorView, dtype), domain_(_domain), memory_type_(mtype) {}
 
 TensorView::TensorView(const std::shared_ptr<c10::TensorType>& tensor_type)
     : Val(ValType::TensorView,
@@ -99,18 +88,10 @@ TensorView::TensorView(const std::shared_ptr<c10::TensorType>& tensor_type)
 TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner)
     : Val(src, ir_cloner),
       domain_(ir_cloner->clone(src->domain_)),
-      compute_at_pos_(src->compute_at_pos_),
-      max_producer_pos_(src->max_producer_pos_),
-      memory_type_(src->memory_type_),
-      swizzle_type_(src->swizzle_type_) {
-  for (const auto id : src->axesToSwizzle()) {
-    axes_to_swizzle_.push_back(ir_cloner->clone(id));
-  }
-}
-
-bool TensorView::hasAnyReduction() const {
-  return domain()->noReductions().size() != domain()->domain().size();
-}
+      compute_at_view_(ir_cloner->clone(src->compute_at_view_)),
+      relative_compute_at_axis_(src->relative_compute_at_axis_),
+      this_compute_at_axis_(src->this_compute_at_axis_),
+      memory_type_(src->memory_type_) {}
 
 bool TensorView::hasReduction() const {
   return domain()->hasReduction();
@@ -124,6 +105,10 @@ bool TensorView::hasGridReduction() const {
   return domain()->hasGridReduction();
 }
 
+bool TensorView::hasBlockBroadcast() const {
+  return domain()->hasBlockBroadcast();
+}
+
 bool TensorView::hasBroadcast() const {
   return domain()->hasBroadcast();
 }
@@ -166,162 +151,184 @@ IterDomain* TensorView::axis(int pos) const {
   return domain()->axis(pos);
 }
 
-void TensorView::setComputeAt(unsigned int pos, bool decrease) {
-  if (pos <= compute_at_pos_ && !decrease) {
-    return;
-  }
+TensorView* TensorView::unsafeClone() const {
+  TensorView* new_view = new TensorView(domain_, getDataType().value());
+  new_view->compute_at_view_ = compute_at_view_;
+  new_view->relative_compute_at_axis_ = relative_compute_at_axis_;
+  new_view->this_compute_at_axis_ = this_compute_at_axis_;
+  new_view->memory_type_ = memory_type_;
+  new_view->name_ = name();
+  return new_view;
+}
+
+void TensorView::setComputeAt(TensorView* computeAtView, int axis) {
+  compute_at_view_ = computeAtView;
+  relative_compute_at_axis_ = axis;
+  setThisComputeAtAxis();
 
   TORCH_INTERNAL_ASSERT(
-      (unsigned)pos <= nDims(),
-      "Invalid this computeAt position for T",
-      name(),
-      ": ",
-      pos);
+      getThisComputeAtAxis() >= 0 &&
+          (unsigned int)getThisComputeAtAxis() <= nDims(),
+      "Invalid computeAt on ",
+      this,
+      " tried to set to local axis ",
+      getThisComputeAtAxis());
 
-  compute_at_pos_ = pos;
+  TORCH_INTERNAL_ASSERT(
+      std::none_of(
+          domain()->domain().begin(),
+          domain()->domain().begin() + getThisComputeAtAxis(),
+          [](IterDomain* id) { return id->isReduction(); }),
+      "Invalid computeAt, reduction domain inside computeAt axis.");
+}
+
+void TensorView::setComputeAt(
+    TensorView* computeAtView,
+    int thisPos,
+    int relPos) {
+  compute_at_view_ = computeAtView;
+  relative_compute_at_axis_ = relPos;
+  this_compute_at_axis_ = thisPos;
+  TORCH_INTERNAL_ASSERT(
+      this_compute_at_axis_ <= nDims(), "Manually set an invalid computeAt.");
 }
 
-void TensorView::setMaxProducer(unsigned int pos, bool decrease) {
-  if (pos <= max_producer_pos_ && !decrease) {
-    return;
+// Where in compute_at_view does this->axis(pos) match up?
+// TODO: This doesn't seem like the safest function as a fusion output can ref
+// another fusion output,  we may want to check that there is a direct
+// consumer/producer relationship between this and compute_at view before using
+// this function, and creating another pass to handle relative outputs.
+int TensorView::getComputeAtRelPos(int pos) {
+  if (!hasComputeAt()) {
+    return pos;
   }
 
-  TORCH_INTERNAL_ASSERT(
-      (unsigned)pos <= nDims(),
-      "Invalid max producer position for T",
-      name(),
-      ": ",
-      pos);
+  if (!compute_at_view_->hasBroadcast()) {
+    return pos;
+  }
 
-  max_producer_pos_ = pos;
-}
+  size_t pos_cav = 0, pos_this = 0;
 
-TensorView* TensorView::computeAt(
-    TensorView* consumer,
-    int position,
-    ComputeAtMode mode) {
-  // Make sure this and consumer are not the same tensor, that's illegal
-  TORCH_CHECK(!sameAs(consumer), "Cannot call this->computeAt(this, ...)");
+  // We could be in an instance where pos == 0, but consumer[0] is bcast and
+  // this[0] is not
 
-  // We support negative axes, so increment it by consumer->nDims() + 1 and make
-  // sure the result is within consumer->nDims() + 1. being at consumer->nDims()
-  // means producer will be computed inline with consumer, hence the +1.
-  if (position < 0)
-    position += int(consumer->nDims()) + 1;
+  while (compute_at_view_->axis(pos_cav)->isBroadcast() &&
+         !(axis(pos_this)->isBroadcast())) {
+    pos_cav++;
+  }
 
-  TORCH_CHECK(
-      (position >= 0 && (unsigned int)position < consumer->nDims() + 1) ||
-          mode == ComputeAtMode::BestEffort,
-      "Compute at called on an position outside valid range.");
+  while ((int)pos_this < pos) {
+    TORCH_INTERNAL_ASSERT(
+        pos_cav < compute_at_view_->nDims(),
+        "Error computing relative position in computeAt.");
 
-  if (mode == ComputeAtMode::BestEffort) {
-    position = std::max(-1, position);
-    position = std::min((int)consumer->nDims(), position);
+    if (compute_at_view_->axis(pos_cav)->isBroadcast() &&
+        !(axis(pos_this)->isBroadcast())) {
+      pos_cav++;
+    } else {
+      pos_cav++;
+      pos_this++;
+    }
   }
 
-  ComputeAt::runAt(this, consumer, (unsigned int)position, mode);
+  return pos_cav;
+}
 
-  return this;
+void TensorView::setThisComputeAtAxis() {
+  if (compute_at_view_ == nullptr) {
+    relative_compute_at_axis_ = 0;
+    this_compute_at_axis_ = 0;
+    return;
+  }
+
+  // this[is{i1}, is{i2},] -> compute at compute_at_view[bS{i0}, iS{i1}, iS{i2}]
+  // axis = 2 this compute at axis = 1
+
+  // pos in compute at view
+  size_t pos_cav = 0, pos_this = 0;
+  while (pos_cav < relative_compute_at_axis_ && pos_this < nDims()) {
+    if (compute_at_view_->axis(pos_cav)->isBroadcast() &&
+        !(axis(pos_this)->isBroadcast())) {
+      pos_cav++;
+    } else {
+      pos_cav++;
+      pos_this++;
+    }
+  }
+
+  TORCH_INTERNAL_ASSERT(
+      pos_cav == relative_compute_at_axis_ ||
+          (pos_cav < compute_at_view_->nDims() &&
+           compute_at_view_->axis(pos_cav)->isBroadcast()),
+      "Error seting up relative position between this and what we view into.");
+
+  this_compute_at_axis_ = pos_this;
 }
 
-TensorView* TensorView::computeWith(
-    TensorView* consumer,
-    int position,
-    ComputeAtMode mode) {
+TensorView* TensorView::computeAt(TensorView* consumer, int axis) {
   // Make sure this and consumer are not the same tensor, that's illegal
   TORCH_CHECK(!sameAs(consumer), "Cannot call this->computeAt(this, ...)");
 
-  // We support negative axes, so increment it by this->nDims() + 1 and make
-  // sure the result is within this->nDims() + 1. being at this->nDims()
-  // means producer will be computed inline with this, hence the +1.
-  if (position < 0)
-    position += int(this->nDims()) + 1;
+  // We support negative axes, so increment it by consumer->nDims() + 1 and make
+  // sure the result is within consumer->nDims() + 1. being at consumer->nDims()
+  // means producer will be computed inline with consumer, hence the +1.
+  if (axis < 0)
+    axis += int(consumer->nDims()) + 1;
   TORCH_CHECK(
-      position >= 0 && (unsigned int)position < this->nDims() + 1,
-      "Compute at called on an position outside valid range.");
+      axis >= 0 && (unsigned int)axis < consumer->nDims() + 1,
+      "Compute at called on an axis outside valid range.");
 
-  ComputeAt::runWith(this, consumer, (unsigned int)position, mode);
+  ComputeAt::run(this, consumer, (unsigned int)axis);
 
   return this;
 }
 
-TensorView* TensorView::split(int axis_, Val* factor, bool inner_split) {
+TensorView* TensorView::split(int axis, Val* factor) {
   // Only check things associated with axis, factor will be validated in
   // IterDomain
   TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim TensorView");
 
-  if (axis_ < 0)
-    axis_ += domain()->nDims();
-
-  TORCH_INTERNAL_ASSERT(
-      axis_ >= 0,
-      "Split axis is less than 0 even after adjusting for nDims: ",
-      axis_);
+  if (axis < 0)
+    axis += domain()->nDims();
 
-  TORCH_CHECK(
-      axis_ >= (int)getComputeAtPosition(),
-      "Cannot split axis within compute at position. Axis = ",
-      axis_,
-      " computeAtPosition = ",
-      getComputeAtPosition());
-
-  TORCH_CHECK(
-      axis_ >= (int)getMaxProducerPosition(),
-      "Cannot split axis within max producer position. Axis = ",
-      axis_,
-      " maxProducerPosition = ",
-      getMaxProducerPosition());
-
-  TORCH_CHECK(
-      axis(axis_)->getParallelType() == ParallelType::Serial,
-      "Splitting an axis of non-Serial parallel type is not supported at this time."
-      " Parallelization strategy must be set after calling split.");
+  if (getComputeAtView() != nullptr)
+    if (axis < (int)getThisComputeAtAxis())
+      TORCH_CHECK(
+          false,
+          "Cannot split axis within compute at range. Axis = ",
+          axis,
+          " thisComputeAtAxis = ",
+          getThisComputeAtAxis());
 
-  domain()->split(axis_, factor, inner_split);
+  domain()->split(axis, factor);
   return this;
 }
 
-TensorView* TensorView::split(int axis, unsigned int factor, bool inner_split) {
-  split(axis, new Int(factor), inner_split);
+TensorView* TensorView::split(int axis, unsigned int factor) {
+  domain()->split(axis, new Int(factor));
   return this;
 }
 
 // Merge "axis" and "axis+1" into 1 dimension
 TensorView* TensorView::merge(int axis_o, int axis_i) {
   TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim TensorView");
-
   if (axis_o < 0)
     axis_o += domain()->nDims();
 
   if (axis_i < 0)
     axis_i += domain()->nDims();
 
-  TORCH_CHECK(
-      axis_o >= (int)getComputeAtPosition() &&
-          axis_i >= (int)getComputeAtPosition(),
-      false,
-      "Cannot merge axes within compute at position. Either axis ",
-      axis_o,
-      " or ",
-      axis_i,
-      " are within computeAtPosition = ",
-      getComputeAtPosition());
-
-  TORCH_CHECK(
-      axis_o >= (int)getMaxProducerPosition() &&
-          axis_i >= (int)getMaxProducerPosition(),
-      "Cannot merge axes within max producer position. Either axis ",
-      axis_o,
-      " or ",
-      axis_i,
-      " are within maxProducerPosition = ",
-      getMaxProducerPosition());
-
-  TORCH_CHECK(
-      axis(axis_o)->getParallelType() == ParallelType::Serial ||
-          axis(axis_i)->getParallelType() == ParallelType::Serial,
-      "Merging axes of non-Serial parallel type is not supported at this time."
-      " Parallelization strategy must be set after calling split.");
+  if (getComputeAtView() != nullptr)
+    if (axis_o + 1 < (int)getThisComputeAtAxis() ||
+        axis_i + 1 < (int)getThisComputeAtAxis())
+      TORCH_CHECK(
+          false,
+          "Cannot merge axis within compute at range. Either axis ",
+          axis_o,
+          " or ",
+          axis_i,
+          " are within thisComputeAtAxis = ",
+          getThisComputeAtAxis());
 
   domain()->merge(axis_o, axis_i);
   return this;
@@ -331,119 +338,24 @@ TensorView* TensorView::reorder(const std::unordered_map<int, int>& old2new_) {
   TORCH_INTERNAL_ASSERT(
       !(nDims() == 0 && old2new_.size() > 0),
       "Tried to reorder a 0-dim TensorView");
-
-  for (auto entry : old2new_) {
-    auto old_pos = entry.first < 0 ? entry.first + (int)nDims() : entry.first;
-    auto new_pos =
-        entry.second < 0 ? entry.second + (int)nDims() : entry.second;
-    if (old_pos == new_pos) {
-      continue;
-    }
-    TORCH_INTERNAL_ASSERT(
-        old_pos >= 0,
-        "Found \"old\" position that's less than 0 even though already adjusted by nDims: ",
-        old_pos);
-    TORCH_INTERNAL_ASSERT(
-        new_pos >= 0,
-        "Found \"new\" position that's less than 0 even though already adjusted by nDims: ",
-        new_pos);
-    TORCH_CHECK(
-        old_pos >= (int)getComputeAtPosition() &&
-            new_pos >= (int)getComputeAtPosition(),
-        "Cannot reorder axes within compute at position. Either axis ",
-        old_pos,
-        " or ",
-        new_pos,
-        " are within computeAtPosition = ",
-        getComputeAtPosition());
-
-    TORCH_CHECK(
-        old_pos >= (int)getMaxProducerPosition() &&
-            new_pos >= (int)getMaxProducerPosition(),
-        "Cannot reorder axes within max producer position. Either axis ",
-        old_pos,
-        " or ",
-        new_pos,
-        " are within maxProducerPosition = ",
-        getMaxProducerPosition());
-  }
-
   domain()->reorder(old2new_);
   return this;
 }
 
-TensorView* TensorView::swizzle(
-    SwizzleType type,
-    const std::vector<int>& axes) {
-  swizzle_type_ = type;
-
-  // Clear previously set swizzle axes if any
-  if (axes_to_swizzle_.size()) {
-    axes_to_swizzle_.clear();
-  }
-
-  if (swizzle_type_ == SwizzleType::Transpose) {
-    TORCH_CHECK(
-        axes.size() == 2,
-        "Invalid axis list: ",
-        axes,
-        ". Number of axes must be two.");
-    TORCH_CHECK(
-        axes[0] != axes[1],
-        "Invalid axis list: ",
-        axes,
-        ". Two distinctive axes must be given.");
-    TORCH_CHECK(
-        getMemoryType() == MemoryType::Shared,
-        "Transpose swizzle is meant for tensors on shared memory.");
-    for (auto pos : axes) {
-      if (pos < 0) {
-        pos += nDims();
-      }
-      TORCH_CHECK(pos >= 0 && pos < (int)nDims(), "Invalid axis: ", pos);
-      TORCH_CHECK(
-          pos >= (int)getComputeAtPosition(),
-          "Invalid axis: ",
-          pos,
-          ". Axis outside computeAt position is not allocated.");
-      TORCH_CHECK(
-          !axis(pos)->isReduction(),
-          "Invalid axis: ",
-          pos,
-          ". Swizzling a reduction axis is not supported");
-      TORCH_CHECK(
-          !axis(pos)->isBroadcast(),
-          "Invalid axis: ",
-          pos,
-          ". Swizzling a broadcast axis is not supported");
-      axes_to_swizzle_.push_back(axis(pos));
-    }
-  }
-
-  return this;
-}
-
 TensorView* TensorView::rFactor(const std::vector<int>& axes) {
-  // TODO: I think we should do this but
-  // NVFuserTest.FusionSmemBlockGemmCache_CUDA prevents it from going in at the
-  // moment.
-
-  // TORCH_INTERNAL_ASSERT(
-  //     !hasComputeAt(), "Cannot rfactor tensors after compute at has been
-  //     set.");
   TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
-  TORCH_INTERNAL_ASSERT(definition()->isA<ReductionOp>());
   FusionGuard fg(fusion());
+  Expr* origin_expr = fusion()->origin(this);
   TORCH_CHECK(
-      definition() != nullptr &&
-          definition()->getExprType() == ExprType::ReductionOp,
+      origin_expr != nullptr &&
+          origin_expr->getExprType() == ExprType::ReductionOp,
       "Error rfactoring ",
       this,
-      " its definition is either a nullptr or not a reduction.");
+      " its origin is either a nullptr or not a reduction.");
   TORCH_CHECK(
       !domain()->hasRFactor(), "Cannot call rfactor on the same view twice.");
 
-  ReductionOp* this_definition = definition()->as<ReductionOp>();
+  ReductionOp* this_origin = origin_expr->as<ReductionOp>();
 
   // Split tensor view into 2 parts
   auto domain_pair = domain()->rFactor(axes);
@@ -461,279 +373,107 @@ TensorView* TensorView::rFactor(const std::vector<int>& axes) {
   TensorView* consumer = this;
 
   // Setup dependency chain, inserting producer before this op.
-  // Expr* producer_definition =
+  // Expr* producer_origin =
   new ReductionOp(
-      this_definition->getReductionOpType(),
-      this_definition->init(),
+      this_origin->getReductionOpType(),
+      this_origin->init(),
       producer,
-      this_definition->in());
+      this_origin->in());
 
-  // Expr* consumer_definition =
+  // Expr* consumer_origin =
   new ReductionOp(
-      this_definition->getReductionOpType(),
-      this_definition->init(),
+      this_origin->getReductionOpType(),
+      this_origin->init(),
       consumer,
       producer);
 
   return producer;
 }
 
-TensorView* TensorView::welfordRfactorHelper(
-    TensorView* tv,
-    const std::vector<int>& axes) {
-  // Hack:
-  // Semantically we should always keep the outputs of welfordOp scheduled
-  // the same but the user end cannot guarantee that.
-  // In order to guarantee that the rFactor is defined meaningfully the
-  // scheduling of the output TV that got the rfactor call is force replayed
-  // towards the other two
-
-  if (!sameAs(tv)) {
-    auto root = tv->getRootDomain();
-    auto this_root = getRootDomain();
-
-    // construct a trivial root domain map
-    std::unordered_map<IterDomain*, IterDomain*> id_map;
-    for (size_t i = 0; i < root.size(); i++) {
-      id_map[this_root[i]] = root[i];
-    }
-
-    // replay on the target tv
-    ReplayTransformations replay(domain()->domain(), id_map);
-
-    // construct the new tensor domain
-    std::vector<IterDomain*> new_id;
-    for (auto id : domain()->domain()) {
-      TORCH_INTERNAL_ASSERT(
-          replay.getReplay().count(id), "Welford Replay Failed");
-      new_id.push_back(replay.getReplay().at(id));
-    }
-
-    std::vector<bool> new_contig(
-        tv->domain()->contiguity().begin(), tv->domain()->contiguity().end());
-    // replace tensor domain of target tv
-    tv->setDomain(new TensorDomain(tv->getRootDomain(), new_id, new_contig));
-  }
-
-  // Split tensor view into 2 parts
-  auto domain_pair = tv->domain()->rFactor(axes);
-  // Producer in the pair
-  auto producer_domain = domain_pair.first;
-  // Consumer in the pair
-  auto consumer_domain = domain_pair.second;
-
-  // This domain will be the consumer, so create the producer
-  TensorView* producer =
-      new TensorView(producer_domain, tv->getDataType().value());
-
-  // Set domain of consumer
-  tv->setDomain(consumer_domain);
-
-  return producer;
-}
-
-WelfordResult TensorView::rFactor(
-    const std::vector<int>& axes,
-    TensorView* avg,
-    TensorView* var,
-    TensorView* n) {
-  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
-  FusionGuard fg(fusion());
-  TORCH_CHECK(
-      definition() != nullptr &&
-          definition()->getExprType() == ExprType::WelfordOp,
-      "Error rfactoring welford ",
-      this,
-      " its definition is either a nullptr or not a welford.");
-  TORCH_CHECK(
-      !domain()->hasRFactor(), "Cannot call rfactor on the same view twice.");
-
-  WelfordOp* wop = definition()->as<WelfordOp>();
-
-  TORCH_INTERNAL_ASSERT(
-      avg->sameAs(wop->outAvg()), "Welford rfactor not used correctly");
-  TORCH_INTERNAL_ASSERT(
-      var->sameAs(wop->outVar()), "Welford rfactor not used correctly");
-  TORCH_INTERNAL_ASSERT(
-      n->sameAs(wop->outN()), "Welford rfactor not used correctly");
-
-  std::vector<std::pair<TensorView*, TensorView*>> tv2rf{
-      {avg, nullptr}, {var, nullptr}, {n, nullptr}};
-
-  // Make sure this gets rfactored last so everybody gets
-  //  replayed correctly
-  for (auto& it : tv2rf) {
-    if (!sameAs(it.first)) {
-      it.second = welfordRfactorHelper(it.first, axes);
-    }
-  }
-
-  for (auto& it : tv2rf) {
-    if (sameAs(it.first)) {
-      it.second = welfordRfactorHelper(it.first, axes);
-    }
-  }
-
-  TensorView* producer_avg = tv2rf[0].second;
-  TensorView* producer_var = tv2rf[1].second;
-  TensorView* producer_n = tv2rf[2].second;
-
-  // Setup dependency chain, inserting producer before this op.
-  // Expr* producer_definition =
-  new WelfordOp(
-      producer_avg,
-      producer_var,
-      producer_n, /*out var/avg/count */
-      wop->initAvg(),
-      wop->initVar(),
-      wop->initN(), /*init var/avg/count */
-      wop->inAvg(),
-      wop->inVar(),
-      wop->inN());
-
-  // Expr* consumer_definition =
-  new WelfordOp(
-      avg,
-      var,
-      n,
-      wop->initAvg(),
-      wop->initVar(),
-      wop->initN(),
-      producer_avg,
-      producer_var,
-      producer_n);
-
-  return WelfordResult(producer_avg, producer_var, producer_n);
-}
-
 TensorView* TensorView::cache_before() {
   FusionGuard fg(fusion());
 
+  Expr* origin_expr = fusion()->origin(this);
   TORCH_CHECK(
-      definition() != nullptr && !isFusionInput(),
+      origin_expr != nullptr && !fusion()->hasInput(this),
       "Error adding cache_before ",
       this,
-      " its definition is a nullptr and we restrict using cache_before on an input.");
+      " its origin is a nullptr and we restrict using cache_before on an input.");
 
   TORCH_CHECK(
-      isFusionOutput() ||
-          definition()->getExprType() != ExprType::ReductionOp ||
-          definition()->getExprType() != ExprType::WelfordOp,
+      fusion()->hasOutput(this) ||
+          origin_expr->getExprType() != ExprType::ReductionOp,
       "Error adding cache_before ",
       this,
-      " its definition is a reduction and it is not an output, instead please use cache_after.");
-
-  // Previously, caching computed-at tensors was allowed but was never
-  // really robust. Make it an error unless it is really needed.
-  TORCH_CHECK(
-      !hasComputeAt(),
-      "Caching computed-at tensors is not allowed. Apply caching before computeAt");
-
-  // It also did additional transformation when a producer tensor has computeAt.
-  // Make sure we no longer rely on that behavior.
-  if (definition() != nullptr) {
-    for (TensorView* producer_of_producer :
-         ir_utils::filterByType<TensorView>(definition()->inputs())) {
-      TORCH_CHECK(
-          !producer_of_producer->hasComputeAt(),
-          "Potentially invalid computeAt and caching detected. Apply caching before computeAt.");
-    }
-  }
+      " its origin is a reduction and it is not an output, instead please use cache_after.");
 
   // Create Producer Domain
-  // This domain will be the consumer which needs a new domain, so replace the
-  // producers domain with this domain.
+  // This domain will be the consumer, so create the producer
   auto root_domain = getRootDomain();
-
   TensorView* producer = new TensorView(
       new TensorDomain(
-          domain()->getRootDomain(),
-          domain()->domain(),
-          domain()->contiguity()),
+          root_domain, std::vector<bool>(root_domain.size(), true)),
       getDataType().value());
 
   // Set domain of consumer
   TensorView* consumer = this;
 
-  size_t i = 0;
-  auto no_reduction_root_domain = TensorDomain::noReductions(getRootDomain());
-  std::vector<IterDomain*> new_root_domain(no_reduction_root_domain.size());
-  for (const auto& dom : no_reduction_root_domain) {
-    new_root_domain[i++] = dom->clone();
+  // this TV is an output and its origin is a reduction
+  // remove reduction axis from this tv
+  if (origin_expr->getExprType() == ExprType::ReductionOp) {
+    size_t i = 0;
+    auto no_reduction_root_domain = TensorDomain::noReductions(getRootDomain());
+    std::vector<IterDomain*> new_root_domain(no_reduction_root_domain.size());
+    for (const auto& dom : no_reduction_root_domain) {
+      new_root_domain[i++] = dom->clone();
+    }
+    consumer->setDomain(new TensorDomain(
+        new_root_domain, std::vector<bool>(new_root_domain.size(), true)));
   }
 
-  consumer->setDomain(new TensorDomain(
-      new_root_domain, std::vector<bool>(new_root_domain.size(), true)));
-
   // Insert producer - Cache_Before (CB) - before this TV.
-  // Before: Prev TV -> [Definition Op] -> This TV
-  // After:  Prev TV -> [Definition Op] -> New CB TV -> [Set Op] -> This TV
+  // Before: Prev TV -> [Origin Op] -> This TV
+  // After:  Prev TV -> [Origin Op] -> New CB TV -> [Set Op] -> This TV
 
   // Get inputs for origin expression
-  auto expr_inputs = definition()->inputs();
-  // Expr* producer_definition =
-  ir_utils::replaceValInExpr(definition(), this, producer);
+  auto expr_inputs = origin_expr->inputs();
+
+  // Expr* producer_origin =
+  createExprConsumer(origin_expr, producer);
 
   // Expr* producer_uses =
   new UnaryOp(UnaryOpType::Set, consumer, producer);
 
-  // definition_ is no longer valid
-  // setDefinition(nullptr);
-
-  auto replayed_consumer_pair =
-      TransformReplay::replayCasP(consumer, producer, -1);
-  consumer->setDomain(replayed_consumer_pair.first);
+  // Before: This TV -> Next TV
+  // After:  New TV (CB) -> This TV -> Next TV
+  if (hasComputeAt()) {
+    TransformReplay::replayPasC(producer, consumer, -1);
+    auto this_ca_pos = getThisComputeAtAxis();
+    producer->computeAt(consumer, this_ca_pos);
+  } else {
+    // Before: Prev TV -> This TV
+    // After:  Prev TV -> New TV (CB) -> This TV
+    // Iterate over origin expression inputs for cache_before on outputs
+    for (TensorView* origin_input :
+         ir_utils::filterByType<TensorView>(expr_inputs)) {
+      if (origin_input->hasComputeAt() &&
+          origin_input->getComputeAtView() == this) {
+        TransformReplay::replayPasC(producer, consumer, -1);
+
+        auto origin_ca_pos = origin_input->getThisComputeAtAxis();
+        auto origin_rel_ca_pos = origin_input->getRelativeComputeAtAxis();
+        origin_input->computeAt(producer, origin_ca_pos);
+        producer->setComputeAt(consumer, origin_rel_ca_pos);
+      }
+    }
+  }
 
   return producer;
 }
 
-TensorView* TensorView::cache_fork() {
-  FusionGuard fg(fusion());
-
-  // Before: [Expr] -> This TV (Global Output) -> [Usage Expr]
-  // After:  [Expr] -> This TV (Local) -> [Usage Expr] > Next TV
-  //                            (Fork) -> [Set Expr]   -> New TV (Global Output)
-
-  TORCH_CHECK(
-      fusion()->hasOutput(this) && !this->uses().empty(),
-      "Error adding cache_fork ",
-      this,
-      " this TensorView must be an output with subsequent uses");
-
-  // Previously, caching computed-at tensors was allowed but was never
-  // really robust. Make it an error unless it is really needed.
-  TORCH_CHECK(
-      !hasComputeAt(),
-      "Caching computed-at tensors is not allowed. Apply caching before computeAt");
-
-  // This domain will be the producer, so create the consumer
-  auto root_domain = TensorDomain::noReductions(getRootDomain());
-  TensorView* new_output = new TensorView(
-      new TensorDomain(
-          IterDomain::clone(root_domain),
-          std::vector<bool>(root_domain.size(), true)),
-      getDataType().value());
-
-  // Create write operation from this TV to new output
-  new UnaryOp(UnaryOpType::Set, new_output, this);
-
-  // The new TV becomes an output.
-  // New TV has global memory type.
-  // This TV has local memory type.
-  fusion()->replaceOutput(this, new_output);
-
-  // Transform new output according to this TV
-  auto replayed_output_pair = TransformReplay::replayCasP(new_output, this, -1);
-  new_output->setDomain(replayed_output_pair.first);
-
-  return new_output;
-}
-
 TensorView* TensorView::cache_after() {
   FusionGuard fg(fusion());
 
-  const bool kIsFusionInput = fusion()->hasInput(this);
-
   // Get all the uses for this Tensorview
   TORCH_CHECK(
       !fusion()->hasOutput(this),
@@ -741,26 +481,6 @@ TensorView* TensorView::cache_after() {
       this,
       " we restrict using cache_after on an output.");
 
-  // Previously, caching computed-at tensors was allowed but was never
-  // really robust. Make it an error unless it is really needed.
-  TORCH_CHECK(
-      !hasComputeAt(),
-      "Caching computed-at tensors is not allowed. Apply caching before computeAt.");
-
-  // It also did additional transformation when this tensor is an
-  // input and the outputs of its consumers have computeAt. Make sure
-  // we no longer rely on that behavior.
-  if (kIsFusionInput) {
-    for (const auto& expr : uses()) {
-      for (TensorView* output :
-           ir_utils::filterByType<TensorView>(expr->outputs())) {
-        TORCH_CHECK(
-            !output->hasComputeAt(),
-            "Potentially invalid computeAt and caching detected. Apply caching before computeAt.");
-      }
-    }
-  }
-
   // Create Consumer Domain
   // Keep Broadcast Axis (Permanent)
   // Remove Reduction Axis
@@ -786,12 +506,37 @@ TensorView* TensorView::cache_after() {
 
   // Expr* consumer_uses =
   for (auto expr : fusion()->unordered_uses(this)) {
-    ir_utils::replaceValInExpr(expr, this, consumer);
+    createExprProducer(expr, this, consumer);
   }
 
-  // Expr* consumer_definition =
+  // Expr* consumer_origin =
   new UnaryOp(UnaryOpType::Set, consumer, producer);
 
+  // Before: This TV -> Next TV
+  // After:  This TV -> New TV (After) -> Next TV
+  if (hasComputeAt()) {
+    TransformReplay::replayCasP(consumer, producer, -1);
+
+    auto rel_ca_pos = getRelativeComputeAtAxis();
+    auto this_ca_pos = getThisComputeAtAxis();
+    auto this_ca_view = getComputeAtView();
+
+    computeAt(consumer, this_ca_pos);
+    consumer->setComputeAt(this_ca_view, rel_ca_pos);
+  } else {
+    // Check users of this TV for computeAt for cache_after on inputs
+    for (const auto& expr : fusion()->unordered_uses(consumer)) {
+      for (TensorView* output :
+           ir_utils::filterByType<TensorView>(expr->outputs())) {
+        if (output->hasComputeAt()) {
+          TransformReplay::replayPasC(consumer, output, -1);
+          auto output_ca_pos = output->getThisComputeAtAxis();
+          consumer->setComputeAt(output, output_ca_pos);
+        }
+      }
+    }
+  }
+
   return consumer;
 }
 
@@ -804,76 +549,157 @@ void TensorView::setMemoryType(MemoryType mt) {
   }
 }
 
-void TensorView::clearReductionIterDomains() {
-  TORCH_INTERNAL_ASSERT(
-      !domain()->hasRFactor(),
-      "should not call clearReductionIterDomains on rfactor tv");
+namespace {
 
-  TORCH_INTERNAL_ASSERT(
-      domain()->domain() == getRootDomain(),
-      "should not call clearReductionIterDomains on already transformed TensorDomains");
-
-  std::vector<IterDomain*> new_root;
-  std::vector<bool> new_contig;
-  for (size_t i = 0; i < getRootDomain().size(); i++) {
-    if (!getRootDomain()[i]->isReduction()) {
-      new_root.push_back(getRootDomain()[i]);
-      new_contig.push_back(domain()->contiguity()[i]);
-    }
+// Create New Expr given consumer - [output of the expression]
+struct CreateExprConsumer : public OptInDispatch {
+ public:
+  static void create(Expr* expr, TensorView* consumer) {
+    CreateExprConsumer cec(consumer);
+    cec.handle(expr);
   }
 
-  setDomain(new TensorDomain(new_root, new_contig));
-}
+ private:
+  explicit CreateExprConsumer(TensorView* consumer) : consumer_(consumer) {}
 
-TensorViewBuilder& TensorViewBuilder::ndims(size_t ndims) {
-  TORCH_CHECK(shape_.empty() || shape_.size() == ndims);
-  TORCH_CHECK(contiguity_.empty() || contiguity_.size() == ndims);
-  ndims_ = ndims;
-  return *this;
-}
+  void handle(Expr* expr) final {
+    OptInDispatch::handle(expr);
+  }
 
-TensorViewBuilder& TensorViewBuilder::dtype(DataType dtype) {
-  dtype_ = dtype;
-  return *this;
-}
+  void handle(UnaryOp* unary_expr) final {
+    new UnaryOp(unary_expr->getUnaryOpType(), consumer_, unary_expr->in());
+  }
 
-TensorViewBuilder& TensorViewBuilder::contiguity(std::vector<bool> contiguity) {
-  TORCH_CHECK(contiguity_.empty(), "Attempting to reset contiguity");
-  if (!contiguity.empty()) {
-    TORCH_CHECK(ndims_ == 0 || ndims_ == contiguity.size());
-    ndims_ = contiguity.size();
+  void handle(BinaryOp* binary_expr) final {
+    new BinaryOp(
+        binary_expr->getBinaryOpType(),
+        consumer_,
+        binary_expr->lhs(),
+        binary_expr->rhs());
   }
-  contiguity_ = std::move(contiguity);
-  return *this;
-}
 
-TensorViewBuilder& TensorViewBuilder::shape(std::vector<int64_t> shape) {
-  TORCH_CHECK(shape_.empty(), "Attempting to reset shape");
-  if (!shape.empty()) {
-    TORCH_CHECK(ndims_ == 0 || ndims_ == shape.size());
-    ndims_ = shape.size();
+  void handle(TernaryOp* ternary_expr) final {
+    new TernaryOp(
+        ternary_expr->getTernaryOpType(),
+        consumer_,
+        ternary_expr->in1(),
+        ternary_expr->in2(),
+        ternary_expr->in3());
   }
-  shape_ = std::move(shape);
-  return *this;
-}
 
-TensorView* TensorViewBuilder::build() const {
-  // Build the domain
-  std::vector<IterDomain*> domain(ndims_, nullptr);
-  for (size_t i = 0; i < ndims_; i++) {
-    if (shape_.empty() || shape_[i] == -1) {
-      domain[i] = new IterDomain(new Int(0), new Int());
+  void handle(ReductionOp* reduction_expr) final {
+    new ReductionOp(
+        reduction_expr->getReductionOpType(),
+        reduction_expr->init(),
+        consumer_,
+        reduction_expr->in());
+  }
+
+  void handle(BroadcastOp* broadcast_expr) final {
+    new BroadcastOp(consumer_, broadcast_expr->in());
+  }
+
+ private:
+  TensorView* consumer_ = nullptr;
+};
+
+// Create New Expr given producer - [an input for the expression]
+struct CreateExprProducer : public OptInDispatch {
+ public:
+  static void create(Expr* expr, TensorView* current, TensorView* producer) {
+    CreateExprProducer cep(current, producer);
+    cep.handle(expr);
+  }
+
+ private:
+  explicit CreateExprProducer(TensorView* current, TensorView* producer)
+      : current_(current), producer_(producer) {}
+
+  void handle(Expr* expr) final {
+    OptInDispatch::handle(expr);
+  }
+
+  void handle(UnaryOp* unary_expr) final {
+    new UnaryOp(unary_expr->getUnaryOpType(), unary_expr->out(), producer_);
+  }
+
+  void handle(BinaryOp* binary_expr) final {
+    if (binary_expr->lhs()->sameAs(current_)) {
+      new BinaryOp(
+          binary_expr->getBinaryOpType(),
+          binary_expr->out(),
+          producer_,
+          binary_expr->rhs());
     } else {
-      TORCH_CHECK(
-          shape_[i] >= 0,
-          "Invalid extent value. ",
-          "For a tensor representing a single scalar use ndims = 0 with no sizes set.");
-      domain[i] = new IterDomain(new Int(0), new Int(shape_[i]));
+      new BinaryOp(
+          binary_expr->getBinaryOpType(),
+          binary_expr->out(),
+          binary_expr->lhs(),
+          producer_);
     }
   }
 
-  // Create the final TensorView
-  return new TensorView(new TensorDomain(domain, contiguity_), dtype_);
+  void handle(TernaryOp* ternary_expr) final {
+    if (ternary_expr->in1()->sameAs(current_)) {
+      new TernaryOp(
+          ternary_expr->getTernaryOpType(),
+          ternary_expr->out(),
+          producer_,
+          ternary_expr->in2(),
+          ternary_expr->in3());
+    } else if (ternary_expr->in2()->sameAs(current_)) {
+      new TernaryOp(
+          ternary_expr->getTernaryOpType(),
+          ternary_expr->out(),
+          ternary_expr->in1(),
+          producer_,
+          ternary_expr->in3());
+    } else {
+      new TernaryOp(
+          ternary_expr->getTernaryOpType(),
+          ternary_expr->out(),
+          ternary_expr->in1(),
+          ternary_expr->in2(),
+          producer_);
+    }
+  }
+
+  void handle(ReductionOp* reduction_expr) final {
+    new ReductionOp(
+        reduction_expr->getReductionOpType(),
+        reduction_expr->init(),
+        reduction_expr->out(),
+        producer_);
+  }
+
+  void handle(BroadcastOp* broadcast_expr) final {
+    new BroadcastOp(broadcast_expr->out(), producer_);
+  }
+
+ private:
+  TensorView* current_ = nullptr;
+  TensorView* producer_ = nullptr;
+};
+
+} // namespace
+
+// In Cache Before, for the origin expr of the original tensor,
+// we create a new operation where the original tensor is replaced
+// with the new cache tensor. This function creates a new expr
+// given the consumer, the output of the expression.
+void TensorView::createExprConsumer(Expr* expr, TensorView* consumer) {
+  CreateExprConsumer::create(expr, consumer);
+}
+
+// In Cache After, for all the uses of the original tensor, we create
+// a new operation where the original tensor is replaced with the new
+// cache tensor. This function creates a new expr given a producer,
+// an input for the expression.
+void TensorView::createExprProducer(
+    Expr* expr,
+    TensorView* current,
+    TensorView* producer) {
+  CreateExprProducer::create(expr, current, producer);
 }
 
 } // namespace cuda
index 7e41caf..1ea703d 100644 (file)
@@ -45,7 +45,7 @@ void ReplayTransformations::handle(Split* s) {
       "Transform traversal failed, modified a node but it was not a leaf node.");
 
   // Replay the split onto mapped
-  auto outs = IterDomain::split(mapped, s->factor(), s->innerSplit());
+  auto outs = IterDomain::split(mapped, s->factor());
   // Remove mapped from the leaf IDs
   leaf_ids_.erase(mapped);
 
@@ -218,16 +218,14 @@ void ReplayTransformations::runReplay() {
 BestEffortReplay::BestEffortReplay(
     const std::vector<IterDomain*>& replay_domain,
     const std::vector<IterDomain*>& target_domain,
-    std::unordered_map<IterDomain*, IterDomain*> target2replay_map,
-    std::unordered_map<IterDomain*, IterDomain*> forward_id_map)
-    : target2replay_id_map_(std::move(target2replay_map)),
-      forward_id_map_(std::move(forward_id_map)) {
-  for (auto entry : target2replay_id_map_) {
+    std::unordered_map<IterDomain*, IterDomain*> replay_map,
+    bool forward_bcast_mismatch)
+    : id_map_(std::move(replay_map)) {
+  for (auto entry : id_map_)
     leaf_ids_[entry.second] = counter++;
-  }
 
   // Grab expr history of iter domains in target_domain
-  std::vector<Expr*> target_exprs = ExprSort::getExprs(
+  std::vector<Expr*> t_exprs = ExprSort::getExprs(
       FusionGuard::getCurFusion(),
       std::vector<Val*>(target_domain.begin(), target_domain.end()));
 
@@ -237,198 +235,160 @@ BestEffortReplay::BestEffortReplay(
   // replay_domain domain. This will be used to propagate the target_domain to
   // replay_domain map.
 
-  // Map replay domain's IterDomains to the Exprs they're used in
-  std::vector<Expr*> replay_exprs = ExprSort::getExprs(
+  // Maps replay domain's IterDomains to the Exprs they're used in
+  std::vector<Expr*> r_exprs = ExprSort::getExprs(
       FusionGuard::getCurFusion(),
       std::vector<Val*>(replay_domain.begin(), replay_domain.end()));
-
-  std::unordered_map<IterDomain*, Expr*> replay_id2expr_map;
-  for (auto replay_expr : replay_exprs) {
-    for (auto id : ir_utils::filterByType<IterDomain>(replay_expr->inputs())) {
+  std::unordered_map<IterDomain*, Expr*> replay_expr_map;
+  for (auto r_expr : r_exprs) {
+    for (auto id : ir_utils::filterByType<IterDomain>(r_expr->inputs())) {
       TORCH_INTERNAL_ASSERT(
-          replay_id2expr_map.find(id) == replay_id2expr_map.end(),
-          "Error trying to map rfactor root domain during replay.",
-          " An IterDomain was found to be used in more than one expression.");
+          replay_expr_map.find(id) == replay_expr_map.end(),
+          "Error trying to map rfactor root domain during replay. IterDomain's shouldn't have more than one use.");
       // Only want to forward rfactor in map
-      replay_id2expr_map[id] = replay_expr;
+      replay_expr_map[id] = r_expr;
     }
   }
 
   std::string err_str(
-      "Error during replay, a transformation was called that conflicts with an rfactor call.");
+      "Error during replay, a computeAt was called that conflicts with an rfactor call.");
 
   // Iterate through target IterDomains' history and compare with what we
   // recorded from replay_domain
-  for (auto target_expr : target_exprs) {
-    auto target_inps_filtered =
-        ir_utils::filterByType<IterDomain>(target_expr->inputs());
-
-    // If any input argument in target expression is in the forward map then
-    // forward the mapped IterDomains in replay and continue to the next
-    // expression as target_expr cannot match a replay_expr
-    if (std::any_of(
-            target_inps_filtered.begin(),
-            target_inps_filtered.end(),
-            [&](IterDomain* target_inp) {
-              return this->inForwardMap(target_inp);
-            })) {
-      for (auto target_inp : target_inps_filtered) {
-        if (inForwardMap(target_inp)) {
-          auto target2replay_it = target2replay_id_map_.find(target_inp);
-          if (target2replay_it != target2replay_id_map_.end()) {
-            // Replace target_inp entry in target2replay_id_map_ with forwarded
-            // id
-            target2replay_id_map_[getForwardedId(target_inp)] =
-                target2replay_it->second;
-            target2replay_id_map_.erase(target_inp);
-          }
-        }
-      }
-      // Continue to next target_expr
-      continue;
-    }
-
-    std::vector<IterDomain*> target_id_inps(
-        target_inps_filtered.begin(), target_inps_filtered.end());
-
-    std::vector<IterDomain*> replay_inps =
-        std::vector<IterDomain*>(target_id_inps.size(), nullptr);
-
-    bool missing_replay_input = false;
-
-    // Map target_expr inputs to replay domain directly
-    for (const auto t_i : c10::irange(target_id_inps.size())) {
-      // There might not be a mapping, that could be okay (depends on rfactor
-      // checking).
-      auto it = target2replay_id_map_.find(target_id_inps[t_i]);
-      if (it != target2replay_id_map_.end()) {
-        replay_inps[t_i] = getForwardedId(it->second);
-      } else {
-        missing_replay_input = true;
-      }
+  for (auto t_expr : t_exprs) {
+    auto t_inps_filtered = ir_utils::filterByType<IterDomain>(t_expr->inputs());
+    std::vector<IterDomain*> t_inps(
+        t_inps_filtered.begin(), t_inps_filtered.end());
+
+    std::vector<IterDomain*> r_inps =
+        std::vector<IterDomain*>(t_inps.size(), nullptr);
+
+    // Map t_expr inputs to replay domain directly
+    for (const auto t_i : c10::irange(t_inps.size())) {
+      // There might not be a mapping, that could be okay.
+      auto it = id_map_.find(t_inps[t_i]);
+      if (it != id_map_.end())
+        r_inps[t_i] = it->second;
     }
 
-    // Check if any of the associated replay id's are part of an rfactor domain
-    bool replay_has_rfactor_inp =
-        std::any_of(replay_inps.begin(), replay_inps.end(), [](IterDomain* id) {
+    bool has_rfactor =
+        std::any_of(r_inps.begin(), r_inps.end(), [](IterDomain* id) {
           return id == nullptr ? false : id->isRFactorProduct();
         });
 
-    // If some replay id inputs are part of rfactor, make sure all target
-    // expression inputs map to a replay input
-    if (replay_has_rfactor_inp) {
+    if (has_rfactor) {
       bool no_missing_exprs = std::none_of(
-          replay_inps.begin(),
-          replay_inps.end(),
-          [&replay_id2expr_map](IterDomain* id) {
+          r_inps.begin(), r_inps.end(), [&replay_expr_map](IterDomain* id) {
             if (id == nullptr) {
               return true;
             } else {
-              return replay_id2expr_map.find(id) == replay_id2expr_map.end();
+              return replay_expr_map.find(id) == replay_expr_map.end();
             }
           });
       TORCH_INTERNAL_ASSERT(no_missing_exprs, err_str);
     }
 
-    // If any inputs are missing, continue as this expr doesn't match.
-    if (missing_replay_input) {
-      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
-      continue;
+    // I would like to have this more generic or have this whole function go
+    // through dispatch, but trying to make quick forward progress on
+    // https://github.com/csarofeen/pytorch/issues/286 This mapping reflects
+    // more closely what is done in ReplayTransform with mismatched
+    // broadcast/merge
+    if (forward_bcast_mismatch && !has_rfactor &&
+        t_expr->getExprType().value() == ExprType::Merge) {
+      auto t_merge = t_expr->as<Merge>();
+      auto t_outer = t_merge->outer();
+      auto t_inner = t_merge->inner();
+      IterDomain* r_outer = id_map_.find(t_outer) != id_map_.end()
+          ? id_map_.at(t_outer)
+          : nullptr;
+      IterDomain* r_inner = id_map_.find(t_inner) != id_map_.end()
+          ? id_map_.at(t_inner)
+          : nullptr;
+      if (r_outer != nullptr && r_inner == nullptr && t_inner->isBroadcast()) {
+        id_map_[t_merge->out()] = r_outer;
+      } else if (
+          r_inner != nullptr && r_outer == nullptr && t_outer->isBroadcast()) {
+        id_map_[t_merge->out()] = r_inner;
+      }
     }
 
-    // Find which replay_expr maps to the target_expr
-    Expr* replay_expr = nullptr;
-    // Check if all inputs have the same expression
-    bool mismatched_replay_exprs = false;
-    for (auto replay_inp : replay_inps) {
-      auto it = replay_id2expr_map.find(replay_inp);
-      if (it != replay_id2expr_map.end()) {
-        if (replay_expr == nullptr) {
-          replay_expr = it->second;
-        } else {
-          mismatched_replay_exprs =
-              mismatched_replay_exprs || replay_expr != it->second;
+    Expr* r_expr = nullptr;
+    for (auto r_inp : r_inps) {
+      if (r_inp != nullptr) {
+        auto it = replay_expr_map.find(r_inp);
+        if (it != replay_expr_map.end()) {
+          r_expr = it->second;
+          break;
         }
-      } else {
-        // If no expr is mapped then set mismatched epxrs to go to continue to
-        // the next target expr
-        mismatched_replay_exprs = true;
       }
     }
 
-    // If expressions of mapped inputs don't match, then continue to next target
-    // expr
-    if (mismatched_replay_exprs || replay_expr == nullptr) {
-      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
+    if (r_expr == nullptr) {
+      TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
       continue;
     }
 
-    bool mismatched_inputs = replay_inps.size() != replay_expr->inputs().size();
-    for (size_t i = 0; i < replay_inps.size() && !mismatched_inputs; i++) {
-      mismatched_inputs =
-          mismatched_inputs || replay_expr->inputs()[i] != replay_inps[i];
+    bool mismatched_inputs = r_inps.size() != r_expr->inputs().size();
+    for (size_t i = 0; i < r_inps.size() && !mismatched_inputs; i++) {
+      if (r_inps[i] == nullptr) {
+        mismatched_inputs = true;
+      } else {
+        mismatched_inputs =
+            mismatched_inputs || r_expr->inputs()[i] != r_inps[i];
+      }
     }
 
-    // If there isn't an rfactor id in the replay's inputs and there's a
-    // mismatched input, continue
     if (mismatched_inputs) {
-      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
+      TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
       continue;
     }
 
-    // If there isn't an rfactor id in the replay's inputs and there's a
-    // mismatch in replay_expr's and target_expr's outputs, continue
-    if (target_expr->outputs().size() != replay_expr->outputs().size()) {
-      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
+    if (t_expr->outputs().size() != r_expr->outputs().size()) {
+      TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
       continue;
     }
 
-    // If there isn't an rfactor id in the replay's inputs and there's a
-    // mismatch in replay_expr's and target_expr's expression type, continue
-    if (replay_expr->getExprType().value() !=
-        target_expr->getExprType().value()) {
-      TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
+    if (r_expr->getExprType().value() != t_expr->getExprType().value()) {
+      TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
       continue;
     }
 
-    // If there isn't an rfactor id in the replay's inputs and there's a
-    // mismatch in replay_expr's and target_expr's split factor (if a split
-    // expr), continue
-    if (replay_expr->getExprType().value() == ExprType::Split) {
-      auto r_split = replay_expr->as<Split>();
-      auto t_split = target_expr->as<Split>();
-      if (!r_split->factor()->sameAs(t_split->factor()) ||
-          r_split->innerSplit() != t_split->innerSplit()) {
-        TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
+    // If the expression is a split, make sure it's split by the same ammount.
+    if (r_expr->getExprType().value() == ExprType::Split) {
+      if (!r_expr->as<Split>()->factor()->sameAs(
+              r_expr->as<Split>()->factor())) {
+        TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
         continue;
       }
     }
 
-    // Take replay expr inputs out of map:
-    for (size_t t_i = 0; t_i < target_id_inps.size(); t_i++) {
-      auto t_inp = target_id_inps[t_i];
-      auto r_orig_inp = target2replay_id_map_.at(t_inp);
-      auto r_maybe_forwarded_inp = replay_inps[t_i];
-
-      // Remove original target2replay_it->second if it's in leaf_ids
-      if (leaf_ids_.find(r_orig_inp) != leaf_ids_.end()) {
-        leaf_ids_.erase(r_orig_inp);
-      }
+    bool missing_input = std::any_of(
+        t_expr->inputs().begin(), t_expr->inputs().end(), [this](Val* inp) {
+          if (inp->getValType() == ValType::IterDomain) {
+            return id_map_.find(inp->as<IterDomain>()) == id_map_.end();
+          }
+          return false;
+        });
 
-      // Check if we used a forwarded id, if so add forwarded id's to tracking.
-      if (r_orig_inp != r_maybe_forwarded_inp) {
-        forwarded_ids_.emplace_back(r_orig_inp);
+    if (missing_input) {
+      TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
+      continue;
+    }
+    // Take target_domain inputs out of map:
+    for (auto t_inp : ir_utils::filterByType<IterDomain>(t_expr->inputs())) {
+      auto it = id_map_.find(t_inp);
+      if (leaf_ids_.find(it->second) != leaf_ids_.end()) {
+        leaf_ids_.erase(it->second);
       }
     }
 
     // Add outputs to map.
-    for (const auto i : c10::irange(target_expr->outputs().size())) {
-      auto t_out = target_expr->output(i);
-      auto r_out = replay_expr->output(i);
+    for (const auto i : c10::irange(t_expr->outputs().size())) {
+      auto t_out = t_expr->output(i);
+      auto r_out = r_expr->output(i);
       if (t_out->getValType() == ValType::IterDomain &&
           r_out->getValType() == ValType::IterDomain) {
-        target2replay_id_map_[t_out->as<IterDomain>()] =
-            r_out->as<IterDomain>();
+        id_map_[t_out->as<IterDomain>()] = r_out->as<IterDomain>();
         leaf_ids_[r_out->as<IterDomain>()] = counter++;
       }
     }
@@ -460,8 +420,8 @@ int BestEffortReplay::findFirstMismatchedID(
   }
 
   BestEffortReplay ber(td2->domain(), td1->domain(), id_map);
-  for (const auto i :
-       c10::irange(std::max(td1->domain().size(), td2->domain().size()))) {
+
+  for (const auto i : c10::irange(td1->domain().size())) {
     if (ber.getReplay().find(td1->axis(i)) == ber.getReplay().end()) {
       return i;
     }
@@ -471,281 +431,7 @@ int BestEffortReplay::findFirstMismatchedID(
       return i;
     }
   }
-  return std::min(td1->nDims(), td2->nDims());
-}
-
-namespace {
-
-// Maps that track information relevant to best effort replay about broadcast
-// axes in consumer that are not in producer
-//
-// For example if we have consumer: T0[i0, b1, b2, i3] and producer:
-// T1[i0, i3]
-//
-// If consumer transformations are:
-// -> T[i0, b1o, b1i, b2o, b2i, i3]
-// -> T[i0*b1i, b1o, b2o, b2i, i3]
-// -> T[i0*b1i*b2o, b1o, b2i, i3]
-// -> T[i0*b1i*b2o*i3, b1o, b2i]
-//
-// forwarding_map would forward i0->i0*b1i and i0*b1i->i0*b1i*b2o
-// compliment_map would have the entry i0->b1i and i0*b1i->b2o
-//
-// The first is to fast forward transformations in consumer involving broadcast
-// axes not in producer. The compliment map is to use later to compute what leaf
-// nodes we may have after the forwarding process is finished. Leaf nodes are
-// only important for replayCasP, so look there to see how this is done. Forward
-// map is used for replayCasP and replayPasC.
-struct ConsumerForwardingInfo {
- public:
-  // Map IterDomain* axes that can safely be forwarded to their output.
-  std::unordered_map<IterDomain*, IterDomain*> forwarding_map;
-
-  // Given a forward id map id_input -> id_forwarded
-  // Track the other inputs in the expr that id_input is an input to. These will
-  // be used to adjust the replay's leaf tracking. Don't need to track one to
-  // many as currently transformations on IterDomains can only have maximum 2
-  // inputs, but maybe in the future we'll have more.
-  std::unordered_map<IterDomain*, std::vector<IterDomain*>> compliment_map;
-
-  ConsumerForwardingInfo(
-      const TensorView* producer,
-      const TensorView* consumer) {
-    // Collect which root axes are in consumer that are not in producer because
-    // of broadcasting
-    std::unordered_set<IterDomain*> consumer_bcast_roots_not_in_producer;
-
-    const auto c2p_root_map =
-        PairwiseRootDomainMap(producer, consumer)
-            .mapConsumerToProducer(consumer->domain(), producer->domain());
-
-    for (auto consumer_root_id : consumer->getRootDomain()) {
-      if (consumer_root_id->isBroadcast()) {
-        if (c2p_root_map.find(consumer_root_id) == c2p_root_map.end()) {
-          consumer_bcast_roots_not_in_producer.emplace(consumer_root_id);
-        }
-      }
-    }
-
-    // We have root axes in consumer that don't exist in producer, now forward
-    // those to include all id's in consumer comprised of only axes not in
-    // producer.
-    auto consumer_bcast_ids_not_in_producer =
-        consumer_bcast_roots_not_in_producer;
-
-    std::vector<Expr*> consumer_history = ExprSort::getExprs(
-        FusionGuard::getCurFusion(),
-        std::vector<Val*>(
-            consumer->domain()->domain().begin(),
-            consumer->domain()->domain().end()));
-
-    auto isIdOnlyInConsumer =
-        [&consumer_bcast_ids_not_in_producer](IterDomain* input_id) {
-          return consumer_bcast_ids_not_in_producer.find(input_id) !=
-              consumer_bcast_ids_not_in_producer.end();
-        };
-
-    for (auto expr : consumer_history) {
-      auto input_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
-      // If expr inputs are all in consumer_bcast_ids_not_in_producer, than so
-      // are all outputs
-      if (std::all_of(input_ids.begin(), input_ids.end(), isIdOnlyInConsumer)) {
-        // add all outputs to not being in producer
-        for (auto output_ids :
-             ir_utils::filterByType<IterDomain>(expr->outputs())) {
-          consumer_bcast_ids_not_in_producer.emplace(output_ids);
-        }
-      } else if (
-          expr->isA<Merge>() &&
-          std::any_of(input_ids.begin(), input_ids.end(), isIdOnlyInConsumer)) {
-        auto merge_expr = expr->as<Merge>();
-        // If
-        // - one of the inputs is made of id's in consumer that don't map to
-        // producer (bcast axes),
-        // - && the other input maps to an id in both consumer and producer
-        // - && this is a merge
-        //   for the sake of BestEffortReplay we can forward the input mapping
-        //   to both consumer and producer to the output of the expression
-        std::vector<IterDomain*> forwarded_ids;
-        std::vector<IterDomain*> compliment_ids;
-
-        for (auto input_id : input_ids) {
-          if (!isIdOnlyInConsumer(input_id)) {
-            forwarded_ids.emplace_back(input_id);
-            forwarding_map.emplace(std::make_pair(input_id, merge_expr->out()));
-          } else {
-            compliment_ids.push_back(input_id);
-          }
-        }
-
-        // Set up compliment map
-        for (auto forwarded_id : forwarded_ids) {
-          compliment_map.emplace(std::make_pair(forwarded_id, compliment_ids));
-        }
-      }
-    }
-  }
-};
-
-} // namespace
-
-BestEffortReplay BestEffortReplay::replayCasP(
-    const TensorView* consumer,
-    const TensorView* producer,
-    int producer_compute_at_axis,
-    const RootDomainMap& root_map) {
-  if (producer_compute_at_axis < 0)
-    producer_compute_at_axis += (int)producer->nDims() + 1;
-
-  TORCH_INTERNAL_ASSERT(
-      producer_compute_at_axis >= 0 &&
-          (unsigned int)producer_compute_at_axis <= producer->nDims(),
-      "Invalid axis provided to BestEffortReplay::replayCasP.");
-
-  // producer ids we need to match in consumer
-  std::vector<IterDomain*> producer_CA_ids(
-      producer->domain()->domain().begin(),
-      producer->domain()->domain().begin() + producer_compute_at_axis);
-  producer_CA_ids = TensorDomain::noReductions(producer_CA_ids);
-
-  // If producer has an rfactor root, that's what will match the consumer
-  std::vector<IterDomain*> producer_root = producer->getMaybeRFactorDomain();
-
-  // Figure out all inputs required to generate the compute_at dimensions. We
-  // need all deps because inputs on producer may be in getRootDomain, but we
-  // may need in rFactorDomain
-  auto all_CA_id_deps = DependencyCheck::getAllValsBetween(
-      {producer_root.begin(), producer_root.end()},
-      {producer_CA_ids.begin(), producer_CA_ids.end()});
-
-  // Figure out minimal set of root IDs needed to produce producer_CA_ids:
-  std::unordered_set<IterDomain*> producer_CA_root_ids;
-  for (IterDomain* id : producer_root) {
-    if (std::find(all_CA_id_deps.begin(), all_CA_id_deps.end(), id) !=
-        all_CA_id_deps.end()) {
-      producer_CA_root_ids.emplace(id);
-    }
-  }
-
-  const auto p2c_root_map = root_map.mapProducerToConsumer(
-      producer->domain(), consumer->domain(), producer_CA_root_ids);
-
-  // See FusionAdvancedComputeAt7 for an example of the forwarding logic
-  ConsumerForwardingInfo consumer_forwarding_info(producer, consumer);
-
-  auto consumer_replay = BestEffortReplay(
-      consumer->domain()->domain(),
-      producer_CA_ids,
-      p2c_root_map,
-      consumer_forwarding_info.forwarding_map);
-
-  // Need to adjust leaf map based on forwarding before returning.
-
-  // ID's could go through more than one forward iteration in the map before it
-  // terminates. Grab every id between the forwarded id, and what it was
-  // forwarded to
-  std::function<void(IterDomain*, std::vector<IterDomain*>&)>
-      collectForwardedIds =
-          [&consumer_forwarding_info, &collectForwardedIds](
-              IterDomain* forward_id,
-              std::vector<IterDomain*>& forwarded_ids) -> void {
-    if (consumer_forwarding_info.forwarding_map.find(forward_id) !=
-        consumer_forwarding_info.forwarding_map.end()) {
-      forwarded_ids.emplace_back(forward_id);
-      collectForwardedIds(
-          consumer_forwarding_info.forwarding_map.at(forward_id),
-          forwarded_ids);
-    }
-  };
-
-  std::vector<IterDomain*> expanded_forwarded_ids;
-  for (auto forwarded_id : consumer_replay.forwarded_ids_) {
-    collectForwardedIds(forwarded_id, expanded_forwarded_ids);
-  }
-
-  // Grab all compliments of forwarded ids.
-  std::vector<IterDomain*> compliments;
-  for (auto forwarded_id : expanded_forwarded_ids) {
-    auto compliment_map_it =
-        consumer_forwarding_info.compliment_map.find(forwarded_id);
-    TORCH_INTERNAL_ASSERT(
-        compliment_map_it != consumer_forwarding_info.compliment_map.end(),
-        "Issue tracking forwarded broadcast merges in best effort replay consumer as producer.");
-    compliments.insert(
-        compliments.end(),
-        compliment_map_it->second.begin(),
-        compliment_map_it->second.end());
-  }
-
-  // Grab all exprs used to make the forwarded compliments
-  auto compliment_exprs = ExprSort::getExprs(
-      FusionGuard::getCurFusion(), {compliments.begin(), compliments.end()});
-
-  // Figure out if there are any leaves in compliment_exprs that aren't
-  // the forwarded id
-  std::unordered_map<IterDomain*, size_t> leaf_ids;
-
-  for (auto expr : compliment_exprs) {
-    for (auto inp : ir_utils::filterByType<IterDomain>(expr->inputs())) {
-      leaf_ids.erase(inp);
-    }
-    for (auto out : ir_utils::filterByType<IterDomain>(expr->outputs())) {
-      // If we used the comliment for forwarded don't add to leaf nodes.
-      if (std::find(compliments.begin(), compliments.end(), out) ==
-          compliments.end()) {
-        leaf_ids.emplace(std::make_pair(out, consumer_replay.counter++));
-      }
-    }
-  }
-
-  consumer_replay.leaf_ids_.insert(leaf_ids.begin(), leaf_ids.end());
-
-  return consumer_replay;
-}
-
-// Runs a best effort replay that ignores broadcast axes that appear in
-// consumer that are not mapped to producer in root_map.
-BestEffortReplay BestEffortReplay::replayPasC(
-    const TensorView* producer,
-    const TensorView* consumer,
-    int consumer_compute_at_axis,
-    const RootDomainMap& root_map) {
-  if (consumer_compute_at_axis < 0)
-    consumer_compute_at_axis += (int)consumer->nDims() + 1;
-  TORCH_INTERNAL_ASSERT(
-      consumer_compute_at_axis >= 0 &&
-          (unsigned int)consumer_compute_at_axis <= consumer->nDims(),
-      "Invalid axis provided to BestEffortReplay::replayPasC.");
-
-  // consumer ids we need to match in producer
-  std::vector<IterDomain*> consumer_CA_ids(
-      consumer->domain()->domain().begin(),
-      consumer->domain()->domain().begin() + consumer_compute_at_axis);
-
-  // Figure out all inputs required to generate the compute_at dimensions
-  auto consumer_CA_root_vals = IterVisitor::getInputsTo(
-      std::vector<Val*>(consumer_CA_ids.begin(), consumer_CA_ids.end()));
-
-  std::unordered_set<IterDomain*> consumer_CA_root_ids;
-  for (auto val : consumer_CA_root_vals) {
-    if (val->getValType().value() == ValType::IterDomain) {
-      consumer_CA_root_ids.emplace(val->as<IterDomain>());
-    }
-  }
-
-  const auto c2p_root_map = root_map.mapConsumerToProducer(
-      consumer->domain(), producer->domain(), consumer_CA_root_ids);
-
-  ConsumerForwardingInfo consumer_forwarding_info(producer, consumer);
-
-  // Instead of replaying from the root, lets try to play forward the history
-  // of producer if they match ops on consumer. Enforce if we modify an
-  // rfactor axis that those ops must match.
-  return BestEffortReplay(
-      producer->domain()->domain(),
-      consumer_CA_ids,
-      c2p_root_map,
-      consumer_forwarding_info.forwarding_map);
+  return td1->nDims();
 }
 
 } // namespace cuda
index 9ea2a9b..c7b4bce 100644 (file)
@@ -5,7 +5,6 @@
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
 #include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
 #include <unordered_map>
 #include <vector>
 
@@ -158,50 +157,23 @@ class TORCH_CUDA_CU_API ReplayTransformations : public IterVisitor {
 
 class TORCH_CUDA_CU_API BestEffortReplay {
  private:
-  std::unordered_map<IterDomain*, IterDomain*> target2replay_id_map_;
-  std::unordered_map<IterDomain*, IterDomain*> forward_id_map_;
+  std::unordered_map<IterDomain*, IterDomain*> id_map_;
   std::unordered_map<IterDomain*, size_t> leaf_ids_;
-  std::vector<IterDomain*> forwarded_ids_;
-
-  // Need to track which id's have been forwarded. Later need to make sure leaf
-  // nodes to produce compliment axes are properly tracked. i.e.
-  // T[i0, b1, b2, i3]
-  // -> T[i0, b1o, b1i, b2o, b2i, i3]
-  // -> T[i0*b1i*b2o, b1o, b2i, i3]
-  // -> T[i0*b1i*b2o*i3, b1o, b2i]
-  // If we forwarded i0 -> i0*b1i*b2o*i3, we need to know that b1o and b2i
-  // are leaf nodes even though their split wasn't part of targets replay.
-
-  // Counter to make sure best effort replay leaf_ids can be grabbed
-  // deterministicly
   size_t counter = 0;
 
-  bool inForwardMap(IterDomain* id) const {
-    return forward_id_map_.find(id) != forward_id_map_.end();
-  }
-
-  IterDomain* getForwardedId(IterDomain* id) const {
-    auto forwarded_id_it = forward_id_map_.find(id);
-    if (forwarded_id_it == forward_id_map_.end()) {
-      return id;
-    } else {
-      return getForwardedId(forwarded_id_it->second);
-    }
-  }
-
  public:
-  // Highly duplicated from the constructor above.
-  // TODO: Remove other constructor
+  // replay_map: mapping of target root domains to corresponding
+  // replay root domains
   BestEffortReplay(
       const std::vector<IterDomain*>& replay_domain,
       const std::vector<IterDomain*>& target_domain,
-      std::unordered_map<IterDomain*, IterDomain*> target2replay_map,
-      std::unordered_map<IterDomain*, IterDomain*> forward_id_map = {});
+      std::unordered_map<IterDomain*, IterDomain*> replay_map,
+      bool forward_bcast_mismatch = false);
 
   // Return iter domain map from target_domain IDs to their "replayed"
   // replay_domain IDs. If not in map, was not replayed.
   const std::unordered_map<IterDomain*, IterDomain*>& getReplay() const {
-    return target2replay_id_map_;
+    return id_map_;
   }
 
   // ids in replay that did not have matching transforms in target_domain
@@ -225,26 +197,8 @@ class TORCH_CUDA_CU_API BestEffortReplay {
     return leaf_vec_;
   }
 
-  // Runs a best effort replay that ignores broadcast axes that appear in
-  // consumer that are not mapped to producer in root_map.
-  static BestEffortReplay replayCasP(
-      const TensorView* consumer,
-      const TensorView* producer,
-      int producer_compute_at_axis,
-      const RootDomainMap& root_map);
-
-  // Runs a best effort replay that ignores broadcast axes that appear in
-  // consumer that are not mapped to producer in root_map.
-  static BestEffortReplay replayPasC(
-      const TensorView* producer,
-      const TensorView* consumer,
-      int consumer_compute_at_axis,
-      const RootDomainMap& root_map);
-
   // Find the first position i where td1[i] is not the same as td2[i]. "Same"
   // means the DAG and input IDs to generate td1[i] and td2[i] are the same.
-  // td1 and td2 are assumed to have some matching iter domains, as this is a
-  // strict same-ness check.
   static int findFirstMismatchedID(
       const TensorDomain* td1,
       const TensorDomain* td2);
index 570f669..120904a 100644 (file)
@@ -5,11 +5,9 @@
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
 #include <torch/csrc/jit/codegen/cuda/transform_iter.h>
 
-#include <deque>
+#include <vector>
 
 namespace torch {
 namespace jit {
@@ -43,13 +41,13 @@ class ReplaySelf : public ReplayTransformations {
         "Transform traversal failed, modified a node but it was not a leaf node.");
 
     // outer loop size
-    Val* remainder = ceilDiv(mapped->extent(), s->factor());
+    Val* oe = ceilDiv(mapped->extent(), s->factor());
 
     // Manually replay the split, following the output of the operations.
     // This is so rfactor ops are replayed correctly.
     IterDomain* ido = new IterDomain(
         new Int(0),
-        s->innerSplit() ? remainder->as<Int>() : s->factor(),
+        oe->as<Int>(),
         s->outer()->getParallelType(),
         s->outer()->getIterType(),
         s->outer()->isRFactorProduct());
@@ -57,13 +55,13 @@ class ReplaySelf : public ReplayTransformations {
     // inner IterDomain
     IterDomain* idi = new IterDomain(
         new Int(0),
-        s->innerSplit() ? s->factor() : remainder->as<Int>(),
+        s->factor(),
         s->inner()->getParallelType(),
-        s->inner()->getIterType(),
+        s->outer()->getIterType(),
         s->inner()->isRFactorProduct());
 
     // Generate the split node
-    new Split(ido, idi, mapped, s->factor(), s->innerSplit());
+    new Split(ido, idi, mapped, s->factor());
 
     // Remove mapped id from leaf IDs
     leaf_ids_.erase(mapped);
@@ -133,10 +131,10 @@ class ReplaySelf : public ReplayTransformations {
 TensorDomain* TransformReplay::fullSelfReplay(
     const TensorDomain* new_self_root,
     const TensorDomain* self) {
-  FUSER_PERF_SCOPE("TransformReplay::fullSelfReplay");
+  FUSER_PERF_SCOPE("fullSelfReplay");
 
   TORCH_INTERNAL_ASSERT(
-      new_self_root->getRootDomain().size() == self->getRootDomain().size(),
+      new_self_root->nDims() == self->getRootDomain().size(),
       "Invalid number of IterDomains provided.");
 
   // Map for replay, should be pretty simple.
@@ -145,28 +143,17 @@ TensorDomain* TransformReplay::fullSelfReplay(
     size_t i = 0;
     for (auto id : self->getRootDomain()) {
       TORCH_INTERNAL_ASSERT(
-          new_self_root->getRootDomain()[i]->start()->isZeroInt() &&
-              id->start()->isZeroInt(),
-          "Replay does not support IterDomains that do not start at 0, received: ",
-          new_self_root->getRootDomain()[i]->start(),
-          " and ",
-          id->start()->isZeroInt());
+          new_self_root->axis(i)->start() == id->start(),
+          "Replay does not support IterDomains that do not start at 0.");
 
       TORCH_INTERNAL_ASSERT(
-          new_self_root->getRootDomain()[i]->getParallelType() ==
-                  id->getParallelType() &&
-              new_self_root->getRootDomain()[i]->isReduction() ==
-                  id->isReduction() &&
-              new_self_root->getRootDomain()[i]->isRFactorProduct() ==
+          new_self_root->axis(i)->getParallelType() == id->getParallelType() &&
+              new_self_root->axis(i)->isReduction() == id->isReduction() &&
+              new_self_root->axis(i)->isRFactorProduct() ==
                   id->isRFactorProduct() &&
-              new_self_root->getRootDomain()[i]->isBroadcast() ==
-                  id->isBroadcast(),
-          "Axes ",
-          id,
-          " and ",
-          new_self_root->getRootDomain()[i],
-          " do not match for self replay.");
-      axis_map[id] = new_self_root->getRootDomain()[i];
+              new_self_root->axis(i)->isBroadcast() == id->isBroadcast(),
+          "Axes do not match for self replay.");
+      axis_map[id] = new_self_root->axis(i);
       i++;
     }
   }
@@ -184,28 +171,10 @@ TensorDomain* TransformReplay::fullSelfReplay(
           "Error during replay, didn't replay an axis.");
       new_domain[i++] = it->second;
     }
-
-    if (self->hasRFactor()) {
-      std::vector<IterDomain*> new_rfactor_domain(
-          self->getMaybeRFactorDomain().size(), nullptr);
-      size_t i = 0;
-      for (auto id : self->getMaybeRFactorDomain()) {
-        auto it = replay.getReplay().find(id);
-        TORCH_INTERNAL_ASSERT(
-            it != replay.getReplay().end(),
-            "Error during replay, didn't replay an axis.");
-        new_rfactor_domain[i++] = it->second;
-      }
-      return new TensorDomain(
-          new_self_root->getRootDomain(),
-          new_rfactor_domain,
-          new_domain,
-          new_self_root->contiguity());
-    }
   }
 
   return new TensorDomain(
-      new_self_root->getRootDomain(), new_domain, new_self_root->contiguity());
+      new_self_root->domain(), new_domain, self->contiguity());
 }
 
 // Producer could have rfactor axes which consumer may want replayed. We can
@@ -214,16 +183,10 @@ TensorDomain* TransformReplay::fullSelfReplay(
 // mapped to in the consumer the operations would all be the same. then we want
 // to start the replay of the producer from the rfactor root axes, not the root.
 std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
-    const TensorView* producer,
-    const TensorView* consumer,
-    int consumer_compute_at_axis,
-    const RootDomainMap& root_map) {
-  FUSER_PERF_SCOPE("TransformReplay::replayPasC");
-
-  // If this is a reduction operation, we may call transform_replay on the
-  // tensor view. When this happens, just return thet target view.
-  if (producer == consumer)
-    return {producer->domain(), producer->nDims()};
+    const TensorDomain* producer,
+    const TensorDomain* consumer,
+    int consumer_compute_at_axis) {
+  FUSER_PERF_SCOPE("replayPasC");
 
   if (consumer_compute_at_axis < 0)
     consumer_compute_at_axis += (int)consumer->nDims() + 1;
@@ -234,14 +197,35 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
 
   // consumer ids we need to match in producer
   std::vector<IterDomain*> consumer_CA_ids(
-      consumer->domain()->domain().begin(),
-      consumer->domain()->domain().begin() + consumer_compute_at_axis);
+      consumer->domain().begin(),
+      consumer->domain().begin() + consumer_compute_at_axis);
+
+  // Figure out all inputs required to generate the compute_at dimensions
+  std::unordered_set<Val*> consumer_CA_root_vals = IterVisitor::getInputsTo(
+      std::vector<Val*>(consumer_CA_ids.begin(), consumer_CA_ids.end()));
+
+  std::unordered_set<IterDomain*> consumer_CA_root_ids;
+  for (auto val : consumer_CA_root_vals) {
+    if (val->getValType().value() == ValType::IterDomain) {
+      consumer_CA_root_ids.emplace(val->as<IterDomain>());
+    }
+  }
+
+  // Map of consumer_CA_root_ids to related producer_CA_ids
+  auto replay_root_map =
+      TensorDomain::mapRootCtoP(consumer, producer, consumer_CA_root_ids);
+
+  // Track which root axes in producer we will send to replay
+  std::unordered_set<IterDomain*> producer_roots4replay;
+  for (auto entry : replay_root_map) {
+    producer_roots4replay.emplace(entry.second);
+  }
 
   // Instead of replaying from the root, lets try to play forward the history of
   // producer if they match ops on consumer. Enforce if we modify an rfactor
   // axis that those ops must match.
-  auto forward_replay = BestEffortReplay::replayPasC(
-      producer, consumer, consumer_compute_at_axis, root_map);
+  BestEffortReplay forward_replay(
+      producer->domain(), consumer_CA_ids, replay_root_map);
 
   // Make a new map based on all the leaves resulting from best effort replay
   id_map forwarded_replay_map;
@@ -258,67 +242,41 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
   auto leaf_ids(replay_PasC.getUnorderedLeafIDs());
 
   // Remove all ids that map to the compute at axis, we're going to replay the
-  // rest, track all dims needed to match consumer CA dims
-  std::vector<IterDomain*> needed_dims;
+  // rest
   for (auto c_id : consumer_CA_ids) {
     auto it = replay_PasC.getReplay().find(c_id);
     if (it == replay_PasC.getReplay().end()) {
       TORCH_INTERNAL_ASSERT(
-          c_id->isBroadcast() || c_id->isGather(),
+          c_id->isBroadcast(),
           "Could not find axis, ",
           c_id,
           ", requested in replay.");
       continue;
     }
-    TORCH_INTERNAL_ASSERT(
-        leaf_ids.find(it->second) != leaf_ids.end(),
-        "Replayed id to match consumer id ",
-        c_id,
-        " should be a leaf in replay map.");
-    leaf_ids.erase(it->second);
-    needed_dims.push_back(it->second);
+    if (leaf_ids.find(it->second) != leaf_ids.end())
+      leaf_ids.erase(it->second);
   }
 
   // leaf_ids now contains all producer ID products that are not used to satisfy
   // the computeAt Turn into a  map so we can play forward these IDs in producer
   // (if possible):
   id_map producer_self_replay_map;
-  for (auto entry : leaf_ids) {
+  for (auto entry : leaf_ids)
     producer_self_replay_map[entry.first] = entry.first;
-  }
-
-  // Check which root domains were used to produce the leaf_ids. We may have
-  // picked up extra roots in consumer because of broadcast forwarding.
-  std::vector<Val*> unordered_non_root_leaf_vals;
-  for (auto leaf_id : replay_PasC.getUnorderedLeafIDs()) {
-    if (leaf_id.first->definition() == nullptr) {
-      continue;
-    } else {
-      unordered_non_root_leaf_vals.emplace_back(leaf_id.first);
-    }
-  }
-
-  auto processed_roots = IterVisitor::getInputsTo(unordered_non_root_leaf_vals);
 
   auto producer_root = producer->getMaybeRFactorDomain();
 
   // Any root domain that was not used to generate computeIDs we can also put in
   // the map to forward their transformations.
-  for (auto producer_root_id : producer_root) {
-    if (std::find(
-            processed_roots.begin(), processed_roots.end(), producer_root_id) ==
-            processed_roots.end() &&
-        std::find(needed_dims.begin(), needed_dims.end(), producer_root_id) ==
-            needed_dims.end()) {
+  for (auto producer_root_id : producer_root)
+    if (producer_roots4replay.find(producer_root_id) ==
+        producer_roots4replay.end()) {
       producer_self_replay_map[producer_root_id] = producer_root_id;
     }
-  }
 
   // Play forward transformations all producer IDs we can
   auto producer_replayed_leaves = BestEffortReplay(
-      producer->domain()->domain(),
-      producer->domain()->domain(),
-      producer_self_replay_map);
+      producer->domain(), producer->domain(), producer_self_replay_map);
 
   /*
    * Accumulate axes in to the new domain in the following order, making sure to
@@ -349,7 +307,7 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
     auto it = replay_PasC.getReplay().find(c_id);
     if (it == replay_PasC.getReplay().end()) {
       TORCH_INTERNAL_ASSERT(
-          c_id->isBroadcast() || c_id->isGather(),
+          c_id->isBroadcast(),
           "Could not find axis, ",
           c_id,
           ", requested in replay.");
@@ -360,9 +318,8 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
   }
 
   unsigned int producer_compute_at_axis = new_IDs.size();
-
   // Add axes in (2)
-  for (auto c_id : consumer->domain()->domain()) {
+  for (auto c_id : consumer->domain()) {
     auto it = replay_PasC.getReplay().find(c_id);
     if (it != replay_PasC.getReplay().end()) {
       auto id = it->second;
@@ -380,7 +337,7 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
   }
 
   // Add axes in (3)
-  for (auto id : producer->domain()->domain()) {
+  for (auto id : producer->domain()) {
     if (producer_replayed_leaves.getUnorderedLeafIDs().find(id) !=
         producer_replayed_leaves.getUnorderedLeafIDs().end()) {
       if (used_IDs.find(id) == used_IDs.end()) {
@@ -399,22 +356,15 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
       producer->getRootDomain(),
       producer->getRFactorDomain(),
       new_IDs,
-      producer->domain()->contiguity());
-
+      producer->contiguity());
   return {replayed, producer_compute_at_axis};
 }
 
 std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
-    const TensorView* consumer,
-    const TensorView* producer,
-    int producer_compute_at_axis,
-    const RootDomainMap& root_map) {
-  FUSER_PERF_SCOPE("TransformReplay::replayCasP");
-
-  // If this is a reduction operation, we may call transform_replay on the same
-  // tensor view. When this happens, just return thet target view.
-  if (consumer == producer)
-    return {consumer->domain(), consumer->nDims()};
+    const TensorDomain* consumer,
+    const TensorDomain* producer,
+    int producer_compute_at_axis) {
+  FUSER_PERF_SCOPE("replayCasP");
 
   if (producer_compute_at_axis < 0)
     producer_compute_at_axis += (int)producer->nDims() + 1;
@@ -426,28 +376,50 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
 
   // producer ids we need to match in consumer
   std::vector<IterDomain*> producer_CA_ids(
-      producer->domain()->domain().begin(),
-      producer->domain()->domain().begin() + producer_compute_at_axis);
+      producer->domain().begin(),
+      producer->domain().begin() + producer_compute_at_axis);
   producer_CA_ids = TensorDomain::noReductions(producer_CA_ids);
 
+  // Grab root domains of producer and consumer
+  std::vector<IterDomain*> consumer_root = consumer->getRootDomain();
+
+  // If producer has an rfactor root, that's what will match the consumer
+  std::vector<IterDomain*> producer_root = producer->getMaybeRFactorDomain();
+
+  // Figure out all inputs required to generate the compute_at dimensions. We
+  // need all deps because inputs on producer may be in getRootDomain, but we
+  // may need in rFactorDomain
+  std::unordered_set<Val*> all_CA_id_deps = DependencyCheck::getAllValsBetween(
+      {producer_root.begin(), producer_root.end()},
+      {producer_CA_ids.begin(), producer_CA_ids.end()});
+
+  // Figure out which root IDs we need:
+  std::unordered_set<IterDomain*> producer_CA_root_ids;
+  for (IterDomain* id : producer_root) {
+    if (all_CA_id_deps.find(id) != all_CA_id_deps.end())
+      producer_CA_root_ids.emplace(id);
+  }
+
+  auto replay_root_map =
+      TensorDomain::mapRootPtoC(producer, consumer, producer_CA_root_ids);
+
+  // Track which root axes in producer we will send to replay
+  std::unordered_set<IterDomain*> consumer_roots4replay;
+  for (auto entry : replay_root_map) {
+    consumer_roots4replay.emplace(entry.second);
+  }
+
   // Instead of replaying from the root, lets try to forward the history of
   // consumer if they match ops on producer. Enforce if we modify an rfactor
   // axis that those ops match.
-  BestEffortReplay forward_replay = BestEffortReplay::replayCasP(
-      consumer, producer, producer_compute_at_axis, root_map);
+  BestEffortReplay forward_replay(
+      consumer->domain(), producer_CA_ids, replay_root_map);
 
-  // Track dangling leaves which can be produced in
-  // BestEffortReplay::replayCasP these don't have any equivalent in producer
-  // so they're not in the map. We will simply map them to themselves so we
-  // don't lose them.
   id_map forwarded_replay_map;
-  auto forward_dangling_leaves = forward_replay.getUnorderedLeafIDs();
   for (auto entry : forward_replay.getReplay()) {
-    if (forward_dangling_leaves.find(entry.second) !=
-        forward_dangling_leaves.end()) {
+    if (forward_replay.getUnorderedLeafIDs().find(entry.second) !=
+        forward_replay.getUnorderedLeafIDs().end())
       forwarded_replay_map[entry.first] = entry.second;
-      forward_dangling_leaves.erase(entry.second);
-    }
   }
 
   // Replay producer dimensions.
@@ -457,8 +429,7 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
   auto leaf_ids(replay_CasP.getUnorderedLeafIDs());
 
   // Remove all ids that map to the compute at axis, we're going to replay the
-  // rest, track all dims that are needed to match producer CA dims
-  std::vector<IterDomain*> needed_dims;
+  // rest
   for (auto p_id : producer_CA_ids) {
     auto it = replay_CasP.getReplay().find(p_id);
     TORCH_INTERNAL_ASSERT(
@@ -466,60 +437,27 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
         "Could not find axis, ",
         p_id,
         ", requested in replay.");
-    TORCH_INTERNAL_ASSERT(
-        leaf_ids.find(it->second) != leaf_ids.end(),
-        "Replayed id to match producer id ",
-        p_id,
-        " should be a leaf in replay map.");
-    leaf_ids.erase(it->second);
-    needed_dims.push_back(it->second);
+    if (leaf_ids.find(it->second) != leaf_ids.end())
+      leaf_ids.erase(it->second);
   }
 
   // leaf_ids now contains all consumer ID products that are not used to satisfy
-  // the computeAt. Turn into a  map so we can play forward these IDs in
-  // consumer (if possible):
+  // the computeAt Turn into a  map so we can play forward these IDs in consumer
+  // (if possible):
   id_map consumer_self_replay_map;
-  for (auto entry : leaf_ids) {
-    consumer_self_replay_map[entry.first] = entry.first;
-  }
-
-  for (auto entry : forward_dangling_leaves) {
+  for (auto entry : leaf_ids)
     consumer_self_replay_map[entry.first] = entry.first;
-  }
-
-  // Check which root domains were used to produce the leaf_ids. We may have
-  // picked up extra roots in consumer because of broadcast forwarding.
-  std::vector<Val*> unordered_non_root_leaf_vals;
-  for (auto leaf_id : replay_CasP.getUnorderedLeafIDs()) {
-    if (leaf_id.first->definition() == nullptr) {
-      continue;
-    } else {
-      unordered_non_root_leaf_vals.emplace_back(leaf_id.first);
-    }
-  }
-
-  auto processed_roots = IterVisitor::getInputsTo(unordered_non_root_leaf_vals);
-
-  std::vector<IterDomain*> consumer_root = consumer->getRootDomain();
 
   // Any root domain that was not used to generate computeIDs we can also put in
   // the map to forward their transformations.
-  for (auto consumer_root_id : consumer_root) {
-    if (std::find(
-            processed_roots.begin(), processed_roots.end(), consumer_root_id) ==
-            processed_roots.end() &&
-        // Don't re-add roots that may have directly mapped in the replay
-        std::find(needed_dims.begin(), needed_dims.end(), consumer_root_id) ==
-            needed_dims.end()) {
+  for (auto consumer_root_id : consumer_root)
+    if (consumer_roots4replay.find(consumer_root_id) ==
+        consumer_roots4replay.end())
       consumer_self_replay_map[consumer_root_id] = consumer_root_id;
-    }
-  }
 
   // Play forward transformations all consumer IDs we can
   auto consumer_replayed_leaves = BestEffortReplay(
-      consumer->domain()->domain(),
-      consumer->domain()->domain(),
-      consumer_self_replay_map);
+      consumer->domain(), consumer->domain(), consumer_self_replay_map);
 
   /*
    * Accumulate axes in to the new domain in the following order, making sure to
@@ -559,7 +497,7 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
   }
 
   // Add axes in (2)
-  for (auto p_id : producer->domain()->domain()) {
+  for (auto p_id : producer->domain()) {
     auto it = replay_CasP.getReplay().find(p_id);
     if (it != replay_CasP.getReplay().end()) {
       auto id = it->second;
@@ -577,7 +515,7 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
   }
 
   // Add axes in (3)
-  for (auto id : consumer->domain()->domain()) {
+  for (auto id : consumer->domain()) {
     if (consumer_replayed_leaves.getUnorderedLeafIDs().find(id) !=
         consumer_replayed_leaves.getUnorderedLeafIDs().end()) {
       if (used_IDs.find(id) == used_IDs.end()) {
@@ -596,182 +534,40 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
       consumer->getRootDomain(),
       consumer->getRFactorDomain(),
       new_IDs,
-      consumer->domain()->contiguity());
+      consumer->contiguity());
 
   return {replayed, producer_CA_ids.size()};
 }
 
 // replay Producer as Consumer
-std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
-    const TensorView* producer,
-    const TensorView* consumer,
+std::pair<TensorView*, unsigned int> TransformReplay::replayPasC(
+    TensorView* producer,
+    TensorView* consumer,
     int compute_at_axis) {
-  // Use the pairwise root map as a default mapper
-  PairwiseRootDomainMap root_map(producer, consumer);
-  return replayPasC(producer, consumer, compute_at_axis, root_map);
-}
-
-std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
-    const TensorView* consumer,
-    const TensorView* producer,
-    int compute_at_axis) {
-  // Use the pairwise root map as a default mapper
-  PairwiseRootDomainMap root_map(producer, consumer);
-  return replayCasP(consumer, producer, compute_at_axis, root_map);
-}
-
-namespace {
-
-std::deque<TensorView*> deduplicate(const std::deque<TensorView*>& tv_deuqe) {
-  std::deque<TensorView*> deduplicated;
-  std::unordered_set<TensorView*> inserted;
-  for (auto tv_entry : tv_deuqe) {
-    if (inserted.find(tv_entry) == inserted.end()) {
-      deduplicated.emplace_back(tv_entry);
-      inserted.emplace(tv_entry);
-    }
-  }
-  return deduplicated;
-}
-
-std::deque<TensorView*> tvInputs(Expr* expr) {
-  auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
-  return std::deque<TensorView*>(tv_inputs.begin(), tv_inputs.end());
-}
-
-std::deque<TensorView*> tvOutputs(Expr* expr) {
-  auto tv_outputs = ir_utils::filterByType<TensorView>(expr->outputs());
-  return std::deque<TensorView*>(tv_outputs.begin(), tv_outputs.end());
-}
-
-std::deque<TensorView*> consumersOf(TensorView* tv) {
-  std::deque<TensorView*> consumer_tvs;
-  for (auto def : tv->uses()) {
-    auto outs = tvOutputs(def);
-    consumer_tvs.insert(consumer_tvs.end(), outs.begin(), outs.end());
-  }
-  return deduplicate(consumer_tvs);
-}
-
-std::deque<TensorView*> producersFor(TensorView* tv) {
-  auto def = tv->definition();
-  if (def == nullptr) {
-    return {};
-  }
-
-  return deduplicate(tvInputs(def));
-}
-
-}; // namespace
-
-bool TransformPropagator::replayPasC(
-    TensorView* producer_tv,
-    TensorView* consumer_tv) {
-  if (producer_tv == starting_tv) {
-    return false;
-  }
-
-  auto consumer_pos_it = replayed_pos.find(consumer_tv);
-  if (consumer_pos_it == replayed_pos.end()) {
-    return false;
-  }
-
-  auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv);
-  auto producerAsC = TransformReplay::replayPasC(
-      producer_tv, consumer_tv, consumer_pos_it->second, pairwiseMap);
-
-  if (replayed_pos.find(producer_tv) != replayed_pos.end()) {
-    if (producerAsC.second <= replayed_pos.at(producer_tv)) {
-      return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks)
-    }
-  }
-
-  producer_tv->setDomain(producerAsC.first);
-  replayed_pos[producer_tv] = producerAsC.second;
-
-  return true;
-}
-
-bool TransformPropagator::replayCasP(
-    TensorView* consumer_tv,
-    TensorView* producer_tv) {
-  if (consumer_tv == starting_tv) {
-    return false;
-  }
-
-  auto producer_pos_it = replayed_pos.find(producer_tv);
-  if (producer_pos_it == replayed_pos.end()) {
-    return false;
-  }
-
-  auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv);
-  auto consumerAsP = TransformReplay::replayCasP(
-      consumer_tv, producer_tv, producer_pos_it->second, pairwiseMap);
-
-  if (replayed_pos.find(consumer_tv) != replayed_pos.end()) {
-    if (consumerAsP.second <= replayed_pos.at(consumer_tv)) {
-      return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks)
-    }
-  }
-
-  consumer_tv->setDomain(consumerAsP.first);
-  replayed_pos[consumer_tv] = consumerAsP.second;
-
-  return true;
-}
+  // If this is a reduction operation, we may call transform_replay on the
 
-TransformPropagator::TransformPropagator(TensorView* from) : starting_tv(from) {
-  // Tensors we should try to propagate in the consumer direction
-  std::deque<TensorView*> consumer_propagation{starting_tv};
-
-  // Tensors we should try to propagate in the producer direction
-  std::deque<TensorView*> producer_propagation{starting_tv};
-
-  // Seed position with local tv
-  replayed_pos[from] = from->nDims();
-
-  // While tensor views are being replayed, if they're modified, make sure we
-  // propagate back to all producers as well as consumers. This is definitely
-  // not the most efficient implementation as what we do is any time a tv is
-  // changed we propagate both forward and backward. If a forward pass touches
-  // every node, the backward pass will try to replay every node, potentially
-  // multiple times.
-  while (!consumer_propagation.empty() || !producer_propagation.empty()) {
-    while (!consumer_propagation.empty()) {
-      // Tensor view we will replay onto consumers
-      auto tv = consumer_propagation.front();
-      consumer_propagation.pop_front();
-
-      // Replay tv forward to its consumers.
-      for (auto consumer_tv : consumersOf(tv)) {
-        auto replayed = replayCasP(consumer_tv, tv);
-        // If consumer has changed, mark we should propagate its consumers
-
-        if (replayed) {
-          consumer_propagation.emplace_back(consumer_tv);
-          producer_propagation.emplace_back(consumer_tv);
-        }
-      }
-    }
+  // tensor view. When this happens, just return thet target view.
+  if (producer == consumer)
+    return {producer, 0};
 
-    while (!producer_propagation.empty()) {
-      // Tensor view we will replay onto producers
-      auto tv = producer_propagation.front();
-      producer_propagation.pop_front();
-      // Replay tv backward to its producers
-      for (auto producer_tv : producersFor(tv)) {
-        auto replayed = replayPasC(producer_tv, tv);
-        if (replayed) {
-          producer_propagation.emplace_back(producer_tv);
-          consumer_propagation.emplace_back(producer_tv);
-        }
-      }
-    }
-  }
+  std::pair<TensorDomain*, unsigned int> replay =
+      replayPasC(producer->domain(), consumer->domain(), compute_at_axis);
+  producer->setDomain(replay.first);
+  return {producer, replay.second};
 }
 
-void TransformPropagator::from(TensorView* tv) {
-  TransformPropagator propagate(tv);
+std::pair<TensorView*, unsigned int> TransformReplay::replayCasP(
+    TensorView* consumer,
+    TensorView* producer,
+    int compute_at_axis) {
+  // If this is a reduction operation, we may call transform_replay on the same
+  // tensor view. When this happens, just return thet target view.
+  if (consumer == producer)
+    return {consumer, 0};
+  std::pair<TensorDomain*, unsigned int> replay =
+      replayCasP(consumer->domain(), producer->domain(), compute_at_axis);
+  consumer->setDomain(replay.first);
+  return {consumer, replay.second};
 }
 
 } // namespace cuda
index 7264afa..22a5cec 100644 (file)
@@ -4,7 +4,6 @@
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <algorithm>
-#include <unordered_map>
 #include <vector>
 
 namespace torch {
@@ -120,32 +119,32 @@ namespace cuda {
 
 class TensorDomain;
 class TensorView;
-class RootDomainMap;
 
 class TORCH_CUDA_CU_API TransformReplay {
  public:
   // Replay producer as consumer, returns {producer, producer_compute_at_axis}.
   static std::pair<TensorDomain*, unsigned int> replayPasC(
-      const TensorView* producer,
-      const TensorView* consumer,
+      const TensorDomain* producer,
+      const TensorDomain* consumer,
+      int consumer_compute_at_axis);
+
+  // Replay producer as consumer, returns {producer, producer_compute_at_axis}.
+  static std::pair<TensorView*, unsigned int> replayPasC(
+      TensorView* producer,
+      TensorView* consumer,
       int consumer_compute_at_axis);
-  static std::pair<TensorDomain*, unsigned int> replayPasC(
-      const TensorView* producer,
-      const TensorView* consumer,
-      int consumer_compute_at_axis,
-      const RootDomainMap& root_map);
 
-  // Replay producer as consumer, returns {replayed_consumer_domain,
-  // consumer_compute_at_axis}.
+  // Replay producer as consumer, returns {consumer, consumer_compute_at_axis}.
   static std::pair<TensorDomain*, unsigned int> replayCasP(
-      const TensorView* consumer,
-      const TensorView* producer,
+      const TensorDomain* consumer,
+      const TensorDomain* producer,
+      int producer_compute_at_axis);
+
+  // Replay producer as consumer, returns {consumer, consumer_compute_at_axis}.
+  static std::pair<TensorView*, unsigned int> replayCasP(
+      TensorView* consumer,
+      TensorView* producer,
       int producer_compute_at_axis);
-  static std::pair<TensorDomain*, unsigned int> replayCasP(
-      const TensorView* consumer,
-      const TensorView* producer,
-      int producer_compute_at_axis,
-      const RootDomainMap& root_map);
 
   // Self replay.
   static TensorDomain* fullSelfReplay(
@@ -153,21 +152,6 @@ class TORCH_CUDA_CU_API TransformReplay {
       const TensorDomain* self);
 };
 
-class TORCH_CUDA_CU_API TransformPropagator {
- private:
-  bool replayPasC(TensorView* producer_tv, TensorView* consumer_tv = nullptr);
-  bool replayCasP(TensorView* consumer_tv, TensorView* producer_tv = nullptr);
-
-  TransformPropagator(TensorView* from);
-
- private:
-  std::unordered_map<TensorView*, unsigned int> replayed_pos;
-  TensorView* starting_tv = nullptr;
-
- public:
-  static void from(TensorView* tv);
-};
-
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
index 0c05606..b43ec54 100644 (file)
@@ -48,13 +48,13 @@ class ReplayRFactor : public ReplayTransformations {
       return ReplayTransformations::handle(s);
 
     // outer loop size
-    Val* remainder = ceilDiv(mapped->extent(), s->factor());
+    Val* oe = ceilDiv(mapped->extent(), s->factor());
 
     // Manually replay the split, making reduction = false and rfactor = true
     // outer IterDomain
     IterDomain* ido = new IterDomain(
         new Int(0),
-        s->innerSplit() ? remainder->as<Int>() : s->factor(),
+        oe->as<Int>(),
         mapped->getParallelType(),
         rfactor_outer ? IterType::Reduction : IterType::Iteration,
         true); // broadcast
@@ -62,13 +62,13 @@ class ReplayRFactor : public ReplayTransformations {
     // inner IterDomain
     IterDomain* idi = new IterDomain(
         new Int(0),
-        s->innerSplit() ? s->factor() : remainder->as<Int>(),
+        s->factor(),
         mapped->getParallelType(),
         rfactor_inner ? IterType::Reduction : IterType::Iteration,
         true);
 
     // Generate the split node
-    new Split(ido, idi, mapped, s->factor(), s->innerSplit());
+    new Split(ido, idi, mapped, s->factor());
 
     // Remove mapped id from leaf IDs
     leaf_ids_.erase(mapped);
@@ -153,7 +153,7 @@ class ReplayRFactor : public ReplayTransformations {
 TensorDomain* TransformRFactor::runReplay(
     TensorDomain* orig_td,
     std::vector<int> axes) {
-  FUSER_PERF_SCOPE("TransformRFactor::runReplay");
+  FUSER_PERF_SCOPE("runReplay");
 
   TORCH_CHECK(!axes.empty(), "No axes provided to rfactor replay.");
 
@@ -304,7 +304,7 @@ TensorDomain* TransformRFactor::runReplay(
 TensorDomain* TransformRFactor::runReplay2(
     TensorDomain* orig_td,
     std::vector<int> axes) {
-  FUSER_PERF_SCOPE("TransformRFactor::runReplay2");
+  FUSER_PERF_SCOPE("runReplay2");
 
   int ndims = (int)orig_td->nDims();
 
index 1ea5ddc..802ccc4 100644 (file)
@@ -8,64 +8,6 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-bool isFloatingPointType(DataType dtype) {
-  switch (dtype) {
-    case DataType::Bool:
-      return false;
-    case DataType::Double:
-    case DataType::Float:
-    case DataType::Half:
-      return true;
-    case DataType::Int:
-    case DataType::Int32:
-      return false;
-    case DataType::Null:
-      TORCH_CHECK(
-          false, "Null type is not a valid argument to isFloatingPoint");
-    default:
-      TORCH_CHECK(false, "Type not supported in isFloatingPoint");
-  }
-}
-
-bool isIntegralType(DataType dtype) {
-  switch (dtype) {
-    case DataType::Bool:
-    case DataType::Double:
-    case DataType::Float:
-    case DataType::Half:
-      return false;
-    case DataType::Int:
-    case DataType::Int32:
-      return true;
-    case DataType::Null:
-      TORCH_CHECK(
-          false, "Null type is not a valid argument to isFloatingPoint");
-    default:
-      TORCH_CHECK(false, "Type not supported in isFloatingPoint");
-  }
-}
-
-bool isIntegerOp(const BinaryOpType bopt) {
-  return bopt >= BinaryOpType::Mod && bopt <= BinaryOpType::Rshift;
-}
-
-bool isLogicalOp(const BinaryOpType bopt) {
-  return bopt >= BinaryOpType::Eq && bopt <= BinaryOpType::NE;
-}
-
-bool alsoBooleanOperator(const BinaryOpType bopt) {
-  return bopt >= BinaryOpType::And && bopt <= BinaryOpType::Xor;
-}
-
-bool alsoBooleanOperator(const UnaryOpType uopt) {
-  return uopt >= UnaryOpType::Not && uopt <= UnaryOpType::Not;
-}
-
-bool noFullIntegerSupport(const BinaryOpType bopt) {
-  return bopt == BinaryOpType::Div || bopt == BinaryOpType::Pow ||
-      bopt == BinaryOpType::Fmod;
-}
-
 // Return highest on list (smallest enum val)
 DataType promote_type(const DataType& t1, const DataType& t2) {
   TORCH_CHECK(
@@ -79,34 +21,27 @@ DataType promote_type(const DataType& t1, const DataType& t2) {
 
 // Return highest on list (smallest enum val)
 ValType promote_type(const ValType& t1, const ValType& t2) {
-  if (t1 == ValType::TensorView || t2 == ValType::TensorView) {
-    return ValType::TensorView;
-  }
-  if (t1 == ValType::Scalar &&
-      (t2 == ValType::Scalar || t2 == ValType::NamedScalar)) {
-    return ValType::Scalar;
-  }
-  if (t2 == ValType::Scalar &&
-      (t1 == ValType::Scalar || t1 == ValType::NamedScalar)) {
-    return ValType::Scalar;
-  }
-  TORCH_CHECK(false, "Expected promotable ValTypes but got: ", t1, " and ", t2);
+  TORCH_CHECK(
+      t1 >= ValType::TensorView && t2 >= ValType::TensorView,
+      "Expected promotable ValTypes but got: ",
+      t1,
+      " and ",
+      t2);
+  // Check that it's a promotable type (with dtype)
+  // static_assert??
+  return t1 < t2 ? t1 : t2;
 }
 
 static const char* data_type2string(DataType t) {
   switch (t) {
     case DataType::Bool:
       return "bool";
-    case DataType::Double:
-      return "double";
     case DataType::Float:
       return "float";
     case DataType::Half:
       return "__half";
     case DataType::Int:
       return "int64_t";
-    case DataType::Int32:
-      return "int";
     case DataType::Null:
       return "nullptr";
     default:
@@ -118,6 +53,8 @@ static const char* data_type2string(DataType t) {
 
 static const char* val_type2string(ValType t) {
   switch (t) {
+    case ValType::TensorIndex:
+      return "TensorIndex";
     case ValType::TensorView:
       return "TensorView";
     case ValType::TensorDomain:
@@ -128,9 +65,21 @@ static const char* val_type2string(ValType t) {
       return "Scalar";
     case ValType::NamedScalar:
       return "NamedScalar";
+    case ValType::KirIterDomain:
+      return "KirIterDomain";
+    case ValType::KirNamedScalar:
+      return "KirNamedScalar";
+    case ValType::KirScalar:
+      return "KirScalar";
+    case ValType::KirTensorDomain:
+      return "KirTensorDomain";
+    case ValType::KirTensorView:
+      return "KirTensorView";
     default:
-      TORCH_INTERNAL_ASSERT(false, "No string found for val type.");
+      break;
   }
+  TORCH_INTERNAL_ASSERT(false, "No string found for val type.");
+  return nullptr;
 }
 
 static const char* expr_type2string(ExprType t) {
@@ -143,87 +92,85 @@ static const char* expr_type2string(ExprType t) {
       return "TernaryOp";
     case ExprType::ReductionOp:
       return "ReductionOp";
+    case ExprType::GridReduction:
+      return "GridReduction";
     case ExprType::BroadcastOp:
       return "BroadcastOp";
-    case ExprType::ShiftOp:
-      return "ShiftOp";
+    case ExprType::ForLoop:
+      return "ForLoop";
+    case ExprType::IfThenElse:
+      return "IfThenElse";
+    case ExprType::Allocate:
+      return "Allocate";
+    case ExprType::Sync:
+      return "SyncThreads";
     case ExprType::Split:
       return "Split";
     case ExprType::Merge:
       return "Merge";
+    case ExprType::KirUnaryOp:
+      return "KirUnaryOp";
+    case ExprType::KirBinaryOp:
+      return "KirBinaryOp";
+    case ExprType::KirTernaryOp:
+      return "KirTernaryOp";
+    case ExprType::KirReductionOp:
+      return "KirReductionOp";
+    case ExprType::KirBroadcastOp:
+      return "KirBroadcastOp";
     default:
-      TORCH_INTERNAL_ASSERT(false, "No string found for expr type.");
-  }
-}
-
-bool needFloatSuffix(UnaryOpType t) {
-  switch (t) {
-    case UnaryOpType::Abs:
-    case UnaryOpType::Cast:
-    case UnaryOpType::Frac:
-    case UnaryOpType::Gelu:
-    case UnaryOpType::Silu:
-    case UnaryOpType::Neg:
-    case UnaryOpType::Relu:
-    case UnaryOpType::Reciprocal:
-    case UnaryOpType::Set:
-    case UnaryOpType::Sigmoid:
-      return false;
-    default:
-      return true;
+      break;
   }
+  TORCH_INTERNAL_ASSERT(false, "No string found for expr type.");
+  return nullptr;
 }
 
 static const char* unary_op_type2string(UnaryOpType t) {
   switch (t) {
     case UnaryOpType::Abs:
-      return "abs";
+      return "fabs";
     case UnaryOpType::Acos:
-      return "acos";
+      return "acosf";
     case UnaryOpType::Asin:
-      return "asin";
+      return "asinf";
     case UnaryOpType::Atan:
-      return "atan";
+      return "atanf";
     case UnaryOpType::Atanh:
-      return "atanh";
+      return "atanhf";
     case UnaryOpType::Cast:
       return "cast";
     case UnaryOpType::Ceil:
-      return "ceil";
+      return "ceilf";
     case UnaryOpType::Cos:
-      return "cos";
+      return "cosf";
     case UnaryOpType::Cosh:
-      return "cosh";
+      return "coshf";
     case UnaryOpType::Exp:
-      return "exp";
+      return "expf";
     case UnaryOpType::Expm1:
-      return "expm1";
+      return "expm1f";
     case UnaryOpType::Erf:
-      return "erf";
+      return "erff";
     case UnaryOpType::Erfc:
-      return "erfc";
+      return "erfcf";
     case UnaryOpType::Floor:
-      return "floor";
+      return "floorf";
     case UnaryOpType::Frac:
       return "frac";
     case UnaryOpType::Gelu:
       return "gelu";
-    case UnaryOpType::Silu:
-      return "silu";
     case UnaryOpType::Lgamma:
-      return "lgamma";
+      return "lgammaf";
     case UnaryOpType::Log:
-      return "log";
+      return "logf";
     case UnaryOpType::Log10:
-      return "log10";
+      return "log10f";
     case UnaryOpType::Log1p:
-      return "log1p";
+      return "log1pf";
     case UnaryOpType::Log2:
-      return "log2";
+      return "log2f";
     case UnaryOpType::Neg:
       return "neg";
-    case UnaryOpType::Not:
-      return "not";
     case UnaryOpType::RandLike:
       return "randLike";
     case UnaryOpType::Reciprocal:
@@ -231,84 +178,62 @@ static const char* unary_op_type2string(UnaryOpType t) {
     case UnaryOpType::Relu:
       return "relu";
     case UnaryOpType::Rsqrt:
-      return "rsqrt";
+      return "rsqrtf";
     case UnaryOpType::Round:
-      return "nearbyint";
+      return "roundf";
     case UnaryOpType::Set:
       return "set";
     case UnaryOpType::Sigmoid:
       return "sigmoid";
     case UnaryOpType::Sin:
-      return "sin";
+      return "sinf";
     case UnaryOpType::Sinh:
-      return "sinh";
+      return "sinhf";
     case UnaryOpType::Sqrt:
-      return "sqrt";
+      return "sqrtf";
     case UnaryOpType::Tan:
-      return "tan";
+      return "tanf";
     case UnaryOpType::Tanh:
-      return "tanh";
+      return "tanhf";
     case UnaryOpType::Trunc:
-      return "trunc";
+      return "truncf";
     default:
-      TORCH_INTERNAL_ASSERT(false, "No string found for unary op type.");
+      break;
   }
-}
-
-std::string stringifyBooleanOp(const UnaryOpType uopt) {
-  TORCH_INTERNAL_ASSERT(
-      uopt == UnaryOpType::Not, uopt, " is not a boolean operator.");
-  return "!";
+  TORCH_INTERNAL_ASSERT(false, "No string found for unary op type.");
+  return nullptr;
 }
 
 static const char* unary_op_type_inline_op2string(UnaryOpType t) {
   switch (t) {
     case UnaryOpType::Neg:
       return "-";
-    case UnaryOpType::Not:
-      return "~";
     case UnaryOpType::Set:
       return "";
-    case UnaryOpType::Address:
-      return "(int64_t) &";
     default:
       break;
   }
   return nullptr;
 }
 
-bool needFloatSuffix(BinaryOpType t) {
-  switch (t) {
-    case BinaryOpType::Atan2:
-    case BinaryOpType::Div:
-    case BinaryOpType::Fmod:
-    case BinaryOpType::Max:
-    case BinaryOpType::Min:
-    case BinaryOpType::Pow:
-      return true;
-    default:
-      return false;
-  }
-}
-
 static const char* binary_op_type2string(BinaryOpType t) {
   switch (t) {
     case BinaryOpType::Add:
       return "add";
     case BinaryOpType::Atan2:
-      return "atan2";
+      return "atan2f";
     case BinaryOpType::Div:
       return "div";
     case BinaryOpType::Fmod:
-      return "fmod";
+      return "fmodf";
     case BinaryOpType::Max:
-      return "fmax";
+      return "fmaxf";
     case BinaryOpType::Min:
-      return "fmin";
+      return "fminf";
     case BinaryOpType::Mul:
       return "mul";
     case BinaryOpType::Pow:
-      return "pow";
+      return "powf";
     case BinaryOpType::Remainder:
       return "remainder";
     case BinaryOpType::Sub:
@@ -334,19 +259,9 @@ static const char* binary_op_type2string(BinaryOpType t) {
     case BinaryOpType::NE:
       return "notEqual";
     default:
-      TORCH_INTERNAL_ASSERT(false, "No string found for binary op type.");
-  }
-}
-
-static const char* binary_op_integer_op2string(BinaryOpType t) {
-  switch (t) {
-    case BinaryOpType::Max:
-      return "max";
-    case BinaryOpType::Min:
-      return "min";
-    default:
       break;
   }
+  TORCH_INTERNAL_ASSERT(false, "No string found for binary op type.");
   return nullptr;
 }
 
@@ -356,19 +271,16 @@ static const char* binary_op_type_inline_op2string(BinaryOpType t) {
       return "+";
     case BinaryOpType::Div:
       return "/";
+    case BinaryOpType::Mod:
+      return "%";
     case BinaryOpType::Mul:
       return "*";
     case BinaryOpType::Sub:
       return "-";
 
-    // Integer ops
-    case BinaryOpType::Mod:
-      return "%";
-    case BinaryOpType::Lshift:
-      return "<<";
-    case BinaryOpType::Rshift:
-      return ">>";
     // Logical Ops
+    case BinaryOpType::And:
+      return "&&";
     case BinaryOpType::Eq:
       return "==";
     case BinaryOpType::GE:
@@ -381,32 +293,12 @@ static const char* binary_op_type_inline_op2string(BinaryOpType t) {
       return "<";
     case BinaryOpType::NE:
       return "!=";
-    // Assume bitwise, otherwise use stringifyBooleanOp
-    case BinaryOpType::And:
-      return "&";
-    case BinaryOpType::Or:
-      return "|";
-    case BinaryOpType::Xor:
-      return "^";
     default:
       break;
   }
   return nullptr;
 }
 
-std::string stringifyBooleanOp(const BinaryOpType bopt) {
-  switch (bopt) {
-    case BinaryOpType::And:
-      return "&&";
-    case BinaryOpType::Or:
-      return "||";
-    case BinaryOpType::Xor:
-      return "!=";
-    default:
-      TORCH_INTERNAL_ASSERT(false, bopt, " is not a boolean operator.")
-  }
-}
-
 static const char* ternary_op_type2string(TernaryOpType t) {
   switch (t) {
     case TernaryOpType::Clamp:
@@ -416,8 +308,10 @@ static const char* ternary_op_type2string(TernaryOpType t) {
     case TernaryOpType::Where:
       return "where";
     default:
-      TORCH_INTERNAL_ASSERT(false, "Unexpected TernaryOpType", t);
+      break;
   }
+  TORCH_INTERNAL_ASSERT(false, "No string found for ternary op type.");
+  return nullptr;
 }
 
 static const char* parallel_type2string(ParallelType t) {
@@ -436,17 +330,15 @@ static const char* parallel_type2string(ParallelType t) {
       return "threadIdx.x";
     case ParallelType::Vectorize:
       return "V";
-    case ParallelType::MisalignedVectorize:
-      return "MV";
     case ParallelType::Unroll:
-      return "UR";
-    case ParallelType::Unswitch:
-      return "US";
+      return "U";
     case ParallelType::Serial:
       return "S";
     default:
-      TORCH_INTERNAL_ASSERT(false, "Unexpected ParallelType", t);
+      break;
   }
+  TORCH_INTERNAL_ASSERT(false, "No string found for parallel type.");
+  return nullptr;
 }
 
 static const char* memory_type2string(MemoryType t) {
@@ -458,8 +350,10 @@ static const char* memory_type2string(MemoryType t) {
     case MemoryType::Global:
       return "global";
     default:
-      TORCH_INTERNAL_ASSERT(false, "Unexpected MemoryType", t);
+      break;
   }
+  TORCH_INTERNAL_ASSERT(false, "No string found for memory type.");
+  return nullptr;
 }
 
 static const char* iter_type2string(IterType t) {
@@ -472,11 +366,9 @@ static const char* iter_type2string(IterType t) {
       return "sb";
     case IterType::BroadcastWithoutStride:
       return "b";
-    case IterType::Gather:
-      return "g";
     default:
-      // Don't try to print t as it would recursively call this function
-      TORCH_INTERNAL_ASSERT(false, "Unexpected IterType");
+      TORCH_INTERNAL_ASSERT(false, "No string found for IterDomain type.");
+      return nullptr;
   }
 }
 
@@ -495,8 +387,10 @@ static const char* thread_size2string(ParallelType t) {
     case ParallelType::TIDx:
       return "blockDim.x";
     default:
-      TORCH_INTERNAL_ASSERT(false, "Unexpected parallel type", t);
+      break;
   }
+  TORCH_INTERNAL_ASSERT(false, "Could not find size of the thread type ", t);
+  return nullptr;
 }
 
 const unsigned int _WORD_SHIFT = 16;
@@ -506,24 +400,28 @@ constexpr unsigned int supported_switch_pair(DataType t1, DataType t2) {
 static const char* supported_casts2string(
     const std::pair<DataType, DataType>& t) {
   switch (supported_switch_pair(std::get<0>(t), std::get<1>(t))) {
-    case supported_switch_pair(DataType::Double, DataType::Float):
-      return "(float)";
-    case supported_switch_pair(DataType::Float, DataType::Double):
-      return "(double)";
-    case supported_switch_pair(DataType::Int32, DataType::Float):
-      return "(float)";
-    case supported_switch_pair(DataType::Int, DataType::Float):
-      return "(double)";
-    case supported_switch_pair(DataType::Int32, DataType::Int):
-      return "(int64_t)";
     case supported_switch_pair(DataType::Float, DataType::Half):
       return "__float2half";
     case supported_switch_pair(DataType::Half, DataType::Float):
       return "__half2float";
-    case supported_switch_pair(DataType::Bool, DataType::Float):
-      return "float";
     default:
-      return nullptr;
+      break;
+  }
+  return nullptr;
+}
+
+bool is_logical_op(const BinaryOpType& bot) {
+  switch (bot) {
+    case BinaryOpType::And:
+    case BinaryOpType::Eq:
+    case BinaryOpType::GE:
+    case BinaryOpType::GT:
+    case BinaryOpType::LE:
+    case BinaryOpType::LT:
+    case BinaryOpType::NE:
+      return true;
+    default:
+      return false;
   }
 }
 
@@ -531,17 +429,14 @@ DataType aten_to_data_type(const at::ScalarType& scalar_type) {
   switch (scalar_type) {
     case at::ScalarType::Bool:
       return DataType::Bool;
-    case at::ScalarType::Double:
-      return DataType::Double;
     case at::ScalarType::Float:
       return DataType::Float;
     case at::ScalarType::Half:
       return DataType::Half;
     case at::ScalarType::Long:
       return DataType::Int;
-    case at::ScalarType::Int:
-      return DataType::Int32;
     default:
+      TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type.");
       return DataType::Null;
   }
 }
@@ -550,18 +445,15 @@ at::ScalarType data_type_to_aten(const DataType& data_type) {
   switch (data_type) {
     case DataType::Bool:
       return at::ScalarType::Bool;
-    case DataType::Double:
-      return at::ScalarType::Double;
     case DataType::Float:
       return at::ScalarType::Float;
     case DataType::Half:
       return at::ScalarType::Half;
     case DataType::Int:
       return at::ScalarType::Long;
-    case DataType::Int32:
-      return at::ScalarType::Int;
     default:
       TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type.");
+      return at::ScalarType::Undefined;
   }
 }
 
@@ -616,12 +508,6 @@ c10::optional<std::string> inline_op_str(const BinaryOpType botype) {
                         : c10::nullopt;
 }
 
-c10::optional<std::string> integer_op_str(const BinaryOpType botype) {
-  const char* str = binary_op_integer_op2string(botype);
-  return str != nullptr ? c10::optional<std::string>(std::string(str))
-                        : c10::nullopt;
-}
-
 std::string stringifyThreadSize(const ParallelType ptype) {
   return thread_size2string(ptype);
 }
@@ -630,42 +516,6 @@ std::string stringifyThread(const ParallelType ptype) {
   return parallel_type2string(ptype);
 }
 
-std::string typePrefix(const DataType data_type) {
-  switch (data_type) {
-    case DataType::Bool:
-      return "b";
-    case DataType::Double:
-      return "d";
-    case DataType::Float:
-    case DataType::Half:
-      return "f";
-    case DataType::Int:
-    case DataType::Int32:
-      return "i";
-    default:
-      TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type.");
-  }
-}
-
-bool isParallelTypeThreadDim(ParallelType ptype) {
-  return ptype == ParallelType::TIDx || ptype == ParallelType::TIDy ||
-      ptype == ParallelType::TIDz;
-}
-
-bool isParallelTypeBlockDim(ParallelType ptype) {
-  return ptype == ParallelType::BIDx || ptype == ParallelType::BIDy ||
-      ptype == ParallelType::BIDz;
-}
-
-bool isParallelTypeThread(ParallelType ptype) {
-  return isParallelTypeBlockDim(ptype) || isParallelTypeThreadDim(ptype);
-}
-
-bool isParallelTypeVectorize(ParallelType ptype) {
-  return ptype == ParallelType::Vectorize ||
-      ptype == ParallelType::MisalignedVectorize;
-}
-
 c10::optional<std::string> cast_func_str(
     const std::pair<DataType, DataType>& cast) {
   const char* str = supported_casts2string(cast);
@@ -677,16 +527,12 @@ size_t dataTypeSize(DataType type) {
   switch (type) {
     case DataType::Bool:
       return sizeof(bool);
-    case DataType::Double:
-      return sizeof(double);
     case DataType::Float:
-      return sizeof(float);
+      return 4;
     case DataType::Half:
-      return sizeof(at::Half);
+      return 2;
     case DataType::Int:
-      return sizeof(uint64_t);
-    case DataType::Int32:
-      return sizeof(uint32_t);
+      return 4;
     default:
       TORCH_INTERNAL_ASSERT(false, "Size undefined for data type, ", type);
   }
index 739c97f..715c06b 100644 (file)
@@ -14,8 +14,6 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-enum class KernelIndexMode { INT32, INT64 };
-
 // https://stackoverflow.com/questions/18837857/cant-use-enum-class-as-unordered-map-key
 struct TypeHash {
   template <typename T>
@@ -31,32 +29,17 @@ enum class ValType {
   TensorView,
   Scalar,
   NamedScalar,
-};
 
-// Manual - The user provides the Bool value. Predicate generation is bypassed.
-// Inline corresponds with PredicateCompute::getInlinePredicate
-// Unswitch corresponds with UnswitchPredicate::get
-// Misaligned - PredicateCompute::getInlinePredicate + Misaligned flag
-// Shift - ShiftPredicateInserter::getShiftPredicate
-// Padding - ShiftPredicateInserter::getPaddingPredicate
-// ReductionWrite - Same as Inline but without reduction axes
-enum class PredicateType {
-  Manual,
-  Inline,
-  Unswitch,
-  Vectorize,
-  Misaligned,
-  Shift,
-  Padding,
-  ReductionWrite
+  // Temporary: Kernel IR nodes
+  TensorIndex,
+  KirNamedScalar,
+  KirScalar,
+  KirTensorDomain,
+  KirIterDomain,
+  KirTensorView,
 };
 
-enum class DataType { Double, Float, Half, Int, Int32, Bool, Null };
-
-// Returns if the datatype is a floating point type
-bool isFloatingPointType(DataType dtype);
-// Returns if the datatype is an integer type
-bool isIntegralType(DataType dtype);
+enum class DataType { Bool, Float, Half, Int, Null };
 
 enum class ExprType {
   Invalid,
@@ -65,18 +48,25 @@ enum class ExprType {
   TernaryOp,
   ReductionOp,
   BroadcastOp,
-  WelfordOp,
-  TransposeOp,
-  ShiftOp,
-  GatherOp,
   Split,
   Merge,
+
+  // Temporary: Kernel IR nodes
+  GridReduction,
+  ForLoop,
+  IfThenElse,
+  Allocate,
+  Sync,
+  KirUnaryOp,
+  KirBinaryOp,
+  KirTernaryOp,
+  KirReductionOp,
+  KirBroadcastOp,
 };
 
 enum class UnaryOpType {
   Abs,
   Acos,
-  Address,
   Asin,
   Atan,
   Atanh,
@@ -91,7 +81,6 @@ enum class UnaryOpType {
   Floor,
   Frac,
   Gelu,
-  Silu,
   Lgamma,
   Log,
   Log10,
@@ -110,15 +99,9 @@ enum class UnaryOpType {
   Sqrt,
   Tan,
   Tanh,
-  Trunc,
-
-  // Might be a bitwise operator or boolean operator.
-  Not
+  Trunc
 };
 
-// Primarily for Not, which could be Not a boolean, or a bitwise not.
-bool alsoBooleanOperator(const UnaryOpType uopt);
-
 // TODO: Order of this list is important as it affects type promotion. it's not
 // in the right order now.
 enum class BinaryOpType {
@@ -135,43 +118,19 @@ enum class BinaryOpType {
   Sub,
   // TypeAs,
 
-  // Integer output ops. If changing modify isIntegerOp
+  // Logical Ops
+  // Int operations, leave position of Mod we depend on its location of first
   Mod,
   CeilDiv,
-  Lshift,
-  Rshift,
-
-  // Logical Ops
-  // Int operations, leave position of Mod as first logical op see
-  // isLogicalOp(BinaryOpType bopt)
+  And,
   Eq,
   GE,
   GT,
   LE,
   LT,
-  NE,
-
-  // Maybe bitwise or boolean op, leave position of and as first bool/int
-  // op. These are ops that have different operators based on output type. See
-  // is boolean op. These ops also don't work on floating point inputs.
-  And,
-  Or,
-  Xor
+  NE
 };
 
-// Return if output of operator should be a boolean
-bool isIntegerOp(const BinaryOpType bopt);
-
-// Return if output of operator should be a boolean
-bool isLogicalOp(const BinaryOpType bopt);
-
-// Operations that could be a bitwise operation or a boolean operation depending
-// on input, for example bitwise_and is also used for boolean and in the jit
-bool alsoBooleanOperator(const BinaryOpType bopt);
-
-//! Operations that have tricky behaviors with all integer inputs
-bool noFullIntegerSupport(const BinaryOpType bopt);
-
 enum class TernaryOpType { Clamp, Threshold, Where };
 
 enum class ParallelType {
@@ -182,9 +141,7 @@ enum class ParallelType {
   TIDy,
   TIDx,
   Vectorize,
-  MisalignedVectorize,
   Unroll,
-  Unswitch,
   Serial
 };
 
@@ -202,24 +159,15 @@ enum class IterType {
   Iteration,
   Reduction,
   BroadcastWithStride,
-  BroadcastWithoutStride,
-  Gather
+  BroadcastWithoutStride
 };
 
-enum class SwizzleType { NoSwizzle, Transpose };
-
-// Returns if function needs an f suffix on the operator when operating on a
-// float value i.e. sin->sinf
-bool needFloatSuffix(UnaryOpType t);
-bool needFloatSuffix(BinaryOpType t);
-
 ValType promote_type(const ValType& t1, const ValType& t2);
 DataType promote_type(const DataType& t1, const DataType& t2);
+bool is_logical_op(const BinaryOpType& bot);
 
-// If type cannot be found (i.e. codegen does not support provided type) returns
-// DataType::Null
-TORCH_CUDA_CU_API DataType aten_to_data_type(const at::ScalarType& scalar_type);
-TORCH_CUDA_CU_API at::ScalarType data_type_to_aten(const DataType& data_type);
+DataType aten_to_data_type(const at::ScalarType& scalar_type);
+at::ScalarType data_type_to_aten(const DataType& data_type);
 
 TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const ValType);
 TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const DataType);
@@ -231,28 +179,16 @@ TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const ParallelType);
 TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const MemoryType);
 TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const IterType);
 
-std::string stringifyBooleanOp(const UnaryOpType);
-std::string stringifyBooleanOp(const BinaryOpType);
-
 std::string stringifyThreadSize(const ParallelType);
 std::string stringifyThread(const ParallelType);
-std::string typePrefix(const DataType);
-
-// TODO: ThreadDim should be BlockDim and BlockDim should be GridDim
-TORCH_CUDA_CU_API bool isParallelTypeThreadDim(ParallelType);
-TORCH_CUDA_CU_API bool isParallelTypeBlockDim(ParallelType);
-TORCH_CUDA_CU_API bool isParallelTypeThread(ParallelType);
-
-TORCH_CUDA_CU_API bool isParallelTypeVectorize(ParallelType);
 
 TORCH_CUDA_CU_API c10::optional<std::string> inline_op_str(const UnaryOpType);
 TORCH_CUDA_CU_API c10::optional<std::string> inline_op_str(const BinaryOpType);
-TORCH_CUDA_CU_API c10::optional<std::string> integer_op_str(const BinaryOpType);
 
 TORCH_CUDA_CU_API c10::optional<std::string> cast_func_str(
     const std::pair<DataType, DataType>&);
 
-TORCH_CUDA_CU_API size_t dataTypeSize(DataType type);
+size_t dataTypeSize(DataType type);
 
 enum class LaunchConfigType {
   Compatible,
diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp
deleted file mode 100644 (file)
index db25fce..0000000
+++ /dev/null
@@ -1,106 +0,0 @@
-
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-
-#include <c10/util/string_view.h>
-
-#include <cstdlib>
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-auto parseDebugDumpOptions() {
-  std::unordered_map<DebugDumpOption, bool> options_map = {
-      {DebugDumpOption::FusionIr, false},
-      {DebugDumpOption::FusionIrMath, false},
-      {DebugDumpOption::KernelIr, false},
-      {DebugDumpOption::CudaKernel, false},
-      {DebugDumpOption::CudaFull, false},
-      {DebugDumpOption::CudaToFile, false},
-      {DebugDumpOption::LaunchParam, false},
-      {DebugDumpOption::FusionSegments, false},
-      {DebugDumpOption::PrintRuntimeArgs, false},
-      {DebugDumpOption::EffectiveBandwidth, false},
-      {DebugDumpOption::FusionSegmentsDrawing, false},
-      {DebugDumpOption::PrintPtxasLog, false},
-      {DebugDumpOption::SchedulerDebug, false},
-      {DebugDumpOption::ParallelDimensions, false}};
-
-  if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_DUMP")) {
-    c10::string_view options_view(dump_options);
-    while (!options_view.empty()) {
-      const auto end_pos = options_view.find_first_of(',');
-      const auto token = options_view.substr(0, end_pos);
-      if (token == "fusion_ir") {
-        options_map[DebugDumpOption::FusionIr] = true;
-      } else if (token == "fusion_ir_math") {
-        options_map[DebugDumpOption::FusionIrMath] = true;
-      } else if (token == "kernel_ir") {
-        options_map[DebugDumpOption::KernelIr] = true;
-      } else if (token == "cuda_kernel") {
-        options_map[DebugDumpOption::CudaKernel] = true;
-      } else if (token == "cuda_full") {
-        options_map[DebugDumpOption::CudaFull] = true;
-      } else if (token == "cuda_to_file") {
-        options_map[DebugDumpOption::CudaToFile] = true;
-      } else if (token == "launch_param") {
-        options_map[DebugDumpOption::LaunchParam] = true;
-      } else if (token == "segmented_fusion") {
-        options_map[DebugDumpOption::FusionSegments] = true;
-      } else if (token == "print_args") {
-        options_map[DebugDumpOption::PrintRuntimeArgs] = true;
-      } else if (token == "dump_eff_bandwidth") {
-        options_map[DebugDumpOption::EffectiveBandwidth] = true;
-      } else if (token == "draw_segmented_fusion") {
-        options_map[DebugDumpOption::FusionSegmentsDrawing] = true;
-      } else if (token == "ptxas_verbose") {
-        options_map[DebugDumpOption::PrintPtxasLog] = true;
-      } else if (token == "scheduler_params") {
-        options_map[DebugDumpOption::SchedulerDebug] = true;
-      } else if (token == "parallel_dimensions") {
-        options_map[DebugDumpOption::ParallelDimensions] = true;
-      } else {
-        TORCH_CHECK(
-            false,
-            "Invalid debug dump option: '",
-            token,
-            "'\nAvailable options:\n",
-            "\tfusion_ir, fusion_ir_math, kernel_ir, cuda_kernel, cuda_full,\n",
-            "\tcuda_to_file, launch_param, segmented_fusion, print_args,\n",
-            "\tdump_eff_bandwidth, draw_segmented_fusion, scheduler_params\n",
-            "\tparallel_dimensions,\n");
-      }
-      options_view = (end_pos != c10::string_view::npos)
-          ? options_view.substr(end_pos + 1)
-          : "";
-    }
-  }
-
-  return options_map;
-}
-
-} // namespace
-
-bool isDebugDumpEnabled(DebugDumpOption option) {
-  const static auto dump_options = parseDebugDumpOptions();
-  return dump_options.at(option);
-}
-
-bool useFallback() {
-  const char* disable_fb_env = getenv("PYTORCH_NVFUSER_DISABLE_FALLBACK");
-  return !(disable_fb_env ? atoi(disable_fb_env) : 0);
-}
-
-bool disableRNGUnrolling() {
-  const char* disable_rng_unroll = getenv("PYTORCH_NVFUSER_DISABLE_RNG_UNROLL");
-  return disable_rng_unroll ? atoi(disable_rng_unroll) : 0;
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
index e7de6fe..f47c944 100644 (file)
@@ -7,48 +7,17 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-//! Types of debug print-outs
-//!
-//! These can be set through the `PYTORCH_NVFUSER_DUMP` environment variable
-//!
-enum class DebugDumpOption {
-  FusionIr, //!< Dump the Fusion IR before lowering
-  FusionIrMath, //!< Dump just the compute (math) part of the Fusion IR
-  KernelIr, //!< Dump the compiler Kernel IR
-  CudaKernel, //!< Dump the generated CUDA C++ kernel code
-  CudaFull, //!< Dump the complete CUDA C++ code
-  CudaToFile, //!< Dump CUDA Strings to File
-  LaunchParam, //!< Dump the Launch parameters of kernel
-  FusionSegments, //!< Dump Segmented Fusion Graph
-  PrintRuntimeArgs, //!< Print the runtime arguments when launching kernels
-  EffectiveBandwidth, //! Measure kernel performance and print effective
-                      //! bandwidth
-  FusionSegmentsDrawing, //!< Dump Segmented Fusion Graph
-  PrintPtxasLog, //!< Print the ptxas verbose log including register usage
-  SchedulerDebug, //! Dump scheduler heuristic parameters
-  ParallelDimensions //!< Dump known parallel dimensions
-};
-
-bool isDebugDumpEnabled(DebugDumpOption option);
-
-// Check if fallback path should be used which will dispatch to eagermode if any
-// errors are encountered. Helpful for debugging.
-bool useFallback();
-
-// Returns if unrolling should not be used for kernels with RNG in them.
-bool disableRNGUnrolling();
-
-//! Ceil integer division
+// Common Functions
 constexpr int64_t ceilDiv(int64_t a, int64_t b) {
   return (a + b - 1) / b;
 }
 
-//! Simple mixin for suppressing copy & move operations, ex:
-//!
-//!  class Foo : public NonCopyable {
-//!   ...
-//!  };
-//!
+// Simple mixin for suppressing copy & move operations, ex:
+//
+//  class Foo : public NonCopyable {
+//   ...
+//  };
+//
 class NonCopyable {
  public:
   NonCopyable() = default;
@@ -58,9 +27,9 @@ class NonCopyable {
   NonCopyable& operator=(const NonCopyable&) = delete;
 };
 
-//! A generic root for a hierarchy of polymorphic classes:
-//! - It ensures virtual destructors
-//! - Provides the base->as<Derived>() and node->isA<T>() notation
+// A generic root for a hierarchy of polymorphic classes:
+// - It ensures virtual destructors
+// - Provides the base->as<Derived>() and node->isA<T>() notation
 class PolymorphicBase {
  public:
   virtual ~PolymorphicBase() = default;
@@ -89,16 +58,16 @@ class PolymorphicBase {
     return downcast_ptr;
   }
 
-  //! Check if the runtime time is T (or derived from T)
-  //!
-  //! \note Don't use this for conditional casts. Instead, use:
-  //!
-  //!  if (auto t = dynamic_cast<T>(p)) { ... }
-  //!
-  //! instead of:
-  //!
-  //!  if (p->isA<T>()) { auto t = p->as<T>(); ... }
-  //!
+  // Check if the runtime time is T (or derived from T)
+  //
+  // NOTE: Don't use this for conditional casts. Use:
+  //
+  //  if (auto t = dynamic_cast<T>(p)) { ... }
+  //
+  // instead of:
+  //
+  //  if (p->isA<T>()) { auto t = p->as<T>(); ... }
+  //
   template <class T>
   bool isA() const {
     return dynamic_cast<const T*>(this) != nullptr;
index 12a1873..0c54f46 100644 (file)
@@ -60,7 +60,7 @@ bool isDifferentiable(const Node* n) {
 
   if (n->kind() == prim::Constant || n->kind() == prim::AutogradZero ||
       n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk ||
-      n->kind() == prim::profile || n->kind() == prim::profile_ivalue)
+      n->kind() == prim::profile)
     return true;
 
   if (n->isMemberOf(differentiable_ops))
index b5c6ce9..b099db1 100644 (file)
@@ -668,13 +668,6 @@ const ExecutionPlan& ProfilingGraphExecutorImpl::getOptimizedPlanFor(
     // before any other pass that could insert `prim::iprofile_value` node on
     // `aten::_grad_sum_to_size` input.
     InsertProfileNodesForSpecializeAutogradZero(pr_.get());
-    // `InsertProfileNodesForCUDAFuser` inserts profile node for non-tensor
-    // value
-#ifndef C10_MOBILE
-    if (RegisterCudaFuseGraph::isRegistered()) {
-      torch::jit::fuser::cuda::InsertProfileNodesForCUDAFuser(pr_.get());
-    }
-#endif
     GRAPH_DUMP("Profiled Graph: ", pr_->graph());
     profiling_plan_ = ExecutionPlan(pr_->graph(), function_name_);
     // fall-through
index 400b54e..6c7c65b 100644 (file)
@@ -5,15 +5,11 @@
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/clear_profiling.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
-#include <torch/csrc/jit/passes/cuda_graph_fuser.h>
 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
 #include <torch/csrc/jit/runtime/autodiff.h>
 #include <torch/csrc/jit/runtime/graph_executor.h>
 #include <torch/csrc/jit/runtime/interpreter.h>
 
-#include <torch/csrc/jit/codegen/cuda/interface.h>
-#include <torch/csrc/jit/ir/ir.h>
-
 namespace torch {
 namespace jit {
 
@@ -205,13 +201,7 @@ void ProfilingRecord::insertShapeProfile(Node* n, size_t offset) {
 }
 
 bool needsProfiledInputs(Node* n) {
-  if (tensorexpr::isSupported(n) ||
-#ifndef C10_MOBILE
-      (RegisterCudaFuseGraph::isRegistered() && fuser::cuda::canFuseNode(n))
-#else
-      false
-#endif
-  ) {
+  if (tensorexpr::isSupported(n)) {
     return true;
   }
 
@@ -242,13 +232,7 @@ bool needsProfiledInputs(Node* n) {
 }
 
 bool needsProfiledOutput(Node* n) {
-  if (tensorexpr::isSupported(n) ||
-#ifndef C10_MOBILE
-      (RegisterCudaFuseGraph::isRegistered() && fuser::cuda::canFuseNode(n))
-#else
-      false
-#endif
-  ) {
+  if (tensorexpr::isSupported(n)) {
     return true;
   }
 
index 534b0af..29dac32 100644 (file)
@@ -289,25 +289,8 @@ class JitTestCase(JitCommonTestCase):
         result.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
         return result
 
-    def assertGraphContains(self, graph, kind, consider_subgraphs=False):
-
-        if consider_subgraphs:
-            strgraph = str(graph)
-            count = strgraph.count(kind) - strgraph.count('with {}'.format(kind))
-            self.assertTrue(count > 0)
-            return
-
-        def nodes(block):
-            out = []
-            for node in block.nodes():
-                if node.kind() == kind:
-                    out.append(node)
-                for block in node.blocks():
-                    out += nodes(block)
-            return out
-
-        out_nodes = nodes(graph)
-        self.assertTrue(len(out_nodes) > 0)
+    def assertGraphContains(self, graph, kind):
+        self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
 
     def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
         def perform_assert(graph, kind, actual, expected, consider_subgraphs):