USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
"USE_CUDNN" OFF)
cmake_dependent_option(
- BUILD_NVFUSER_BENCHMARK "Build C++ binaries for nvfuser benchmarks" ON
- "USE_CUDA;BUILD_TEST" OFF)
-cmake_dependent_option(
USE_WHOLE_CUDNN "Use whole-library linking for cuDNN" OFF
"USE_STATIC_CUDNN" OFF)
cmake_dependent_option(
_(aten, baddbmm) \
_(aten, bartlett_window) \
_(aten, batch_norm) \
-_(aten, _batch_norm_impl_index) \
-_(aten, _batch_norm_impl_index_backward) \
_(aten, bernoulli) \
_(aten, bilinear) \
_(aten, binary_cross_entropy) \
_(aten, gather) \
_(aten, gcd) \
_(aten, gelu) \
-_(aten, gelu_backward) \
_(aten, geometric) \
_(aten, geqrf) \
_(aten, get_device) \
_(aten, narrow_copy) \
_(aten, native_batch_norm) \
_(aten, native_batch_norm_backward) \
-_(aten, native_layer_norm) \
-_(aten, native_layer_norm_backward) \
_(aten, native_clone) \
_(aten, native_get_device) \
_(aten, native_norm) \
_(prim, CudaFusionGroup) \
_(prim, CudaFusionGuard) \
_(prim, FunctionalGraph) \
- _(prim, add_optional) \
_(prim, DifferentiableGraph) \
_(prim, TensorExprGroup) \
_(prim, StaticSubgraph) \
+++ /dev/null
-if(USE_CUDA)
- add_executable(nvfuser_bench
- batch_norm.cpp
- bert.cpp
- broadcast.cpp
- gelu_backward.cpp
- heuristic_lookup.cpp
- instance_norm.cpp
- layer_norm.cpp
- lstm_cell.cpp
- reduction.cpp
- softmax.cpp
- scale_bias_relu.cpp
- utils.cpp
- main.cpp)
-
- target_link_libraries(nvfuser_bench PRIVATE torch_library benchmark)
-endif()
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-//------------------------------------------------------------------------------
-
-static void setupBatchNorm(Fusion* fusion, DataType dtype) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
- FusionGuard fg(fusion);
-
- const bool kTraining = true;
- const float kMomentum = 0.1;
- const float kEps = 1e-5;
-
- // setup fusion
- auto input = makeContigTensor(4, dtype);
- auto weight = makeContigTensor(1, dtype);
- auto bias = makeContigTensor(1, dtype);
- auto running_mean = makeContigTensor(1, DataType::Float);
- auto running_var = makeContigTensor(1, DataType::Float);
-
- fusion->addInput(input);
- fusion->addInput(weight);
- fusion->addInput(bias);
- fusion->addInput(running_mean);
- fusion->addInput(running_var);
-
- if (dtype == DataType::Half) {
- input = castOp(DataType::Float, input);
- weight = castOp(DataType::Float, weight);
- bias = castOp(DataType::Float, bias);
- }
-
- auto momentum_ptr = new Double(kMomentum);
- auto eps_ptr = new Double(kEps);
-
- auto result = batch_norm(
- input,
- weight,
- bias,
- running_mean,
- running_var,
- kTraining,
- momentum_ptr,
- eps_ptr);
-
- auto output = result.output;
-
- if (dtype == DataType::Half) {
- output = castOp(DataType::Half, output);
- }
-
- fusion->addOutput(output);
-}
-
-static void NvFuserScheduler_BatchNorm(
- benchmark::State& benchmark_state,
- FusionExecutorCache* fusion_executor_cache,
- DataType dtype) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
- const bool kTraining = true;
- const float kMomentum = 0.1;
- const float kEps = 1e-5;
-
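-  // Shape is {N, C, H, W} with H = W = range(2).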
- std::vector<int64_t> input_shape{
- benchmark_state.range(0),
- benchmark_state.range(1),
- benchmark_state.range(2),
- benchmark_state.range(2)};
-
- // inputs
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- auto fp32_options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor at_x = at::randn(input_shape, options);
- at::Tensor at_weight = at::ones({input_shape[1]}, options);
- at::Tensor at_bias = at::zeros({input_shape[1]}, options);
- at::Tensor at_run_mean = at::zeros({input_shape[1]}, fp32_options);
- at::Tensor at_run_var = at::ones({input_shape[1]}, fp32_options);
- std::vector<c10::IValue> aten_inputs(
- {at_x, at_weight, at_bias, at_run_mean, at_run_var});
-
- runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
-
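-  // Rough traffic model: each tensor is counted as read once and written
-  // once (the factor of 2); the running stats are always fp32.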
- benchmark_state.SetBytesProcessed(
- (int64_t(benchmark_state.iterations()) *
- (2 * (at_x.numel() + at_weight.numel() + at_bias.numel())) *
- int64_t(dataTypeSize(dtype))) +
- (2 * (at_run_mean.numel() + at_run_var.numel()) *
- int64_t(dataTypeSize(DataType::Float))));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_BatchNorm(
- benchmark::State& benchmark_state,
- DataType dtype) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
- const float kMomentum = 0.1;
- const float kEps = 1e-5;
- std::vector<int64_t> input_shape{
- benchmark_state.range(0),
- benchmark_state.range(1),
- benchmark_state.range(2),
- benchmark_state.range(2)};
-
- // inputs
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- auto fp32_options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor at_x = at::randn(input_shape, options);
- at::Tensor at_weight = at::ones({input_shape[1]}, options);
- at::Tensor at_bias = at::zeros({input_shape[1]}, options);
- at::Tensor at_running_mean = at::zeros({input_shape[1]}, fp32_options);
- at::Tensor at_running_var = at::ones({input_shape[1]}, fp32_options);
-
- auto ato_weight = c10::optional<at::Tensor>(at_weight);
- auto ato_bias = c10::optional<at::Tensor>(at_bias);
- auto ato_running_mean = c10::optional<at::Tensor>(at_running_mean);
- auto ato_running_var = c10::optional<at::Tensor>(at_running_var);
-
- auto output = at::batch_norm(
- at_x,
- ato_weight,
- ato_bias,
- ato_running_mean,
- ato_running_var,
- true,
- kMomentum,
- kEps,
- true);
- cudaDeviceSynchronize();
-
- for (auto _ : benchmark_state) {
- CudaKernelTimer timer;
- auto output = at::batch_norm(
- at_x,
- ato_weight,
- ato_bias,
- ato_running_mean,
- ato_running_var,
- true,
- kMomentum,
- kEps,
- true);
- benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
- cudaDeviceSynchronize();
- }
- benchmark_state.SetBytesProcessed(
- (int64_t(benchmark_state.iterations()) *
- (2 * (at_x.numel() + at_weight.numel() + at_bias.numel())) *
- int64_t(dataTypeSize(dtype))) +
- (2 * (at_running_mean.numel() + at_running_var.numel()) *
- int64_t(dataTypeSize(DataType::Float))));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_BatchNorm_fp32(benchmark::State& benchmark_state) {
- Baseline_BatchNorm(benchmark_state, DataType::Float);
-}
-
-static void Baseline_BatchNorm_fp16(benchmark::State& benchmark_state) {
- Baseline_BatchNorm(benchmark_state, DataType::Half);
-}
-
-//------------------------------------------------------------------------------
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_BatchNorm_fp32,
- setupBatchNorm,
- NvFuserScheduler_BatchNorm,
- DataType::Float);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{32, 32}, {64, 512}, {8, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{64, 128}, {64, 128}, {8, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{128, 128}, {128, 512}, {8, 128}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{16, 64}, {2, 4}, {128, 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_BatchNorm_fp16,
- setupBatchNorm,
- NvFuserScheduler_BatchNorm,
- DataType::Half);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{32, 32}, {64, 512}, {8, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{64, 128}, {64, 128}, {8, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{128, 128}, {128, 512}, {8, 128}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_BatchNorm_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{16, 64}, {2, 4}, {128, 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-BENCHMARK(Baseline_BatchNorm_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{32, 32}, {64, 512}, {8, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{64, 128}, {64, 128}, {8, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{128, 128}, {128, 512}, {8, 128}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{16, 64}, {2, 4}, {128, 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{32, 32}, {64, 512}, {8, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{64, 128}, {64, 128}, {8, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{128, 128}, {128, 512}, {8, 128}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_BatchNorm_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{16, 64}, {2, 4}, {128, 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include <sstream>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-// Set up the fused div + add + softmax + dropout forward graph
-static void setupDivMaxSoftmaxDropoutForward(Fusion* fusion, DataType dtype) {
- FusionGuard fg(fusion);
-
- bool is_fp16 = dtype == DataType::Half;
-
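-  // tv0 has size-1 middle dimensions ({-1, 1, 1, -1}), i.e. a mask-like
-  // tensor that gets broadcast against the full 4-D tensor tv1.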
- TensorView* tv0 = TensorViewBuilder()
- .ndims(4)
- .dtype(dtype)
- .contiguity({true, false, false, true})
- .shape({-1, 1, 1, -1})
- .build();
- TensorView* tv1 = makeContigTensor(4, dtype);
-
- fusion->addInput(tv0);
- fusion->addInput(tv1);
-
- // TODO: should be input
- auto d16 = new Double(1.0);
-
- if (is_fp16) {
- tv0 = castOp(DataType::Float, tv0);
- tv1 = castOp(DataType::Float, tv1);
- }
-
- auto tv2 = div(tv1, d16);
- auto tv3 = add(tv2, tv0);
-
- auto tv10 = softmax(tv3, 3);
- auto dropout_tvs = dropout(tv10, new Double(0.9));
- auto tv12 = dropout_tvs.mask;
- auto tv14 = dropout_tvs.output;
-
- if (is_fp16) {
- tv14 = castOp(DataType::Half, tv14);
- tv10 = castOp(DataType::Half, tv10);
- tv3 = castOp(DataType::Half, tv3);
- }
-
- fusion->addOutput(tv14);
- fusion->addOutput(tv12);
- fusion->addOutput(tv10);
- fusion->addOutput(tv3);
-}
-
-static void setupDivMaxSoftmaxDropoutBackward(Fusion* fusion, DataType dtype) {
- TensorView* tv0 = makeContigTensor(4, dtype);
-  // Strangely, tv1 isn't used anywhere; need to come back to that...
- TensorView* tv1 = makeContigTensor(4, dtype);
- TensorView* tv2 = makeContigTensor(4, dtype);
- TensorView* tv3 = makeContigTensor(4, DataType::Bool);
-
- fusion->addInput(tv0);
- fusion->addInput(tv1);
- fusion->addInput(tv2);
- fusion->addInput(tv3);
-
- bool is_fp16 = dtype == DataType::Half;
- if (is_fp16) {
- tv0 = castOp(DataType::Float, tv0);
- tv1 = castOp(DataType::Float, tv1);
- tv2 = castOp(DataType::Float, tv2);
- }
-
- // TODO: should be inputs
- auto d32 = new Double(1.0);
- // fusion->addInput(d32);
- auto d33 = new Double(2.0);
- // fusion->addInput(d33);
-
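-  // Softmax backward: with y = the softmax output (tv0) and g = the
-  // dropout-scaled upstream gradient (tv5), this computes
-  // dx = (y * g - y * sum(y * g)) / d32 over the innermost axis.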
- auto tv4 = mul(tv2, tv3);
- auto tv5 = mul(tv4, d33);
- auto tv6 = mul(tv5, tv0);
- auto tv7 = sum(tv6, {-1});
- auto tv8 = broadcast(tv7, {false, false, false, true});
- auto tv9 = mul(tv0, tv8);
- auto tv10 = sub(tv6, tv9);
- auto tv11 = div(tv10, d32);
-
- if (is_fp16) {
- tv10 = castOp(DataType::Half, tv10);
- tv11 = castOp(DataType::Half, tv11);
- }
-
- fusion->addOutput(tv11);
- fusion->addOutput(tv10);
-}
-
-static void MagicScheduler_DivMaxSoftDropFwd(
- benchmark::State& benchmark_state,
- DataType dtype) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto w = benchmark_state.range(0);
- auto x = benchmark_state.range(1);
- auto y = benchmark_state.range(2);
- auto z = benchmark_state.range(3);
-
- setupDivMaxSoftmaxDropoutForward(&fusion, dtype);
-
- auto tvs = ir_utils::allTvs(&fusion);
-
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({w, 1, 1, z}, options);
- at::Tensor t1 = at::randn({w, x, y, z}, options);
-
- std::vector<c10::IValue> at_inputs = {t0, t1};
- std::vector<at::Tensor> cg_outputs;
-
- auto norm_params = getNormalizationHeuristics(&fusion, at_inputs);
- TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
- scheduleNormalization(&fusion, norm_params.value());
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.setMeasureKernelTimeFlag(true);
- // Sync everything up before we start
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- CudaKernelTimer timer;
- cg_outputs = fe.runFusion({t0, t1}, norm_params.value().lparams);
- benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
- }
-  // Sync everything up before we're finished; we don't want to run ahead
-  // on the CPU while benchmarking.
- cudaDeviceSynchronize();
-
- int64_t bytes = 0;
- for (auto tensor : std::vector<at::Tensor>({t0, t1})) {
- bytes += tensor.numel() *
- (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
- }
-
- for (auto tensor : cg_outputs) {
- bytes += tensor.numel() *
- (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
- }
-
- benchmark_state.SetBytesProcessed(
- bytes * int64_t(benchmark_state.iterations()));
-}
-
-static void MagicScheduler_DivMaxSoftDropBwd(
- benchmark::State& benchmark_state,
- DataType dtype) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto w = benchmark_state.range(0);
- auto x = benchmark_state.range(1);
- auto y = benchmark_state.range(2);
- auto z = benchmark_state.range(3);
-
- setupDivMaxSoftmaxDropoutBackward(&fusion, dtype);
-
- auto tvs = ir_utils::allTvs(&fusion);
-
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({w, x, y, z}, options);
- at::Tensor t1 = at::randn({w, x, y, z}, options);
- at::Tensor t2 = at::randn({w, x, y, z}, options);
- at::Tensor t3 = at::randn({w, x, y, z}, options).round().to(at::kBool);
-
- std::vector<c10::IValue> at_inputs = {t0, t1, t2, t3};
- std::vector<at::Tensor> cg_outputs;
-
- auto norm_params = getNormalizationHeuristics(&fusion, at_inputs);
- TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
- scheduleNormalization(&fusion, norm_params.value());
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.setMeasureKernelTimeFlag(true);
- // Sync everything up before we start
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- CudaKernelTimer timer;
- cg_outputs = fe.runFusion({t0, t1, t2, t3}, norm_params.value().lparams);
- benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
- }
-  // Sync everything up before we're finished; we don't want to run ahead
-  // on the CPU while benchmarking.
- cudaDeviceSynchronize();
-
- int64_t bytes = 0;
-  // For some reason t1 isn't used; ignore it.
- for (auto tensor : std::vector<at::Tensor>({t0, t2, t3})) {
- bytes += tensor.numel() *
- (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
- }
-
- for (auto tensor : cg_outputs) {
- bytes += tensor.numel() *
- (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
- }
-
- benchmark_state.SetBytesProcessed(
- bytes * int64_t(benchmark_state.iterations()));
-}
-
-static void setupBiasDropoutAddLayernormFwd(Fusion* fusion, DataType dtype) {
- FusionGuard fg(fusion);
-
- bool is_fp16 = dtype == DataType::Half;
-
- TensorView* tv0 = makeContigTensor(1, dtype);
- TensorView* tv1 = makeContigTensor(1, dtype);
- TensorView* tv2 = makeContigTensor(3, dtype);
- TensorView* tv3 = makeContigTensor(3, dtype);
- TensorView* tv4 = makeContigTensor(1, dtype);
-
- fusion->addInput(tv0);
- fusion->addInput(tv1);
- fusion->addInput(tv2);
- fusion->addInput(tv3);
- fusion->addInput(tv4);
-
- if (is_fp16) {
- tv0 = castOp(DataType::Float, tv0);
- tv1 = castOp(DataType::Float, tv1);
- tv2 = castOp(DataType::Float, tv2);
- tv3 = castOp(DataType::Float, tv3);
- tv4 = castOp(DataType::Float, tv4);
- }
-
- auto tv5 = broadcast(tv4, {true, true, false});
- auto tv6 = add(tv3, tv5);
- auto dropout_outs = dropout(tv6, new Double(0.9));
-
- auto tv8 = dropout_outs.output;
- auto tv10 = dropout_outs.mask;
-
-  auto tv11 = add(tv8, tv2);
-
- auto layer_norm_outs = layer_norm(tv11, 1, tv0, tv1, new Double(1e-5));
- auto tv14 = layer_norm_outs.output;
- auto tv21 = layer_norm_outs.mean;
- auto tv26 = layer_norm_outs.invstd;
-
- if (is_fp16) {
- tv11 = castOp(DataType::Half, tv11);
- tv14 = castOp(DataType::Half, tv14);
- tv21 = castOp(DataType::Half, tv21);
- tv26 = castOp(DataType::Half, tv26);
- }
-
- fusion->addOutput(tv8);
- fusion->addOutput(tv11);
- fusion->addOutput(tv14);
- fusion->addOutput(tv21);
- fusion->addOutput(tv26);
-}
-
-static void MagicScheduler_BiasDropoutAddLayernormFwd(
- benchmark::State& benchmark_state,
- DataType dtype) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto x = benchmark_state.range(0);
- auto y = benchmark_state.range(1);
- auto z = benchmark_state.range(2);
-
- setupBiasDropoutAddLayernormFwd(&fusion, dtype);
-
- auto tvs = ir_utils::allTvs(&fusion);
-
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({z}, options);
- at::Tensor t1 = at::randn({z}, options);
- at::Tensor t2 = at::randn({x, y, z}, options);
- at::Tensor t3 = at::randn({x, y, z}, options);
- at::Tensor t4 = at::randn({z}, options);
-
- std::vector<c10::IValue> at_inputs = {t0, t1, t2, t3, t4};
- std::vector<at::Tensor> cg_outputs;
-
- auto norm_params = getNormalizationHeuristics(&fusion, at_inputs);
- TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
- scheduleNormalization(&fusion, norm_params.value());
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.setMeasureKernelTimeFlag(true);
- // Sync everything up before we start
-
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- CudaKernelTimer timer;
- cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
- benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
- }
-  // Sync everything up before we're finished; we don't want to run ahead
-  // on the CPU while benchmarking.
- cudaDeviceSynchronize();
-
- int64_t bytes = 0;
- for (auto inp : at_inputs) {
- auto tensor = inp.toTensor();
- bytes += tensor.numel() *
- (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
- }
-
- for (auto tensor : cg_outputs) {
- bytes += tensor.numel() *
- (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
- }
-
- benchmark_state.SetBytesProcessed(
- bytes * int64_t(benchmark_state.iterations()));
-}
-
-static void MagicScheduler_fp32_BiasDropoutAddLayernormFwd(
- benchmark::State& benchmark_state) {
- MagicScheduler_BiasDropoutAddLayernormFwd(benchmark_state, DataType::Float);
-}
-
-static void setupBiasDropoutAddLayernormBwd1(Fusion* fusion, DataType dtype) {
- FusionGuard fg(fusion);
-
- bool is_fp16 = dtype == DataType::Half;
-
- TensorView* tv1 = makeContigTensor(3, dtype);
- TensorView* tv2 = makeContigTensor(3, dtype);
- TensorView* tv3 = TensorViewBuilder()
- .ndims(3)
- .dtype(dtype)
- .contiguity({true, true, true})
- .shape({-1, -1, 1})
- .build();
- TensorView* tv4 = TensorViewBuilder()
- .ndims(3)
- .dtype(dtype)
- .contiguity({true, true, true})
- .shape({-1, -1, 1})
- .build();
-
- fusion->addInput(tv1);
- fusion->addInput(tv2);
- fusion->addInput(tv3);
- fusion->addInput(tv4);
-
- if (is_fp16) {
- tv1 = castOp(DataType::Float, tv1);
- tv2 = castOp(DataType::Float, tv2);
- tv3 = castOp(DataType::Float, tv3);
- tv4 = castOp(DataType::Float, tv4);
- }
-
- auto tv7 = sub(tv2, tv3);
- auto tv8 = mul(tv7, tv4);
- auto tv24 = sum(tv1, {0, 1});
- auto tv22 = mul(tv1, tv8);
- auto tv23 = sum(tv22, {0, 1});
-
- if (is_fp16) {
- tv24 = castOp(DataType::Half, tv24);
- tv23 = castOp(DataType::Half, tv23);
- tv8 = castOp(DataType::Half, tv8);
- }
-
- fusion->addOutput(tv24);
- fusion->addOutput(tv23);
- fusion->addOutput(tv8);
-}
-
-static void MagicScheduler_BiasDropoutAddLayernormBwd1(
- benchmark::State& benchmark_state,
- DataType dtype) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto x = benchmark_state.range(0);
- auto y = benchmark_state.range(1);
- auto z = benchmark_state.range(2);
-
- setupBiasDropoutAddLayernormBwd1(&fusion, dtype);
-
- auto tvs = ir_utils::allTvs(&fusion);
-
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({x, y, z}, options);
- at::Tensor t1 = at::randn({x, y, z}, options);
- at::Tensor t2 = at::randn({x, y, 1}, options);
- at::Tensor t3 = at::randn({x, y, 1}, options);
-
- std::vector<c10::IValue> at_inputs = {t0, t1, t2, t3};
- std::vector<at::Tensor> cg_outputs;
-
- auto norm_params = getNormalizationHeuristics(&fusion, at_inputs);
- TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
- scheduleNormalization(&fusion, norm_params.value());
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.setMeasureKernelTimeFlag(true);
- // Sync everything up before we start
-
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- clearL2Cache();
- cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
- benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
- }
-  // Sync everything up before we're finished; we don't want to run ahead
-  // on the CPU while benchmarking.
- cudaDeviceSynchronize();
-
- int64_t bytes = 0;
- for (auto inp : at_inputs) {
- auto tensor = inp.toTensor();
- bytes += tensor.numel() *
- (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
- }
-
- for (auto tensor : cg_outputs) {
- bytes += tensor.numel() *
- (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
- }
-
- benchmark_state.SetBytesProcessed(
- bytes * int64_t(benchmark_state.iterations()));
-}
-
-static void setupBiasDropoutAddLayernormBwd2(Fusion* fusion, DataType dtype) {
- FusionGuard fg(fusion);
-
- bool is_fp16 = dtype == DataType::Half;
-
- TensorView* tv4 = TensorViewBuilder()
- .ndims(3)
- .dtype(dtype)
- .contiguity({true, true, true})
- .shape({-1, -1, 1})
- .build();
- TensorView* tv5 = makeContigTensor(1, dtype);
- TensorView* tv1 = makeContigTensor(3, dtype);
- TensorView* tv8 = makeContigTensor(3, dtype);
-
- fusion->addInput(tv4);
- fusion->addInput(tv5);
- fusion->addInput(tv1);
- fusion->addInput(tv8);
-
- if (is_fp16) {
- tv4 = castOp(DataType::Float, tv4);
- tv5 = castOp(DataType::Float, tv5);
- tv1 = castOp(DataType::Float, tv1);
- tv8 = castOp(DataType::Float, tv8);
- }
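-  // d36 is the extent N of the normalized (innermost) axis and d47 = 1/N;
-  // both scale terms of the layer norm input gradient computed below.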
- auto d36 = mul(new Double(1.0), tv1->axis(2)->extent());
- auto d47 = unaryOp(UnaryOpType::Reciprocal, d36);
-
- auto tv9 = broadcast(tv5, {true, true, false});
- auto tv10 = mul(tv1, tv9);
- auto tv14 = mul(tv10, tv8);
- auto tv15 = sum(tv14, {2});
- auto tv16 = broadcast(tv15, {false, false, true});
- auto tv17 = mul(tv8, tv16);
- auto tv12 = sum(tv10, {2});
- auto tv13 = broadcast(tv12, {false, false, true});
- auto tv11 = mul(d36, tv10);
- auto tv18 = sub(tv11, tv13);
- auto tv20 = mul(d47, tv4);
- auto tv19 = sub(tv18, tv17);
- auto tv21 = mul(tv20, tv19);
-
- if (is_fp16) {
- tv21 = castOp(DataType::Half, tv21);
- }
-
- fusion->addOutput(tv21);
-}
-
-static void MagicScheduler_BiasDropoutAddLayernormBwd2(
- benchmark::State& benchmark_state,
- DataType dtype) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto x = benchmark_state.range(0);
- auto y = benchmark_state.range(1);
- auto z = benchmark_state.range(2);
-
- setupBiasDropoutAddLayernormBwd2(&fusion, dtype);
-
- auto tvs = ir_utils::allTvs(&fusion);
-
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
- at::Tensor t4 = at::randn({x, y, 1}, options);
- at::Tensor t5 = at::randn({z}, options);
- at::Tensor t1 = at::randn({x, y, z}, options);
- at::Tensor t8 = at::randn({x, y, z}, options);
-
- std::vector<c10::IValue> at_inputs = {t4, t5, t1, t8};
- std::vector<at::Tensor> cg_outputs;
-
- auto norm_params = getNormalizationHeuristics(&fusion, at_inputs);
- TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
- scheduleNormalization(&fusion, norm_params.value());
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.setMeasureKernelTimeFlag(true);
- // Sync everything up before we start
-
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- CudaKernelTimer timer;
- cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
- benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
- }
-  // Sync everything up before we're finished; we don't want to run ahead
-  // on the CPU while benchmarking.
- cudaDeviceSynchronize();
-
- int64_t bytes = 0;
- for (auto inp : at_inputs) {
- auto tensor = inp.toTensor();
- bytes += tensor.numel() *
- (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
- }
-
- for (auto tensor : cg_outputs) {
- bytes += tensor.numel() *
- (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
- }
-
- benchmark_state.SetBytesProcessed(
- bytes * int64_t(benchmark_state.iterations()));
-}
-
-static void setupBiasDropoutAddLayernormBwd3(Fusion* fusion, DataType dtype) {
- FusionGuard fg(fusion);
-
- bool is_fp16 = dtype == DataType::Half;
-
- TensorView* tv0 = makeContigTensor(3, dtype);
- TensorView* tv21 = makeContigTensor(3, dtype);
-
- fusion->addInput(tv0);
- fusion->addInput(tv21);
-
- if (is_fp16) {
- tv0 = castOp(DataType::Float, tv0);
- tv21 = castOp(DataType::Float, tv21);
- }
-
-  // Uncertain whether this is the right value, but going with it anyway
- auto d34 = div(new Double(1.0), tv0->axis(2)->extent());
-
- auto tv25 = mul(tv21, tv0);
- auto tv26 = mul(tv25, d34);
- auto tv27 = sum(tv26, {0, 1});
-
- if (is_fp16) {
-    tv26 = castOp(DataType::Half, tv26);
- tv27 = castOp(DataType::Half, tv27);
- }
-
- fusion->addOutput(tv26);
- fusion->addOutput(tv27);
-}
-
-static void MagicScheduler_BiasDropoutAddLayernormBwd3(
- benchmark::State& benchmark_state,
- DataType dtype) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto x = benchmark_state.range(0);
- auto y = benchmark_state.range(1);
- auto z = benchmark_state.range(2);
-
- setupBiasDropoutAddLayernormBwd3(&fusion, dtype);
-
- auto tvs = ir_utils::allTvs(&fusion);
-
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({x, y, z}, options);
- at::Tensor t21 = at::randn({x, y, z}, options);
-
- std::vector<c10::IValue> at_inputs = {t0, t21};
- std::vector<at::Tensor> cg_outputs;
-
- auto norm_params = getNormalizationHeuristics(&fusion, at_inputs);
- TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
- scheduleNormalization(&fusion, norm_params.value());
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.setMeasureKernelTimeFlag(true);
- // Sync everything up before we start
-
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- CudaKernelTimer timer;
- cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
- benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
- }
-  // Sync everything up before we're finished; we don't want to run ahead
-  // on the CPU while benchmarking.
- cudaDeviceSynchronize();
-
- int64_t bytes = 0;
- for (auto inp : at_inputs) {
- auto tensor = inp.toTensor();
- bytes += tensor.numel() *
- (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
- }
-
- for (auto tensor : cg_outputs) {
- bytes += tensor.numel() *
- (int64_t)dataTypeSize(aten_to_data_type(tensor.scalar_type()));
- }
-
- benchmark_state.SetBytesProcessed(
- bytes * int64_t(benchmark_state.iterations()));
-}
-
-//------------------------------------------------------------------------------
-
-static void DivMaxSoftDropFwd_fp32(benchmark::State& benchmark_state) {
- MagicScheduler_DivMaxSoftDropFwd(benchmark_state, DataType::Float);
-}
-
-static void DivMaxSoftDropBwd_fp32(benchmark::State& benchmark_state) {
- MagicScheduler_DivMaxSoftDropBwd(benchmark_state, DataType::Float);
-}
-
-static void DivMaxSoftDropFwd_fp16(benchmark::State& benchmark_state) {
- MagicScheduler_DivMaxSoftDropFwd(benchmark_state, DataType::Half);
-}
-
-static void DivMaxSoftDropBwd_fp16(benchmark::State& benchmark_state) {
- MagicScheduler_DivMaxSoftDropBwd(benchmark_state, DataType::Half);
-}
-
-static void BiasDropoutAddLayernormBwd1_fp32(
- benchmark::State& benchmark_state) {
- MagicScheduler_BiasDropoutAddLayernormBwd1(benchmark_state, DataType::Float);
-}
-
-// Use a full Ampere wave here
-static void BiasDropoutAddLayernormBwd1_tf32(
- benchmark::State& benchmark_state) {
- MagicScheduler_BiasDropoutAddLayernormBwd1(benchmark_state, DataType::Float);
-}
-
-static void BiasDropoutAddLayernormBwd2_fp32(
- benchmark::State& benchmark_state) {
- MagicScheduler_BiasDropoutAddLayernormBwd2(benchmark_state, DataType::Float);
-}
-
-static void BiasDropoutAddLayernormBwd3_fp32(
- benchmark::State& benchmark_state) {
- MagicScheduler_BiasDropoutAddLayernormBwd3(benchmark_state, DataType::Float);
-}
-
-//------------------------------------------------------------------------------
-
-BENCHMARK(DivMaxSoftDropFwd_fp32)
- ->RangeMultiplier(8)
- ->Ranges({{8, 8}, {16, 16}, {128, 128}, {128, 128}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(DivMaxSoftDropBwd_fp32)
- ->RangeMultiplier(8)
- ->Ranges({{8, 8}, {16, 16}, {128, 128}, {128, 128}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(DivMaxSoftDropFwd_fp16)
- ->RangeMultiplier(8)
- ->Ranges({{8, 8}, {16, 16}, {128, 128}, {128, 128}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(DivMaxSoftDropBwd_fp16)
- ->RangeMultiplier(8)
- ->Ranges({{8, 8}, {16, 16}, {128, 128}, {128, 128}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(BiasDropoutAddLayernormBwd1_fp32)
- ->RangeMultiplier(2)
- ->Ranges({{32, 1024}, {128, 128}, {1024, 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-// Use a full Ampere wave here
-BENCHMARK(BiasDropoutAddLayernormBwd1_tf32)
- ->RangeMultiplier(2)
- ->Ranges({{32, 1024}, {128, 128}, {864, 864}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(BiasDropoutAddLayernormBwd2_fp32)
- ->Ranges({{32, 1024}, {128, 128}, {1024, 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(BiasDropoutAddLayernormBwd3_fp32)
- ->Ranges({{32, 1024}, {128, 128}, {1024, 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include <sstream>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-// Set up a broadcast-and-add fusion; bcast_axis selects the broadcast axis
-static void setupBroadcast(Fusion* fusion, DataType dtype, int bcast_axis) {
- FusionGuard fg(fusion);
-
- bool is_fp16 = dtype == DataType::Half;
-
- TensorView* tv0 = makeContigTensor(2, dtype);
- TensorView* tv1 = makeContigTensor(1, dtype);
-
- fusion->addInput(tv0);
- fusion->addInput(tv1);
-
- std::vector<bool> bcast_pattern(2, false);
- bcast_pattern[bcast_axis] = true;
-
- if (is_fp16) {
- tv0 = castOp(DataType::Float, tv0);
- tv1 = castOp(DataType::Float, tv1);
- }
-
- TensorView* tv2 = broadcast(tv1, bcast_pattern);
- TensorView* tv3 = add(tv0, tv2);
-
- if (is_fp16) {
- tv3 = castOp(DataType::Half, tv3);
- }
-
- fusion->addOutput(tv3);
-}
-
-static void NvFuserScheduler_Broadcast(
- benchmark::State& benchmark_state,
- FusionExecutorCache* fusion_executor_cache,
- DataType dtype,
- int bcast_dim) {
- auto bcast_size = benchmark_state.range(0);
- auto iter_size = benchmark_state.range(1);
-
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
-
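-  // bcast_dim selects which axis of t0 the 1-D tensor t1 is broadcast
-  // across: 0 broadcasts over the outer axis, 1 over the inner axis.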
- at::Tensor t0 =
- (bcast_dim ? at::randn({iter_size, bcast_size}, options)
- : at::randn({bcast_size, iter_size}, options));
-
- at::Tensor t1 = at::randn({iter_size}, options);
-
- fusion_executor_cache->profile(true);
- fusion_executor_cache->runFusionWithInputs({t0, t1});
-
- auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
- auto executor_instance = compile_log.fusion_executor;
- TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
- TORCH_INTERNAL_ASSERT(compile_log.launch_constraints.has_value());
- auto params = toString(compile_log.pointwise_params.value());
- auto lparams = toString(compile_log.launch_constraints.value());
-
- benchmark_state.SetLabel(params + lparams);
-
- fusion_executor_cache->profile(false);
- executor_instance->setMeasureKernelTimeFlag(true);
- // Sync everything up before we start
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- auto cg_outputs = fusion_executor_cache->runFusionWithInputs({t0, t1});
- benchmark_state.SetIterationTime(
- executor_instance->kernelTimeMs() / 1000.0);
- clearL2Cache();
- }
-  // Sync everything up before we're finished; we don't want to run ahead
-  // on the CPU while benchmarking.
- cudaDeviceSynchronize();
-
- benchmark_state.SetBytesProcessed(
- int64_t(benchmark_state.iterations()) *
- (iter_size * bcast_size * 2 + iter_size) * int64_t(dataTypeSize(dtype)));
-}
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Broadcast_Outer_fp32,
- setupBroadcast,
- NvFuserScheduler_Broadcast,
- DataType::Float,
- 0);
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Broadcast_Outer_fp16,
- setupBroadcast,
- NvFuserScheduler_Broadcast,
- DataType::Half,
- 0);
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Broadcast_Inner_fp32,
- setupBroadcast,
- NvFuserScheduler_Broadcast,
- DataType::Float,
- 1);
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Broadcast_Inner_fp16,
- setupBroadcast,
- NvFuserScheduler_Broadcast,
- DataType::Half,
- 1);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp32)
- ->RangeMultiplier(8)
- ->Ranges({{1, 1024 * 1024}, {160, 320}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp32)
- ->RangeMultiplier(8)
- ->Ranges({{32768, 64 * 1024 * 1024}, {2, 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp32)
- ->RangeMultiplier(8)
- ->Ranges({{2, 16}, {32768, 64 * 1024 * 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp16)
- ->RangeMultiplier(8)
- ->Ranges({{1, 1024 * 1024}, {160, 320}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp16)
- ->RangeMultiplier(8)
- ->Ranges({{32768, 64 * 1024 * 1024}, {2, 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp16)
- ->RangeMultiplier(8)
- ->Ranges({{2, 16}, {32768, 64 * 1024 * 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Outer_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp32)
- ->RangeMultiplier(8)
- ->Ranges({{1, 1024 * 1024}, {160, 320}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp32)
- ->RangeMultiplier(8)
- ->Ranges({{32768, 64 * 1024 * 1024}, {2, 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp32)
- ->RangeMultiplier(8)
- ->Ranges({{2, 16}, {32768, 64 * 1024 * 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp16)
- ->RangeMultiplier(8)
- ->Ranges({{1, 1024 * 1024}, {160, 320}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp16)
- ->RangeMultiplier(8)
- ->Ranges({{32768, 64 * 1024 * 1024}, {2, 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp16)
- ->RangeMultiplier(8)
- ->Ranges({{2, 16}, {32768, 64 * 1024 * 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Broadcast_Inner_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
+++ /dev/null
-
-// Based on NVFuserTest.FusionBiasGeluBwd_CUDA
-
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-static void setupFusion(Fusion* fusion) {
- FusionGuard fg(fusion);
-
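-  // Constants from the tanh approximation of GELU: k_079 = sqrt(2/pi),
-  // k_004 is the cubic coefficient, and k_010 = 3 * 0.044715 * sqrt(2/pi),
-  // which appears in the derivative.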
- const float k_079 = 0.79788456;
- const float k_004 = 0.044715;
- const float k_010 = 0.1070322243;
-
- // gradient tensor
- auto t0 = makeContigTensor(3, DataType::Half);
- fusion->addInput(t0);
-
- auto t1 = castOp(DataType::Float, t0);
-
- // bias tensor
- auto t2 = makeContigTensor(1, DataType::Half);
- fusion->addInput(t2);
-
- auto t3 = castOp(DataType::Float, t2);
-
- // input tensor
- auto t4 = makeContigTensor(3, DataType::Half);
- fusion->addInput(t4);
-
- auto t5 = castOp(DataType::Float, t4);
- auto t6 = broadcast(t3, {true, true, false});
- auto t7 = add(t6, t5);
- auto t8 = mul(t7, new Double(k_079));
- auto t9 = mul(t7, new Double(k_004));
- auto t10 = mul(t9, t7);
- auto t11 = add(t10, new Int(1));
- auto t12 = mul(t8, t11);
- auto t13 = unaryOp(UnaryOpType::Tanh, t12);
- auto t14 = mul(t7, new Double(0.5));
- auto t15 = mul(t13, t13);
- auto t16 = unaryOp(UnaryOpType::Neg, t15);
- auto t17 = add(t16, new Int(1));
- auto t18 = mul(t7, new Double(k_010));
- auto t19 = mul(t18, t7);
- auto t20 = add(t19, new Double(k_079));
- auto t21 = mul(t17, t20);
- auto t22 = mul(t14, t21);
- auto t23 = add(t13, new Int(1));
- auto t24 = mul(t23, new Double(0.5));
- auto t25 = add(t22, t24);
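-  // t25 is d(GELU)/dx under the tanh approximation; multiply by the
-  // incoming gradient t1 to get the backward result.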
- auto t26 = mul(t25, t1);
-
- // Save float output for validation
- fusion->addOutput(t26);
- auto t27 = castOp(DataType::Half, t26);
- fusion->addOutput(t27);
-}
-
-static std::vector<c10::IValue> setupInputs() {
- at::manual_seed(0);
-
- auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
- std::vector<int64_t> input_shape{6, 512, 4096};
- std::vector<int64_t> bias_shape{4096};
- auto at_input = at::randn(input_shape, options);
- auto at_bias = at::randn(bias_shape, options);
- auto at_grad = at::randn(input_shape, options);
-
- return {at_grad, at_bias, at_input};
-}
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_SetupFusion(benchmark::State& benchmark_state) {
- for (auto _ : benchmark_state) {
- Fusion fusion;
- setupFusion(&fusion);
- }
-}
-
-BENCHMARK(GeluBackward_SetupFusion)->Unit(benchmark::kMicrosecond);
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_AutoSchedule(benchmark::State& benchmark_state) {
- for (auto _ : benchmark_state) {
- // Setup (not included in the measurement)
- benchmark_state.PauseTiming();
- Fusion fusion;
- setupFusion(&fusion);
- std::vector<c10::IValue> inputs = setupInputs();
- benchmark_state.ResumeTiming();
-
- // Auto-schedule
- schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
- }
-}
-
-BENCHMARK(GeluBackward_AutoSchedule)->Unit(benchmark::kMicrosecond);
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_Lower(benchmark::State& benchmark_state) {
- constexpr int kHiddenFeatures = 512;
- constexpr int kBatchSize = 64;
-
- Fusion fusion;
-
- // setup fusion
- setupFusion(&fusion);
-
- // inputs
- std::vector<c10::IValue> inputs = setupInputs();
-
- schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
- for (auto _ : benchmark_state) {
- GpuLower gpu_lower(&fusion);
- }
-}
-
-BENCHMARK(GeluBackward_Lower)->Unit(benchmark::kMillisecond);
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_Compile(benchmark::State& benchmark_state) {
- Fusion fusion;
-
- // setup fusion
- setupFusion(&fusion);
-
- // inputs
- std::vector<c10::IValue> inputs = setupInputs();
-
- schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
- for (auto _ : benchmark_state) {
- FusionExecutor executor;
- executor.compileFusion(&fusion);
- }
-}
-
-BENCHMARK(GeluBackward_Compile)->Unit(benchmark::kMillisecond);
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_RunFusion(benchmark::State& benchmark_state) {
- Fusion fusion;
-
- // setup fusion
- setupFusion(&fusion);
-
- // inputs
- std::vector<c10::IValue> inputs = setupInputs();
-
- // outputs
- std::vector<at::Tensor> outputs;
-
- auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
- FusionExecutor executor;
- executor.compileFusion(&fusion);
-
- cudaDeviceSynchronize();
-
- for (auto _ : benchmark_state) {
- outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
- cudaDeviceSynchronize();
- clearL2Cache();
- }
-}
-
-BENCHMARK(GeluBackward_RunFusion)->Unit(benchmark::kMicrosecond);
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_RunFusion_GpuOnly(benchmark::State& benchmark_state) {
- Fusion fusion;
-
- // setup fusion
- setupFusion(&fusion);
-
- // inputs
- std::vector<c10::IValue> inputs = setupInputs();
-
- // outputs
- std::vector<at::Tensor> outputs;
-
- auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
- FusionExecutor executor;
- executor.setMeasureKernelTimeFlag(true);
- executor.compileFusion(&fusion);
-
- cudaDeviceSynchronize();
-
- for (auto _ : benchmark_state) {
- outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
- benchmark_state.SetIterationTime(executor.kernelTimeMs() / 1000.0);
- clearL2Cache();
- }
-}
-
-BENCHMARK(GeluBackward_RunFusion_GpuOnly)
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-static void GeluBackward_RunFusion_CpuOnly(benchmark::State& benchmark_state) {
- Fusion fusion;
-
- // setup fusion
- setupFusion(&fusion);
-
- // inputs
- std::vector<c10::IValue> inputs = setupInputs();
-
- // outputs
- std::vector<at::Tensor> outputs;
-
- auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
- FusionExecutor executor;
- executor.setExecuteKernelFlag(false);
- executor.compileFusion(&fusion);
-
- for (auto _ : benchmark_state) {
- outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
- }
-}
-
-BENCHMARK(GeluBackward_RunFusion_CpuOnly)->Unit(benchmark::kMicrosecond);
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-// Make a tensor of dimensionality ndims that is known to be non-contiguous
-// and has unknown (symbolic) sizes
-TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
- return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
-}
-
-// Make a non-contiguous tensor of compile-time known sizes
-TensorView* makeConcreteTensor(
- std::vector<int64_t> shape,
- DataType dtype = DataType::Float) {
- return TensorViewBuilder().shape(shape).dtype(dtype).build();
-}
-
-static auto getLayerBackwardNormRuntime(
- std::unique_ptr<Fusion> fusion_ptr,
- std::unique_ptr<FusionExecutorCache>& fec,
- std::vector<at::IValue>& aten_inputs,
- std::vector<int64_t>& shape,
- std::vector<int64_t>& norm_shape) {
- Fusion& fusion = *fusion_ptr.get();
-
- const size_t kM = shape.size();
- const size_t kN = norm_shape.size();
- const size_t kOuterNumDims = kM - kN;
-
- std::vector<int64_t> outer_shape;
- for (size_t idx = 0; idx < kOuterNumDims; ++idx) {
- outer_shape.push_back(shape[idx]);
- }
- for (size_t idx = kOuterNumDims; idx < kM; ++idx) {
- outer_shape.push_back(1);
- }
-
- auto grad_out = makeSymbolicTensor(shape.size());
- auto input = makeSymbolicTensor(shape.size());
- auto mean = makeConcreteTensor(outer_shape);
- auto rstd = makeConcreteTensor(outer_shape);
- auto weight = makeSymbolicTensor(norm_shape.size());
- auto bias = makeSymbolicTensor(norm_shape.size());
- fusion.addInput(grad_out);
- fusion.addInput(input);
- fusion.addInput(mean);
- fusion.addInput(rstd);
- fusion.addInput(weight);
- fusion.addInput(bias);
-
- auto grads = layer_norm_backward(
- grad_out,
- input,
- norm_shape,
- mean,
- rstd,
- weight,
- bias,
- {true, true, true});
-
- fusion.addOutput(grads.grad_input);
- fusion.addOutput(grads.grad_weight);
- fusion.addOutput(grads.grad_bias);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_grad_out = at::randn(shape, options);
- at::Tensor aten_input = at::randn(shape, options);
- at::Tensor aten_weight = at::randn(norm_shape, options);
- at::Tensor aten_bias = at::randn(norm_shape, options);
- auto at_weight = c10::optional<at::Tensor>(aten_weight);
- auto at_bias = c10::optional<at::Tensor>(aten_bias);
-
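-  // Run the ATen layer norm forward once to produce the mean and rstd
-  // tensors that the backward fusion consumes as inputs.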
- const float kEps = 1e-5;
- auto aten_results =
- at::native_layer_norm(aten_input, norm_shape, at_weight, at_bias, kEps);
- auto aten_output = std::get<0>(aten_results);
- auto aten_mean = std::get<1>(aten_results);
- auto aten_rstd = std::get<2>(aten_results);
-
- fec = std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
- aten_inputs = {
- aten_grad_out, aten_input, aten_mean, aten_rstd, aten_weight, aten_bias};
- auto cg_outputs = fec->runFusionWithInputs(aten_inputs);
-
- return fec->getMostRecentKernelRuntime();
-}
-
-static void LayerNormBackward_HeuristicLookup(
- benchmark::State& benchmark_state) {
- std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
- FusionGuard fg(fusion_ptr.get());
-
- // PreAllocate
- std::unique_ptr<FusionExecutorCache> fec;
- std::vector<at::IValue> aten_inputs;
-
- std::vector<int64_t> shape{20, 100, 35, 67};
- std::vector<int64_t> norm_shape{67};
-
- auto runtime = getLayerBackwardNormRuntime(
- std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);
- TORCH_INTERNAL_ASSERT(
- runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
-
- for (auto _ : benchmark_state) {
-    // The heuristic lookup itself is the operation being measured
- runtime->getMaybeHeuristicsFor(aten_inputs);
- }
-}
-
-static auto getLayerForwardNormRuntime(
- std::unique_ptr<Fusion> fusion_ptr,
- std::unique_ptr<FusionExecutorCache>& fec,
- std::vector<at::IValue>& aten_inputs,
- std::vector<int64_t>& shape,
- std::vector<int64_t>& norm_shape) {
- Fusion& fusion = *fusion_ptr.get();
-
- const float kEps = 1e-5;
- Double* eps_ptr = new Double(kEps);
-
- auto input = makeSymbolicTensor(shape.size());
- fusion.addInput(input);
-
- auto result = layer_norm(input, norm_shape, nullptr, nullptr, eps_ptr);
-
- fusion.addOutput(result.output);
- fusion.addOutput(result.mean);
- fusion.addOutput(result.invstd);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn(shape, options);
-
- fec = std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
- aten_inputs = {aten_input};
- auto cg_outputs = fec->runFusionWithInputs(aten_inputs);
-
- return fec->getMostRecentKernelRuntime();
-}
-
-static void LayerNormForward_HeuristicLookup(
- benchmark::State& benchmark_state) {
- std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
- FusionGuard fg(fusion_ptr.get());
-
- // PreAllocate
- std::unique_ptr<FusionExecutorCache> fec;
- std::vector<at::IValue> aten_inputs;
-
- std::vector<int64_t> shape{20, 100, 35, 67};
- std::vector<int64_t> norm_shape{67};
-
- auto runtime = getLayerForwardNormRuntime(
- std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);
- TORCH_INTERNAL_ASSERT(
- runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
-
- for (auto _ : benchmark_state) {
-    // The heuristic lookup itself is the operation being measured
- runtime->getMaybeHeuristicsFor(aten_inputs);
- }
-}
-
-BENCHMARK(LayerNormBackward_HeuristicLookup)->Unit(benchmark::kMicrosecond);
-BENCHMARK(LayerNormForward_HeuristicLookup)->Unit(benchmark::kMicrosecond);
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-static void setupInstanceNorm(Fusion* fusion, DataType dtype) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
- FusionGuard fg(fusion);
-
- auto input = makeContigTensor(4, dtype);
- auto weight = makeContigTensor(1, dtype);
- auto bias = makeContigTensor(1, dtype);
- auto running_mean = makeContigTensor(1, DataType::Float);
- auto running_var = makeContigTensor(1, DataType::Float);
-
- fusion->addInput(input);
- fusion->addInput(weight);
- fusion->addInput(bias);
- fusion->addInput(running_mean);
- fusion->addInput(running_var);
-
- if (dtype == DataType::Half) {
- input = castOp(DataType::Float, input);
- weight = castOp(DataType::Float, weight);
- bias = castOp(DataType::Float, bias);
- }
-
- const bool kTraining = true;
- const float kMomentum = 0.1;
- const float kEps = 1e-5;
- auto momentum_ptr = new Double(kMomentum);
- auto eps_ptr = new Double(kEps);
-
- auto norm = instance_norm(
- input,
- weight,
- bias,
- running_mean,
- running_var,
- kTraining,
- momentum_ptr,
- eps_ptr);
-
- auto output = unaryOp(UnaryOpType::Relu, norm.output);
-
- if (dtype == DataType::Half) {
- output = castOp(DataType::Half, output);
- }
-
- fusion->addOutput(output);
-}
-
-//------------------------------------------------------------------------------
-
-static void NvFuserScheduler_InstanceNorm(
- benchmark::State& benchmark_state,
- FusionExecutorCache* fusion_executor_cache,
- DataType dtype) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
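-  // NCHW layout: N = range(0), C = range(2), H = W = range(1).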
- std::vector<int64_t> input_shape{
- benchmark_state.range(0),
- benchmark_state.range(2),
- benchmark_state.range(1),
- benchmark_state.range(1)};
-
- // inputs
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- auto fp32_options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor at_x = at::randn(input_shape, options);
- at::Tensor at_weight = at::ones({input_shape[1]}, options);
- at::Tensor at_bias = at::zeros({input_shape[1]}, options);
- at::Tensor at_mean = at::zeros({input_shape[1]}, fp32_options);
- at::Tensor at_var = at::ones({input_shape[1]}, fp32_options);
-
- std::vector<c10::IValue> aten_inputs = {
- at_x, at_weight, at_bias, at_mean, at_var};
-
- runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
-
- const size_t kSize =
- input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
- const size_t kChannels = input_shape[1];
-
- // Read: x, weight, bias, running_mean, running_var
- // Write: y, running_mean, running_var
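-  // => x and y: 2 * kSize elements of dtype; weight and bias: 2 * kChannels
-  // of dtype; running stats: 2 * kChannels of float, each read and written.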
- benchmark_state.SetBytesProcessed(
- benchmark_state.iterations() *
- ((kChannels * 2 + kSize * 2) * dataTypeSize(dtype) +
- (kChannels * 2 * 2) * dataTypeSize(DataType::Float)));
-}
-
-static void Baseline_InstanceNorm(
- benchmark::State& benchmark_state,
- DataType dtype) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
- std::vector<int64_t> input_shape{
- benchmark_state.range(0),
- benchmark_state.range(2),
- benchmark_state.range(1),
- benchmark_state.range(1)};
- const float kMomentum = 0.1;
- const float kEps = 1e-5;
- const auto aten_dtype = data_type_to_aten(dtype);
-
- at::manual_seed(0);
- auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);
- auto fp32_options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor at_x = at::randn(input_shape, options);
- at::Tensor at_weight = at::ones({input_shape[1]}, options);
- at::Tensor at_bias = at::zeros({input_shape[1]}, options);
- at::Tensor at_mean = at::zeros({input_shape[1]}, fp32_options);
- at::Tensor at_var = at::ones({input_shape[1]}, fp32_options);
-
- auto ato_weight = c10::optional<at::Tensor>(at_weight);
- auto ato_bias = c10::optional<at::Tensor>(at_bias);
- auto ato_running_mean = c10::optional<at::Tensor>(at_mean);
- auto ato_running_var = c10::optional<at::Tensor>(at_var);
-
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- CudaKernelTimer timer;
-
- auto norm = at::instance_norm(
- at_x,
- ato_weight,
- ato_bias,
- ato_running_mean,
- ato_running_var,
- true,
- kMomentum,
- kEps,
- false);
- auto output = at::relu(norm);
-
- benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
- cudaDeviceSynchronize();
- clearL2Cache();
- cudaDeviceSynchronize();
- }
-
- const size_t kSize =
- input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
- const size_t kChannels = input_shape[1];
-
- // Read: x, weight, bias, running_mean, running_var
- // Write: y, running_mean, running_var
- benchmark_state.SetBytesProcessed(
- benchmark_state.iterations() *
- ((kChannels * 2 + kSize * 2) * dataTypeSize(dtype) +
- (kChannels * 2 * 2) * dataTypeSize(DataType::Float)));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_InstanceNorm_fp32(benchmark::State& benchmark_state) {
- Baseline_InstanceNorm(benchmark_state, DataType::Float);
-}
-
-static void Baseline_InstanceNorm_fp16(benchmark::State& benchmark_state) {
- Baseline_InstanceNorm(benchmark_state, DataType::Half);
-}
-
-//------------------------------------------------------------------------------
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_fp32_InstanceNorm,
- setupInstanceNorm,
- NvFuserScheduler_InstanceNorm,
- DataType::Float);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_fp32_InstanceNorm)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8}, {640, 640}, {64, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_fp16_InstanceNorm,
- setupInstanceNorm,
- NvFuserScheduler_InstanceNorm,
- DataType::Half);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_fp16_InstanceNorm)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8}, {640, 640}, {64, 256}})
- ->Unit(benchmark::kMicrosecond)
-    ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-BENCHMARK(Baseline_InstanceNorm_fp32)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8}, {640, 640}, {64, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_InstanceNorm_fp16)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8}, {640, 640}, {64, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-//------------------------------------------------------------------------------
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-//------------------------------------------------------------------------------
-
-static void setupLayerNorm(Fusion* fusion, DataType dtype) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
- FusionGuard fg(fusion);
-
-  const float kEps = 1e-5;
-
- Double* eps_ptr = new Double(kEps);
-
- // setup fusion
- auto input = makeContigTensor(2, dtype);
- auto weight = makeContigTensor(1, dtype);
- auto bias = makeContigTensor(1, dtype);
-
- fusion->addInput(input);
- fusion->addInput(weight);
- fusion->addInput(bias);
-
- if (dtype == DataType::Half) {
- input = castOp(DataType::Float, input);
- weight = castOp(DataType::Float, weight);
- bias = castOp(DataType::Float, bias);
- }
-
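-  // The literal 1 is the number of trailing dimensions to normalize over
-  // (here just the innermost dimension).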
- auto layer_norm_results = layer_norm(input, 1, weight, bias, eps_ptr);
-
- auto output = layer_norm_results.output;
-
- if (dtype == DataType::Half) {
- output = castOp(DataType::Half, output);
- }
-
- fusion->addOutput(output);
-}
-
-static void NvFuserScheduler_LayerNorm(
- benchmark::State& benchmark_state,
- FusionExecutorCache* fusion_executor_cache,
- DataType dtype) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
- std::vector<int64_t> input_shape{656, benchmark_state.range(0)};
- const float kEps = 1e-5;
-
- // inputs
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- at::Tensor input = at::randn(input_shape, options);
- at::Tensor weight = at::randn({input_shape[1]}, options);
- at::Tensor bias = at::randn({input_shape[1]}, options);
-
- std::vector<c10::IValue> aten_inputs({input, weight, bias});
-
- runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
-
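-  // Bytes = input read + output write (2 * numel) plus weight and bias reads.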
- benchmark_state.SetBytesProcessed(
- int64_t(benchmark_state.iterations()) *
- (2 * input.numel() + weight.numel() + bias.numel()) *
- int64_t(dataTypeSize(dtype)));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_LayerNorm(
- benchmark::State& benchmark_state,
- DataType dtype) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
- std::vector<int64_t> input_shape{656, benchmark_state.range(0)};
- const int kReductionAxis = 1;
- std::vector<int64_t> norm_shape;
-  for (size_t idx = kReductionAxis; idx < input_shape.size(); ++idx) {
- norm_shape.push_back(input_shape[idx]);
- }
-
- // inputs
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- at::Tensor input = at::randn(input_shape, options);
- at::Tensor weight = at::randn({input_shape[1]}, options);
- at::Tensor bias = at::randn({input_shape[1]}, options);
-
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- CudaKernelTimer timer;
- auto output = at::layer_norm(input, norm_shape, weight, bias);
- benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
- cudaDeviceSynchronize();
- clearL2Cache();
- cudaDeviceSynchronize();
- }
-}
-
-static void Baseline_LayerNorm_fp32(benchmark::State& benchmark_state) {
- Baseline_LayerNorm(benchmark_state, DataType::Float);
-}
-
-static void Baseline_LayerNorm_fp16(benchmark::State& benchmark_state) {
- Baseline_LayerNorm(benchmark_state, DataType::Half);
-}
-
-//------------------------------------------------------------------------------
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_fp32_LayerNorm,
- setupLayerNorm,
- NvFuserScheduler_LayerNorm,
- DataType::Float);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_fp32_LayerNorm)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8 << 12}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_fp16_LayerNorm,
- setupLayerNorm,
- NvFuserScheduler_LayerNorm,
- DataType::Half);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_fp16_LayerNorm)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8 << 12}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-BENCHMARK(Baseline_LayerNorm_fp32)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8 << 12}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_LayerNorm_fp16)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8 << 12}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-// TODO: add LSTM function to composite operations
-// Function Signature: cy, hy = lstm(x, cx)
-static void setupFusion(Fusion* fusion) {
- FusionGuard fg(fusion);
-
- TensorView* tvs[16];
- for (size_t i = 0; i < 16; i++) {
- tvs[i] = makeContigTensor(2, DataType::Float);
- fusion->addInput(tvs[i]);
- }
-
- const auto cx = makeContigTensor(2, DataType::Float);
- fusion->addInput(cx);
-
- const auto in_x = add(add(add(tvs[0], tvs[1]), tvs[2]), tvs[3]);
- const auto forget_x = add(add(add(tvs[4], tvs[5]), tvs[6]), tvs[7]);
- const auto cell_x = add(add(add(tvs[8], tvs[9]), tvs[10]), tvs[11]);
- const auto out_x = add(add(add(tvs[12], tvs[13]), tvs[14]), tvs[15]);
- auto lstm_result = lstm(cx, in_x, forget_x, cell_x, out_x);
-
- fusion->addOutput(lstm_result.cell);
- fusion->addOutput(lstm_result.hidden);
-}
-
-static std::vector<c10::IValue> setupInputs(
- int hidden_features,
- int batch_size) {
- at::manual_seed(0);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const at::Tensor large_tensor0 =
- at::randn({batch_size, hidden_features * 4}, options);
- const at::Tensor large_tensor1 =
- at::randn({batch_size, hidden_features * 4}, options);
- const at::Tensor large_tensor2 =
- at::randn({batch_size, hidden_features * 4}, options);
- const at::Tensor large_tensor3 =
- at::randn({batch_size, hidden_features * 4}, options);
-
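-  // chunk(4, 1) splits each [batch, 4 * hidden] tensor into the four
-  // per-gate slices, yielding the 16 gate inputs the fusion expects.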
- const auto chunked0 = large_tensor0.chunk(4, 1);
- const auto chunked1 = large_tensor1.chunk(4, 1);
- const auto chunked2 = large_tensor2.chunk(4, 1);
- const auto chunked3 = large_tensor3.chunk(4, 1);
-
- std::vector<c10::IValue> inputs;
- inputs.insert(inputs.end(), chunked0.begin(), chunked0.end());
- inputs.insert(inputs.end(), chunked1.begin(), chunked1.end());
- inputs.insert(inputs.end(), chunked2.begin(), chunked2.end());
- inputs.insert(inputs.end(), chunked3.begin(), chunked3.end());
-
- const auto at_cx = at::randn({batch_size, hidden_features}, options);
- inputs.push_back(at_cx);
-
- return inputs;
-}
-
-//------------------------------------------------------------------------------
-
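-// The benchmarks below time each stage separately: fusion IR construction,
-// auto-scheduling, lowering, kernel compilation, and execution.
-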
-static void LstmCell_SetupFusion(benchmark::State& benchmark_state) {
- for (auto _ : benchmark_state) {
- Fusion fusion;
- setupFusion(&fusion);
- }
-}
-
-BENCHMARK(LstmCell_SetupFusion)->Unit(benchmark::kMicrosecond);
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_AutoSchedule(benchmark::State& benchmark_state) {
- constexpr int kHiddenFeatures = 512;
- constexpr int kBatchSize = 64;
-
- for (auto _ : benchmark_state) {
- // Setup (not included in the measurement)
- benchmark_state.PauseTiming();
- Fusion fusion;
- setupFusion(&fusion);
- std::vector<c10::IValue> inputs = setupInputs(kHiddenFeatures, kBatchSize);
- benchmark_state.ResumeTiming();
-
- // Auto-schedule
- schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
- }
-}
-
-BENCHMARK(LstmCell_AutoSchedule)->Unit(benchmark::kMicrosecond);
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_Lower(benchmark::State& benchmark_state) {
- constexpr int kHiddenFeatures = 512;
- constexpr int kBatchSize = 64;
-
- Fusion fusion;
-
- // setup fusion
- setupFusion(&fusion);
-
- // inputs
- std::vector<c10::IValue> inputs = setupInputs(kHiddenFeatures, kBatchSize);
-
- schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
- for (auto _ : benchmark_state) {
- GpuLower gpu_lower(&fusion);
- }
-}
-
-BENCHMARK(LstmCell_Lower)->Unit(benchmark::kMillisecond);
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_Compile(benchmark::State& benchmark_state) {
- constexpr int kHiddenFeatures = 512;
- constexpr int kBatchSize = 64;
-
- Fusion fusion;
-
- // setup fusion
- setupFusion(&fusion);
-
- // inputs
- std::vector<c10::IValue> inputs = setupInputs(kHiddenFeatures, kBatchSize);
-
- schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
- for (auto _ : benchmark_state) {
- FusionExecutor executor;
- executor.compileFusion(&fusion);
- }
-}
-
-BENCHMARK(LstmCell_Compile)->Unit(benchmark::kMillisecond);
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_RunFusion(
- benchmark::State& benchmark_state,
- int hidden_features,
- int batch_size) {
- Fusion fusion;
-
- // setup fusion
- setupFusion(&fusion);
-
- // inputs
- std::vector<c10::IValue> inputs = setupInputs(hidden_features, batch_size);
-
- // outputs
- std::vector<at::Tensor> outputs;
-
- auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
- FusionExecutor executor;
- executor.compileFusion(&fusion);
-
- cudaDeviceSynchronize();
-
- for (auto _ : benchmark_state) {
- outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
- cudaDeviceSynchronize();
- }
-}
-
-BENCHMARK_CAPTURE(LstmCell_RunFusion, Small, 512, 64)
- ->Unit(benchmark::kMicrosecond);
-
-BENCHMARK_CAPTURE(LstmCell_RunFusion, Medium, 1024, 128)
- ->Unit(benchmark::kMicrosecond);
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_RunFusion_GpuOnly(
- benchmark::State& benchmark_state,
- int hidden_features,
- int batch_size) {
- Fusion fusion;
-
- // setup fusion
- setupFusion(&fusion);
-
- // inputs
- std::vector<c10::IValue> inputs = setupInputs(hidden_features, batch_size);
-
- // outputs
- std::vector<at::Tensor> outputs;
-
- auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
- FusionExecutor executor;
- executor.setMeasureKernelTimeFlag(true);
- executor.compileFusion(&fusion);
-
- cudaDeviceSynchronize();
-
- for (auto _ : benchmark_state) {
- outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
- benchmark_state.SetIterationTime(executor.kernelTimeMs() / 1000.0);
- cudaDeviceSynchronize();
- clearL2Cache();
- cudaDeviceSynchronize();
- }
-}
-
-BENCHMARK_CAPTURE(LstmCell_RunFusion_GpuOnly, Small, 512, 64)
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK_CAPTURE(LstmCell_RunFusion_GpuOnly, Medium, 1024, 128)
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-static void LstmCell_RunFusion_CpuOnly(
- benchmark::State& benchmark_state,
- int hidden_features,
- int batch_size) {
- Fusion fusion;
-
- // setup fusion
- setupFusion(&fusion);
-
- // inputs
- std::vector<c10::IValue> inputs = setupInputs(hidden_features, batch_size);
-
- // outputs
- std::vector<at::Tensor> outputs;
-
- auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
-
- FusionExecutor executor;
- executor.setExecuteKernelFlag(false);
- executor.compileFusion(&fusion);
-
- for (auto _ : benchmark_state) {
- outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
- }
-}
-
-BENCHMARK_CAPTURE(LstmCell_RunFusion_CpuOnly, Small, 512, 64)
- ->Unit(benchmark::kMicrosecond);
-
-BENCHMARK_CAPTURE(LstmCell_RunFusion_CpuOnly, Medium, 1024, 128)
- ->Unit(benchmark::kMicrosecond);
+++ /dev/null
-#include <benchmark/benchmark.h>
-
-BENCHMARK_MAIN();
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include <sstream>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-// Build a 2D sum-reduction fusion over the given axis, casting through float
-// when the I/O dtype is half.
-static void setupReduction(Fusion* fusion, DataType dtype, int red_axis) {
- FusionGuard fg(fusion);
-
- bool is_fp16 = dtype == DataType::Half;
-
- TensorView* tv0 = makeContigTensor(2, dtype);
- fusion->addInput(tv0);
-
- TensorView* tv0_cast = tv0;
- if (is_fp16) {
- tv0_cast = castOp(DataType::Float, tv0);
- }
-
- TensorView* tv1 = sum(tv0_cast, {red_axis});
-
- TensorView* tv1_cast = tv1;
- if (is_fp16) {
- tv1_cast = castOp(DataType::Half, tv1);
- }
-
- fusion->addOutput(tv1_cast);
-}
-
-static void NvFuserScheduler_Reduction(
- benchmark::State& benchmark_state,
- FusionExecutorCache* fusion_executor_cache,
- DataType dtype,
- int reduction_dim) {
- auto reduction_size = benchmark_state.range(0);
- auto iter_size = benchmark_state.range(1);
-
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
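-  // reduction_dim == 1 reduces the innermost (contiguous) dimension;
-  // reduction_dim == 0 reduces the outermost one.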
- at::Tensor aten_input =
- (reduction_dim ? at::randn({iter_size, reduction_size}, options)
- : at::randn({reduction_size, iter_size}, options));
-
- fusion_executor_cache->profile(true);
- fusion_executor_cache->runFusionWithInputs({aten_input});
-
- auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
- auto executor_instance = compile_log.fusion_executor;
- TORCH_INTERNAL_ASSERT(compile_log.reduction_params.has_value());
- TORCH_INTERNAL_ASSERT(compile_log.launch_constraints.has_value());
- auto rparams = toString(compile_log.reduction_params.value());
- auto lparams = toString(compile_log.launch_constraints.value());
-
- benchmark_state.SetLabel(rparams + lparams);
-
- fusion_executor_cache->profile(false);
- executor_instance->setMeasureKernelTimeFlag(true);
- // Sync everything up before we start
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- auto cg_outputs = fusion_executor_cache->runFusionWithInputs({aten_input});
- benchmark_state.SetIterationTime(
- executor_instance->kernelTimeMs() / 1000.0);
- clearL2Cache();
- }
- // Sync everything up before we're finished, don't want to run ahead on the
- // cpu while benchmarking.
- cudaDeviceSynchronize();
-
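-  // Bytes = input read (iter_size * reduction_size elements) + output write
-  // (iter_size elements).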
- benchmark_state.SetBytesProcessed(
- int64_t(benchmark_state.iterations()) *
- (iter_size * reduction_size + iter_size) * int64_t(dataTypeSize(dtype)));
-}
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Reduction_Outer_fp32,
- setupReduction,
- NvFuserScheduler_Reduction,
- DataType::Float,
- 0);
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Reduction_Outer_fp16,
- setupReduction,
- NvFuserScheduler_Reduction,
- DataType::Half,
- 0);
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Reduction_Inner_fp32,
- setupReduction,
- NvFuserScheduler_Reduction,
- DataType::Float,
- 1);
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Reduction_Inner_fp16,
- setupReduction,
- NvFuserScheduler_Reduction,
- DataType::Half,
- 1);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp32)
- ->RangeMultiplier(8)
- ->Ranges({{1, 1024 * 1024}, {160, 320}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{32768, 128 * 1024 * 1024}, {2, 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{2, 16}, {32768, 128 * 1024 * 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp32)
- ->RangeMultiplier(2)
- ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp16)
- ->RangeMultiplier(8)
- ->Ranges({{1, 1024 * 1024}, {160, 320}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{32768, 128 * 1024 * 1024}, {2, 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{2, 16}, {32768, 128 * 1024 * 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Outer_fp16)
- ->RangeMultiplier(2)
- ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp32)
- ->RangeMultiplier(8)
- ->Ranges({{1, 1024 * 1024}, {160, 320}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{32768, 128 * 1024 * 1024}, {2, 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp32)
- ->RangeMultiplier(4)
- ->Ranges({{2, 16}, {32768, 128 * 1024 * 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp32)
- ->RangeMultiplier(2)
- ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp16)
- ->RangeMultiplier(8)
- ->Ranges({{1, 1024 * 1024}, {160, 320}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{32768, 128 * 1024 * 1024}, {2, 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp16)
- ->RangeMultiplier(4)
- ->Ranges({{2, 16}, {32768, 128 * 1024 * 1024}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Reduction_Inner_fp16)
- ->RangeMultiplier(2)
- ->Ranges({{128, 1024 * 16}, {128, 1024 * 16}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-static void setupSBR(Fusion* fusion, DataType dtype) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
- FusionGuard fg(fusion);
-
- const size_t kNumberOfDims = 4;
-
- std::vector<int64_t> bcast_shape(kNumberOfDims, 1);
- bcast_shape[bcast_shape.size() - 1] = -1;
-
- std::vector<bool> bcast_contig(kNumberOfDims, false);
- bcast_contig[bcast_contig.size() - 1] = true;
-
- auto x = makeContigTensor(kNumberOfDims, dtype);
-
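-  // Scale and bias are [1, 1, 1, C] tensors broadcast over N, H, W; only the
-  // channel dimension is marked contiguous.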
- auto scale = TensorViewBuilder()
- .contiguity(bcast_contig)
- .shape(bcast_shape)
- .dtype(dtype)
- .build();
-
- auto bias = TensorViewBuilder()
- .contiguity(bcast_contig)
- .shape(bcast_shape)
- .dtype(dtype)
- .build();
-
- fusion->addInput(x);
- fusion->addInput(scale);
- fusion->addInput(bias);
-
- if (dtype == DataType::Half) {
- x = castOp(DataType::Float, x);
- scale = castOp(DataType::Float, scale);
- bias = castOp(DataType::Float, bias);
- }
-
- auto scale_bias = add(mul(x, scale), bias);
- auto scale_bias_relu = unaryOp(UnaryOpType::Relu, scale_bias);
-
- if (dtype == DataType::Half) {
- scale_bias_relu = castOp(DataType::Half, scale_bias_relu);
- }
- fusion->addOutput(scale_bias_relu);
-}
-
-static void setupSBRNorm(Fusion* fusion, DataType dtype) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
- FusionGuard fg(fusion);
-
- const size_t kNumberOfDims = 4;
-
- auto x = makeContigTensor(kNumberOfDims, dtype);
- auto weight = makeContigTensor(1, dtype);
- auto bias = makeContigTensor(1, dtype);
- auto mean = makeContigTensor(1, dtype);
- auto var = makeContigTensor(1, dtype);
-
- fusion->addInput(x);
- fusion->addInput(weight);
- fusion->addInput(bias);
- fusion->addInput(mean);
- fusion->addInput(var);
-
- std::vector<bool> broadcast_mask(kNumberOfDims, true);
- broadcast_mask[broadcast_mask.size() - 1] = false;
-
- if (dtype == DataType::Half) {
- x = castOp(DataType::Float, x);
- weight = castOp(DataType::Float, weight);
- bias = castOp(DataType::Float, bias);
- mean = castOp(DataType::Float, mean);
- var = castOp(DataType::Float, var);
- }
-
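-  // Fold mean/var/weight/bias into a single per-channel scale and bias,
-  // applied below as a multiply-add followed by relu.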
- auto rsqrt = unaryOp(UnaryOpType::Rsqrt, var);
- auto this_scale = mul(weight, rsqrt);
- auto this_bias = mul(sub(bias, mean), this_scale);
-
- auto bcast_scale = broadcast(this_scale, broadcast_mask);
- auto bcast_bias = broadcast(this_bias, broadcast_mask);
-
- auto scale_bias = add(mul(x, bcast_scale), bcast_bias);
- auto scale_bias_relu = unaryOp(UnaryOpType::Relu, scale_bias);
-
- if (dtype == DataType::Half) {
- scale_bias_relu = castOp(DataType::Half, scale_bias_relu);
- }
-
- fusion->addOutput(scale_bias_relu);
-}
-
-//------------------------------------------------------------------------------
-
-static void NvFuserScheduler_SBR(
- benchmark::State& benchmark_state,
- FusionExecutorCache* fusion_executor_cache,
- DataType dtype) {
- // N, H, W, C format
- std::vector<int64_t> input_shape{
- benchmark_state.range(0),
- benchmark_state.range(1),
- benchmark_state.range(1),
- benchmark_state.range(2)};
-
- // inputs
- at::manual_seed(0);
- std::vector<int64_t> static_bcast_shape{1, 1, 1, benchmark_state.range(2)};
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- at::Tensor at_x = at::randn(input_shape, options);
- at::Tensor at_scale = at::ones(static_bcast_shape, options);
- at::Tensor at_bias = at::zeros(static_bcast_shape, options);
-
- // inputs
- std::vector<c10::IValue> aten_inputs = {at_x, at_scale, at_bias};
-
- fusion_executor_cache->profile(true);
- fusion_executor_cache->runFusionWithInputs(aten_inputs);
-
- auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
- auto executor_instance = compile_log.fusion_executor;
- TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
- TORCH_INTERNAL_ASSERT(compile_log.launch_constraints.has_value());
- auto params = toString(compile_log.pointwise_params.value());
- auto lparams = toString(compile_log.launch_constraints.value());
-
-  benchmark_state.SetLabel(params + lparams);
-
- fusion_executor_cache->profile(false);
- executor_instance->setMeasureKernelTimeFlag(true);
- // Sync everything up before we start
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- auto cg_outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
- benchmark_state.SetIterationTime(
- executor_instance->kernelTimeMs() / 1000.0);
- clearL2Cache();
- }
-
- // Sync everything up before we're finished, don't want to run ahead on the
- // cpu while benchmarking.
- cudaDeviceSynchronize();
-
- const size_t size =
- input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
- const size_t channels = input_shape[3];
- benchmark_state.SetBytesProcessed(
- int64_t(benchmark_state.iterations()) * (channels * 2 + size * 2) *
- int64_t(dataTypeSize(dtype)));
-}
-
-static void Baseline_SBR(benchmark::State& benchmark_state, DataType dtype) {
- // N, H, W, C format
- std::vector<int64_t> input_shape{
- benchmark_state.range(0),
- benchmark_state.range(1),
- benchmark_state.range(1),
- benchmark_state.range(2)};
- std::vector<int64_t> bcast_shape{benchmark_state.range(2)};
-
- // inputs
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- at::Tensor at_x = at::randn(input_shape, options);
- at::Tensor at_scale = at::ones(bcast_shape, options);
- at::Tensor at_bias = at::zeros(bcast_shape, options);
-
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- CudaKernelTimer timer;
-
- auto scale = at::mul(at_x, at_scale);
- auto bias = at::add(scale, at_bias);
- auto output = at::relu(bias);
-
- benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
- cudaDeviceSynchronize();
- clearL2Cache();
- cudaDeviceSynchronize();
- }
-
- const size_t size =
- input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
- const size_t channels = input_shape[3];
- benchmark_state.SetBytesProcessed(
- int64_t(benchmark_state.iterations()) * (channels * 2 + size * 2) *
- int64_t(dataTypeSize(dtype)));
-}
-
-//------------------------------------------------------------------------------
-
-static void NvFuserScheduler_SBR_Norm(
- benchmark::State& benchmark_state,
- FusionExecutorCache* fusion_executor_cache,
- DataType dtype) {
- // N, H, W, C format
- std::vector<int64_t> input_shape{
- benchmark_state.range(0),
- benchmark_state.range(1),
- benchmark_state.range(1),
- benchmark_state.range(2)};
- std::vector<int64_t> bcast_shape{benchmark_state.range(2)};
-
- // inputs
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- at::Tensor at_x = at::randn(input_shape, options);
- at::Tensor at_weight = at::ones(bcast_shape, options);
- at::Tensor at_bias = at::zeros(bcast_shape, options);
- at::Tensor at_mean = at::zeros(bcast_shape, options);
- at::Tensor at_var = at::ones(bcast_shape, options);
-
- // inputs
- std::vector<c10::IValue> aten_inputs = {
- at_x, at_weight, at_bias, at_mean, at_var};
-
- fusion_executor_cache->profile(true);
- fusion_executor_cache->runFusionWithInputs(aten_inputs);
-
- auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
- auto executor_instance = compile_log.fusion_executor;
- TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
- TORCH_INTERNAL_ASSERT(compile_log.launch_constraints.has_value());
- auto params = toString(compile_log.pointwise_params.value());
- auto lparams = toString(compile_log.launch_constraints.value());
-
- benchmark_state.SetLabel(params + lparams);
-
- fusion_executor_cache->profile(false);
- executor_instance->setMeasureKernelTimeFlag(true);
- // Sync everything up before we start
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- auto cg_outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
- benchmark_state.SetIterationTime(
- executor_instance->kernelTimeMs() / 1000.0);
- clearL2Cache();
- }
-
- // Sync everything up before we're finished, don't want to run ahead on the
- // cpu while benchmarking.
- cudaDeviceSynchronize();
-
- const size_t size =
- input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
- const size_t channels = input_shape[3];
- benchmark_state.SetBytesProcessed(
- int64_t(benchmark_state.iterations()) * (channels * 4 + size * 2) *
- int64_t(dataTypeSize(dtype)));
-}
-
-static void Baseline_SBR_Norm(
- benchmark::State& benchmark_state,
- DataType dtype) {
- // N, H, W, C format
- std::vector<int64_t> input_shape{
- benchmark_state.range(0),
- benchmark_state.range(1),
- benchmark_state.range(1),
- benchmark_state.range(2)};
- std::vector<int64_t> bcast_shape{1, 1, 1, benchmark_state.range(2)};
-
- // inputs
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- at::Tensor at_x = at::randn(input_shape, options);
- at::Tensor at_weight = at::ones(bcast_shape, options);
- at::Tensor at_bias = at::zeros(bcast_shape, options);
- at::Tensor at_mean = at::zeros(bcast_shape, options);
- at::Tensor at_var = at::ones(bcast_shape, options);
-
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- CudaKernelTimer timer;
-
- auto this_scale = at::mul(at_weight, at::rsqrt(at_var));
- auto this_bias = at::mul(at::sub(at_bias, at_mean), this_scale);
-
- auto scale = at::mul(at_x, this_scale);
- auto bias = at::add(scale, this_bias);
- auto output = at::relu(bias);
-
- benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
- cudaDeviceSynchronize();
- }
-
- const size_t size =
- input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3];
- const size_t channels = input_shape[3];
- benchmark_state.SetBytesProcessed(
- int64_t(benchmark_state.iterations()) * (channels * 4 + size * 2) *
- int64_t(dataTypeSize(dtype)));
-}
-
-//------------------------------------------------------------------------------
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_SBR_fp32,
- setupSBR,
- NvFuserScheduler_SBR,
- DataType::Float);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_SBR_fp32)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8}, {640, 640}, {64, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_SBR_fp16,
- setupSBR,
- NvFuserScheduler_SBR,
- DataType::Half);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_SBR_fp16)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8}, {640, 640}, {64, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_SBR_Norm_fp32,
- setupSBRNorm,
- NvFuserScheduler_SBR_Norm,
- DataType::Float);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_SBR_Norm_fp32)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8}, {640, 640}, {64, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_SBR_Norm_fp16,
- setupSBRNorm,
- NvFuserScheduler_SBR_Norm,
- DataType::Half);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_SBR_Norm_fp16)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8}, {640, 640}, {64, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-static void Baseline_SBR_fp32(benchmark::State& benchmark_state) {
- Baseline_SBR(benchmark_state, DataType::Float);
-}
-
-BENCHMARK(Baseline_SBR_fp32)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8}, {640, 640}, {64, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-static void Baseline_SBR_fp16(benchmark::State& benchmark_state) {
- Baseline_SBR(benchmark_state, DataType::Half);
-}
-
-BENCHMARK(Baseline_SBR_fp16)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8}, {640, 640}, {64, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-static void Baseline_SBR_Norm_fp32(benchmark::State& benchmark_state) {
- Baseline_SBR_Norm(benchmark_state, DataType::Float);
-}
-
-BENCHMARK(Baseline_SBR_Norm_fp32)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8}, {640, 640}, {64, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-static void Baseline_SBR_Norm_fp16(benchmark::State& benchmark_state) {
- Baseline_SBR_Norm(benchmark_state, DataType::Half);
-}
-
-BENCHMARK(Baseline_SBR_Norm_fp16)
- ->RangeMultiplier(2)
- ->Ranges({{8, 8}, {640, 640}, {64, 256}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <cuda_runtime.h>
-
-#include "utils.h"
-
-using namespace torch::jit::fuser::cuda;
-
-//------------------------------------------------------------------------------
-
-static void setupSoftmax(
- Fusion* fusion,
- DataType dtype,
- const int reduction_axis) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
- FusionGuard fg(fusion);
- // setup fusion
- auto input = makeContigTensor(2, dtype);
- fusion->addInput(input);
-
- if (dtype == DataType::Half) {
- input = castOp(DataType::Float, input);
- }
-
- auto output = softmax(input, reduction_axis);
-
- if (dtype == DataType::Half) {
- output = castOp(DataType::Half, output);
- }
-
- fusion->addOutput(output);
-}
-
-static void NvFuserScheduler_Softmax(
- benchmark::State& benchmark_state,
- FusionExecutorCache* fusion_executor_cache,
- DataType dtype,
- const int reduction_axis) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
- std::vector<int64_t> input_shape{
- benchmark_state.range(1), benchmark_state.range(0)};
-
- // inputs
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn(input_shape, options);
- std::vector<c10::IValue> aten_inputs({aten_input});
-
- runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
-
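-  // Bytes = one read plus one write of the full tensor.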
- benchmark_state.SetBytesProcessed(
- int64_t(benchmark_state.iterations()) *
- (2 * aten_input.numel() * int64_t(dataTypeSize(dtype))));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_Softmax(
- benchmark::State& benchmark_state,
- DataType dtype) {
- std::vector<int64_t> input_shape{
- benchmark_state.range(1), benchmark_state.range(0)};
- const int kReductionAxis = benchmark_state.range(2);
-
- // inputs
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn(input_shape, options);
-
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- CudaKernelTimer timer;
- auto output = at::_softmax(aten_input, kReductionAxis, false);
- benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
- cudaDeviceSynchronize();
- clearL2Cache();
- cudaDeviceSynchronize();
- }
-
- benchmark_state.SetBytesProcessed(
- int64_t(benchmark_state.iterations()) *
- (2 * aten_input.numel() * int64_t(dataTypeSize(dtype))));
-}
-
-static void Baseline_Softmax_fp32(benchmark::State& benchmark_state) {
- Baseline_Softmax(benchmark_state, DataType::Float);
-}
-
-static void Baseline_Softmax_fp16(benchmark::State& benchmark_state) {
- Baseline_Softmax(benchmark_state, DataType::Half);
-}
-
-//------------------------------------------------------------------------------
-
-static void setupSoftmaxDropout(
- Fusion* fusion,
- DataType dtype,
- const int kReductionAxis) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
- FusionGuard fg(fusion);
-
- constexpr float kDropoutProbability = 0.9;
- constexpr float kScale = 1.0f / kDropoutProbability;
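-  // Note: kDropoutProbability is treated as the keep probability here (the
-  // scale is 1 / keep, as in inverted dropout); the ATen baseline below
-  // passes the drop rate (0.1) to at::dropout instead.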
-
- // setup fusion
- auto attention_scores = makeContigTensor(4, dtype);
- auto attention_mask = makeContigTensor(4, dtype);
-
- Double* divisor = new Double();
-
- fusion->addInput(attention_scores);
- fusion->addInput(attention_mask);
- fusion->addInput(divisor);
-
- if (dtype == DataType::Half) {
- attention_scores = castOp(DataType::Float, attention_scores);
- attention_mask = castOp(DataType::Float, attention_mask);
- }
-
- attention_scores = div(attention_scores, divisor);
- attention_scores = add(attention_scores, attention_mask);
- auto attention_probs = softmax(attention_scores, kReductionAxis);
- auto prob = new Double(kDropoutProbability);
- auto scale = new Double(kScale);
- auto dropout_results = dropout(attention_probs, prob, scale);
- auto output = dropout_results.output;
-
- if (dtype == DataType::Half) {
- attention_scores = castOp(DataType::Half, attention_scores);
- attention_probs = castOp(DataType::Half, attention_probs);
- output = castOp(DataType::Half, output);
- }
-
- fusion->addOutput(attention_scores);
- fusion->addOutput(attention_probs);
- fusion->addOutput(output);
-
- fusion->addOutput(dropout_results.mask);
-}
-
-static void NvFuserScheduler_SoftmaxDropout(
- benchmark::State& benchmark_state,
- FusionExecutorCache* fusion_executor_cache,
- DataType dtype,
- const int kReductionAxis) {
- TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
-
-  // Input shape [256, 12, 100, range(0)]; kReductionAxis picks the reduced dim.
- std::vector<int64_t> input_shape{256, 12, 100, benchmark_state.range(0)};
-
- constexpr int kHiddenSize = 768;
- constexpr int kNumAttentionHeads = 12;
- constexpr int kAttentionHeadSize = kHiddenSize / kNumAttentionHeads;
-
- // inputs
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- at::Tensor at_scores = at::randn(input_shape, options);
- at::Tensor at_mask = at::randn(input_shape, options);
- std::vector<c10::IValue> aten_inputs(
- {at_scores, at_mask, sqrt(kAttentionHeadSize)});
-
- runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
-
- // 5 dtype: attention_scores + attention_mask + attention_scores_out +
- // attention_probs_out + output
- // 1 bool: dropout_results.mask
- // All the same size
- benchmark_state.SetBytesProcessed(
- int64_t(benchmark_state.iterations()) * 5 * at_scores.numel() *
- int64_t(dataTypeSize(dtype)) +
- // bool mask
- int64_t(benchmark_state.iterations()) * at_scores.numel() *
- int64_t(dataTypeSize(DataType::Bool)));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_Softmax_Dropout(
- benchmark::State& benchmark_state,
- const int kReductionAxis,
- DataType dtype) {
- std::vector<int64_t> input_shape{256, 12, 100, benchmark_state.range(0)};
-
- constexpr int kHiddenSize = 768;
- constexpr int kNumAttentionHeads = 12;
- constexpr float kDropoutProbability = 0.1;
- constexpr int kAttentionHeadSize = kHiddenSize / kNumAttentionHeads;
-
- // inputs
- at::manual_seed(0);
- auto options =
- at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
- at::Tensor attention_scores = at::randn(input_shape, options);
- at::Tensor at_y = at::randn(input_shape, options);
-
- cudaDeviceSynchronize();
-
- for (auto _ : benchmark_state) {
- // Create
- CudaKernelTimer timer;
-
-    // Run on a fresh temporary so the inputs are not mutated across
-    // iterations.
-    auto scores = attention_scores / sqrt(kAttentionHeadSize);
-    scores = scores + at_y;
-    auto attention_probs = at::_softmax(scores, kReductionAxis, false);
- attention_probs = at::dropout(attention_probs, kDropoutProbability, true);
-
- // Record
- benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
- cudaDeviceSynchronize();
- clearL2Cache();
- cudaDeviceSynchronize();
- }
-
- // 5 dtype: attention_scores + attention_mask + attention_scores_out +
- // attention_probs_out + output
- // 1 bool: dropout_results.mask
- // All the same size
- benchmark_state.SetBytesProcessed(
- int64_t(benchmark_state.iterations()) * 5 * attention_scores.numel() *
- int64_t(dataTypeSize(dtype)) +
- // bool mask
- int64_t(benchmark_state.iterations()) * attention_scores.numel() *
- int64_t(dataTypeSize(DataType::Bool)));
-}
-
-//------------------------------------------------------------------------------
-
-static void Baseline_Softmax_Dropout_Inner_fp32(
- benchmark::State& benchmark_state) {
- Baseline_Softmax_Dropout(benchmark_state, 3, DataType::Float);
-}
-
-static void Baseline_Softmax_Dropout_Outer_fp32(
- benchmark::State& benchmark_state) {
- Baseline_Softmax_Dropout(benchmark_state, 1, DataType::Float);
-}
-
-static void Baseline_Softmax_Dropout_Inner_fp16(
- benchmark::State& benchmark_state) {
- Baseline_Softmax_Dropout(benchmark_state, 3, DataType::Half);
-}
-
-static void Baseline_Softmax_Dropout_Outer_fp16(
- benchmark::State& benchmark_state) {
- Baseline_Softmax_Dropout(benchmark_state, 1, DataType::Half);
-}
-
-//------------------------------------------------------------------------------
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Softmax_Outer_fp32,
- setupSoftmax,
- NvFuserScheduler_Softmax,
- DataType::Float,
- 0);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Outer_fp32)
- ->RangeMultiplier(2)
- ->Ranges({{656, 656}, {8, 8 << 12}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Softmax_Inner_fp32,
- setupSoftmax,
- NvFuserScheduler_Softmax,
- DataType::Float,
- 1);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Inner_fp32)
- ->RangeMultiplier(2)
- ->Ranges({{656, 656}, {8, 8 << 12}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Softmax_Outer_fp16,
- setupSoftmax,
- NvFuserScheduler_Softmax,
- DataType::Half,
- 0);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Outer_fp16)
- ->RangeMultiplier(2)
- ->Ranges({{656, 656}, {8, 8 << 12}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Softmax_Inner_fp16,
- setupSoftmax,
- NvFuserScheduler_Softmax,
- DataType::Half,
- 1);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Inner_fp16)
- ->RangeMultiplier(2)
- ->Ranges({{656, 656}, {8, 8 << 12}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Softmax_Dropout_Inner_fp32,
- setupSoftmaxDropout,
- NvFuserScheduler_SoftmaxDropout,
- DataType::Float,
- 3);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Dropout_Inner_fp32)
-    ->DenseRange(8, 128, 8)
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Softmax_Dropout_Outer_fp32,
- setupSoftmaxDropout,
- NvFuserScheduler_SoftmaxDropout,
- DataType::Float,
- 1);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Dropout_Outer_fp32)
-    ->DenseRange(8, 128, 8)
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Softmax_Dropout_Inner_fp16,
- setupSoftmaxDropout,
- NvFuserScheduler_SoftmaxDropout,
- DataType::Half,
- 3);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Dropout_Inner_fp16)
-    ->DenseRange(8, 128, 8)
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-NVFUSER_BENCHMARK_DEFINE(
- NvFuserScheduler_Softmax_Dropout_Outer_fp16,
- setupSoftmaxDropout,
- NvFuserScheduler_SoftmaxDropout,
- DataType::Half,
- 1);
-
-NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_Dropout_Outer_fp16)
-    ->DenseRange(8, 128, 8)
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-//------------------------------------------------------------------------------
-
-BENCHMARK(Baseline_Softmax_fp32)
- ->RangeMultiplier(2)
- ->Ranges({{656, 656}, {8, 8 << 12}, {0, 1}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_Softmax_fp16)
- ->RangeMultiplier(2)
- ->Ranges({{656, 656}, {8, 8 << 12}, {0, 1}})
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_Softmax_Dropout_Inner_fp32)
-    ->DenseRange(8, 128, 8)
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_Softmax_Dropout_Outer_fp32)
-    ->DenseRange(8, 128, 8)
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_Softmax_Dropout_Inner_fp16)
-    ->DenseRange(8, 128, 8)
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
-
-BENCHMARK(Baseline_Softmax_Dropout_Outer_fp16)
-    ->DenseRange(8, 128, 8)
- ->Unit(benchmark::kMicrosecond)
- ->UseManualTime();
+++ /dev/null
-#include "utils.h"
-
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <sstream>
-
-using namespace torch::jit::fuser::cuda;
-
-std::string toString(ReductionParams rparams) {
- std::stringstream ss;
- if (rparams.fastest_dim) {
- ss << "/Fastest dim";
- } else {
- ss << "/Slow dim";
- }
- if (rparams.cross_grid) {
- ss << "/cross grid";
- }
- if (rparams.cross_block) {
- ss << "/cross block";
- }
- if (rparams.multiple_reds_per_blk) {
- ss << "/multiple reductions per block ";
- }
- if (rparams.loop_unroll > 1) {
- ss << (rparams.vectorize ? "/Vectorize " : "/Unroll ")
- << (rparams.reduction_unroll ? "reduction dim " : "iter dim ")
- << rparams.loop_unroll;
- }
- if (rparams.batches_per_block > 1) {
- ss << "/batches per block " << rparams.batches_per_block << " ";
- }
- if (rparams.persistent_kernel) {
- ss << "/persistent";
- }
-
- if (rparams.split_grid_dim) {
- ss << "/split grid dim";
- }
- return ss.str();
-}
-
-std::string toString(PointwiseParams params) {
- std::stringstream ss;
- if (params.break_point) {
- ss << "2D Schedule at " << params.break_point << "/";
- if (params.split_block) {
- ss << " Split block into y-dim/";
- }
- if (params.split_grid_y_dim) {
- ss << " Split y grid dim/";
- }
- } else {
-    ss << "1D/";
- }
- if (params.inner_factor > 1) {
- if (params.vectorize) {
- ss << "Vectorize, Factor: " << params.inner_factor;
- } else {
- ss << "Unroll, Factor: " << params.inner_factor;
- }
- }
- return ss.str();
-}
-
-std::string toString(LaunchParams lparams) {
- std::stringstream ss;
- ss << "/Launch_Parameters["
- << "block(" << lparams.bdimz() << "/" << lparams.bdimy() << "/"
- << lparams.bdimx() << ")/grid(" << lparams.gdimz() << "/"
- << lparams.gdimy() << "/" << lparams.gdimx() << ")/" << lparams.smem()
- << "]";
- return ss.str();
-}
-
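-// Evict L2 by cloning a float buffer sized to the full cache
-// (l2CacheSize / sizeof(float) elements); the copy touches every line.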
-void clearL2Cache() {
- torch::NoGradGuard no_grad;
- auto l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize;
- auto options =
- torch::TensorOptions().dtype(torch::kFloat32).device(at::kCUDA, 0);
-
- auto l2_elems = l2_cache_size / 4;
- torch::Tensor t0 = torch::empty(l2_elems, options);
- torch::Tensor t1 = torch::clone(t0);
-}
-
-TensorView* makeContigTensor(size_t ndims, DataType dtype) {
- return TensorViewBuilder()
- .ndims(ndims)
- .dtype(dtype)
- .contiguity(std::vector<bool>(ndims, true))
- .build();
-}
-
-void runBenchmarkIterations(
- benchmark::State& benchmark_state,
- FusionExecutorCache* fusion_executor_cache,
- std::vector<c10::IValue>& aten_inputs) {
- fusion_executor_cache->runFusionWithInputs(aten_inputs);
- bool segmented =
- fusion_executor_cache->getMostRecentKernelRuntime()->isSegmented();
-
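-  // A non-segmented fusion executes as a single kernel, so device-side time
-  // can be read from the executor; segmented fusions launch several kernels
-  // and are timed with CUDA events around the whole call instead.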
- if (!segmented) {
- fusion_executor_cache->profile(true);
- fusion_executor_cache->runFusionWithInputs(aten_inputs);
- auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
- auto executor_instance = compile_log.fusion_executor;
- TORCH_INTERNAL_ASSERT(compile_log.reduction_params.has_value());
- TORCH_INTERNAL_ASSERT(compile_log.launch_constraints.has_value());
- auto rparams = toString(compile_log.reduction_params.value());
- auto lparams = toString(compile_log.launch_constraints.value());
- benchmark_state.SetLabel(rparams + lparams);
- executor_instance->setMeasureKernelTimeFlag(true);
-
- // Sync everything up before we start
- cudaDeviceSynchronize();
- for (auto _ : benchmark_state) {
- auto cg_outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
- benchmark_state.SetIterationTime(
- executor_instance->kernelTimeMs() / 1000.0);
- clearL2Cache();
- }
- // Sync everything up before we're finished, don't want to run ahead on the
- // cpu while benchmarking.
- cudaDeviceSynchronize();
- } else {
- // Segmented
- // Sync everything up before we start
- {
- // Compile/warmup
- auto cg_outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
- }
- cudaDeviceSynchronize();
- CudaKernelTimer timer;
- for (auto _ : benchmark_state) {
- timer.restart();
- auto cg_outputs = fusion_executor_cache->runFusionWithInputs(aten_inputs);
- benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
- clearL2Cache();
- }
-    // Sync everything up before we're finished; we don't want the CPU to
-    // run ahead while benchmarking.
- cudaDeviceSynchronize();
- }
-}
-
-namespace executorCache {
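-// One FusionExecutorCache per benchmark graph, keyed by graph name;
-// thread_local so fixtures on different threads don't share state.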
-thread_local ExecutorMap executor_map_;
-ExecutorMap& getGlobalMap() {
- return executor_map_;
-}
-} // namespace executorCache
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-
-#include <benchmark/benchmark.h>
-
-#include <ATen/cuda/CUDAContext.h>
-#include <torch/torch.h>
-
-#include <cuda_runtime.h>
-
-using namespace torch::jit::fuser::cuda;
-
-std::string toString(ReductionParams rparams);
-std::string toString(PointwiseParams params);
-std::string toString(LaunchParams lparams);
-
-// Run benchmark iterations with the provided inputs. If the fusion is not
-// segmented, report the kernel time from the runtime as well as the
-// heuristic parameters; if it is segmented, fall back to CUDA-event timers.
-// L2 is cleared between iterations.
-void runBenchmarkIterations(
- benchmark::State& benchmark_state,
- FusionExecutorCache* fusion_executor_cache,
- std::vector<c10::IValue>& aten_inputs);
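-//
-// A minimal usage sketch (the tensor sizes are made up; real benchmarks
-// derive them from benchmark_state.range()):
-//
-//   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-//   at::Tensor t0 = at::randn({128, 128}, options);
-//   std::vector<c10::IValue> inputs({t0});
-//   runBenchmarkIterations(benchmark_state, executor_cache, inputs);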
-
-void clearL2Cache();
-
-// Make a tensor of dimensionality ndims that is known to be fully
-// contiguous, but with unknown sizes. Taken from test_gpu.cpp.
-TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float);
-
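-// Timer based on a pair of CUDA events. elapsed() records the finish event,
-// waits for both events, and returns the GPU-side time in milliseconds since
-// the last restart() (or since construction).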
-class CudaKernelTimer {
- public:
- CudaKernelTimer() {
-    // Create the events and immediately record the start point
- cudaEventCreate(&start_event);
- cudaEventCreate(&finish_event);
- cudaEventRecord(start_event);
- }
-
- ~CudaKernelTimer() {
- cudaEventDestroy(start_event);
- cudaEventDestroy(finish_event);
- }
-
- void restart() {
- cudaEventRecord(start_event);
- }
-
- float elapsed() {
-    // Record the finish point and wait for both events before reading the
-    // elapsed time
- cudaEventRecord(finish_event);
- cudaEventSynchronize(start_event);
- cudaEventSynchronize(finish_event);
- cudaEventElapsedTime(&kernel_time_ms_, start_event, finish_event);
- return kernel_time_ms_;
- }
-
- private:
-  // Timing state
- float kernel_time_ms_ = 0;
- cudaEvent_t start_event = {};
- cudaEvent_t finish_event = {};
-};
-
-namespace executorCache {
-using ExecutorPtr = std::unique_ptr<FusionExecutorCache>;
-using ExecutorMap = std::unordered_map<std::string, ExecutorPtr>;
-ExecutorMap& getGlobalMap();
-} // namespace executorCache
-
-//! Utility to manage FusionExecutorCache instances for
-//! all defined benchmarks
-class BenchmarkGraph : public benchmark::Fixture {
- public:
- using SetupFusionFunction = std::function<void(Fusion*)>;
- using SetupFusionMap = std::unordered_map<std::string, SetupFusionFunction>;
-
- virtual std::string graphName() = 0;
- virtual SetupFusionFunction setupFusion() = 0;
-
- FusionExecutorCache* getExecutorCache() {
- auto& executor_ = getExecutorCacheMap()[graphName()];
- TORCH_INTERNAL_ASSERT(executor_);
- return executor_.get();
- }
-
- void SetUp(const ::benchmark::State& state) {
- auto& executor_ = getExecutorCacheMap()[graphName()];
-    // Make sure the same graph hasn't been compiled before
- if (!executor_) {
- auto fusion_ptr = std::make_unique<Fusion>();
-      // Name the guard so it lives for the duration of setup
-      FusionGuard fg(fusion_ptr.get());
- setupFusion()(fusion_ptr.get());
- getExecutorCacheMap()[graphName()] =
- std::make_unique<FusionExecutorCache>(std::move(fusion_ptr));
- }
- }
-
- void TearDown(const ::benchmark::State& state) {}
-
- protected:
- static executorCache::ExecutorMap& getExecutorCacheMap() {
- return executorCache::getGlobalMap();
- }
-};
-
-#define NVFUSER_TO_STRING_HELPER(n) std::string(#n)
-#define NVFUSER_TO_STRING(n) NVFUSER_TO_STRING_HELPER(n)
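-// (Two macro levels so the argument is expanded before being stringized.)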
-
-//! NVFUSER_BENCHMARK_RUN utility usage:
-//! This utility helps create and manage FusionExecutorCaches and tries to
-//! use NVFuser's caching mechanism to avoid re-compilation.
-//!
-//! There are two macros in this utility, NVFUSER_BENCHMARK_DEFINE and
-//! NVFUSER_BENCHMARK_RUN, and the user needs to supply two functions,
-//! SETUP_FUSION and RUN_FUSION, with the following signatures:
-//!
-//!   SETUP_FUSION(Fusion*, args...);
-//!   RUN_FUSION(benchmark::State&, FusionExecutorCache*, args...);
-//!
-//! where args... are additional arguments, which need to be the same for
-//! SETUP_FUSION and RUN_FUSION.
-//!
-//! SETUP_FUSION is called once in each benchmark definition to build the
-//! Fusion IR graph.
-//!
-//! RUN_FUSION is just like a normal benchmark instance, except that a
-//! FusionExecutorCache is provided for scheduling, running and timing the
-//! fusion runs. It is called once in each benchmark instance. For example:
-//!
-//!   NVFUSER_BENCHMARK_RUN(my_benchmark)
-//!       ->RangeMultiplier(2)
-//!       ->Ranges({{1, 4}})
-//!
-//! calls RUN_FUSION 3 times.
-//!
-//! To register a benchmark, the API is:
-//!
-//!   NVFUSER_BENCHMARK_DEFINE(my_benchmark, SETUP_FUSION, RUN_FUSION, args...);
-//!
-//! where my_benchmark is any unique name given to this benchmark,
-//! SETUP_FUSION and RUN_FUSION are as described above, and args... is the
-//! arg list supplied to both.
-//!
-//! Each NVFUSER_BENCHMARK_DEFINE registers a benchmark with a single
-//! FusionExecutorCache, i.e. a single fusion graph, and multiple benchmark
-//! data points can be registered like:
-//!
-//!   NVFUSER_BENCHMARK_RUN(my_benchmark)
-//!       ->Ranges({{1, 2}});
-//!
-//!   NVFUSER_BENCHMARK_RUN(my_benchmark)
-//!       ->Ranges({{3, 4}});
-//!
-//! All data points will use the same FusionExecutorCache, so recompilation
-//! is avoided as much as possible.
-
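-//! A minimal end-to-end sketch (setupMyFusion and runMyFusion are
-//! hypothetical functions following the signatures above; see
-//! setupBatchNorm / NvFuserScheduler_BatchNorm in batch_norm.cpp for real
-//! examples):
-//!
-//!   static void setupMyFusion(Fusion* fusion, DataType dtype) {
-//!     FusionGuard fg(fusion);
-//!     auto tv0 = makeContigTensor(2, dtype);
-//!     fusion->addInput(tv0);
-//!     fusion->addOutput(add(tv0, new Double(1.0)));
-//!   }
-//!
-//!   static void runMyFusion(
-//!       benchmark::State& benchmark_state,
-//!       FusionExecutorCache* executor_cache,
-//!       DataType dtype) {
-//!     auto options =
-//!         at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-//!     std::vector<c10::IValue> inputs({at::randn({128, 128}, options)});
-//!     runBenchmarkIterations(benchmark_state, executor_cache, inputs);
-//!   }
-//!
-//!   NVFUSER_BENCHMARK_DEFINE(
-//!       my_benchmark, setupMyFusion, runMyFusion, DataType::Float);
-//!   NVFUSER_BENCHMARK_RUN(my_benchmark)->Ranges({{1, 4}});
-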
-#define NVFUSER_BENCHMARK_DEFINE( \
- BENCHMARK_NAME, SETUP_FUSION, RUN_FUSION, ...) \
- class BENCHMARK_NAME##___GRAPH : public BenchmarkGraph { \
- public: \
- std::string graphName() { \
- return NVFUSER_TO_STRING(BENCHMARK_NAME##___GRAPH); \
- } \
- SetupFusionFunction setupFusion() { \
- return [](Fusion* fusion) { SETUP_FUSION(fusion, __VA_ARGS__); }; \
- } \
- }; \
- BENCHMARK_DEFINE_F(BENCHMARK_NAME##___GRAPH, BENCHMARK_NAME) \
- (benchmark::State & benchmark_state) { \
- RUN_FUSION( \
- benchmark_state, \
- BENCHMARK_NAME##___GRAPH::getExecutorCache(), \
- __VA_ARGS__); \
- }
-
-#define NVFUSER_BENCHMARK_RUN(BENCHMARK_NAME) \
- BENCHMARK_REGISTER_F(BENCHMARK_NAME##___GRAPH, BENCHMARK_NAME)
# The list of NVFUSER runtime files
list(APPEND NVFUSER_RUNTIME_FILES
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/block_reduction.cu
- ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu
- ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/block_sync_default.cu
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/broadcast.cu
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/fp16_support.cu
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/grid_reduction.cu
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/helpers.cu
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/random_numbers.cu
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/tensor.cu
- ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/welford.cu
${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh
${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/UnpackRaw.cuh
)
add_subdirectory(${TORCH_ROOT}/benchmarks/cpp/tensorexpr ${CMAKE_BINARY_DIR}/tensorexpr_bench)
endif()
-if(BUILD_NVFUSER_BENCHMARK)
- add_subdirectory(${TORCH_ROOT}/benchmarks/cpp/nvfuser ${CMAKE_BINARY_DIR}/nvfuser_bench)
-endif()
-
if(BUILD_CPP_BENCHMARKS)
add_subdirectory(${TORCH_ROOT}/benchmarks/cpp ${PROJECT_BINARY_DIR}/bin)
endif()
message(STATUS " BUILD_CAFFE2_MOBILE : ${BUILD_CAFFE2_MOBILE}")
message(STATUS " BUILD_STATIC_RUNTIME_BENCHMARK: ${BUILD_STATIC_RUNTIME_BENCHMARK}")
message(STATUS " BUILD_TENSOREXPR_BENCHMARK: ${BUILD_TENSOREXPR_BENCHMARK}")
- message(STATUS " BUILD_NVFUSER_BENCHMARK: ${BUILD_NVFUSER_BENCHMARK}")
message(STATUS " BUILD_BINARY : ${BUILD_BINARY}")
message(STATUS " BUILD_CUSTOM_PROTOBUF : ${BUILD_CUSTOM_PROTOBUF}")
if(${CAFFE2_LINK_LOCAL_PROTOBUF})
if(USE_CUDA)
list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu.cpp)
- list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_shift.cpp)
endif()
add_executable(test_jit
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/codegen.h>
-#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_graphviz.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/mutator.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
+#include <torch/csrc/jit/codegen/cuda/scheduler.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
#include <torch/csrc/jit/codegen/cuda/parser.h>
#include <torch/csrc/jit/ir/irparser.h>
-#include "test_gpu_validator.h"
-
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAStream.h>
-#include <algorithm>
#include <iostream>
// Tests go in torch::jit
namespace jit {
using namespace torch::jit::fuser::cuda;
-using namespace at::indexing;
namespace {
-// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
- return TensorViewBuilder()
- .ndims(ndims)
- .dtype(dtype)
- .contiguity(std::vector<bool>(ndims, true))
- .build();
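+// Make a tensor of dimensionality nDims that is marked fully contiguous
+// but has unknown sizes.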
+TensorView* makeContigTensor(int nDims, DataType dtype = DataType::Float) {
+ std::vector<IterDomain*> dom;
+ for (int i = 0; i < nDims; i++)
+ dom.push_back(new IterDomain(new Int(0), new Int()));
+ std::vector<bool> contig(dom.size(), true);
+ return new TensorView(new TensorDomain(dom, contig), dtype);
}
-// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
- return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
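+// Make a tensor of dimensionality nDims with unknown sizes and unknown
+// contiguity.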
+TensorView* makeDummyTensor(int nDims, DataType dtype = DataType::Float) {
+  // We can uncomment the below statement to run all tests with contiguous
+  // tensors:
+  // return makeContigTensor(nDims, dtype);
+ std::vector<IterDomain*> dom;
+ for (int i = 0; i < nDims; i++)
+ dom.push_back(new IterDomain(new Int(0), new Int()));
+ return new TensorView(new TensorDomain(dom), dtype);
}
-// Make a non-contiguous tensor of compile-time known sizes
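+// Make a tensor where each entry of sizes gives a compile-time known extent;
+// a negative entry leaves that dimension's extent symbolic.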
TensorView* makeConcreteTensor(
- std::vector<int64_t> shape,
+ std::vector<int> sizes,
+ DataType dtype = DataType::Float) {
+  // We can uncomment the below statement to run all tests with contiguous
+  // tensors:
+  // return makeContigTensor(sizes.size(), dtype);
+ std::vector<IterDomain*> dom;
+ for (int size : sizes) {
+ if (size >= 0) {
+ dom.push_back(new IterDomain(new Int(0), new Int(size)));
+ } else {
+ dom.push_back(new IterDomain(new Int(0), new Int()));
+ }
+ }
+ return new TensorView(new TensorDomain(dom), dtype);
+}
+
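+// Make a tensor of dimensionality nDims with unknown sizes and the given
+// per-dimension contiguity flags.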
+TensorView* makeTensorWithContig(
+ int nDims,
+ std::vector<bool> contig_info,
DataType dtype = DataType::Float) {
- return TensorViewBuilder().shape(shape).dtype(dtype).build();
+ std::vector<IterDomain*> dom;
+ for (int i = 0; i < nDims; i++)
+ dom.push_back(new IterDomain(new Int(0), new Int()));
+ return new TensorView(new TensorDomain(dom, contig_info), dtype);
}
void checkIntValue(
- ExpressionEvaluator& evaluator,
+ StatefulExpressionEvaluator& evaluator,
Val* val,
Int::ScalarType expected_value) {
TORCH_CHECK(val->isAnInt());
- const auto actual_value = evaluator.evaluate(val);
- TORCH_CHECK(actual_value.has_value());
- TORCH_CHECK(actual_value.value() == expected_value);
-}
-
-void checkIntValue(
- kir::ExpressionEvaluator& evaluator,
- const kir::Val* val,
- kir::Int::ScalarType expected_value) {
- const auto actual_value = evaluator.evaluate(val);
+ const auto actual_value = evaluator.inferValue(val);
TORCH_CHECK(actual_value.has_value());
TORCH_CHECK(actual_value.value() == expected_value);
}
-bool isPredicated(TensorView* tv, GpuLower& gpulw) {
- auto parent_scope = gpulw.lowerValue(tv)->definition()->parentScope();
- if (parent_scope->isA<kir::IfThenElse>()) {
- return !parent_scope->predicate()->value()->isConst();
- }
- return true;
-}
-
} // namespace
// 1. Test cases are void() functions.
.empty());
// Construct an interesting IR
- TensorView* tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- TensorView* tv2 = add(tv0, new Double(3.141));
+ TensorView* tv2 = add(tv0, new Float(3.141));
TensorView* tv3 = broadcast(tv0, {false, true, false, true});
- TensorView* tv4 = reductionOp(BinaryOpType::Add, {2}, new Double(0), tv3);
- TensorView* tv5 = clamp(tv4, new Double(0.f), new Double(1.f));
+ TensorView* tv4 = reductionOp(BinaryOpType::Add, {2}, new Float(0), tv3);
+ TensorView* tv5 = clamp(tv4, new Float(0.f), new Float(1.f));
TensorView* tv6 = add(tv2, tv2);
// Another checkpoint before adding outputs
Fusion fusion;
FusionGuard fg(&fusion);
- Double* f = new Double{2.f};
+ Float* f = new Float{2.f};
std::stringstream ss1, ss2, ss3;
ss1 << f;
ss2 << static_cast<Val*>(f);
ss3 << static_cast<Statement*>(f);
TORCH_CHECK(
ss1.str().compare(ss2.str()) == 0 && ss1.str().compare(ss3.str()) == 0,
- "Error with dispatch system where results differ by passing Double* vs Val* vs Statement*.");
+ "Error with dispatch system where results differ by passing Float* vs Val* vs Statement*.");
}
// Evaluate basic scalar operations with constant values
Fusion fusion;
FusionGuard fg(&fusion);
- ExpressionEvaluator evaluator(&fusion);
+ StatefulExpressionEvaluator evaluator(&fusion);
auto* a = new Int(7);
auto* b = new Int(3);
Fusion fusion;
FusionGuard fg(&fusion);
- ExpressionEvaluator evaluator(&fusion);
+ StatefulExpressionEvaluator evaluator(&fusion);
auto* a = new Int();
auto* b = new Int();
auto* e = new Int(0);
// trying to evaluate before binding should give empty results
- TORCH_CHECK(!evaluator.evaluate(a).has_value());
- TORCH_CHECK(!evaluator.evaluate(d).has_value());
+ TORCH_CHECK(!evaluator.inferValue(a).has_value());
+ TORCH_CHECK(!evaluator.inferValue(d).has_value());
- evaluator.bind(a, 7);
- evaluator.bind(b, 3);
+ evaluator.safeBind(a, 7);
+ evaluator.safeBind(b, 3);
// can't bind to the results of expressions
- ASSERT_ANY_THROW(evaluator.bind(c, 100));
+ ASSERT_ANY_THROW(evaluator.safeBind(c, 100));
// can't bind to concrete values
- ASSERT_ANY_THROW(evaluator.bind(e, 100));
+ ASSERT_ANY_THROW(evaluator.safeBind(e, 100));
checkIntValue(evaluator, c, 10);
checkIntValue(evaluator, sub(a, b), 4);
checkIntValue(evaluator, d, -4);
// Reset evaluation context
- evaluator = ExpressionEvaluator(&fusion);
+ evaluator = StatefulExpressionEvaluator(&fusion);
- evaluator.bind(a, 2);
- evaluator.bind(b, 5);
+ evaluator.safeBind(a, 2);
+ evaluator.safeBind(b, 5);
checkIntValue(evaluator, c, 7);
checkIntValue(evaluator, sub(a, b), -3);
FusionGuard fg(&fusion);
// Create a non-trivial IR
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
fusion.addInput(tv0);
fusion.addInput(tv1);
- TensorView* tv2 = add(tv1, new Double(2.0));
+ TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);
fusion.addOutput(tv3);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
// 1. Create an evaluator
- ExpressionEvaluator evaluator(&fusion);
+ StatefulExpressionEvaluator evaluator(&fusion);
// 2. Bind values
//
// (ex. `tv0->getRootDomain()[0]->extent()`
// instead of `tv0->axis(0)->extent()`)
//
- evaluator.bind(tv0->getRootDomain()[0]->extent(), 6);
- evaluator.bind(tv0->getRootDomain()[1]->extent(), 128);
- evaluator.bind(tv1->getRootDomain()[0]->extent(), 6);
- evaluator.bind(tv1->getRootDomain()[1]->extent(), 128);
+ evaluator.safeBind(tv0->getRootDomain()[0]->extent(), 6);
+ evaluator.safeBind(tv0->getRootDomain()[1]->extent(), 128);
+ evaluator.safeBind(tv1->getRootDomain()[0]->extent(), 6);
+ evaluator.safeBind(tv1->getRootDomain()[1]->extent(), 128);
// 3. Evaluate and check result values
TORCH_CHECK(tv2->domain()->nDims() == 3);
- checkIntValue(evaluator, tv2->axis(0)->extent(), 2);
- checkIntValue(evaluator, tv2->axis(1)->extent(), 4);
- checkIntValue(evaluator, tv2->axis(2)->extent(), 128);
+ checkIntValue(evaluator, tv2->axis(0)->rawExtent(), 2);
+ checkIntValue(evaluator, tv2->axis(1)->rawExtent(), 4);
+ checkIntValue(evaluator, tv2->axis(2)->rawExtent(), 128);
TORCH_CHECK(tv3->domain()->nDims() == 3);
- checkIntValue(evaluator, tv3->axis(0)->extent(), 2);
- checkIntValue(evaluator, tv3->axis(1)->extent(), 4);
- checkIntValue(evaluator, tv3->axis(2)->extent(), 128);
+ checkIntValue(evaluator, tv3->axis(0)->rawExtent(), 2);
+ checkIntValue(evaluator, tv3->axis(1)->rawExtent(), 4);
+ checkIntValue(evaluator, tv3->axis(2)->rawExtent(), 128);
}
// Evaluate expressions in a more complex IR
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- TensorView* tv1 = mul(tv0, new Double(-1.0));
- TensorView* tv2 = add(tv0, new Double(3.0));
- TensorView* tv3 = mul(tv0, new Double(2.0));
+ TensorView* tv1 = mul(tv0, new Float(-1.0));
+ TensorView* tv2 = add(tv0, new Float(3.0));
+ TensorView* tv3 = mul(tv0, new Float(2.0));
TensorView* tv4 = add(tv2, tv1);
TensorView* tv5 = add(tv4, tv3);
TensorView* tv6 = add(tv0, tv3);
tv5->merge(0);
// 1. Create an evaluator
- ExpressionEvaluator evaluator(&fusion);
+ StatefulExpressionEvaluator evaluator(&fusion);
// 2. Bind values
- evaluator.bind(tv0->getRootDomain()[0]->extent(), 129);
- evaluator.bind(tv0->getRootDomain()[1]->extent(), 127);
+ evaluator.safeBind(tv0->getRootDomain()[0]->extent(), 129);
+ evaluator.safeBind(tv0->getRootDomain()[1]->extent(), 127);
// Evaluate and check extent values
TORCH_CHECK(tv0->domain()->nDims() == 2);
- checkIntValue(evaluator, tv0->axis(0)->extent(), 129);
- checkIntValue(evaluator, tv0->axis(1)->extent(), 127);
+ checkIntValue(evaluator, tv0->axis(0)->rawExtent(), 129);
+ checkIntValue(evaluator, tv0->axis(1)->rawExtent(), 127);
TORCH_CHECK(tv3->domain()->nDims() == 2);
- checkIntValue(evaluator, tv3->axis(0)->extent(), 129);
- checkIntValue(evaluator, tv3->axis(1)->extent(), 127);
+ checkIntValue(evaluator, tv3->axis(0)->rawExtent(), 129);
+ checkIntValue(evaluator, tv3->axis(1)->rawExtent(), 127);
TORCH_CHECK(tv4->domain()->nDims() == 2);
- checkIntValue(evaluator, tv4->axis(0)->extent(), 129);
- checkIntValue(evaluator, tv4->axis(1)->extent(), 127);
+ checkIntValue(evaluator, tv4->axis(0)->rawExtent(), 129);
+ checkIntValue(evaluator, tv4->axis(1)->rawExtent(), 127);
TORCH_CHECK(tv5->domain()->nDims() == 1);
- checkIntValue(evaluator, tv5->axis(0)->extent(), 16383);
+ checkIntValue(evaluator, tv5->axis(0)->rawExtent(), 16383);
TORCH_CHECK(tv6->domain()->nDims() == 3);
- checkIntValue(evaluator, tv6->axis(0)->extent(), 26);
- checkIntValue(evaluator, tv6->axis(1)->extent(), 5);
- checkIntValue(evaluator, tv6->axis(2)->extent(), 127);
+ checkIntValue(evaluator, tv6->axis(0)->rawExtent(), 26);
+ checkIntValue(evaluator, tv6->axis(1)->rawExtent(), 5);
+ checkIntValue(evaluator, tv6->axis(2)->rawExtent(), 127);
}
// Evaluate expressions post lowering
FusionGuard fg(&fusion);
// Create a non-trivial IR
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
fusion.addInput(tv0);
fusion.addInput(tv1);
- TensorView* tv2 = add(tv1, new Double(2.0));
+ TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);
fusion.addOutput(tv3);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
- auto* bid_x = add(tv3->axis(0)->extent(), new Int(0));
- auto* tid_x = add(tv3->axis(-1)->extent(), new Int(0));
+ auto* bid_x = add(tv3->axis(0)->rawExtent(), new Int(0));
+ auto* tid_x = add(tv3->axis(-1)->rawExtent(), new Int(0));
// Lower
GpuLower gpulw(&fusion);
// 1. Create an evaluation context
- ExpressionEvaluator evaluator(&fusion);
+ StatefulExpressionEvaluator evaluator(&fusion);
// 2. Bind values
- evaluator.bind(tv0->getRootDomain()[0]->extent(), 6);
- evaluator.bind(tv0->getRootDomain()[1]->extent(), 128);
- evaluator.bind(tv1->getRootDomain()[0]->extent(), 6);
- evaluator.bind(tv1->getRootDomain()[1]->extent(), 128);
+ evaluator.safeBind(tv0->getRootDomain()[0]->extent(), 6);
+ evaluator.safeBind(tv0->getRootDomain()[1]->extent(), 128);
+ evaluator.safeBind(tv1->getRootDomain()[0]->extent(), 6);
+ evaluator.safeBind(tv1->getRootDomain()[1]->extent(), 128);
// 3. Evaluate and check result values
TORCH_CHECK(tv2->domain()->nDims() == 3);
- checkIntValue(evaluator, tv2->axis(0)->extent(), 2);
- checkIntValue(evaluator, tv2->axis(1)->extent(), 4);
- checkIntValue(evaluator, tv2->axis(2)->extent(), 128);
+ checkIntValue(evaluator, tv2->axis(0)->rawExtent(), 2);
+ checkIntValue(evaluator, tv2->axis(1)->rawExtent(), 4);
+ checkIntValue(evaluator, tv2->axis(2)->rawExtent(), 128);
TORCH_CHECK(tv3->domain()->nDims() == 3);
- checkIntValue(evaluator, tv3->axis(0)->extent(), 2);
- checkIntValue(evaluator, tv3->axis(1)->extent(), 4);
- checkIntValue(evaluator, tv3->axis(2)->extent(), 128);
+ checkIntValue(evaluator, tv3->axis(0)->rawExtent(), 2);
+ checkIntValue(evaluator, tv3->axis(1)->rawExtent(), 4);
+ checkIntValue(evaluator, tv3->axis(2)->rawExtent(), 128);
checkIntValue(evaluator, bid_x, 2);
checkIntValue(evaluator, tid_x, 128);
}
-// Kernel IR: Evaluate basic scalar operations with constant values
-TEST(NVFuserTest, KernelExprEvalConstants_CUDA) {
- kir::Kernel kernel;
- kir::IrBuilder ir_builder(&kernel);
-
- auto a = ir_builder.create<kir::Int>(7);
- auto b = ir_builder.create<kir::Int>(3);
- auto c = ir_builder.subExpr(a, b);
- auto d = ir_builder.divExpr(a, b);
- auto e = ir_builder.mulExpr(c, d);
-
- kir::ExpressionEvaluator evaluator;
-
- checkIntValue(evaluator, ir_builder.negExpr(a), -7);
- checkIntValue(evaluator, ir_builder.addExpr(a, b), 10);
- checkIntValue(evaluator, ir_builder.negExpr(e), -8);
- checkIntValue(evaluator, ir_builder.modExpr(a, b), 1);
- checkIntValue(evaluator, ir_builder.ceilDivExpr(a, b), 3);
-}
-
-// Kernel IR: Evaluate basic scalar operations with bound values
-TEST(NVFuserTest, KernelExprEvalBindings_CUDA) {
- kir::Kernel kernel;
- kir::IrBuilder ir_builder(&kernel);
-
- kir::ExpressionEvaluator evaluator;
-
- auto a = ir_builder.create<kir::Int>(c10::nullopt);
- auto b = ir_builder.create<kir::Int>(c10::nullopt);
- auto c = ir_builder.addExpr(a, b);
- auto d = ir_builder.negExpr(ir_builder.ceilDivExpr(c, b));
- auto e = ir_builder.create<kir::Int>(0);
-
- // trying to evaluate before binding should give empty results
- TORCH_CHECK(!evaluator.evaluate(a).has_value());
- TORCH_CHECK(!evaluator.evaluate(d).has_value());
-
- evaluator.bind(a, 7);
- evaluator.bind(b, 3);
-
- // can't bind to the results of expressions
- ASSERT_ANY_THROW(evaluator.bind(c, 100));
-
- // can't bind to concrete values
- ASSERT_ANY_THROW(evaluator.bind(e, 100));
-
- checkIntValue(evaluator, c, 10);
- checkIntValue(evaluator, ir_builder.subExpr(a, b), 4);
- checkIntValue(evaluator, ir_builder.modExpr(a, b), 1);
- checkIntValue(evaluator, ir_builder.ceilDivExpr(a, b), 3);
- checkIntValue(evaluator, d, -4);
-
- // Reset the evaluation context
- evaluator = kir::ExpressionEvaluator();
-
- evaluator.bind(a, 2);
- evaluator.bind(b, 5);
-
- checkIntValue(evaluator, c, 7);
- checkIntValue(evaluator, ir_builder.subExpr(a, b), -3);
- checkIntValue(evaluator, ir_builder.modExpr(a, b), 2);
- checkIntValue(evaluator, ir_builder.ceilDivExpr(a, b), 1);
- checkIntValue(evaluator, d, -2);
-}
-
TEST(NVFuserTest, FusionClear_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
// 1. Create a dummy IR
{
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
fusion.addInput(tv0);
fusion.addInput(tv1);
- TensorView* tv2 = add(tv1, new Double(2.0));
+ TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);
fusion.addOutput(tv3);
fusion.clear();
- TORCH_CHECK(fusion.unordered_exprs().empty());
+ TORCH_CHECK(fusion.exprs().empty());
TORCH_CHECK(fusion.vals().empty());
TORCH_CHECK(fusion.inputs().empty());
TORCH_CHECK(fusion.outputs().empty());
TORCH_CHECK(!fusion.hasReduction());
+ TORCH_CHECK(!fusion.hasBlockReduction());
+ TORCH_CHECK(!fusion.hasGridReduction());
// 3. Rebuild the IR
{
- TensorView* tv0 = makeSymbolicTensor(3);
- TensorView* tv1 = makeSymbolicTensor(3);
- TensorView* tv2 = add(tv1, new Double(2.0));
+ TensorView* tv0 = makeDummyTensor(3);
+ TensorView* tv1 = makeDummyTensor(3);
+ TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);
fusion.addInput(tv0);
{
FusionGuard fg(&original_fusion);
- auto tv0 = makeSymbolicTensor(3);
- auto tv1 = makeSymbolicTensor(3);
- auto tv2 = add(tv1, new Double(2.0));
+ auto tv0 = makeDummyTensor(3);
+ auto tv1 = makeDummyTensor(3);
+ auto tv2 = add(tv1, new Float(2.0));
auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2);
original_fusion.addInput(tv0);
{
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(3);
- auto tv1 = makeSymbolicTensor(3);
- auto tv2 = add(tv1, new Double(2.0));
+ auto tv0 = makeDummyTensor(3);
+ auto tv1 = makeDummyTensor(3);
+ auto tv2 = add(tv1, new Float(2.0));
auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2);
fusion.addInput(tv0);
// standard library containers:
// https://en.cppreference.com/w/cpp/utility/move
//
- TORCH_CHECK(fusion.unordered_exprs().empty());
+ TORCH_CHECK(fusion.exprs().empty());
TORCH_CHECK(fusion.vals().empty());
TORCH_CHECK(fusion.inputs().empty());
TORCH_CHECK(fusion.outputs().empty());
Fusion fusion;
FusionGuard fg(&fusion);
- Double* d1 = new Double(1.f);
- Double* d2 = new Double{2.f};
- Double* d3 = new Double();
+ Float* f1 = new Float(1.f);
+ Float* f2 = new Float{2.f};
+ Float* f3 = new Float();
// Disrupt the fusion to make sure guard works well
{
Fusion fusion2;
FusionGuard fg(&fusion2);
- Double* d1 = new Double(1.f);
- Double* d2 = new Double(2.f);
- add(d1, d2);
+ Float* f1 = new Float(1.f);
+ Float* f2 = new Float(2.f);
+ add(f1, f2);
ss2 << fusion2;
}
- new BinaryOp(BinaryOpType::Add, d3, d1, d2);
+ new BinaryOp(BinaryOpType::Add, f3, f1, f2);
ss1 << fusion;
TORCH_CHECK(
Fusion fusion;
FusionGuard fg(&fusion);
- Double* d4 = new Double{4.f};
+ Float* f4 = new Float{4.f};
+ Int* i1 = new Int{3};
+ auto f5 = add(f4, i1);
+
+ TORCH_CHECK(f5->getDataType() == DataType::Float);
+}
+
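+// Mutator that replaces any constant Float with value 1.0 by a new
+// Float(0.0), exercising the OptOutMutator dispatch below.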
+class ZeroMutator : public OptOutMutator {
+ public:
+ Statement* mutate(Float* f) {
+ if (f->isConst() && *(f->value()) == 1.0)
+ return new Float(0.0);
+ return f;
+ }
+ void mutate(Fusion* f) {
+ OptOutMutator::mutate(f);
+ }
+};
+
+TEST(NVFuserTest, FusionMutator_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
+
+ Float* f4 = new Float{1.f};
Int* i1 = new Int{3};
- auto d5 = add(d4, i1);
+ Val* f5 = add(f4, i1);
+ ZeroMutator mutator;
+ mutator.mutate(&fusion);
+ Val* lhs = static_cast<BinaryOp*>(fusion.origin(f5))->lhs();
+ TORCH_CHECK(
+ lhs->getValType().value() == ValType::Scalar &&
+ lhs->getDataType().value() == DataType::Float);
+ Float* flhs = static_cast<Float*>(lhs);
- TORCH_CHECK(d5->getDataType() == DataType::Double);
+ TORCH_CHECK(flhs->value().value() == 0.f);
}
TEST(NVFuserTest, FusionRegister_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- Double* v1 = new Double{1.f};
- Double* v2 = new Double{2.f};
+ Float* v1 = new Float{1.f};
+ Float* v2 = new Float{2.f};
Val* v3 = binaryOp(BinaryOpType::Add, v1, v2);
Val* v4 = binaryOp(BinaryOpType::Add, v1, v2);
TORCH_CHECK(v1->name() + 1 == v2->name());
TORCH_CHECK(v2->name() + 1 == v3->name());
TORCH_CHECK(v3->name() + 1 == v4->name());
- TORCH_CHECK(v3->definition()->name() + 1 == v4->definition()->name());
+ TORCH_CHECK(fusion.origin(v3)->name() + 1 == fusion.origin(v4)->name());
}
// Dummy expr with 2 outputs, used only for the toposort test.
// e1: v4 = add(v3, v2)
// e2: v5 = add(v2, v4)
// e3: v6 = add(v5, v5)
- Double* v0 = new Double{1.f};
- Double* v1 = new Double{2.f};
- Double* v2 = new Double();
- Double* v3 = new Double();
- Double* v4 = new Double();
- Double* v5 = new Double();
- Double* v6 = new Double();
-
- std::vector<Val*> inputs = {v0, v1};
- for (auto val : inputs) {
- fusion.addInput(val);
- }
+ Float* v0 = new Float{1.f};
+ Float* v1 = new Float{2.f};
+ Float* v2 = new Float();
+ Float* v3 = new Float();
+ Float* v4 = new Float();
+ Float* v5 = new Float();
+ Float* v6 = new Float();
Expr* e0 = new DummyExpr(v3, v2, v1, v0);
Expr* e1 = new BinaryOp(BinaryOpType::Add, v4, v3, v2);
Expr* e2 = new BinaryOp(BinaryOpType::Add, v5, v2, v4);
Expr* e3 = new BinaryOp(BinaryOpType::Add, v6, v5, v5);
+ std::vector<Expr*> exprs = fusion.exprs();
+
+ TORCH_CHECK(exprs.size() == 4);
+ TORCH_CHECK(exprs[0] == e0);
+ TORCH_CHECK(exprs[1] == e1);
+ TORCH_CHECK(exprs[2] == e2);
+ TORCH_CHECK(exprs[3] == e3);
+
fusion.addOutput(v2);
- fusion.addOutput(v3);
- auto exprs = fusion.exprs();
- TORCH_CHECK(exprs.size() == 1, "Found ", exprs.size(), " but expecting 1");
+ exprs = fusion.exprs(true);
+ TORCH_CHECK(exprs.size() == 1);
TORCH_CHECK(exprs[0] == e0);
fusion.addOutput(v5);
- exprs = fusion.exprs();
- TORCH_CHECK(exprs.size() == 3, "Found ", exprs.size(), " but expecting 3");
+ exprs = fusion.exprs(true);
TORCH_CHECK(exprs[0] == e0);
TORCH_CHECK(exprs[1] == e1);
TORCH_CHECK(exprs[2] == e2);
fusion.addOutput(v4);
- exprs = fusion.exprs();
- TORCH_CHECK(exprs.size() == 3, "Found ", exprs.size(), " but expecting 3");
+ exprs = fusion.exprs(true);
+ TORCH_CHECK(exprs[0] == e0);
+ TORCH_CHECK(exprs[1] == e1);
+ TORCH_CHECK(exprs[2] == e2);
+
+ fusion.addOutput(v3);
+ exprs = fusion.exprs(true);
TORCH_CHECK(exprs[0] == e0);
TORCH_CHECK(exprs[1] == e1);
TORCH_CHECK(exprs[2] == e2);
fusion.addOutput(v6);
- exprs = fusion.exprs();
- TORCH_CHECK(exprs.size() == 4, "Found ", exprs.size(), " but expecting 4");
+ exprs = fusion.exprs(true);
+ TORCH_CHECK(exprs.size() == 4);
TORCH_CHECK(exprs[0] == e0);
TORCH_CHECK(exprs[1] == e1);
TORCH_CHECK(exprs[2] == e2);
TORCH_CHECK(exprs[3] == e3);
- TORCH_CHECK(v2->definition()->name() == 0);
- TORCH_CHECK(v3->definition()->name() == 0);
- TORCH_CHECK(v4->definition()->name() == 1);
- TORCH_CHECK(v5->definition()->name() == 2);
- TORCH_CHECK(v6->definition()->name() == 3);
+ TORCH_CHECK(fusion.origin(v2)->name() == 0);
+ TORCH_CHECK(fusion.origin(v3)->name() == 0);
+ TORCH_CHECK(fusion.origin(v4)->name() == 1);
+ TORCH_CHECK(fusion.origin(v5)->name() == 2);
+ TORCH_CHECK(fusion.origin(v6)->name() == 3);
}
TEST(NVFuserTest, FusionTensor_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
- auto tv1 = makeSymbolicTensor(1);
- auto scalar0 = new Double(0);
+ auto tv0 = makeDummyTensor(1);
+ auto tv1 = makeDummyTensor(1);
+ auto scalar0 = new Float(0);
auto scalar1 = new Int(0);
auto scalar2 = new Int(1);
TORCH_CHECK(tvs[0] == tv0);
TORCH_CHECK(tvs[1] == tv1);
- std::vector<Double*> floats(
- ir_utils::filterByType<Double>(vals).begin(),
- ir_utils::filterByType<Double>(vals).end());
+ std::vector<Float*> floats(
+ ir_utils::filterByType<Float>(vals).begin(),
+ ir_utils::filterByType<Float>(vals).end());
TORCH_CHECK(floats.size() == 1);
TORCH_CHECK(floats[0] == scalar0);
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv = makeSymbolicTensor(3);
+ TensorView* tv = makeDummyTensor(3);
tv = tv->split(2, 2);
TORCH_CHECK(tv->nDims() == 4);
- Expr* outer = tv->axis(2)->extent()->definition();
+ Expr* outer = tv->axis(2)->extent()->getOrigin();
TORCH_CHECK(
outer->getExprType().value() == ExprType::BinaryOp &&
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv = makeSymbolicTensor(3);
+ TensorView* tv = makeDummyTensor(3);
tv = tv->merge(1);
- Expr* axisOp = tv->axis(1)->extent()->definition();
+ Expr* axisOp = tv->axis(1)->extent()->getOrigin();
TORCH_CHECK(
tv->nDims() == 2 && axisOp->getExprType() == ExprType::BinaryOp &&
std::unordered_map<int, int> swap{{0, 2}, {2, 0}};
- auto tv = makeSymbolicTensor(3);
+ auto tv = makeDummyTensor(3);
std::vector<IterDomain*> ref;
ref = std::vector<IterDomain*>(
tv->domain()->domain().begin(), tv->domain()->domain().end());
for (int i = 0; i < (int)tv->nDims(); i++)
TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1)));
- tv = makeSymbolicTensor(3);
+ tv = makeDummyTensor(3);
ref = std::vector<IterDomain*>(
tv->domain()->domain().begin(), tv->domain()->domain().end());
for (int i = 0; i < (int)tv->nDims(); i++)
TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1)));
- tv = makeSymbolicTensor(3);
+ tv = makeDummyTensor(3);
ref = std::vector<IterDomain*>(
tv->domain()->domain().begin(), tv->domain()->domain().end());
for (int i = 1; i < (int)tv->nDims(); i++)
TORCH_CHECK(ref[i - 1]->sameAs(tv->axis(i)));
- tv = makeSymbolicTensor(3);
+ tv = makeDummyTensor(3);
ref = std::vector<IterDomain*>(
tv->domain()->domain().begin(), tv->domain()->domain().end());
tv->reorder(swap);
Fusion fusion;
FusionGuard fg(&fusion);
- Double* fval1 = new Double();
- Double* fval1_copy = fval1;
- Double* fval2 = new Double();
- Double* fone = new Double(1.0);
+ Float* fval1 = new Float();
+ Float* fval1_copy = fval1;
+ Float* fval2 = new Float();
+ Float* fone = new Float(1.0);
TORCH_CHECK(fval1->sameAs(fval1_copy));
TORCH_CHECK(!fval1->sameAs(fval2));
TORCH_CHECK(!fone->sameAs(fval1));
- TORCH_CHECK(fone->sameAs(new Double(1.0)));
+ TORCH_CHECK(fone->sameAs(new Float(1.0)));
Int* ival1 = new Int();
Int* ival1_copy = ival1;
TORCH_CHECK(!ione->sameAs(ival1));
TORCH_CHECK(ione->sameAs(new Int(1)));
- BinaryOp* add1 = new BinaryOp(BinaryOpType::Add, new Double(), fval1, ival1);
+ BinaryOp* add1 = new BinaryOp(BinaryOpType::Add, new Float(), fval1, ival1);
BinaryOp* add1_copy =
- new BinaryOp(BinaryOpType::Add, new Double(), fval1, ival1);
- BinaryOp* sub1 = new BinaryOp(BinaryOpType::Sub, new Double(), fval1, ival1);
+ new BinaryOp(BinaryOpType::Add, new Float(), fval1, ival1);
+ BinaryOp* sub1 = new BinaryOp(BinaryOpType::Sub, new Float(), fval1, ival1);
- UnaryOp* neg1 = new UnaryOp(UnaryOpType::Neg, new Double(), fval1);
- UnaryOp* neg2 = new UnaryOp(UnaryOpType::Neg, new Double(), fval2);
- UnaryOp* neg1_copy = new UnaryOp(UnaryOpType::Neg, new Double(), fval1);
+ UnaryOp* neg1 = new UnaryOp(UnaryOpType::Neg, new Float(), fval1);
+ UnaryOp* neg2 = new UnaryOp(UnaryOpType::Neg, new Float(), fval2);
+ UnaryOp* neg1_copy = new UnaryOp(UnaryOpType::Neg, new Float(), fval1);
TORCH_CHECK(add1->sameAs(add1_copy));
TORCH_CHECK(!add1->sameAs(sub1));
Fusion fusion;
FusionGuard fg(&fusion);
- Double* d0 = new Double(0.f);
- Double* d1 = new Double(1.f);
- auto d2 = add(d0, d1);
-
- auto d3 = add(d2, d2);
-
- Double* d4 = new Double(4.f);
- Double* d5 = new Double(5.f);
- auto d6 = add(d4, d5);
-
- Double* d7 = new Double(7.f);
- Double* d8 = new Double(8.f);
- auto d9 = add(d7, d8);
-
- auto d10 = add(d6, d9);
-
- auto d11 = add(d3, d10);
-
- TORCH_CHECK(DependencyCheck::isDependencyOf(d0, d11));
- TORCH_CHECK(DependencyCheck::isDependencyOf(d1, d11));
- TORCH_CHECK(DependencyCheck::isDependencyOf(d2, d11));
- TORCH_CHECK(DependencyCheck::isDependencyOf(d3, d11));
- TORCH_CHECK(DependencyCheck::isDependencyOf(d6, d11));
- TORCH_CHECK(DependencyCheck::isDependencyOf(d9, d11));
- TORCH_CHECK(DependencyCheck::isDependencyOf(d0, d2));
- TORCH_CHECK(DependencyCheck::isDependencyOf(d2, d3));
- TORCH_CHECK(DependencyCheck::isDependencyOf(d4, d6));
- TORCH_CHECK(DependencyCheck::isDependencyOf(d8, d10));
-
- TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d0));
- TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d1));
- TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d2));
- TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d3));
- TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d4));
- TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d5));
- TORCH_CHECK(!DependencyCheck::isDependencyOf(d2, d0));
- TORCH_CHECK(!DependencyCheck::isDependencyOf(d3, d2));
- TORCH_CHECK(!DependencyCheck::isDependencyOf(d6, d4));
- TORCH_CHECK(!DependencyCheck::isDependencyOf(d10, d8));
-
- auto dep_chain = DependencyCheck::getSingleDependencyChain(d0, d11);
- TORCH_CHECK(dep_chain.back() == d11);
+ Float* f0 = new Float(0.f);
+ Float* f1 = new Float(1.f);
+ auto f2 = add(f0, f1);
+
+ auto f3 = add(f2, f2);
+
+ Float* f4 = new Float(4.f);
+ Float* f5 = new Float(5.f);
+ auto f6 = add(f4, f5);
+
+ Float* f7 = new Float(7.f);
+ Float* f8 = new Float(8.f);
+ auto f9 = add(f7, f8);
+
+ auto f10 = add(f6, f9);
+
+ auto f11 = add(f3, f10);
+
+ TORCH_CHECK(DependencyCheck::isDependencyOf(f0, f11));
+ TORCH_CHECK(DependencyCheck::isDependencyOf(f1, f11));
+ TORCH_CHECK(DependencyCheck::isDependencyOf(f2, f11));
+ TORCH_CHECK(DependencyCheck::isDependencyOf(f3, f11));
+ TORCH_CHECK(DependencyCheck::isDependencyOf(f6, f11));
+ TORCH_CHECK(DependencyCheck::isDependencyOf(f9, f11));
+ TORCH_CHECK(DependencyCheck::isDependencyOf(f0, f2));
+ TORCH_CHECK(DependencyCheck::isDependencyOf(f2, f3));
+ TORCH_CHECK(DependencyCheck::isDependencyOf(f4, f6));
+ TORCH_CHECK(DependencyCheck::isDependencyOf(f8, f10));
+
+ TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f0));
+ TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f1));
+ TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f2));
+ TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f3));
+ TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f4));
+ TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f5));
+ TORCH_CHECK(!DependencyCheck::isDependencyOf(f2, f0));
+ TORCH_CHECK(!DependencyCheck::isDependencyOf(f3, f2));
+ TORCH_CHECK(!DependencyCheck::isDependencyOf(f6, f4));
+ TORCH_CHECK(!DependencyCheck::isDependencyOf(f10, f8));
+
+ auto dep_chain = DependencyCheck::getSingleDependencyChain(f0, f11);
+ TORCH_CHECK(dep_chain.back() == f11);
dep_chain.pop_back();
- TORCH_CHECK(dep_chain.back() == d3);
+ TORCH_CHECK(dep_chain.back() == f3);
dep_chain.pop_back();
- TORCH_CHECK(dep_chain.back() == d2);
+ TORCH_CHECK(dep_chain.back() == f2);
dep_chain.pop_back();
- dep_chain = DependencyCheck::getSingleDependencyChain(d6, d11);
- TORCH_CHECK(dep_chain.back() == d11);
+ dep_chain = DependencyCheck::getSingleDependencyChain(f6, f11);
+ TORCH_CHECK(dep_chain.back() == f11);
dep_chain.pop_back();
- TORCH_CHECK(dep_chain.back() == d10);
+ TORCH_CHECK(dep_chain.back() == f10);
dep_chain.pop_back();
- dep_chain = DependencyCheck::getSingleDependencyChain(d4, d11);
- TORCH_CHECK(dep_chain.back() == d11);
+ dep_chain = DependencyCheck::getSingleDependencyChain(f4, f11);
+ TORCH_CHECK(dep_chain.back() == f11);
dep_chain.pop_back();
- TORCH_CHECK(dep_chain.back() == d10);
+ TORCH_CHECK(dep_chain.back() == f10);
dep_chain.pop_back();
- TORCH_CHECK(dep_chain.back() == d6);
+ TORCH_CHECK(dep_chain.back() == f6);
dep_chain.pop_back();
- dep_chain = DependencyCheck::getSingleDependencyChain(d11, d2);
+ dep_chain = DependencyCheck::getSingleDependencyChain(f11, f2);
TORCH_CHECK(dep_chain.empty());
}
TEST(NVFuserTest, FusionParser_CUDA) {
-  // This test may not pass if using a custom block sync as there may
-  // be additional calls. Skip the test as it's not specifically
-  // relevant to block synchronization.
- if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
- return;
- }
auto g = std::make_shared<Graph>();
const auto graph0_string = R"IR(
graph(%0 : Float(2, strides=[1]),
auto fusion = parseJitIR(g);
FusionGuard fg(fusion.get());
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- // Avoid vectorization here as those kernels can't be lowered twice at the
- // moment
at::Tensor input1 = at::randn({16}, options);
at::Tensor input2 = at::randn({16}, options);
- auto lparams = schedulePointwise(fusion.get(), {input1, input2});
+ scheduleFusion(fusion.get(), {input1, input2});
// CONSIDER:
// 1. this can be moved to a dedicated "golden" file
// 2. use a fuzzy compare (ignore non-significant whitespaces for example)
const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3) {
- if ((((((((((nvfuser_index_t)blockIdx.x) * 1) + (1 - 1)) * 1) + (1 - 1)) * 128) + ((nvfuser_index_t)threadIdx.x)) < T0.size[0])) {
- constexpr nvfuser_index_t ki169 = 0;
- float T5[1];
- constexpr nvfuser_index_t ki203 = 0;
- T5[ki203] = 0;
- constexpr nvfuser_index_t ki194 = 0;
- T5[ki194]
- = T1[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki169) * 1) + ki194) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)];
- float T4[1];
- constexpr nvfuser_index_t ki209 = 0;
- T4[ki209] = 0;
- constexpr nvfuser_index_t ki189 = 0;
- T4[ki189]
- = T0[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki169) * 1) + ki189) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)];
- float T6[1];
- constexpr nvfuser_index_t ki178 = 0;
- float T2[1];
- T2[0]
- = T4[ki178]
- * T5[ki178];
- T6[ki178]
- = T2[0]
- * T4[ki178];
- constexpr nvfuser_index_t ki171 = 0;
- T3[(((((((((nvfuser_index_t)blockIdx.x) * 1) + ki169) * 1) + ki171) * 128) + ((nvfuser_index_t)threadIdx.x)) * 1)]
- = T6[ki171];
+ float T2[1];
+ if ((((((blockIdx.x * 1) + (1 - 1)) * 128) + threadIdx.x) < T0.size[0])) {
+ for(size_t i6 = 0; i6 < 1; ++i6) {
+ T2[i6]
+ = T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
+ * T1[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
+ T3[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
+ = T2[i6]
+ * T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
+ }
+ } else {
+ for(size_t i6 = 0; i6 < 1; ++i6) {
+ if ((((((blockIdx.x * 1) + i6) * 128) + threadIdx.x) < T0.size[0])) {
+ T2[i6]
+ = T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
+ * T1[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
+ }
+ if ((((((blockIdx.x * 1) + i6) * 128) + threadIdx.x) < T0.size[0])) {
+ T3[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
+ = T2[i6]
+ * T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
+ }
+ }
}
}
)";
<< " \n ========= EXPECTED ========= \n"
<< expected_kernel << "\n========= ACTUAL ========== \n"
<< actual_kernel << "\n=================" << std::endl;
- auto it = std::mismatch(
- expected_kernel.begin(),
- expected_kernel.end(),
- actual_kernel.begin(),
- actual_kernel.end());
- std::string actual_mismatched_snippet(it.second, actual_kernel.end());
- actual_mismatched_snippet = actual_mismatched_snippet.substr(0, 10);
- std::string expected_mismatched_snippet(it.first, expected_kernel.end());
- expected_mismatched_snippet = expected_mismatched_snippet.substr(0, 10);
- std::cerr << "First mismatch found at: " << actual_mismatched_snippet
- << ", expected: " << expected_mismatched_snippet << std::endl;
TORCH_CHECK(false);
}
FusionExecutor fe;
fe.compileFusion(fusion.get());
- auto outputs = fe.runFusion({input1, input2}, lparams);
+ auto outputs = fe.runFusion({input1, input2});
at::Tensor output_ref = input1 * input2 * input1;
TORCH_CHECK(output_ref.equal(outputs[0]));
}
auto ID0 = new kir::IterDomain(new IterDomain(new Int(0), new Int(8)));
TensorView* TV2 = add(TV0, TV1);
-  BinaryOp* op = static_cast<BinaryOp*>(TV2->definition());
+ BinaryOp* op = static_cast<BinaryOp*>(TV2->getOrigin());
fusion.addOutput(TV2);
auto fl = new kir::ForLoop(new kir::Int(c10::nullopt), ID0, {op});
#endif
}
-TEST(NVFuserTest, FusionOuterSplit_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(3);
-
- new BinaryOp(BinaryOpType::Add, tv0, new Double(0.0), new Double(1.0));
- TensorView* tv1 = add(tv0, new Double(2.0));
- TensorView* tv2 = add(tv1, new Double(3.0));
- fusion.addOutput(tv2);
-
- //[I0, I1, I2]
- tv2->split(-1, 4, false);
- //[I0, I1, I2o{4}, I2i]
- tv2->merge(0);
- tv2->merge(0);
- //[I0*I1*I2o{4}, I2i]
- tv2->split(0, 2);
- //[I0*I1*I2o{4}o, I0*I1*I2o{4}i{2}, I2i]
- tv2->reorder({{0, 1}, {1, 0}});
- // I0*I1*I2o{4}i{2}, [I0*I1*I2o{4}o, I2i]
-
- tv0->computeAt(tv2, -1);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor output = at::empty({2, 6, 32}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({}, {output});
-
- at::Tensor output_ref = at::zeros_like(output, options);
- output_ref = output_ref + 0.0 + 1.0 + 2.0 + 3.0;
-
- TORCH_CHECK(output_ref.equal(output));
-}
-
TEST(NVFuserTest, FusionCodeGen_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(3);
+ TensorView* tv0 = makeDummyTensor(3);
- new BinaryOp(BinaryOpType::Add, tv0, new Double(0.0), new Double(1.0));
- TensorView* tv1 = add(tv0, new Double(2.0));
- TensorView* tv2 = add(tv1, new Double(3.0));
+ new BinaryOp(BinaryOpType::Add, tv0, new Float(0.0), new Float(1.0));
+ TensorView* tv1 = add(tv0, new Float(2.0));
+ TensorView* tv2 = add(tv1, new Float(3.0));
fusion.addOutput(tv2);
//[I0, I1, I2]
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(3);
- TensorView* tv1 = makeSymbolicTensor(3);
- TensorView* tv2 = add(tv1, new Double(2.0));
+ TensorView* tv0 = makeDummyTensor(3);
+ TensorView* tv1 = makeDummyTensor(3);
+ TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);
fusion.addInput(tv0);
  // Do math with it; it returns a `Val*` but can be static_cast back to
  // TensorView
- TensorView* tv2 = add(tv1, new Double(2.0));
+ TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);
// Register your outputs
FusionGuard fg(&fusion);
// Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
// Register your inputs
fusion.addInput(tv0);
  // Do math with it; it returns a `Val*` but can be static_cast back to
  // TensorView
- TensorView* tv2 = add(tv1, new Double(2.0));
+ TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);
// Register your outputs
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- TensorView* tv1 = mul(tv0, new Double(0.5));
- TensorView* tv2 = mul(tv1, new Double(-1.0));
- TensorView* tv3 = add(tv1, new Double(3.0));
- TensorView* tv4 = mul(tv1, new Double(2.0));
+ TensorView* tv1 = mul(tv0, new Float(0.5));
+ TensorView* tv2 = mul(tv1, new Float(-1.0));
+ TensorView* tv3 = add(tv1, new Float(3.0));
+ TensorView* tv4 = mul(tv1, new Float(2.0));
TensorView* tv5 = add(tv3, tv2);
TensorView* tv6 = add(tv5, tv4);
tv0->computeAt(tv7, 1);
- GpuLower gpulw(&fusion);
-
-  // The compute-at position of the last tensor should be zero.
- TORCH_CHECK(
- tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 &&
- tv7->getMaxProducerPosition() == 1);
- TORCH_CHECK(
- tv7->nDims() == 3 && tv6->getComputeAtPosition() == 0 &&
- tv6->getMaxProducerPosition() == 1);
- // The position of every other tensor should be 1.
- for (auto tv : {tv1, tv2, tv3, tv4, tv5}) {
- TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1);
- TORCH_CHECK(gpulw.caLoopMap().areMapped(tv7->axis(0), tv->axis(0)));
- }
+ TORCH_CHECK(tv1->hasComputeAt() && tv1->nDims() == 3);
+ TORCH_CHECK(tv2->getComputeAtView() == tv5 && tv2->nDims() == 3);
+ TORCH_CHECK(tv3->getComputeAtView() == tv5 && tv3->nDims() == 3);
+ TORCH_CHECK(tv4->hasComputeAt() && tv4->nDims() == 3);
+ TORCH_CHECK(tv5->getComputeAtView() == tv6 && tv5->nDims() == 3);
+ TORCH_CHECK(tv6->getComputeAtView() == tv7 && tv6->nDims() == 3);
+ TORCH_CHECK(!tv7->hasComputeAt());
for (Val* val : fusion.vals()) {
if (!fusion.hasInput(val) &&
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({129, 127}, options);
+ at::Tensor t0 = at::randn({129, 127}, options);
- auto t1 = aten_input.mul({0.5});
+ auto t1 = t0.mul({0.5});
auto t2 = t1.mul({-1.0});
auto t3 = t1.add({3.0});
auto t4 = t1.mul({2.0});
auto t6 = t5.add(t4);
auto t7 = t1.add(t4);
- std::vector<at::Tensor> aten_outputs = {t6, t7};
- std::vector<at::Tensor> cg_outputs = {
- at::empty_like(aten_input, options), at::empty_like(aten_input, options)};
+ at::Tensor kernel_tv6 = at::empty_like(t0, options);
+ at::Tensor kernel_tv7 = at::empty_like(t0, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- fe.runFusion({aten_input}, cg_outputs);
+ fe.runFusion({t0}, {kernel_tv6, kernel_tv7});
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
+ TORCH_CHECK(at::allclose(kernel_tv6, t6));
+ TORCH_CHECK(at::allclose(kernel_tv7, t7));
}
TEST(NVFuserTest, FusionAdvancedComputeAt2_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- TensorView* tv1 = mul(tv0, new Double(-1.0));
- TensorView* tv2 = add(tv0, new Double(3.0));
- TensorView* tv3 = mul(tv0, new Double(2.0));
+ TensorView* tv1 = mul(tv0, new Float(-1.0));
+ TensorView* tv2 = add(tv0, new Float(3.0));
+ TensorView* tv3 = mul(tv0, new Float(2.0));
TensorView* tv4 = add(tv2, tv1);
TensorView* tv5 = add(tv4, tv3);
}
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({129, 127}, options);
+ at::Tensor t0 = at::randn({129, 127}, options);
- auto t1 = input.mul({-1.0});
- auto t2 = input.add({3.0});
- auto t3 = input.mul({2.0});
+ auto t1 = t0.mul({-1.0});
+ auto t2 = t0.add({3.0});
+ auto t3 = t0.mul({2.0});
auto t4 = t2.add(t1);
auto t5 = t4.add(t3);
auto t6 = t5.add(t3);
- std::vector<at::Tensor> aten_outputs = {t5, t6};
-
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
+ auto outputs = fe.runFusion({t0});
- testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__);
+ TORCH_CHECK(at::allclose(outputs[0], t5));
+ TORCH_CHECK(at::allclose(outputs[1], t6));
}
TEST(NVFuserTest, FusionAdvancedComputeAt3_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(4);
+ TensorView* tv0 = makeDummyTensor(4);
fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(4);
+ TensorView* tv1 = makeDummyTensor(4);
fusion.addInput(tv1);
- TensorView* tv2 = mul(tv1, new Double(.979361));
+ TensorView* tv2 = mul(tv1, new Float(.979361));
TensorView* tv3 = mul(tv2, tv0);
fusion.addOutput(tv3);
at::Tensor t1 = at::rand_like(t0, options);
auto t2 = t1.mul({0.979361});
- auto aten_output = t2.mul(t0);
-
- std::vector<IValue> aten_inputs = {t0, t1};
+ auto t3 = t2.mul(t0);
- at::Tensor cg_output = at::empty_like(t0, options);
+ at::Tensor kernel_tv3 = at::empty_like(t0, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- fe.runFusion(aten_inputs, {cg_output});
+ fe.runFusion({t0, t1}, {kernel_tv3});
- testValidate(
- &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ TORCH_CHECK(at::allclose(kernel_tv3, t3));
}
TEST(NVFuserTest, FusionAdvancedComputeAt4_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(4);
+ TensorView* tv0 = makeDummyTensor(4);
fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(4);
+ TensorView* tv1 = makeDummyTensor(4);
fusion.addInput(tv1);
- TensorView* tv2 = makeSymbolicTensor(4);
+ TensorView* tv2 = makeDummyTensor(4);
fusion.addInput(tv2);
- TensorView* tv3 = makeSymbolicTensor(4);
+ TensorView* tv3 = makeDummyTensor(4);
fusion.addInput(tv3);
TensorView* tv4 = sub(tv2, tv3);
auto t4 = t2.sub(t3);
auto t5 = t1.add(t4);
- auto aten_output = t5.sub(t0);
-
- std::vector<IValue> aten_inputs = {t0, t1, t2, t3};
+ auto t6 = t5.sub(t0);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ auto outputs = fe.runFusion({t0, t1, t2, t3});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ TORCH_CHECK(at::allclose(outputs[0], t6));
}
TEST(NVFuserTest, FusionAdvancedComputeAt5_CUDA) {
FusionGuard fg(&fusion);
// Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
fusion.addInput(tv1);
- TensorView* tv2 = add(tv0, new Double(2.0));
+ TensorView* tv2 = add(tv0, new Float(2.0));
TensorView* tv3 = mul(tv1, tv2);
fusion.addOutput(tv3);
at::Tensor t1 = at::rand_like(t0, options);
auto t2 = t0.add(2.0);
- auto aten_output = t1.mul(t2);
-
- std::vector<IValue> aten_inputs = {t0, t1};
+ auto t3 = t1.mul(t2);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ auto outputs = fe.runFusion({t0, t1});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ TORCH_CHECK(at::allclose(outputs[0], t3));
}
TEST(NVFuserTest, FusionAdvancedComputeAt6_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
fusion.addInput(tv1);
- TensorView* tv2 = add(tv0, new Double(2.0));
+ TensorView* tv2 = add(tv0, new Float(2.0));
TensorView* tv3 = mul(tv1, tv2);
fusion.addOutput(tv3);
at::Tensor t1 = at::rand_like(t0, options);
auto t2 = t0.add(2.0);
- auto aten_output = t1.mul(t2);
-
- std::vector<IValue> aten_inputs = {t0, t1};
+ auto t3 = t1.mul(t2);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ auto outputs = fe.runFusion({t0, t1});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ TORCH_CHECK(at::allclose(outputs[0], t3));
}
-TEST(NVFuserTest, FusionAdvancedComputeAt7_CUDA) {
+TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) {
+ // tv1 = tv0 * 0.5
+ // tv2 = tv1 * -1
+ // tv3 = tv2 * -2
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
+ TensorView* tv0 = makeDummyTensor(1);
fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1.0));
-
- auto tv2 = makeSymbolicTensor(1);
- fusion.addInput(tv2);
-
- auto tv3 = add(tv2, new Double(3.0));
-
- auto tv4 = add(tv1, tv3);
- fusion.addOutput(tv4);
-
- auto tv5 = broadcast(tv1, {false, true});
+ TensorView* tv1 = mul(tv0, new Float(0.5));
+ TensorView* tv2 = mul(tv1, new Float(-1.0));
+ TensorView* tv3 = mul(tv1, new Float(-2.0));
+ fusion.addOutput(tv2);
+ fusion.addOutput(tv3);
- auto tv6 = makeSymbolicTensor(2);
- fusion.addInput(tv6);
+  // This computeAt will affect tv2 as well, even though tv2 is not on
+  // the data-flow path between tv1 and tv3. The reason is that tv1 is
+  // now computed at tv3, so tv2 must also be computed at the same
+  // location. Overall, the expressions of all tensors are merged and
+  // computed in a single loop nest.
+ TensorView* computeAtTarget = tv3;
+ computeAtTarget->split(0, 128);
+ tv1->computeAt(computeAtTarget, 1);
- auto tv7 = mul(tv5, tv6);
+ TensorView* affected_tensors[] = {tv1, tv2, tv3};
+ for (auto tv : affected_tensors) {
+ TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
+ }
- fusion.addOutput(tv7);
+ // Note that tv2 is also computed at tv3.
+ TORCH_CHECK(tv1->getComputeAtView() == computeAtTarget);
+ TORCH_CHECK(tv2->getComputeAtView() == tv3);
+ TORCH_CHECK(!tv3->hasComputeAt());
- tv7->split(1, 2);
- tv7->merge(0);
- tv7->split(0, 4);
- tv7->split(0, 128);
+ computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
+ for (auto tv : affected_tensors) {
+ tv->axis(-1)->parallelize(ParallelType::TIDx);
+ }
- tv7->axis(0)->parallelize(ParallelType::BIDx);
- tv7->axis(1)->parallelize(ParallelType::TIDx);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- tv0->computeAt(tv7, 1);
- auto tv5_domain = tv5->domain()->domain();
+ at::Tensor t0 = at::randn({1000}, options);
- // These computeAt transformations should not affect the TV5 domain
- tv0->computeAt(tv4, -1);
- tv2->computeAt(tv4, -1);
+ auto t1 = t0 * 0.5;
+ auto t2 = t1 * -1.0;
+ auto t3 = t1 * -2.0;
- auto tv5_domain_current = tv5->domain()->domain();
- TORCH_CHECK(tv5_domain == tv5_domain_current, "Invalid TV5 domain");
+ at::Tensor kernel_tv2 = at::empty_like(t0, options);
+ at::Tensor kernel_tv3 = at::empty_like(t0, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
+ fe.runFusion({t0}, {kernel_tv2, kernel_tv3});
- const int numel_x = 100;
- const int numel_y = 200;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto t0 = at::randn({numel_x}, options);
- auto t2 = at::randn({numel_x}, options);
- auto t6 = at::randn({numel_x, numel_y}, options);
-
- auto t1 = t0.add(1.0);
- auto t3 = t2.add(3.0);
- auto t4 = t1.add(t3);
- auto t5 = t1.unsqueeze(1);
- auto t7 = t5.mul(t6);
-
- std::vector<IValue> aten_inputs = {t0, t2, t6};
- std::vector<at::Tensor> aten_outputs = {t4, t7};
-
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
+ TORCH_CHECK(at::allclose(kernel_tv2, t2));
+ TORCH_CHECK(at::allclose(kernel_tv3, t3));
}
-TEST(NVFuserTest, FusionAdvancedComputeAt8_CUDA) {
+// Similar to ComputeAtMultiConsumers, but with a common consumer.
+TEST(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) {
+ // tv1 = tv0 * 0.5
+ // tv2 = tv1 * -1
+ // tv3 = tv2 * -2
+ // tv4 = tv2 + tv3
+ // tv5 = tv4 * 5
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
+ TensorView* tv0 = makeDummyTensor(1);
fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1.0));
-
- auto tv2 = makeSymbolicTensor(1);
- fusion.addInput(tv2);
+ TensorView* tv1 = mul(tv0, new Float(0.5));
+ TensorView* tv2 = mul(tv1, new Float(-1.0));
+ TensorView* tv3 = mul(tv1, new Float(-2.0));
+ TensorView* tv4 = add(tv2, tv3);
+ TensorView* tv5 = mul(tv4, new Float(5.0));
+ fusion.addOutput(tv3);
+ fusion.addOutput(tv4);
+ fusion.addOutput(tv5);
- auto tv3 = add(tv2, new Double(3.0));
+ // Computing tv1 at tv3. This will affect tv2 as discussed in
+ // ComputeAtMultiConsumers above. Additionally, notice that tv4 is
+ // the common consumer of tv2 and tv3, so they are computed at
+ // tv4. The indirect propagation of the computeAt should stop at the
+ // common consumer, and no further change should occur. More
+ // specifically, tv4 and tv5 should not have a computeAt tensor.
+ TensorView* computeAtTarget = tv3;
+ computeAtTarget->split(0, 128);
+ tv1->computeAt(computeAtTarget, 1);
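+
+ // Expected computeAt structure (a sketch; asserted below):
+ //   tv1 -> tv3 (the target), tv2 -> tv4, tv3 -> tv4;
+ //   tv4 and tv5 keep no computeAt view.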
- auto tv4 = add(tv1, tv3);
- fusion.addOutput(tv4);
+ TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4};
+ for (auto tv : affected_tensors) {
+ TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
+ }
- auto tv5 = broadcast(tv1, {false, true});
+ TORCH_CHECK(tv1->getComputeAtView() == computeAtTarget);
+ TORCH_CHECK(tv2->getComputeAtView() == tv4);
+ TORCH_CHECK(tv3->getComputeAtView() == tv4);
+ TORCH_CHECK(!tv4->hasComputeAt());
+ TORCH_CHECK(!tv5->hasComputeAt());
- auto tv6 = makeSymbolicTensor(2);
- fusion.addInput(tv6);
+ computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
- auto tv7 = mul(tv5, tv6);
+ for (auto tv : affected_tensors) {
+ tv->axis(-1)->parallelize(ParallelType::TIDx);
+ }
- fusion.addOutput(tv7);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- tv7->split(1, 2);
- tv7->merge(0);
- tv7->split(0, 128, false);
- tv7->split(0, 4, false);
+ at::Tensor t0 = at::randn({1000}, options);
- tv7->axis(0)->parallelize(ParallelType::BIDx);
- tv7->axis(1)->parallelize(ParallelType::TIDx);
+ auto t1 = t0 * 0.5;
+ auto t2 = t1 * -1.0;
+ auto t3 = t1 * -2.0;
+ auto t4 = t2 + t3;
+ auto t5 = t4 * 5.0;
- // Reverse computeAt structure from previous test
- tv0->computeAt(tv4, -1);
- tv2->computeAt(tv4, -1);
- tv0->computeAt(tv7, -1);
+ at::Tensor kernel_tv3 = at::empty_like(t0, options);
+ at::Tensor kernel_tv4 = at::empty_like(t0, options);
+ at::Tensor kernel_tv5 = at::empty_like(t0, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
+ fe.runFusion({t0}, {kernel_tv3, kernel_tv4, kernel_tv5});
- const int numel_x = 100;
- const int numel_y = 200;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto t0 = at::randn({numel_x}, options);
- auto t2 = at::randn({numel_x}, options);
- auto t6 = at::randn({numel_x, numel_y}, options);
-
- auto t1 = t0.add(1.0);
- auto t3 = t2.add(3.0);
- auto t4 = t1.add(t3);
- auto t5 = t1.unsqueeze(1);
- auto t7 = t5.mul(t6);
-
- std::vector<IValue> aten_inputs = {t0, t2, t6};
- std::vector<at::Tensor> aten_outputs = {t4, t7};
-
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
+ TORCH_CHECK(at::allclose(kernel_tv3, t3));
+ TORCH_CHECK(at::allclose(kernel_tv4, t4));
+ TORCH_CHECK(at::allclose(kernel_tv5, t5));
}
-TEST(NVFuserTest, FusionAdvancedComputeWith1_CUDA) {
- // Case 1
+TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) {
// tv1 = tv0 * 0.5
// tv2 = tv1 * -1
- // tv3 = tv1 + 3
- // tv4 = tv1 * 2
- // tv5 = tv3 + tv2
- // tv6 = tv5 + tv4
- // tv7 = tv1 + tv4
+ // tv3 = tv2 * -1
+ // tv4 = tv1 + 4
+ // tv5 = tv3 + tv4
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- TensorView* tv1 = mul(tv0, new Double(0.5));
- TensorView* tv2 = mul(tv1, new Double(-1.0));
- TensorView* tv3 = add(tv1, new Double(3.0));
- TensorView* tv4 = mul(tv1, new Double(2.0));
- TensorView* tv5 = add(tv3, tv2);
-
- TensorView* tv6 = add(tv5, tv4);
- TensorView* tv7 = add(tv1, tv4);
-
- fusion.addOutput(tv6);
- fusion.addOutput(tv7);
-
- // Lets setup to actually run
- tv0->merge(0);
- tv0->split(0, 128);
- tv0->split(0, 4);
+ TensorView* tv1 = mul(tv0, new Float(0.5));
+ TensorView* tv2 = mul(tv1, new Float(-1.0));
+ TensorView* tv3 = mul(tv2, new Float(-1.0));
+ TensorView* tv4 = add(tv1, new Float(4.0));
+ TensorView* tv5 = add(tv3, tv4);
- tv0->axis(0)->parallelize(ParallelType::BIDx);
+ fusion.addOutput(tv5);
- tv0->computeWith(tv7, 1);
+ TensorView* computeAtTarget = tv3;
- GpuLower gpulw(&fusion);
+ computeAtTarget->merge(0);
+ computeAtTarget->split(0, 128);
+ computeAtTarget->split(0, 4);
- // The this-position of the last tensor should be zero.
- TORCH_CHECK(
- tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 &&
- tv7->getMaxProducerPosition() == 1);
- TORCH_CHECK(
- tv7->nDims() == 3 && tv6->getComputeAtPosition() == 0 &&
- tv6->getMaxProducerPosition() == 1);
+ computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
- // The position of every other tensor should be 1.
- for (auto tv : {tv1, tv2, tv3, tv4, tv5}) {
- TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1);
- TORCH_CHECK(gpulw.caLoopMap().areMapped(tv7->axis(0), tv->axis(0)));
- }
+ // This computeAt will affect all tensors including tv3, tv4 and
+ // tv5, even though it appears to impact only tv1 and tv2. The
+ // reason is that tv1 is now computed at tv3, so tv4 must also be
+ // computed at the same location. Similarly, the consumer of tv4,
+ // tv5, must also be computed at the same location. Overall, the
+ // effect is that the expressions of all tensors are merged and
+ // computed in a single loop nest. Internally, this will be
+ // realized by making all tensors, except for those in the path
+ // between tv1 and tv3, computed at tv5, which we call the common
+ // consumer.
+ tv1->computeAt(computeAtTarget, 1);
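+
+ // Expected computeAt chain (a sketch; asserted below):
+ //   tv1 -> tv2 -> tv3 -> tv5, tv4 -> tv5,
+ //   and tv5, the common consumer, has no computeAt view.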
+ // All tensors should have the same dimensionality as the target
for (Val* val : fusion.vals()) {
- if (!fusion.hasInput(val) &&
- val->getValType().value() == ValType::TensorView) {
- TensorView* tv = static_cast<TensorView*>(val);
+ if (fusion.hasInput(val) ||
+ val->getValType().value() != ValType::TensorView) {
+ continue;
+ }
+ TensorView* tv = val->as<TensorView>();
+ TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
+ }
+
+ TORCH_CHECK(tv1->getComputeAtView() == tv2);
+ TORCH_CHECK(tv2->getComputeAtView() == tv3);
+ // tv3 and tv4 are computed at tv5
+ TORCH_CHECK(tv3->getComputeAtView() == tv5);
+ TORCH_CHECK(tv4->getComputeAtView() == tv5);
+ TORCH_CHECK(!tv5->hasComputeAt());
+
+ for (Val* val : fusion.vals()) {
+ if (!fusion.hasInput(val) &&
+ val->getValType().value() == ValType::TensorView) {
+ TensorView* tv = val->as<TensorView>();
tv->axis(1)->parallelize(ParallelType::Unroll);
tv->axis(-1)->parallelize(ParallelType::TIDx);
}
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({129, 127}, options);
+ at::Tensor t0 = at::randn({129, 127}, options);
- auto t1 = aten_input.mul({0.5});
+ auto t1 = t0.mul({0.5});
auto t2 = t1.mul({-1.0});
- auto t3 = t1.add({3.0});
- auto t4 = t1.mul({2.0});
- auto t5 = t3.add(t2);
- auto t6 = t5.add(t4);
- auto t7 = t1.add(t4);
+ auto t3 = t2.mul({-1.0});
+ auto t4 = t1.add({4.0});
+ auto t5 = t3 + t4;
- std::vector<at::Tensor> aten_outputs = {t6, t7};
- std::vector<at::Tensor> cg_outputs = {
- at::empty_like(aten_input, options), at::empty_like(aten_input, options)};
+ at::Tensor kernel_tv5 = at::empty_like(t0, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- fe.runFusion({aten_input}, cg_outputs);
+ fe.runFusion({t0}, {kernel_tv5});
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
+ TORCH_CHECK(at::allclose(kernel_tv5, t5));
}
-TEST(NVFuserTest, FusionAdvancedComputeWith2_CUDA) {
- // Case 2
- // tv1 = tv0 * -1
- // tv2 = tv0 + 3
- // tv3 = tv0 * 2
- // tv4 = tv2 + tv1
- // tv5 = tv4 + tv3
- // tv6 = tv5 + tv3
+// Similar to the above common consumer test but adds an additional
+// tensor that has no common consumer with the other tensors.
+TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) {
+ // tv1 = tv0 * 0.5
+ // tv2 = tv1 * -1
+ // tv3 = tv2 * -1
+ // tv4 = tv1 + 4
+ // tv5 = tv2 + tv3
+ // tv6 = tv1 + 6
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- TensorView* tv1 = mul(tv0, new Double(-1.0));
- TensorView* tv2 = add(tv0, new Double(3.0));
- TensorView* tv3 = mul(tv0, new Double(2.0));
- TensorView* tv4 = add(tv2, tv1);
-
- TensorView* tv5 = add(tv4, tv3);
- TensorView* tv6 = add(tv5, tv3);
+ TensorView* tv1 = mul(tv0, new Float(0.5));
+ TensorView* tv2 = mul(tv1, new Float(-1.0));
+ TensorView* tv3 = mul(tv2, new Float(-1.0));
+ TensorView* tv4 = add(tv1, new Float(4.0));
+ TensorView* tv5 = add(tv3, tv4);
+ TensorView* tv6 = add(tv1, new Float(6.0));
fusion.addOutput(tv5);
fusion.addOutput(tv6);
- // Lets setup to actually run
- tv0->merge(0);
- tv0->split(0, 128);
- tv0->split(0, 4);
+ TensorView* computeAtTarget = tv3;
+
+ computeAtTarget->merge(0);
+ computeAtTarget->split(0, 128);
+ computeAtTarget->split(0, 4);
+
+ computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
+
+ // This will have the same impact on the tensors except for tv5 and
+ // tv6. tv6 does not have any common consumer with the computeAt
+ // target, but since it uses tv1, it must also be computed at the
+ // same location as the other impacted tensors. We can either make
+ // tv5 computed at tv6 or tv6 computed at tv5. In this case, tv5
+ // should be computed at tv6 just because the current implementation
+ // orders the computeAt relationship based on the order in which
+ // tensors are specified as outputs.
+
+ tv1->computeAt(computeAtTarget, 1);
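+
+ // Expected computeAt chain (a sketch; asserted below):
+ //   tv1 -> tv2 -> tv3 -> tv5 -> tv6, tv4 -> tv5,
+ //   and tv6 has no computeAt view.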
+
+ // All tensors should have the same dimensionality as the target
+ for (Val* val : fusion.vals()) {
+ if (fusion.hasInput(val) ||
+ val->getValType().value() != ValType::TensorView) {
+ continue;
+ }
+ TensorView* tv = val->as<TensorView>();
+ TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
+ }
- tv0->axis(0)->parallelize(ParallelType::BIDx);
+ TORCH_CHECK(tv1->getComputeAtView() == tv2);
+ TORCH_CHECK(tv2->getComputeAtView() == tv3);
- tv0->computeWith(tv6, 1);
+ // tv3 and tv4 are computed at tv5
+ TORCH_CHECK(tv3->getComputeAtView() == tv5);
+ TORCH_CHECK(tv4->getComputeAtView() == tv5);
+
+ // tv5 should be computed at tv6 since tv5 is added as an output
+ // before tv6. If fusion.addOutput(tv6) were called first, tv6 would
+ // be computed at tv5 instead.
+ TORCH_CHECK(tv5->getComputeAtView() == tv6);
+ TORCH_CHECK(!tv6->hasComputeAt());
for (Val* val : fusion.vals()) {
if (!fusion.hasInput(val) &&
val->getValType().value() == ValType::TensorView) {
- TensorView* tv = static_cast<TensorView*>(val);
-
+ TensorView* tv = val->as<TensorView>();
tv->axis(1)->parallelize(ParallelType::Unroll);
tv->axis(-1)->parallelize(ParallelType::TIDx);
}
}
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({129, 127}, options);
- auto t1 = input.mul({-1.0});
- auto t2 = input.add({3.0});
- auto t3 = input.mul({2.0});
- auto t4 = t2.add(t1);
- auto t5 = t4.add(t3);
- auto t6 = t5.add(t3);
+ at::Tensor t0 = at::randn({129, 127}, options);
- std::vector<at::Tensor> aten_outputs = {t5, t6};
+ auto t1 = t0.mul({0.5});
+ auto t2 = t1.mul({-1.0});
+ auto t3 = t2.mul({-1.0});
+ auto t4 = t1.add({4.0});
+ auto t5 = t3 + t4;
+ auto t6 = t1.add({6.0});
+
+ at::Tensor kernel_tv5 = at::empty_like(t0, options);
+ at::Tensor kernel_tv6 = at::empty_like(t0, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
+ fe.runFusion({t0}, {kernel_tv5, kernel_tv6});
- testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__);
+ TORCH_CHECK(at::allclose(kernel_tv5, t5));
+ TORCH_CHECK(at::allclose(kernel_tv6, t6));
}
-TEST(NVFuserTest, FusionAdvancedComputeWith3_CUDA) {
- // Case 3
- // T2 = T1 * 0.979361
- // T3 = T2 * T0
+// Similar to ComputeAtCommonConsumer1 but with an additional tensor
+// that has no data dependency with the consumer.
+TEST(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) {
+ // tv1 = tv0 * 0.5
+ // tv2 = tv1 * -1
+ // tv3 = tv1 * -2
+ // tv4 = tv2 + tv3
+ // tv5 = tv4 * 5
+ // tv6 = tv1 * 6
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(4);
+ TensorView* tv0 = makeDummyTensor(1);
fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(4);
- fusion.addInput(tv1);
-
- TensorView* tv2 = mul(tv1, new Double(.979361));
- TensorView* tv3 = mul(tv2, tv0);
-
+ TensorView* tv1 = mul(tv0, new Float(0.5));
+ TensorView* tv2 = mul(tv1, new Float(-1.0));
+ TensorView* tv3 = mul(tv1, new Float(-2.0));
+ TensorView* tv4 = add(tv2, tv3);
+ TensorView* tv5 = mul(tv4, new Float(5.0));
+ // Notice that tv6 is not a consumer of tv4.
+ TensorView* tv6 = mul(tv1, new Float(6.0));
fusion.addOutput(tv3);
+ fusion.addOutput(tv4);
+ fusion.addOutput(tv5);
+ fusion.addOutput(tv6);
- // Lets setup to actually run
- while (tv0->nDims() > 1)
- tv0->merge(0);
- tv0->split(0, 128);
- tv0->split(0, 4);
-
- while (tv1->nDims() > 1)
- tv1->merge(0);
- tv1->split(0, 128);
- tv1->split(0, 4);
+ TensorView* computeAtTarget = tv3;
+ computeAtTarget->split(0, 128);
+ tv1->computeAt(computeAtTarget, 1);
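+
+ // Expected computeAt chain (a sketch; asserted below):
+ //   tv1 -> tv3 (the target), tv2 -> tv4, tv3 -> tv4, tv4 -> tv5,
+ //   tv5 -> tv6, and tv6 has no computeAt view.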
- tv0->computeWith(tv3, 1);
- tv1->computeWith(tv3, 1);
+ TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4, tv6};
+ for (auto tv : affected_tensors) {
+ TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
+ }
- tv3->axis(0)->parallelize(ParallelType::BIDx);
+ TORCH_CHECK(tv1->getComputeAtView() == computeAtTarget);
+ TORCH_CHECK(tv2->getComputeAtView() == tv4);
+ TORCH_CHECK(tv3->getComputeAtView() == tv4);
+ TORCH_CHECK(tv4->getComputeAtView() == tv5);
+ TORCH_CHECK(tv5->getComputeAtView() == tv6);
+ TORCH_CHECK(!tv6->hasComputeAt());
- for (Val* val : fusion.vals()) {
- if (!fusion.hasInput(val) &&
- val->getValType().value() == ValType::TensorView) {
- TensorView* tv = static_cast<TensorView*>(val);
+ computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
- tv->axis(1)->parallelize(ParallelType::Unroll);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
+ for (auto tv : affected_tensors) {
+ tv->axis(-1)->parallelize(ParallelType::TIDx);
}
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
- at::Tensor t1 = at::rand_like(t0, options);
- auto t2 = t1.mul({0.979361});
- auto aten_output = t2.mul(t0);
+ at::Tensor t0 = at::randn({1000}, options);
- std::vector<IValue> aten_inputs = {t0, t1};
+ auto t1 = t0 * 0.5;
+ auto t2 = t1 * -1.0;
+ auto t3 = t1 * -2.0;
+ auto t4 = t2 + t3;
+ auto t5 = t4 * 5.0;
+ auto t6 = t1 * 6.0;
- at::Tensor cg_output = at::empty_like(t0, options);
+ at::Tensor kernel_tv3 = at::empty_like(t0, options);
+ at::Tensor kernel_tv4 = at::empty_like(t0, options);
+ at::Tensor kernel_tv5 = at::empty_like(t0, options);
+ at::Tensor kernel_tv6 = at::empty_like(t0, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- fe.runFusion(aten_inputs, {cg_output});
+ fe.runFusion({t0}, {kernel_tv3, kernel_tv4, kernel_tv5, kernel_tv6});
- testValidate(
- &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ TORCH_CHECK(at::allclose(kernel_tv3, t3));
+ TORCH_CHECK(at::allclose(kernel_tv4, t4));
+ TORCH_CHECK(at::allclose(kernel_tv5, t5));
+ TORCH_CHECK(at::allclose(kernel_tv6, t6));
}
-TEST(NVFuserTest, FusionAdvancedComputeWith4_CUDA) {
- // Case 4
- // T4 = T2 - T3
- // T5 = T1 + T4
- // T6 = T5 - T0
- Fusion fusion;
- FusionGuard fg(&fusion);
+namespace {
- TensorView* tv0 = makeSymbolicTensor(4);
- fusion.addInput(tv0);
+void checkConcretized(
+ TensorView* v0,
+ int a0,
+ TensorView* v1,
+ int a1,
+ bool should_concretize) {
+ if (should_concretize) {
+ TORCH_CHECK(
+ IterDomain::concretizeDomain(v0->axis(a0))->sameAs(v1->axis(a1)));
+ } else {
+ TORCH_CHECK(
+ !IterDomain::concretizeDomain(v0->axis(a0))->sameAs(v1->axis(a1)));
+ }
+}
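+
+// Usage sketch: checkConcretized(tv2, 0, tv1, 0, true) asserts that the
+// (broadcast) axis 0 of tv2 concretizes to the same domain as axis 0 of
+// tv1; passing false asserts that it does not.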
- TensorView* tv1 = makeSymbolicTensor(4);
- fusion.addInput(tv1);
+} // namespace
- TensorView* tv2 = makeSymbolicTensor(4);
- fusion.addInput(tv2);
+TEST(NVFuserTest, FusionBCastConcretizeBasic_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- TensorView* tv3 = makeSymbolicTensor(4);
- fusion.addInput(tv3);
+ // tv0: [I I]
+ TensorView* tv0 = makeDummyTensor(2);
- TensorView* tv4 = sub(tv2, tv3);
- TensorView* tv5 = add(tv1, tv4);
- TensorView* tv6 = sub(tv5, tv0);
+ // tv1: [I I I]
+ TensorView* tv1 = makeDummyTensor(3);
- fusion.addOutput(tv6);
- std::vector<TensorView*> tvs = {tv0, tv1, tv2};
- for (auto tv : tvs) {
- // Lets setup to actually run
- while (tv->nDims() > 1) {
- tv->merge(0);
- }
- tv->split(0, 128);
- tv->split(0, 4);
- tv->computeWith(tv6, 1);
- }
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
- tv6->axis(0)->parallelize(ParallelType::BIDx);
+ // tv2*: [B I I]
+ auto tv2_0 = broadcast(tv0, {true, false, false});
+ auto tv2_1 = broadcast(tv0, {true, false, false});
+ auto tv2 = add(tv2_0, tv2_1);
- for (Val* val : fusion.vals()) {
- if (!fusion.hasInput(val) &&
- val->getValType().value() == ValType::TensorView) {
- TensorView* tv = static_cast<TensorView*>(val);
+ // tv3: [I I I]
+ auto tv3 = add(tv2, tv1);
- tv->axis(1)->parallelize(ParallelType::Unroll);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
- }
+ fusion.addOutput(tv3);
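+
+ // Through the add producing tv3, the broadcast axis of tv2 (and of
+ // tv2_0/tv2_1) should be concretized to the leading axis of tv1,
+ // while mismatched axis pairs should not map; checked below.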
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
- at::Tensor t1 = at::rand_like(t0, options);
- at::Tensor t2 = at::rand_like(t0, options);
- at::Tensor t3 = at::rand_like(t0, options);
+ checkConcretized(tv2, 0, tv1, 0, true);
+ checkConcretized(tv2_0, 0, tv1, 0, true);
+ checkConcretized(tv2_1, 0, tv1, 0, true);
+ checkConcretized(tv2_0, 1, tv1, 0, false);
+ checkConcretized(tv2_0, 0, tv1, 1, false);
+}
- auto t4 = t2.sub(t3);
- auto t5 = t1.add(t4);
- auto aten_output = t5.sub(t0);
+TEST(NVFuserTest, FusionBCastConcretizeRfactor_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- std::vector<IValue> aten_inputs = {t0, t1, t2, t3};
+ // both tv0 and tv1 = [I, I]
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ //[B,I,I]
+ auto tv2 = broadcast(tv1, {true, false, false});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+ //[B,I,R]
+ auto tv3 = sum(tv2, {2});
-TEST(NVFuserTest, FusionAdvancedComputeWith5_CUDA) {
- // Case 5
- // tv2 = tv0 + 2.0
- // tv3 = tv1 * tv2
- Fusion fusion;
- FusionGuard fg(&fusion);
+ auto tv5 = add(tv3, tv1);
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(2);
fusion.addInput(tv1);
- TensorView* tv2 = add(tv0, new Double(2.0));
- TensorView* tv3 = mul(tv1, tv2);
- fusion.addOutput(tv3);
-
- tv2->merge(0);
- tv2->split(-1, 8);
- tv2->split(-1, 4);
-
- tv2->computeWith(tv3, 1);
- tv3->axis(0)->parallelize(ParallelType::BIDx);
+ fusion.addOutput(tv5);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({63, 65}, options);
- at::Tensor t1 = at::rand_like(t0, options);
+ // scheduling:
+ //[B,I,R0,R1=128], root = [B,I,R]
+ tv3->split(2, 128);
- auto t2 = t0.add(2.0);
- auto aten_output = t1.mul(t2);
+ // root=[B,I,Irf], rfactor=[B,I,Irf,Rrf]
+ auto tv4 = tv3->rFactor({3});
- std::vector<IValue> aten_inputs = {t0, t1};
+ checkConcretized(tv2, 0, tv5, 0, true);
+ checkConcretized(tv4, 0, tv5, 0, true);
+ checkConcretized(tv3, 0, tv5, 0, true);
+}
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+namespace {
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+void checkIdProvedEquivalent(
+ TensorView* v0,
+ int a0,
+ TensorView* v1,
+ int a1,
+ bool should_prove) {
+ if (should_prove) {
+ TORCH_CHECK(IterDomain::proveEquivalent(v0->axis(a0), v1->axis(a1)));
+ } else {
+ TORCH_CHECK(!IterDomain::proveEquivalent(v0->axis(a0), v1->axis(a1)));
+ }
}
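+
+// Usage sketch: checkIdProvedEquivalent(v0, a0, v1, a1, true) asserts
+// that IterDomain::proveEquivalent can prove axis a0 of v0 equivalent
+// to axis a1 of v1; passing false asserts that it cannot.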
-TEST(NVFuserTest, FusionAdvancedComputeWith6_CUDA) {
+} // namespace
+
+TEST(NVFuserTest, FusionProveIdEqBasic_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
+ TensorView* tv2 = makeDummyTensor(3);
+
fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(2);
fusion.addInput(tv1);
- TensorView* tv2 = add(tv0, new Double(2.0));
- TensorView* tv3 = mul(tv1, tv2);
- fusion.addOutput(tv3);
+ auto tv3 = broadcast(tv0, {true, false, false});
+ auto tv4 = broadcast(tv1, {false, true, false});
+ auto tv5 = add(tv3, tv4);
+ fusion.addOutput(tv5);
- tv2->merge(0);
- tv2->split(-1, 8);
- tv2->split(-1, 4);
- tv3->merge(0);
- tv3->split(-1, 8);
+ checkIdProvedEquivalent(tv0, 0, tv4, 1, true);
+ checkIdProvedEquivalent(tv1, 0, tv4, 0, true);
+ checkIdProvedEquivalent(tv1, 1, tv0, 1, true);
+ checkIdProvedEquivalent(tv0, 0, tv5, 1, true);
+ checkIdProvedEquivalent(tv1, 1, tv5, 2, true);
+ checkIdProvedEquivalent(tv0, 0, tv1, 0, false);
+ checkIdProvedEquivalent(tv0, 1, tv1, 0, false);
+ checkIdProvedEquivalent(tv0, 0, tv1, 1, false);
+}
- tv2->computeWith(tv3, 1);
+TEST(NVFuserTest, FusionProveIdEqRfactor_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- tv3->axis(0)->parallelize(ParallelType::BIDx);
+ // [I,I]
+ TensorView* tv0 = makeDummyTensor(2);
+ // [I,I,I]
+ TensorView* tv1 = makeDummyTensor(3);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({63, 65}, options);
- at::Tensor t1 = at::rand_like(t0, options);
+ //[I,I,R]
+ auto tv2 = sum(tv1, {2});
- auto t2 = t0.add(2.0);
- auto aten_output = t1.mul(t2);
+ auto tv5 = add(tv2, tv0);
+
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
+ fusion.addOutput(tv5);
- std::vector<IValue> aten_inputs = {t0, t1};
+ // scheduling:
+ //[B,I,R0,R1=128], root = [B,I,R]
+ tv2->split(2, 128);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ // root=[B,I,Irf], rfactor=[B,I,Irf,Rrf]
+ auto tv3 = tv2->rFactor({3});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ checkIdProvedEquivalent(tv1, 0, tv0, 0, true);
+ checkIdProvedEquivalent(tv2, 0, tv0, 0, true);
+ checkIdProvedEquivalent(tv3, 0, tv0, 0, true);
}
-TEST(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) {
- // tv1 = tv0 * 0.5
- // tv2 = tv1 * -1
- // tv3 = tv2 * -2
+TEST(NVFuserTest, FusionScalarInputs_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(1);
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
+ TensorView* tv1 = makeDummyTensor(2);
+ fusion.addInput(tv1);
- TensorView* tv1 = mul(tv0, new Double(0.5));
- TensorView* tv2 = mul(tv1, new Double(-1.0));
- TensorView* tv3 = mul(tv1, new Double(-2.0));
- fusion.addOutput(tv2);
- fusion.addOutput(tv3);
-
- // This computeAt will affect tv2 as well, even though tv2 is not in
- // the data-flow path between tv1 and tv3. The reason is that tv1 is
- // now computed at tv3, so tv2 must also be computed at the same
- // location. Overall, what will happen is basically we merge
- // expressions of all tensors and compute them in a single loop
- // nest.
- TensorView* computeAtTarget = tv3;
- computeAtTarget->split(0, 128);
- tv1->computeAt(computeAtTarget, 1);
+ Float* f0 = new Float();
+ fusion.addInput(f0);
+ Float* f1 = new Float();
+ fusion.addInput(f1);
+ Float* f2 = new Float();
+ fusion.addInput(f2);
+ Float* f3 = new Float();
+ fusion.addInput(f3);
+ Val* f4 = mul(f0, f1);
+ Val* f5 = sub(f2, f3);
+
+ TensorView* tv2 = sub(tv1, f4);
+ TensorView* tv3 = add(tv0, f5);
+ TensorView* tv4 = mul(tv3, tv2);
- TensorView* affected_tensors[] = {tv1, tv2, tv3};
- for (auto tv : affected_tensors) {
- TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
- }
+ fusion.addOutput(tv4);
- GpuLower gpulw(&fusion);
+ // Lets setup to actually run
+ while (tv4->nDims() > 1)
+ tv4->merge(0);
+ tv4->split(0, 128);
+ tv4->split(0, 4);
- TORCH_CHECK(tv1->getComputeAtPosition() == 1);
- TORCH_CHECK(
- tv2->getComputeAtPosition() == 0 && tv2->getMaxProducerPosition() == 1);
- TORCH_CHECK(
- tv3->getComputeAtPosition() == 0 && tv3->getMaxProducerPosition() == 1);
+ tv0->computeAt(tv4, 1);
+ tv1->computeAt(tv4, 1);
- // Note that tv2 is also computed at tv3.
- for (auto tv : {tv1, tv2}) {
- TORCH_CHECK(
- gpulw.caLoopMap().areMapped(tv->axis(0), computeAtTarget->axis(0)));
- }
+ tv4->axis(0)->parallelize(ParallelType::BIDx);
- TORCH_CHECK(tv3->getComputeAtPosition() == 0);
+ for (Val* val : fusion.vals()) {
+ if (!fusion.hasInput(val) &&
+ val->getValType().value() == ValType::TensorView) {
+ TensorView* tv = static_cast<TensorView*>(val);
- computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
- for (auto tv : affected_tensors) {
- tv->axis(-1)->parallelize(ParallelType::TIDx);
+ tv->axis(1)->parallelize(ParallelType::Unroll);
+ tv->axis(-1)->parallelize(ParallelType::TIDx);
+ }
}
+ // f4 = f0 * f1
+ // f5 = f2 - f3
+ // t2 = t1 - f4
+ // t3 = t0 + f5
+ // t4 = t3 * t2
+
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({1000}, options);
+ float fl0 = 0.1;
+ float fl1 = -0.2;
+ float fl2 = 0.3;
+ float fl3 = -0.4;
+ float fl4 = fl0 * fl1;
+ float fl5 = fl2 - fl3;
+
+ at::Tensor t0 = at::randn({129, 127}, options);
+ at::Tensor t1 = at::rand_like(t0, options);
- auto t1 = aten_input * 0.5;
- auto t2 = t1 * -1.0;
- auto t3 = t1 * -2.0;
+ auto t2 = t1.sub(fl4);
+ auto t3 = t0.add(fl5);
+ auto t4 = t3.mul(t2);
- std::vector<at::Tensor> aten_outputs = {t2, t3};
+ at::Tensor kernel_tv4 = at::empty_like(t0, options);
- std::vector<at::Tensor> cg_outputs = {
- at::empty_like(aten_input, options), at::empty_like(aten_input, options)};
+ at::Scalar test(fl0);
FusionExecutor fe;
fe.compileFusion(&fusion);
- fe.runFusion({aten_input}, cg_outputs);
+ fe.runFusion(
+ {t0,
+ t1,
+ at::Scalar(fl0),
+ at::Scalar(fl1),
+ at::Scalar(fl2),
+ at::Scalar(fl3)},
+ {kernel_tv4});
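+
+ // Note: scalar fusion inputs are passed positionally as at::Scalar
+ // values, in the same order they were registered via fusion.addInput.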
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
+ TORCH_CHECK(at::allclose(kernel_tv4, t4));
}
-// Similar to ComputeAtMultiConsumers, but with a common consumer.
-TEST(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) {
- // tv1 = tv0 * 0.5
- // tv2 = tv1 * -1
- // tv3 = tv2 * -2
- // tv4 = tv2 + tv3
- // tv5 = tv4 * 5
+TEST(NVFuserTest, FusionLoopUnroll_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(1);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(3);
+ TensorView* tv1 = makeDummyTensor(3);
+
+ // Register your inputs
fusion.addInput(tv0);
+ fusion.addInput(tv1);
- TensorView* tv1 = mul(tv0, new Double(0.5));
- TensorView* tv2 = mul(tv1, new Double(-1.0));
- TensorView* tv3 = mul(tv1, new Double(-2.0));
- TensorView* tv4 = add(tv2, tv3);
- TensorView* tv5 = mul(tv4, new Double(5.0));
- fusion.addOutput(tv3);
- fusion.addOutput(tv4);
- fusion.addOutput(tv5);
+ // Do math with it; it returns a `Val*` that can be static_cast back to a
+ // TensorView
+ TensorView* tv2 = add(tv1, new Float(2.0));
+ TensorView* tv3 = add(tv0, tv2);
- // Computing tv1 at tv3. This will affect tv2 as discussed in
- // ComplexComputeAt1. Additionally, in this case, notice that tv4 is
- // the common consumer of tv2 and tv3, so they are computed at
- // tv4. The indirect propagation of the computeAt should stop at the
- // common consumer, and no further change should occur. More
- // specifically, the computeAT position of tv4 and tv5 should be zero.
- TensorView* computeAtTarget = tv3;
- computeAtTarget->split(0, 128);
- tv1->computeAt(computeAtTarget, 1);
+ // Register your outputs
+ fusion.addOutput(tv3);
- TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4};
- for (auto tv : affected_tensors) {
- TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
- }
+ int block_size = 16;
- TORCH_CHECK(tv1->getComputeAtPosition() == 1);
- TORCH_CHECK(tv2->getComputeAtPosition() == 1);
- TORCH_CHECK(tv3->getComputeAtPosition() == 1);
- TORCH_CHECK(tv4->getComputeAtPosition() == 0);
- TORCH_CHECK(tv5->getComputeAtPosition() == 0);
+ tv3->merge(0, 1);
+ tv3->merge(0, 1);
- computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
+ tv3->split(0, block_size);
+ tv3->split(0, 4);
- for (auto tv : affected_tensors) {
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
+ // For all inputs, computeAt the output inline; temporaries should be squeezed
+ // between them
+ tv0->computeAt(tv3, 1);
+ tv1->computeAt(tv3, 1);
- // Transform tv5 to make it look like the rest
- tv5->split(0, 128);
- tv5->axis(1)->parallelize(ParallelType::TIDx);
- tv5->axis(0)->parallelize(ParallelType::BIDx);
+ // Parallelize
+ tv2->axis(1)->parallelize(ParallelType::Unroll);
+ tv3->axis(1)->parallelize(ParallelType::Unroll);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(0)->parallelize(ParallelType::BIDx);
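+
+ // Rough sketch of the resulting schedule of tv3 (extents symbolic):
+ //   axis 0: ceilDiv(ceilDiv(numel, block_size), 4) -> BIDx
+ //   axis 1: 4                                       -> Unroll
+ //   axis 2: block_size                              -> TIDx
+ // tv2 is produced inline in tv3's loop nest via the computeAt above.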
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({1000}, options);
-
- auto t1 = aten_input * 0.5;
- auto t2 = t1 * -1.0;
- auto t3 = t1 * -2.0;
- auto t4 = t2 + t3;
- auto t5 = t4 * 5.0;
-
- std::vector<at::Tensor> aten_outputs = {t3, t4, t5};
- std::vector<at::Tensor> cg_outputs = {
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options)};
+ at::Tensor input0 = at::rand({129, 13, 3}, options);
+ at::Tensor input1 = at::rand({129, 13, 3}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- fe.runFusion({aten_input}, cg_outputs);
+ auto outputs = fe.runFusion({input0, input1});
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
+ TORCH_CHECK(outputs[0].equal(input0.add(input1.add(2.0))));
}
-TEST(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) {
- // tv1 = tv0 * 0.5
- // tv2 = tv1 * -1
- // tv3 = tv2 * -1
- // tv4 = tv1 + 4
- // tv5 = tv3 + tv4
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- TensorView* tv1 = mul(tv0, new Double(0.5));
- TensorView* tv2 = mul(tv1, new Double(-1.0));
- TensorView* tv3 = mul(tv2, new Double(-1.0));
- TensorView* tv4 = add(tv1, new Double(4.0));
- TensorView* tv5 = add(tv3, tv4);
-
- fusion.addOutput(tv5);
-
- TensorView* computeAtTarget = tv3;
-
- computeAtTarget->merge(0);
- computeAtTarget->split(0, 128);
- computeAtTarget->split(0, 4);
+/*
+ * Helper function for single op testing that generates a codegen operand
+ */
- computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
+Val* gen_jit_operand(std::pair<ValType, DataType> desc) {
+ if (desc.first == ValType::TensorView) {
+ return makeDummyTensor(2, desc.second);
+ } else if (desc.first == ValType::Scalar) {
+ if (desc.second == DataType::Float)
+ return new Float();
+ else if (desc.second == DataType::Int)
+ return new Int();
+ else
+ TORCH_CHECK(false, "Not currently supported type", desc.first);
+ } else {
+ TORCH_CHECK(false, "Not currently supported type", desc.first);
+ }
+ return nullptr;
+}
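+
+// Usage sketch: gen_jit_operand({ValType::TensorView, DataType::Half})
+// returns a 2-D Half TensorView, while {ValType::Scalar, DataType::Int}
+// returns a new Int() to be registered as a fusion input.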
- // This computeAt will affect all tensors including tv3, tv4 and
- // tv5, even though it appears to impact only tv1 and tv2. The
- // reason is that tv1 is now computed at tv3, so tv4 must also be
- // computed at the same location. Similarly, the consumer of tv4,
- // tv5, must also be computed at the same location. Overall, what
- // will happen is basically we merge expressions of all tensors and
- // compute them in a single loop nest. Internally, this will be
- // realized by making all tensors, except for those in the path
- // between tv1 and tv3, computed at tv5, which we call the common
- // consumer.
- tv1->computeAt(computeAtTarget, 1);
+/*
+ * Helper function for single op testing that generates an ATen operand
+ */
- // All tensors should have the same dimenionality as the target
- for (Val* val : fusion.vals()) {
- if (fusion.hasInput(val) ||
- val->getValType().value() != ValType::TensorView) {
- continue;
- }
- TensorView* tv = val->as<TensorView>();
- TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
- if (tv == tv5) {
- TORCH_CHECK(tv->getComputeAtPosition() == 0);
+IValue gen_aten_operand(
+ std::pair<ValType, DataType> desc,
+ int blocks,
+ int threads,
+ bool rand) {
+ if (desc.first == ValType::TensorView) {
+ if (desc.second == DataType::Float) {
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ if (rand)
+ return IValue(at::rand({blocks, threads}, options));
+ else
+ return IValue(at::empty({blocks, threads}, options));
+ } else if (desc.second == DataType::Half) {
+ auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+ if (rand)
+ return IValue(at::rand({blocks, threads}, options));
+ else
+ return IValue(at::empty({blocks, threads}, options));
+ } else if (desc.second == DataType::Bool) {
+ if (rand) {
+ auto options =
+ at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ return IValue(at::rand({blocks, threads}, options).to(at::kBool));
+ } else {
+ auto options =
+ at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0);
+ return IValue(at::empty({blocks, threads}, options));
+ }
} else {
- TORCH_CHECK(tv->getComputeAtPosition() == 1);
- }
- }
-
- for (auto tv : ir_utils::filterByType<TensorView>(fusion.vals())) {
- if (!fusion.hasInput(tv)) {
- tv->axis(1)->parallelize(ParallelType::Unroll);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
+ TORCH_CHECK(false, "Not currently supported type ", desc.second);
}
+ } else if (desc.first == ValType::Scalar) {
+ if (desc.second == DataType::Float)
+ return IValue(at::Scalar(1.f));
+ else if (desc.second == DataType::Int)
+ return IValue(at::Scalar(1));
+ else
+ TORCH_CHECK(false, "Not currently supported type ", desc.first);
+ } else {
+ TORCH_CHECK(false, "Not currently supported type ", desc.first);
}
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::randn({129, 127}, options);
-
- auto t1 = aten_input.mul({0.5});
- auto t2 = t1.mul({-1.0});
- auto t3 = t2.mul({-1.0});
- auto t4 = t1.add({4.0});
- auto aten_output = t3 + t4;
-
- at::Tensor cg_output = at::empty_like(aten_input, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({aten_input}, {cg_output});
-
- testValidate(
- &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
+ return nullptr;
}
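+
+/*
+ * Usage sketch: gen_aten_operand({ValType::TensorView, DataType::Float},
+ * 64, 128, true) returns an IValue wrapping a random 64x128 CUDA float
+ * tensor; with rand == false an uninitialized buffer is returned instead.
+ */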
-// Similar to the above common consumer test but adds an additional
-// tensor that has no common consumer with the other tensors.
-TEST(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) {
- // tv1 = tv0 * 0.5
- // tv2 = tv1 * -1
- // tv3 = tv2 * -1
- // tv4 = tv1 + 4
- // tv5 = tv2 + tv3
- // tv6 = tv1 + 6
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- TensorView* tv1 = mul(tv0, new Double(0.5));
- TensorView* tv2 = mul(tv1, new Double(-1.0));
- TensorView* tv3 = mul(tv2, new Double(-1.0));
- TensorView* tv4 = add(tv1, new Double(4.0));
- TensorView* tv5 = add(tv3, tv4);
- TensorView* tv6 = add(tv1, new Double(6.0));
-
- fusion.addOutput(tv5);
- fusion.addOutput(tv6);
-
- TensorView* computeAtTarget = tv3;
-
- computeAtTarget->merge(0);
- computeAtTarget->split(0, 128);
- computeAtTarget->split(0, 4);
-
- computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
-
- // This will have the same impact on the tensors except for tv5 and
- // tv6. tv6 does not have any common consumer with the computeAt
- // target, but since it uses tv1, it must be also computed at the
- // same location as the other impacted tensors. We can either make
- // tv5 computed at tv6 or tv6 computed at tv5. In this case, tv5
- // should be computed at tv6 just because the current implementation
- // orders the computeAt relationship based on the order in which
- // tensors are specified as outputs.
-
- tv1->computeAt(computeAtTarget, 1);
-
- // All tensors should have the same dimenionality as the target
- for (auto tv : ir_utils::filterByType<TensorView>(fusion.vals())) {
- if (fusion.hasInput(tv)) {
- continue;
- }
- TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
- if (tv == tv5 || tv == tv6) {
- TORCH_CHECK(tv->getComputeAtPosition() == 0);
- TORCH_CHECK(tv->getMaxProducerPosition() == 1);
- } else {
- TORCH_CHECK(tv->getComputeAtPosition() == 1);
- }
- }
-
- for (Val* val : fusion.vals()) {
- if (!fusion.hasInput(val) &&
- val->getValType().value() == ValType::TensorView) {
- TensorView* tv = val->as<TensorView>();
- tv->axis(1)->parallelize(ParallelType::Unroll);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
- }
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::randn({129, 127}, options);
-
- auto t1 = aten_input.mul({0.5});
- auto t2 = t1.mul({-1.0});
- auto t3 = t2.mul({-1.0});
- auto t4 = t1.add({4.0});
- auto t5 = t3 + t4;
- auto t6 = t1.add({6.0});
-
- std::vector<at::Tensor> aten_outputs = {t5, t6};
- std::vector<at::Tensor> cg_outputs = {
- at::empty_like(aten_input, options), at::empty_like(aten_input, options)};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({aten_input}, cg_outputs);
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-// Similar to ComputeAtCommonConsumer1 but with an addtiona ltensor
-// that does not have data dependency with the consumer.
-TEST(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) {
- // tv1 = tv0 * 0.5
- // tv2 = tv1 * -1
- // tv3 = tv1 * -2
- // tv4 = tv2 + tv3
- // tv5 = tv4 * 5
- // tv6 = tv1 * 6
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- TensorView* tv1 = mul(tv0, new Double(0.5));
- TensorView* tv2 = mul(tv1, new Double(-1.0));
- TensorView* tv3 = mul(tv1, new Double(-2.0));
- TensorView* tv4 = add(tv2, tv3);
- TensorView* tv5 = mul(tv4, new Double(5.0));
- // Notice that tv6 is not a consumer of tv4.
- TensorView* tv6 = mul(tv1, new Double(6.0));
- fusion.addOutput(tv3);
- fusion.addOutput(tv4);
- fusion.addOutput(tv5);
- fusion.addOutput(tv6);
-
- TensorView* computeAtTarget = tv3;
- computeAtTarget->split(0, 128);
- tv1->computeAt(computeAtTarget, 1);
-
- TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4, tv5, tv6};
- for (auto tv : affected_tensors) {
- TORCH_CHECK(tv->nDims() == computeAtTarget->nDims());
- if (tv == tv6 || tv == tv5) {
- TORCH_CHECK(tv->getComputeAtPosition() == 0);
- } else {
- TORCH_CHECK(tv->getComputeAtPosition() == 1);
- }
- }
-
- computeAtTarget->axis(0)->parallelize(ParallelType::BIDx);
-
- for (auto tv : affected_tensors) {
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::randn({1000}, options);
-
- auto t1 = aten_input * 0.5;
- auto t2 = t1 * -1.0;
- auto t3 = t1 * -2.0;
- auto t4 = t2 + t3;
- auto t5 = t4 * 5.0;
- auto t6 = t1 * 6.0;
-
- std::vector<at::Tensor> aten_outputs = {t3, t4, t5, t6};
- std::vector<at::Tensor> cg_outputs = {
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options)};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({aten_input}, cg_outputs);
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-namespace {
-
-void checkConcretized(
- TensorView* v0,
- int a0,
- TensorView* v1,
- int a1,
- bool should_concretize) {
- if (should_concretize) {
- TORCH_CHECK(
- IterDomain::concretizeDomain(v0->axis(a0))->sameAs(v1->axis(a1)));
- } else {
- TORCH_CHECK(
- !IterDomain::concretizeDomain(v0->axis(a0))->sameAs(v1->axis(a1)));
- }
-}
-
-} // namespace
-
-TEST(NVFuserTest, FusionBCastConcretizeBasic_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // tv0: [I I]
- TensorView* tv0 = makeSymbolicTensor(2);
-
- // tv1: [I I I]
- TensorView* tv1 = makeSymbolicTensor(3);
-
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- // tv2*: [B I I]
- auto tv2_0 = broadcast(tv0, {true, false, false});
- auto tv2_1 = broadcast(tv0, {true, false, false});
- auto tv2 = add(tv2_0, tv2_1);
-
- // tv3: [I I I]
- auto tv3 = add(tv2, tv1);
-
- fusion.addOutput(tv3);
-
- checkConcretized(tv2, 0, tv1, 0, true);
- checkConcretized(tv2_0, 0, tv1, 0, true);
- checkConcretized(tv2_1, 0, tv1, 0, true);
- checkConcretized(tv2_0, 1, tv1, 0, false);
- checkConcretized(tv2_0, 0, tv1, 1, false);
-}
-
-TEST(NVFuserTest, FusionBCastConcretizeRfactor_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // both tv0 and tv1 = [I, I]
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
-
- //[B,I,I]
- auto tv2 = broadcast(tv1, {true, false, false});
-
- //[B,I,R]
- auto tv3 = sum(tv2, {2});
-
- auto tv5 = add(tv3, tv1);
-
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addOutput(tv5);
-
- // scheduling:
- //[B,I,R0,R1=128], root = [B,I,R]
- tv3->split(2, 128);
-
- // root=[B,I,Irf], rfactor=[B,I,Irf,Rrf]
- auto tv4 = tv3->rFactor({3});
-
- checkConcretized(tv2, 0, tv5, 0, true);
- checkConcretized(tv4, 0, tv5, 0, true);
- checkConcretized(tv3, 0, tv5, 0, true);
-}
-
-namespace {
-
-void checkIdMapped(
- ComputeAtRootDomainMap& root_map,
- TensorView* v0,
- IterDomain* id0,
- TensorView* v1,
- IterDomain* id1,
- bool should_map) {
- if (should_map) {
- TORCH_CHECK(
- root_map.canMap(v0->domain(), id0, v1->domain(), id1),
- "Should be mappable: ",
- id0,
- " of ",
- v0,
- " and ",
- id1,
- " of ",
- v1);
- } else {
- TORCH_CHECK(
- !root_map.canMap(v0->domain(), id0, v1->domain(), id1),
- "Should not be mappable: ",
- id0,
- " of ",
- v0,
- " and ",
- id1,
- " of ",
- v1);
- }
-}
-
-void checkIdMapped(
- TensorView* v0,
- const std::vector<IterDomain*>& root0,
- const std::vector<bool> should_map0,
- TensorView* v1,
- const std::vector<IterDomain*>& root1,
- const std::vector<bool> should_map1) {
- ComputeAtRootDomainMap map;
- map.build();
- TORCH_INTERNAL_ASSERT(root0.size() == should_map0.size());
- TORCH_INTERNAL_ASSERT(root1.size() == should_map1.size());
- size_t idx0 = 0;
- for (size_t i = 0; i < root0.size(); ++i) {
- size_t idx1 = 0;
- for (size_t j = 0; j < root1.size(); ++j) {
- if (should_map0[i] && should_map1[j] && idx0 == idx1) {
- checkIdMapped(map, v0, root0[i], v1, root1[j], true);
- } else {
- checkIdMapped(map, v0, root0[i], v1, root1[j], false);
- }
- if (should_map1[j])
- ++idx1;
- }
- if (should_map0[i])
- ++idx0;
- }
-}
-
-void checkIdMapped(
- TensorView* v0,
- const std::vector<IterDomain*>& root0,
- TensorView* v1,
- const std::vector<IterDomain*>& root1) {
- checkIdMapped(
- v0,
- root0,
- std::vector<bool>(root0.size(), true),
- v1,
- root1,
- std::vector<bool>(root1.size(), true));
-}
-
-} // namespace
-
-TEST(NVFuserTest, FusionRootMappingBasic_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
-
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- auto tv3 = broadcast(tv0, {true, false, false});
- auto tv4 = broadcast(tv1, {false, true, false});
- auto tv5 = add(tv3, tv4);
- fusion.addOutput(tv5);
-
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, true},
- tv4,
- tv4->getRootDomain(),
- {false, true, true});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, true},
- tv4,
- tv4->getRootDomain(),
- {true, false, true});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {false, true},
- tv1,
- tv1->getRootDomain(),
- {false, true});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, true},
- tv5,
- tv5->getRootDomain(),
- {false, true, true});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, true},
- tv5,
- tv5->getRootDomain(),
- {true, false, true});
- checkIdMapped(tv3, tv3->getRootDomain(), tv4, tv4->getRootDomain());
- checkIdMapped(tv3, tv3->getRootDomain(), tv5, tv5->getRootDomain());
- checkIdMapped(tv4, tv4->getRootDomain(), tv5, tv5->getRootDomain());
-}
-
-TEST(NVFuserTest, FusionRootMappingRfactor_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // [I,I]
- TensorView* tv0 = makeSymbolicTensor(2);
- // [I,I,I]
- TensorView* tv1 = makeSymbolicTensor(3);
-
- //[I,I,R]
- auto tv2 = sum(tv1, {2});
- auto tv3 = add(tv2, tv0);
-
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addOutput(tv3);
-
- // scheduling:
- //[B,I,R0,R1=128], root = [B,I,R]
- tv2->split(2, 128);
-
- // root=[B,I,Irf], rfactor=[B,I,Irf,Rrf]
- auto tv4 = tv2->rFactor({3});
-
- checkIdMapped(tv1, tv1->getRootDomain(), tv4, tv4->getRootDomain());
- checkIdMapped(
- tv4,
- tv4->getRFactorDomain(),
- {true, true, true, false},
- tv2,
- tv2->getRootDomain(),
- {true, true, true});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, true, false},
- tv2,
- tv2->getRootDomain(),
- {true, true, false});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, true, false},
- tv3,
- tv3->getRootDomain(),
- {true, true});
- checkIdMapped(
- tv2,
- tv2->getRootDomain(),
- {true, true, false},
- tv3,
- tv3->getRootDomain(),
- {true, true});
- checkIdMapped(tv0, tv0->getRootDomain(), tv3, tv3->getRootDomain());
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, true},
- tv1,
- tv1->getRootDomain(),
- {true, true, false});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, true},
- tv2,
- tv2->getRootDomain(),
- {true, true, false});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, true},
- tv4,
- tv4->getRFactorDomain(),
- {true, true, false, false});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, true},
- tv4,
- tv4->getRootDomain(),
- {true, true, false});
-}
-
-TEST(NVFuserTest, FusionRootMappingReductionDependency1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- auto tv1 = sum(tv0, {1});
- auto tv2 = broadcast(tv1, {false, true});
- fusion.addOutput(tv2);
-
- // The second dimension cannot be mapped as it would require recomputation.
- checkIdMapped(tv0, tv0->getRootDomain(), tv1, tv1->getRootDomain());
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, false},
- tv2,
- tv2->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, false},
- tv2,
- tv2->getRootDomain(),
- {true, false});
-}
-
-TEST(NVFuserTest, FusionRootMappingReductionDependency2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- auto tv1 = sum(tv0, {1});
- auto tv2 = broadcast(tv1, {false, true});
- auto tv3 = add(tv0, tv2);
- fusion.addOutput(tv3);
-
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, false},
- tv1,
- tv1->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, false},
- tv2,
- tv2->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, false},
- tv3,
- tv3->getRootDomain(),
- {true, false});
- checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain());
-}
-
-TEST(NVFuserTest, FusionRootMappingReductionDependency3_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- auto tv1 = sum(tv0, {1});
- auto tv2 = broadcast(tv1, {false, true});
- fusion.addOutput(tv2);
-
- tv1->split(-1, 4);
- auto tv3 = tv1->rFactor({-2});
-
- checkIdMapped(tv0, tv0->getRootDomain(), tv3, tv3->getRootDomain());
- checkIdMapped(
- tv3,
- tv3->getMaybeRFactorDomain(),
- {true, false, true},
- tv1,
- tv1->getRootDomain(),
- {true, true});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, false},
- tv2,
- tv2->getRootDomain(),
- {true, false});
-}
-
-TEST(NVFuserTest, FusionRootMappingReductionDependency4_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- auto tv1 = sum(tv0, {1});
- auto tv2 = broadcast(tv1, {false, true});
- auto tv3 = add(tv0, tv2);
- fusion.addOutput(tv3);
-
- tv1->split(-1, 4);
- auto tv4 = tv1->rFactor({-2});
-
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, false},
- tv4,
- tv4->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv4,
- tv4->getMaybeRFactorDomain(),
- {true, false, true},
- tv1,
- tv1->getRootDomain(),
- {true, true});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, false},
- tv2,
- tv2->getRootDomain(),
- {true, false});
- checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain());
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, false},
- tv2,
- tv2->getRootDomain(),
- {true, false});
-}
-
-// Reproducer of issue #749
-TEST(NVFuserTest, FusionRootMappingReductionDependency5_CUDA_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = sum(tv1, {1});
- auto tv3 = broadcast(tv2, {false, true});
- auto tv4 = add(tv0, tv3);
- auto tv5 = add(tv4, tv1);
- fusion.addOutput(tv5);
-
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, false},
- tv1,
- tv1->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, false},
- tv2,
- tv2->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv2,
- tv2->getRootDomain(),
- {true, false},
- tv3,
- tv3->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv3,
- tv3->getRootDomain(),
- {true, true},
- tv4,
- tv4->getRootDomain(),
- {true, true});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, false},
- tv4,
- tv4->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv4,
- tv4->getRootDomain(),
- {true, true},
- tv5,
- tv5->getRootDomain(),
- {true, true});
-}
-
-// Similar to RootMappingReductionDependency5 but with rFactor
-TEST(NVFuserTest, FusionRootMappingReductionDependency6_CUDA_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = sum(tv1, {1});
- auto tv3 = broadcast(tv2, {false, true});
- auto tv4 = add(tv0, tv3);
- auto tv5 = add(tv4, tv1);
- fusion.addOutput(tv5);
-
- tv2->split(1, 4);
- auto tv6 = tv2->rFactor({-1});
-
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, false},
- tv1,
- tv1->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, false},
- tv6,
- tv6->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv6,
- tv6->getMaybeRFactorDomain(),
- {true, true, false},
- tv2,
- tv2->getRootDomain(),
- {true, true});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, false},
- tv2,
- tv2->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv2,
- tv2->getRootDomain(),
- {true, false},
- tv3,
- tv3->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv3,
- tv3->getRootDomain(),
- {true, true},
- tv4,
- tv4->getRootDomain(),
- {true, true});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true, false},
- tv4,
- tv4->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv4,
- tv4->getRootDomain(),
- {true, true},
- tv5,
- tv5->getRootDomain(),
- {true, true});
-}
-
-TEST(NVFuserTest, FusionRootMappingMultipleBroadcast_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(1);
- auto tv1 = broadcast(tv0, {false, true});
- auto tv2 = broadcast(tv0, {true, false});
- auto tv3 = add(tv1, tv2);
- fusion.addOutput(tv3);
-
- // tv0 cannot be mapped with the consumers as it would mean its only
- // domain would be mapped to both the first and second domains of
- // the two consumers, thus computing tv0 at both corresponding loops.
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {false},
- tv1,
- tv1->getRootDomain(),
- {false, false});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {false},
- tv2,
- tv2->getRootDomain(),
- {false, false});
- checkIdMapped(tv1, tv1->getRootDomain(), tv3, tv3->getRootDomain());
- checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain());
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {false},
- tv3,
- tv3->getRootDomain(),
- {false, false});
-}
-
-TEST(NVFuserTest, FusionRootMappingMultipleBroadcastWithNoCommonConsumer_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(1);
- auto tv1 = broadcast(tv0, {false, true});
- auto tv2 = broadcast(tv0, {true, false});
- fusion.addOutput(tv1);
- fusion.addOutput(tv2);
-
- // If there is no common consumer, there is no recomputation constraint.
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true},
- tv1,
- tv1->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true},
- tv2,
- tv2->getRootDomain(),
- {false, true});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, false},
- tv2,
- tv2->getRootDomain(),
- {false, true});
-}
-
-TEST(NVFuserTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- auto tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
- auto tv2 = makeSymbolicTensor(2);
- fusion.addInput(tv2);
- auto tv3 = broadcast(tv0, {false, true});
- auto tv4 = add(tv1, tv3);
- fusion.addOutput(tv4);
- auto tv5 = add(tv2, tv3);
- fusion.addOutput(tv5);
-
- // Broadcast domains can be used with multiple domains with
- // different sizes. In this test, the broadcast domain of tv3 has
- // two consumers, tv4 and tv5, which may have different sizes. Each
- // of the consumers is used with the broadcast domain of tv3, but
- // the two consumers may not have the same size, it is not possible
- // to map those domains.
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true},
- tv3,
- tv3->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true},
- tv1,
- tv1->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true},
- tv2,
- tv2->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, false},
- tv2,
- tv2->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, false},
- tv3,
- tv3->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv2,
- tv2->getRootDomain(),
- {true, false},
- tv3,
- tv3->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv3,
- tv3->getRootDomain(),
- {true, false},
- tv4,
- tv4->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv3,
- tv3->getRootDomain(),
- {true, false},
- tv5,
- tv5->getRootDomain(),
- {true, false});
- checkIdMapped(
- tv4,
- tv4->getRootDomain(),
- {true, false},
- tv5,
- tv5->getRootDomain(),
- {true, false});
-}
-
-TEST(NVFuserTest, FusionRootMappingBroadcast_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- // tv0[I0]
- fusion.addInput(tv0);
- auto tv1 = broadcast(tv0, {true, false});
- // tv1[B1, I0]
- auto tv2 = broadcast(tv1, {true, false, false});
- // tv2[B2, B1, I0]
- fusion.addOutput(tv2);
-
- // In this case, tv1 and tv2 have one and two broadcast domains,
- // respectively. It is the second broadcast domain of tv2 that is mapped
- // to the broadcast domain of tv1.
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true},
- tv1,
- tv1->getRootDomain(),
- {false, true});
- checkIdMapped(
- tv1,
- tv1->getRootDomain(),
- {true, true},
- tv2,
- tv2->getRootDomain(),
- {false, true, true}); // Not {true, false, true}
- checkIdMapped(
- tv0,
- tv0->getRootDomain(),
- {true},
- tv2,
- tv2->getRootDomain(),
- {false, false, true});
-}
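-
-// A minimal ATen model of the chained broadcasts above (a sketch; the helper
-// name and sizes are illustrative): tv0[I0] -> tv1[B1, I0] -> tv2[B2, B1, I0]
-// corresponds to unsqueezing on the left twice, and only the innermost axis
-// carries the original data.
-static void broadcastChainSketch() {
- at::Tensor t0 = at::randn({8});
- at::Tensor t1 = t0.unsqueeze(0); // [1, 8], like tv1[B1, I0]
- at::Tensor t2 = t1.unsqueeze(0); // [1, 1, 8], like tv2[B2, B1, I0]
- TORCH_INTERNAL_ASSERT(t2.sizes() == at::IntArrayRef({1, 1, 8}));
-}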
-
-// Reproducer of issue #723
-TEST(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- auto tv1 = makeSymbolicTensor(2);
-
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = broadcast(tv0, {true, false});
- auto tv3 = sum(tv2, {0});
- auto tv4 = add(tv2, tv1);
-
- fusion.addOutput(tv3);
- fusion.addOutput(tv4);
-
- ComputeAtRootDomainMap map;
- map.build();
-
- checkIdMapped(
- map, tv2, tv2->getRootDomain()[0], tv4, tv4->getRootDomain()[0], true);
- checkIdMapped(
- map, tv2, tv2->getRootDomain()[0], tv3, tv3->getRootDomain()[0], true);
-
- tv2->computeAt(tv4, -1);
-
- const int x = 11;
- const int y = 12;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({x}, options);
- at::Tensor t1 = at::randn({y, x}, options);
- std::vector<IValue> aten_inputs = {t0, t1};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(aten_inputs);
-
- auto t3 = t0;
- auto t4 = t0.unsqueeze(0).expand({y, x}) + t1;
-
- testValidate(&fusion, outputs, aten_inputs, {t3, t4}, __LINE__, __FILE__);
-}
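-
-// The sum over the broadcast axis above is a trivial reduction: reducing an
-// axis of extent 1 leaves the values unchanged, which is why t3 is just t0.
-// A minimal ATen sketch of the same identity (helper name illustrative):
-static void trivialReductionSketch() {
- at::Tensor t0 = at::randn({7});
- at::Tensor t3 = t0.unsqueeze(0).sum(0); // reduce the size-1 axis away
- TORCH_INTERNAL_ASSERT(t3.allclose(t0));
-}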
-
-TEST(NVFuserTest, FusionComputeAtFailDueToRootMapping_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = broadcast(tv1, {true, false});
- auto tv3 = broadcast(tv1, {false, true});
- auto tv4 = add(tv2, tv3);
- fusion.addOutput(tv4);
-
- // computeAt should fail as there is no valid root mapping.
- ASSERT_ANY_THROW(tv1->computeAt(tv4, 1));
-}
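-
-// Why computeAt fails above: tv4 = tv2 + tv3 is an outer "sum table" of tv1
-// with itself, so tv1's single domain would have to map to both axes of tv4
-// at once. A minimal ATen sketch of the shape pattern (name illustrative):
-static void outerSumSketch() {
- at::Tensor t1 = at::randn({5});
- // t1 feeds both axes of the 2-D result, so no single loop over t1 can be
- // shared with both loops of the consumer.
- at::Tensor t4 = t1.unsqueeze(0) + t1.unsqueeze(1); // [5, 5]
- TORCH_INTERNAL_ASSERT(t4.sizes() == at::IntArrayRef({5, 5}));
-}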
-
-TEST(NVFuserTest, FusionScalarInputs_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
-
- Double* d0 = new Double();
- fusion.addInput(d0);
- Double* d1 = new Double();
- fusion.addInput(d1);
- Double* d2 = new Double();
- fusion.addInput(d2);
- Double* d3 = new Double();
- fusion.addInput(d3);
- Val* d4 = mul(d0, d1);
- Val* d5 = sub(d2, d3);
-
- TensorView* tv2 = sub(tv1, d4);
- TensorView* tv3 = add(tv0, d5);
- TensorView* tv4 = mul(tv3, tv2);
-
- fusion.addOutput(tv4);
-
- // Let's set up to actually run
- while (tv4->nDims() > 1)
- tv4->merge(0);
- tv4->split(0, 128);
- tv4->split(0, 4);
-
- tv0->computeAt(tv4, 1);
- tv1->computeAt(tv4, 1);
-
- tv4->axis(0)->parallelize(ParallelType::BIDx);
-
- for (Val* val : fusion.vals()) {
- if (!fusion.hasInput(val) &&
- val->getValType().value() == ValType::TensorView) {
- TensorView* tv = static_cast<TensorView*>(val);
-
- tv->axis(1)->parallelize(ParallelType::Unroll);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
- }
-
- // d4 = d0 * d1
- // d5 = d2 - d3
- // t2 = t1 - d4
- // t3 = t0 + d5
- // t4 = t3 * t2
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- float fl0 = 0.1;
- float fl1 = -0.2;
- float fl2 = 0.3;
- float fl3 = -0.4;
- float fl4 = fl0 * fl1;
- float fl5 = fl2 - fl3;
-
- at::Tensor t0 = at::randn({129, 127}, options);
- at::Tensor t1 = at::rand_like(t0, options);
-
- auto t2 = t1.sub(fl4);
- auto t3 = t0.add(fl5);
- auto aten_output = t3.mul(t2);
-
- at::Tensor cg_output = at::empty_like(t0, options);
-
- at::Scalar test(fl0);
-
- std::vector<IValue> aten_inputs = {
- t0,
- t1,
- at::Scalar(fl0),
- at::Scalar(fl1),
- at::Scalar(fl2),
- at::Scalar(fl3)};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion(aten_inputs, {cg_output});
-
- testValidate(
- &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionLoopUnroll_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(3);
- TensorView* tv1 = makeSymbolicTensor(3);
-
- // Register your inputs
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- // Do math with it; it returns a `Val*` but can be static_cast back to a
- // TensorView
- TensorView* tv2 = add(tv1, new Double(2.0));
- TensorView* tv3 = add(tv0, tv2);
-
- // Register your outputs
- fusion.addOutput(tv3);
-
- int block_size = 16;
-
- tv3->merge(0, 1);
- tv3->merge(0, 1);
-
- tv3->split(0, block_size);
- tv3->split(0, 4);
-
- // For all inputs, computeAt the output inline; temporaries should be
- // squeezed between them
- tv0->computeAt(tv3, 1);
- tv1->computeAt(tv3, 1);
-
- // Parallelize
- tv2->axis(1)->parallelize(ParallelType::Unroll);
- tv3->axis(1)->parallelize(ParallelType::Unroll);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(0)->parallelize(ParallelType::BIDx);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor input0 = at::randn({129, 13, 3}, options);
- at::Tensor input1 = at::randn({129, 13, 3}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion({input0, input1});
-
- TORCH_CHECK(outputs[0].equal(input0.add(input1.add(2.0))));
-}
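-
-// A sketch of the launch-shape arithmetic implied by the schedule above,
-// assuming split uses ceiling division (helper name illustrative): merging
-// [129, 13, 3] gives 5031 elements; split(0, 16) then split(0, 4) yields
-// ceilDiv(ceilDiv(5031, 16), 4) = 79 blocks of 16 threads, unrolled by 4.
-static int64_t ceilDivSketch(int64_t a, int64_t b) {
- return (a + b - 1) / b;
-}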
-
-/*
- * Helper function for single op testing that generates a codegen operand
- */
-
-Val* gen_jit_operand(std::pair<ValType, DataType> desc) {
- if (desc.first == ValType::TensorView) {
- return makeSymbolicTensor(2, desc.second);
- } else if (desc.first == ValType::Scalar) {
- if (desc.second == DataType::Float || desc.second == DataType::Double) {
- // Double is the IR scalar type for both Float and Double.
- return new Double();
- } else if (desc.second == DataType::Int) {
- return new Int();
- } else {
- TORCH_CHECK(false, "Not currently supported type: ", desc.second);
- }
- } else {
- TORCH_CHECK(false, "Not currently supported type: ", desc.first);
- }
- return nullptr;
-}
-
-/*
- * Helper function for single op testing that generates an ATen operand
- */
-
-IValue gen_aten_operand(
- std::pair<ValType, DataType> desc,
- int blocks,
- int threads,
- bool rand) {
- if (desc.first == ValType::TensorView) {
- if (desc.second == DataType::Double || desc.second == DataType::Float ||
- desc.second == DataType::Half) {
- auto options = at::TensorOptions()
- .dtype(data_type_to_aten(desc.second))
- .device(at::kCUDA, 0);
- if (rand) {
- return IValue(at::rand({blocks, threads}, options));
- } else {
- return IValue(at::empty({blocks, threads}, options));
- }
- } else if (desc.second == DataType::Int || desc.second == DataType::Int32) {
- auto dtype = desc.second == DataType::Int32 ? at::kInt : at::kLong;
- if (rand) {
- auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- return IValue(at::randn({blocks, threads}, options).mul(5).to(dtype));
- } else {
- auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
- return IValue(at::empty({blocks, threads}, options));
- }
- } else if (desc.second == DataType::Bool) {
- if (rand) {
- auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- return IValue(
- at::rand({blocks, threads}, options).round().to(at::kBool));
- } else {
- auto options =
- at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0);
- return IValue(at::empty({blocks, threads}, options));
- }
- } else {
- TORCH_CHECK(false, "Not currently supported type: ", desc.second);
- }
- } else if (desc.first == ValType::Scalar) {
- // IValue scalars can only be double, int64, or bool
- if (desc.second == DataType::Double || desc.second == DataType::Float ||
- desc.second == DataType::Half) {
- return IValue(at::Scalar(1.f));
- } else if (desc.second == DataType::Int) {
- return IValue(at::Scalar(1));
- } else {
- TORCH_CHECK(false, "Not currently supported type: ", desc.second);
- }
- } else {
- TORCH_CHECK(false, "Not currently supported type: ", desc.first);
- }
- return nullptr;
-}
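-
-// The integer and boolean paths above sample from float distributions first:
-// randn(...).mul(5) keeps integer magnitudes small, and rand(...).round()
-// yields a roughly even {0, 1} mask. A host-side sketch of the same recipes
-// (helper name illustrative):
-static std::pair<at::Tensor, at::Tensor> randomIntBoolSketch(int64_t n) {
- auto opts = at::TensorOptions().dtype(at::kFloat);
- at::Tensor ints = at::randn({n}, opts).mul(5).to(at::kLong);
- at::Tensor bools = at::rand({n}, opts).round().to(at::kBool);
- return {ints, bools};
-}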
-
-/*
- * Templatized helper function to generate a single-op comparison between
- * the JIT CUDA codegen and the ATen library.
- */
-
-using OutputPair = std::pair<ValType, DataType>;
-template <
- typename AtenFunc,
- typename JitFunc,
- typename InputTuple,
- size_t... NumInputs>
-void test_op(
- int blocks,
- int threads,
- std::string op_str,
- AtenFunc af,
- JitFunc jf,
- OutputPair op,
- InputTuple it,
- std::index_sequence<NumInputs...>) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Generate Input JIT function Inputs and add them as Inputs to the Fusion
- // Graph
- std::array<Val*, sizeof...(NumInputs)> jit_inputs = {
- gen_jit_operand(std::get<NumInputs>(it))...};
- std::for_each(jit_inputs.begin(), jit_inputs.end(), [&fusion](Val* v) {
- fusion.addInput(v);
- });
- TensorView* out =
- static_cast<TensorView*>(jf(std::get<NumInputs>(jit_inputs)...));
- fusion.addOutput(out);
-
- std::for_each(jit_inputs.begin(), jit_inputs.end(), [out](Val* v) {
- if (v->getValType() == ValType::TensorView)
- static_cast<TensorView*>(v)->computeAt(out, -1);
- });
- out->axis(0)->parallelize(ParallelType::BIDx);
- out->axis(-1)->parallelize(ParallelType::TIDx);
-
- std::array<IValue, sizeof...(NumInputs)> aten_inputs = {gen_aten_operand(
- std::get<NumInputs>(it), blocks, threads, /*rand*/ true)...};
- const at::ArrayRef<IValue> aten_inputs_ivalues(aten_inputs);
-
- at::Tensor cg_output =
- gen_aten_operand(op, blocks, threads, /*rand*/ false).toTensor();
- std::vector<at::Tensor> output_vect = {cg_output};
- cudaDeviceSynchronize();
- if (fusion.isStochastic())
- at::manual_seed(0);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion(aten_inputs_ivalues, output_vect);
- cudaDeviceSynchronize();
-
- if (fusion.isStochastic())
- at::manual_seed(0);
- at::Tensor aten_output = af(aten_inputs);
- cudaDeviceSynchronize(); // This sync shouldn't be necessary
-
- std::string op_msg = "Operation " + op_str;
-
- testValidate(
- &fusion,
- {cg_output},
- aten_inputs,
- {aten_output},
- __LINE__,
- __FILE__,
- op_msg);
-}
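-
-// A minimal sketch of the std::index_sequence expansion used above, kept
-// C++14-compatible per the gcc 5.4 note further below (helper name
-// illustrative): each std::get<Is>(t) is expanded once per index, here via
-// the array-initializer trick instead of a C++17 fold expression.
-template <typename Tuple, typename F, size_t... Is>
-static void forEachByIndexSketch(Tuple& t, F f, std::index_sequence<Is...>) {
- int expand[] = {0, (f(std::get<Is>(t)), 0)...};
- (void)expand;
-}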
-
-/*
- * Templatized helper function that uses variadic templates to process a
- * variable-length input tuple of different operand types.
- */
-template <typename AtenFunc, typename JitFunc, typename InputTuple>
-void test_op(
- int blocks,
- int threads,
- std::string op_str,
- AtenFunc af,
- JitFunc jf,
- OutputPair op,
- InputTuple it) {
- static constexpr auto size = std::tuple_size<InputTuple>::value;
- test_op(
- blocks,
- threads,
- op_str,
- af,
- jf,
- op,
- it,
- std::make_index_sequence<size>{});
-}
-
-TEST(NVFuserTest, FusionUnaryOps_CUDA) {
- using OpTuple =
- std::tuple<at::Tensor (*)(const at::Tensor&), UnaryOpType, std::string>;
-
- // [Note: explicit tuple type for uniform initialization list]
- // The tuple type must be explicitly specified for each uniform
- // initialization list within the vector to keep this code compatible with
- // some older environments we still support, e.g. gcc 5.4 + CUDA 9.2.
- std::vector<OpTuple> ops{
- OpTuple{at::abs, UnaryOpType::Abs, "abs"},
- OpTuple{at::acos, UnaryOpType::Acos, "acos"},
- OpTuple{at::asin, UnaryOpType::Asin, "asin"},
- OpTuple{at::atan, UnaryOpType::Atan, "atan"},
- // There does not appear to be an appropriate ATen function for atanh
- // OpTuple{at::atanh, UnaryOpType::Atanh, "atanh" },
- OpTuple{at::ceil, UnaryOpType::Ceil, "ceil"},
- OpTuple{at::cos, UnaryOpType::Cos, "cos"},
- OpTuple{at::cosh, UnaryOpType::Cosh, "cosh"},
- OpTuple{at::erf, UnaryOpType::Erf, "erf"},
- OpTuple{at::erfc, UnaryOpType::Erfc, "erfc"},
- OpTuple{at::exp, UnaryOpType::Exp, "exp"},
- OpTuple{at::expm1, UnaryOpType::Expm1, "expm1"},
- OpTuple{at::floor, UnaryOpType::Floor, "floor"},
- OpTuple{at::frac, UnaryOpType::Frac, "frac"},
- // OpTuple{at::gelu, UnaryOpType::Gelu, "gelu"},
- OpTuple{at::lgamma, UnaryOpType::Lgamma, "lgamma"},
- OpTuple{at::log, UnaryOpType::Log, "log"},
- OpTuple{at::log10, UnaryOpType::Log10, "log10"},
- OpTuple{at::log1p, UnaryOpType::Log1p, "log1p"},
- OpTuple{at::log2, UnaryOpType::Log2, "log2"},
- OpTuple{at::neg, UnaryOpType::Neg, "neg"},
- OpTuple{at::reciprocal, UnaryOpType::Reciprocal, "reciprocal"},
- OpTuple{at::relu, UnaryOpType::Relu, "relu"},
- OpTuple{at::round, UnaryOpType::Round, "round"},
- OpTuple{at::rsqrt, UnaryOpType::Rsqrt, "rsqrt"},
- OpTuple{at::sigmoid, UnaryOpType::Sigmoid, "sigmoid"},
- OpTuple{at::sin, UnaryOpType::Sin, "sin"},
- OpTuple{at::sinh, UnaryOpType::Sinh, "sinh"},
- OpTuple{at::sqrt, UnaryOpType::Sqrt, "sqrt"},
- OpTuple{at::tan, UnaryOpType::Tan, "tan"},
- OpTuple{at::tanh, UnaryOpType::Tanh, "tanh"},
- OpTuple{at::trunc, UnaryOpType::Trunc, "trunc"}};
-
- std::vector<DataType> dtypes = {DataType::Float, DataType::Double};
-
- for (auto dtype : dtypes) {
- std::for_each(ops.begin(), ops.end(), [&](OpTuple& op) {
- test_op(
- /*blocks*/ 640,
- /*threads*/ 64,
- /*name*/ std::get<2>(op),
- /*Aten Func */
- [&op](std::array<IValue, 1>& vals) {
- return std::get<0>(op)(vals[0].toTensor());
- },
- /*JIT Func */
- [&op](Val* in1) -> Val* { return unaryOp(std::get<1>(op), in1); },
- /*Output */ std::make_pair(ValType::TensorView, dtype),
- /*Inputs Tuple*/
- std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
- });
-
- test_op(
- /*blocks*/ 128,
- /*threads*/ 64,
- /*name*/ "rand_like",
- /*Aten Func */
- [](std::array<IValue, 1>& vals) {
- return at::rand_like(vals[0].toTensor());
- },
- /*JIT Func */
- [](Val* in1) -> Val* { return unaryOp(UnaryOpType::RandLike, in1); },
- /*Output */ std::make_pair(ValType::TensorView, dtype),
- /*Inputs Tuple*/
- std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
- }
-
- dtypes = {DataType::Int, DataType::Int32, DataType::Bool};
- for (auto dtype : dtypes) {
- test_op(
- /*blocks*/ 128,
- /*threads*/ 64,
- /*name*/ "bitwise_not",
- /*Aten Func */
- [](std::array<IValue, 1>& vals) {
- return at::bitwise_not(vals[0].toTensor());
- },
- /*JIT Func */
- [](Val* in1) -> Val* { return unaryOp(UnaryOpType::Not, in1); },
- /*Output */ std::make_pair(ValType::TensorView, dtype),
- /*Inputs Tuple*/
- std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
- }
-}
-
-TEST(NVFuserTest, FusionBinaryOps_CUDA) {
- using AtenFuncSig = at::Tensor (*)(const at::Tensor&, const at::Tensor&);
- using OpTuple = std::tuple<AtenFuncSig, BinaryOpType, std::string>;
-
- // see [Note: explicit tuple type for uniform initialization list]
- std::vector<OpTuple> logic_ops{
- OpTuple{at::eq, BinaryOpType::Eq, "eq"},
- OpTuple{at::ge, BinaryOpType::GE, "ge"},
- OpTuple{at::gt, BinaryOpType::GT, "gt"},
- OpTuple{at::le, BinaryOpType::LE, "le"},
- OpTuple{at::lt, BinaryOpType::LT, "lt"},
- OpTuple{at::ne, BinaryOpType::NE, "ne"}};
- std::vector<DataType> dtypes = {DataType::Double, DataType::Float};
-
- for (auto dtype : dtypes) {
- std::for_each(logic_ops.begin(), logic_ops.end(), [&](OpTuple& op) {
- test_op(
- /*blocks*/ 640,
- /*threads*/ 64,
- /*name*/ std::get<2>(op),
- /*Aten Func */
- [&op](std::array<IValue, 2>& vals) {
- return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor());
- },
- /*JIT Func */
- [&op](Val* in1, Val* in2) -> Val* {
- return binaryOp(std::get<1>(op), in1, in2);
- },
- /*Output */ std::make_pair(ValType::TensorView, DataType::Bool),
- /*Inputs Tuple*/
- std::make_tuple(
- std::make_pair(ValType::TensorView, dtype),
- std::make_pair(ValType::TensorView, dtype)));
- });
-
- // see [Note: explicit tuple type for uniform initialization list]
- std::vector<OpTuple> math_ops{
- OpTuple{at::atan2, BinaryOpType::Atan2, "atan2"},
- OpTuple{at::div, BinaryOpType::Div, "div"},
- OpTuple{at::fmod, BinaryOpType::Fmod, "fmod"},
- OpTuple{at::max, BinaryOpType::Max, "max"},
- OpTuple{at::min, BinaryOpType::Min, "min"},
- OpTuple{at::mul, BinaryOpType::Mul, "mul"},
- OpTuple{at::pow, BinaryOpType::Pow, "pow"},
- // NOTE: Remainder does not match the Aten impl exactly
- // despite using an identical function.
- OpTuple{at::remainder, BinaryOpType::Remainder, "remainder"},
- };
-
- std::for_each(math_ops.begin(), math_ops.end(), [&](OpTuple& op) {
- test_op(
- /*blocks*/ 640,
- /*threads*/ 64,
- /*name*/ std::get<2>(op),
- /*Aten Func */
- [&op](std::array<IValue, 2>& vals) {
- return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor());
- },
- /*JIT Func */
- [&op](Val* in1, Val* in2) -> Val* {
- return binaryOp(std::get<1>(op), in1, in2);
- },
- /*Output */ std::make_pair(ValType::TensorView, dtype),
- /*Inputs Tuple*/
- std::make_tuple(
- std::make_pair(ValType::TensorView, dtype),
- std::make_pair(ValType::TensorView, dtype)));
- });
-
- test_op(
- /*blocks*/ 640,
- /*threads*/ 64,
- /*name*/ "add_alpha",
- /*Aten Func */
- [](std::array<IValue, 3>& vals) {
- return at::add(
- vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
- },
- /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&add_alpha),
- /*Output */ std::make_pair(ValType::TensorView, dtype),
- /*Inputs Tuple*/
- std::make_tuple(
- std::make_pair(ValType::TensorView, dtype),
- std::make_pair(ValType::TensorView, dtype),
- std::make_pair(ValType::Scalar, dtype)));
-
- test_op(
- /*blocks*/ 640,
- /*threads*/ 64,
- /*name*/ "sub_alpha",
- /*Aten Func */
- [](std::array<IValue, 3>& vals) {
- return at::sub(
- vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
- },
- /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&sub_alpha),
- /*Output */ std::make_pair(ValType::TensorView, dtype),
- /*Inputs Tuple*/
- std::make_tuple(
- std::make_pair(ValType::TensorView, dtype),
- std::make_pair(ValType::TensorView, dtype),
- std::make_pair(ValType::Scalar, dtype)));
- }
-}
-
-TEST(NVFuserTest, FusionTernaryOps_CUDA) {
- std::vector<DataType> dtypes = {DataType::Double, DataType::Float};
-
- for (auto dtype : dtypes) {
- test_op(
- /*blocks*/ 640,
- /*threads*/ 64,
- /*name*/ "clamp",
- /*Aten Func */
- [](std::array<IValue, 1>& vals) {
- return at::clamp(vals[0].toTensor(), 0.f, 1.f);
- },
- /*JIT Func */
- [&](Val* in1) -> Val* {
- // Double is the IR scalar type for both Float and Double dtypes.
- return clamp(in1, new Double(0.f), new Double(1.f));
- },
- /*Output */ std::make_pair(ValType::TensorView, dtype),
- /*Inputs Tuple*/
- std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
- test_op(
- /*blocks*/ 640,
- /*threads*/ 64,
- /*name*/ "threshold",
- /*Aten Func */
- [](std::array<IValue, 1>& vals) {
- return at::threshold(vals[0].toTensor(), 0.f, 1.f);
- },
- /*JIT Func */
- [&](Val* in1) -> Val* {
- // Same scalar type either way; see the clamp case above.
- return threshold(in1, new Double(0.f), new Double(1.f));
- },
- /*Output */ std::make_pair(ValType::TensorView, dtype),
- /*Inputs Tuple*/
- std::make_tuple(std::make_pair(ValType::TensorView, dtype)));
- test_op(
- /*blocks*/ 640,
- /*threads*/ 64,
- /*name*/ "where",
- /*Aten Func */
- [](std::array<IValue, 3>& vals) {
- return at::where(
- vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
- },
- /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&where),
- /*Output */ std::make_pair(ValType::TensorView, dtype),
- /*Inputs Tuple*/
- std::make_tuple(
- std::make_pair(ValType::TensorView, DataType::Bool),
- std::make_pair(ValType::TensorView, dtype),
- std::make_pair(ValType::TensorView, dtype)));
- }
-}
-
-TEST(NVFuserTest, FusionCompoundOps_CUDA) {
- std::vector<DataType> dtypes = {DataType::Double, DataType::Float};
-
- for (auto dtype : dtypes) {
- test_op(
- /*blocks*/ 640,
- /*threads*/ 64,
- /*name*/ "lerp",
- /*Aten Func */
- [](std::array<IValue, 3>& vals) {
- return at::lerp(
- vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
- },
- /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&lerp),
- /*Output */ std::make_pair(ValType::TensorView, dtype),
- /*Inputs Tuple*/
- std::make_tuple(
- std::make_pair(ValType::TensorView, dtype),
- std::make_pair(ValType::TensorView, dtype),
- std::make_pair(ValType::TensorView, dtype)));
- test_op(
- /*blocks*/ 640,
- /*threads*/ 64,
- /*name*/ "addcmul",
- /*Aten Func */
- [](std::array<IValue, 4>& vals) {
- return at::addcmul(
- vals[0].toTensor(),
- vals[1].toTensor(),
- vals[2].toTensor(),
- vals[3].toScalar());
- },
- /*JIT Func */
- static_cast<Val* (*)(Val*, Val*, Val*, Val*)>(&addcmul),
- /*Output */ std::make_pair(ValType::TensorView, dtype),
- /*Inputs Tuple*/
- std::make_tuple(
- std::make_pair(ValType::TensorView, dtype),
- std::make_pair(ValType::TensorView, dtype),
- std::make_pair(ValType::TensorView, dtype),
- std::make_pair(ValType::Scalar, dtype)));
- }
-}
-
-TEST(NVFuserTest, FusionCastOps_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2, DataType::Half);
-
- TensorView* intrm1 = castOp(DataType::Float, tv0);
- TensorView* out = castOp(DataType::Half, intrm1);
-
- fusion.addInput(tv0);
- fusion.addOutput(out);
- tv0->computeAt(out, -1);
-
- out->axis(0)->parallelize(ParallelType::BIDx);
- out->axis(-1)->parallelize(ParallelType::TIDx);
-
- auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
-
- at::Tensor input1 = at::randn({1, 4}, options);
- at::Tensor ref_output = at::empty_like(input1);
-
- std::array<IValue, 1> inputs = {input1};
- const at::ArrayRef<IValue> input_ivalues(inputs);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(input_ivalues);
-
- ref_output = at::_cast_Half(at::_cast_Float(input1));
-
- TORCH_CHECK(
- outputs[0].equal(ref_output),
- "\nOp Type: -- ",
- "cast FP16->FP32->FP16",
- " -- had a mismatch.\n",
- "\nABS MAX DIFF: ",
- outputs[0].sub(ref_output).abs().max(),
- "\n");
-}
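-
-// The exact-equality check above works because fp16 -> wider float -> fp16 is
-// an exact round trip: every fp16 value is representable in fp32 and fp64.
-// A minimal ATen sketch of that property (helper name illustrative):
-static void halfRoundTripSketch() {
- auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
- at::Tensor h = at::randn({16}, opts);
- TORCH_INTERNAL_ASSERT(h.to(at::kFloat).to(at::kHalf).equal(h));
-}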
-
-// Start off simple: block on the outer dim;
-// block stride + thread all-reduce + unrolling on the inner dim
-TEST(NVFuserTest, FusionReduction1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- // tv1[I0, R1] = tv0[I0, I1]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
- tv1->split(1, 128);
- // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
- tv1->split(1, 4);
- // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1]
-
- TensorView* tv2 = tv1->rFactor({1});
- // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1]
- // tv1[I0, R1oi{4}, R1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}]
-
- TensorView* tv3 = tv1->rFactor({1});
- // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1]
- // tv3[I0, R1oi{4}, Ir1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}]
- // tv1[I0, R1i{128}] = tv3[I0, R1oi{4}, Ir1i{128}]
-
- // Do the computeAt chain incrementally; can print in between for debugging
- tv0->computeAt(tv2, 1);
- tv2->computeAt(tv3, 1);
- tv3->computeAt(tv1, 1);
-
- // Redo it all at once, because why not.
- tv0->computeAt(tv1, 1);
-
- tv2->axis(2)->parallelize(ParallelType::Unroll);
- tv1->axis(0)->parallelize(ParallelType::BIDx);
-
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
- int numel_x = 65000;
- int numel_y = 1025;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y}, options);
- at::Tensor cg_output = at::empty({numel_x}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({input}, {cg_output});
-
- auto aten_output = input.to(at::kDouble).sum({1});
-
- testValidate(
- &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
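-
-// The two rFactor calls above turn one long sum into a chain of partial sums.
-// A plain host-side model of the same arithmetic (a sketch; the helper name
-// and the single-level factoring are illustrative, while the real schedule
-// factors twice):
-static double factoredSumSketch(const std::vector<double>& in, int64_t groups) {
- std::vector<double> partial(groups, 0.0); // like tv2/tv3: partial results
- for (size_t i = 0; i < in.size(); ++i) {
- partial[i % groups] += in[i];
- }
- double total = 0.0; // like tv1: final reduction over the partials
- for (double p : partial) {
- total += p;
- }
- return total;
-}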
-
-TEST(NVFuserTest, FusionReduction2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- // tv1[I0, R1] = tv0[I0, I1]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-
- fusion.addOutput(tv1);
-
- // Switches to try some different scenarios. Maybe we should iterate over
- // all permutations.
- bool bind_bidx = true;
- bool bind_tidx = true;
- bool bind_tidy = true;
- bool bind_unroll = true;
-
- int numel_x = 1025; // Cannot exceed block dim max size / tidy
- int numel_y = 129;
- int tidx = 16;
- int tidy = 8;
- int unroll_factor = 4;
-
- tv1->split(1, tidx);
- // tv1[I0, R1o, R1i{tidx}] = tv0[I0, I1]
-
- tv1->split(1, unroll_factor);
- // tv1[I0, R1oo, R1oi{unroll}, R1i{tidx}] = tv0[I0, I1]
-
- tv1->split(0, tidy);
-
- TensorView* tv2 = tv1->rFactor({-3});
- // tv2[I0, >R1oo<, Ir1oi{unroll}, Ir1i{tidx}]
- // tv1[I0o, I0i{tidy}, R1oi{unroll}, R1i{tidx}]
-
- TensorView* tv3 = tv1->rFactor({-2});
- // tv2[I0, >R1oo<, Ir1oi{unroll}, Ir1i{tidx}]
- // tv3[I0, R1oi{unroll}, Ir1i{tidx}]
- // tv1[I0o, I0i{tidy}, R1i{tidx}]
-
- tv0->computeAt(tv1, -2);
-
- if (bind_unroll)
- tv2->axis(-2)->parallelize(ParallelType::Unroll);
- if (bind_bidx)
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- if (bind_tidy)
- tv1->axis(1)->parallelize(ParallelType::TIDy);
-
- if (bind_tidx) {
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- }
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
-
- auto aten_output = input.to(at::kDouble).sum({1});
- testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReduction3_CUDA) {
- // What if Z participates in the reduction with X?
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- // tv1[I0, R1] = tv0[I0, I1]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-
- fusion.addOutput(tv1);
-
- int numel_x = 1025;
- int numel_y = 129;
- int tidx = 16;
- int tidz = 8;
-
- tv1->split(1, tidz);
- // tv1[I0, R1o, R1i{tidz}] = tv0[I0, I1]
-
- tv1->split(1, tidx);
- // tv1[I0, R1oo, R1oi{tidx}, R1i{tidz}] = tv0[I0, I1]
-
- TensorView* tv2 = tv1->rFactor({-3});
- // tv2[I0, >R1oo<, Ir1oi{tidx}, Ir1i{tidz}]
- // tv1[I0o, R1oi{tidx}, R1i{tidz}]
-
- tv0->computeAt(tv1, -3);
-
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(-2)->parallelize(ParallelType::TIDx);
- tv1->axis(-1)->parallelize(ParallelType::TIDz);
-
- tv2->axis(-2)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDz);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
- at::Tensor cg_output = at::empty({numel_x}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({aten_input}, {cg_output});
-
- auto aten_output = aten_input.to(at::kDouble).sum({1});
-
- testValidate(
- &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReduction4_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
-
- TensorView* tv2 = add(tv0, tv1);
- // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1]
-
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- TensorView* tv3 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv2);
- // tv3[I0, R1] = tv2[I0, I1]
-
- TensorView* tv4 = makeSymbolicTensor(1);
- fusion.addInput(tv4);
-
- // tv5[I0] = tv3[I0, R1] * tv4[I0]
- TensorView* tv5 = mul(tv3, tv4);
- fusion.addOutput(tv5);
-
- int tidx = 16;
-
- // RFactor the reduction
- tv3->split(1, tidx);
- // tv3[I0, R1o, R1i{tidx}] = tv2[I0, I1]
-
- TensorView* tv6 = tv3->rFactor({-2});
- // tv6[I0, R1o, iR1i{tidx}] = tv2[I0, I1]
- // tv3[I0, R1i{tidx}] = tv6[I0, R1o, iR1i{tidx}]
- tv2->computeAt(tv6, 2);
-
- // Compute at inline with tv5 (only 1D)
- tv6->computeAt(tv3, 1);
- tv3->computeAt(tv5, 1);
-
- tv5->axis(0)->parallelize(ParallelType::BIDx);
-
- // Only the intermediate tensors need this, but it doesn't hurt to do it
- // on the inputs as well (tv0, tv1, tv4)
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv6->axis(-1)->parallelize(ParallelType::TIDx);
-
- int numel_x = 1025;
- int numel_y = 129;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- at::Tensor t1 = at::randn({numel_x, numel_y}, options);
- at::Tensor t4 = at::randn({numel_x}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({t0, t1, t4});
-
- auto t2 = t0.add(t1);
- auto t3 = t2.to(at::kDouble).sum({1});
- auto aten_output = t3.mul(t4);
-
- testValidate(
- &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReduction5_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(3);
-
- fusion.addInput(tv0);
-
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-
- fusion.addOutput(tv1);
-
- int bidy = 2;
- int tidy = 4;
- int tidx = 5;
-
- int dim1 = 11;
-
- tv1->split(-2, tidy);
-
- TensorView* tv2 = tv1->rFactor({-3});
-
- tv0->computeAt(tv1, 1);
- tv1->axis(0)->parallelize(ParallelType::BIDy);
-
- for (auto* val : fusion.vals()) {
- if (!fusion.hasInput(val) &&
- val->getValType().value() == ValType::TensorView) {
- val->as<TensorView>()->axis(-1)->parallelize(ParallelType::TIDx);
- }
- }
-
- tv2->axis(-2)->parallelize(ParallelType::TIDy);
- tv1->axis(-2)->parallelize(ParallelType::TIDy);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({bidy, dim1, tidx}, options);
-
- at::Tensor cg_output = at::empty({bidy, tidx}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({input}, {cg_output});
-
- auto aten_output = input.to(at::kDouble).sum({1});
- testValidate(
- &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReduction6_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const int bdimx = 64;
- const int bdimy = 8;
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(3);
- fusion.addInput(tv0);
-
- // tv1[I0, R1, R2] = tv0[I0, I1, I2]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
- tv1->split(2, bdimx);
- // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]
- tv1->split(1, bdimy);
- // tv1[I0, R1o, R1i{8}, R2o, R2i{128}] = tv0[I0, I1, I2]
-
- TensorView* tv2 = tv1->rFactor({3});
- // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2]
- // tv1[I0, R1o, R1i{8}, R2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}]
-
- TensorView* tv3 = tv1->rFactor({1});
- // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2]
- // tv3[I0, R1o, I1i{8}, I2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}]
- // tv1[I0, R1i{8}, R2i{128}] = tv3[I0, R1o, I1i{8}, I2i{128}]
-
- tv3->computeAt(tv1, 1);
- tv2->computeAt(tv3, 2);
-
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv3->axis(0)->parallelize(ParallelType::BIDx);
-
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv1->axis(-2)->parallelize(ParallelType::TIDy);
- tv3->axis(-2)->parallelize(ParallelType::TIDy);
- tv2->axis(-3)->parallelize(ParallelType::TIDy);
-
- int numel_x = 650;
- int numel_y = 1000;
- int numel_z = 4;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y, numel_z}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
-
- auto aten_output = input.to(at::kDouble).sum({1, 2});
- testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionMultiGridReduction_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- TensorView* tv1 = max(tv0, {0});
- TensorView* tv2 = sum(tv0, {0});
-
- fusion.addOutput(tv1);
- fusion.addOutput(tv2);
-
- int numel_x = 4;
- int numel_y = 2;
-
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(1)->parallelize(ParallelType::TIDx);
-
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(1)->parallelize(ParallelType::TIDx);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
-
- std::vector<at::Tensor> aten_outputs = {
- std::get<0>(input.to(at::kDouble).max(0)), input.to(at::kDouble).sum(0)};
- testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__);
-}
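-
-// Note on the two reductions above: binding a reduction axis to BIDx makes
-// each a grid reduction, combining partial results across thread blocks
-// rather than within a single block as in the earlier reduction tests.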
-
-TEST(NVFuserTest, FusionMultiGridReduction2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = sum(tv0, {0});
- auto tv2 = sum(tv1, {0});
- fusion.addOutput(tv2);
-
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(1)->parallelize(ParallelType::BIDy);
- tv2->axis(0)->parallelize(ParallelType::BIDy);
-
- FusionExecutor fe;
- ASSERT_ANY_THROW(fe.compileFusion(&fusion));
-}
-
-TEST(NVFuserTest, FusionReductionTFT_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- // tv1[I0, R1] = tv0[I0, I1]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
-
- fusion.addOutput(tv1);
-
- int numel_x = 1025;
- int numel_y = 129;
- int tidx = 16;
- int tidy = 8;
- int tidz = 8;
-
- tv1->split(1, tidx);
- // tv1[I0, R1o, R1i{tidx}]
-
- tv1->split(1, tidz);
- // tv1[I0, R1oo, R1Oi{tidz}, R1R1i{tidx}]
-
- tv1->split(0, tidy);
- // tv1[I0o, I0i, R1oo, R1Oi{tidz}, R1R1i{tidx}]
-
- TensorView* tv2 = tv1->rFactor({2});
- // tv2[I0o, I0i, R1oo, I1Oi{tidz}, I1R1i{tidx}]
- // tv1[I0o, I0i, R1Oi{tidz}, R1R1i{tidx}]
-
- tv2->computeAt(tv1, 2);
-
- tv1->axis(1)->parallelize(ParallelType::TIDy);
-
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv1->axis(-2)->parallelize(ParallelType::TIDz);
- tv2->axis(-2)->parallelize(ParallelType::TIDz);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y}, options);
- at::Tensor cg_output = at::empty({numel_x}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({input}, {cg_output});
-
- auto aten_output = input.to(at::kDouble).sum({1});
- testValidate(
- &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReductionOuterSplit_CUDA) {
- // Based on FusionReduction4
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
-
- TensorView* tv2 = add(tv0, tv1);
- // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1]
-
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- TensorView* tv3 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv2);
- // tv3[I0, R1] = tv2[I0, I1]
-
- TensorView* tv4 = makeSymbolicTensor(1);
- fusion.addInput(tv4);
-
- // tv5[I0] = tv3[I0, R1] * tv4[I0]
- TensorView* tv5 = mul(tv3, tv4);
- fusion.addOutput(tv5);
-
- // RFactor the reduction
- tv3->split(1, 16, false);
- // tv3[I0, R1o{16}, R1i{tidx}] = tv2[I0, I1]
-
- TensorView* tv6 = tv3->rFactor({-2});
- // tv6[I0, R1o{16}, iR1i{tidx}] = tv2[I0, I1]
- // tv3[I0, R1i{tidx}] = tv6[I0, R1o{16}, iR1i{tidx}]
- tv2->computeAt(tv6, 2);
-
- // Compute at inline with tv5 (only 1D)
- tv6->computeAt(tv3, 1);
- tv3->computeAt(tv5, 1);
-
- tv5->axis(0)->parallelize(ParallelType::BIDx);
-
- // Only the intermediate tensors need this, but it doesn't hurt to do it
- // on the inputs as well (tv0, tv1, tv4)
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv6->axis(-1)->parallelize(ParallelType::TIDx);
-
- int numel_x = 1025;
- int numel_y = 129;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- at::Tensor t1 = at::randn({numel_x, numel_y}, options);
- at::Tensor t4 = at::randn({numel_x}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({t0, t1, t4});
-
- auto t2 = t0.add(t1);
- auto t3 = t2.to(at::kDouble).sum({1});
- auto aten_output = t3.mul(t4);
-
- testValidate(
- &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__);
-}
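-
-// split(1, 16, false) above is an outer split: the literal factor becomes the
-// extent of the outer domain rather than the inner one. A sketch of both
-// conventions, assuming ceiling division (helper name illustrative):
-// inner split: [N] -> [ceilDiv(N, f), f]; outer split: [N] -> [f, ceilDiv(N, f)]
-static std::pair<int64_t, int64_t> splitExtentsSketch(
- int64_t n,
- int64_t f,
- bool inner_split) {
- int64_t other = (n + f - 1) / f;
- return inner_split ? std::make_pair(other, f) : std::make_pair(f, other);
-}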
-
-TEST(NVFuserTest, FusionBranches_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
- TensorView* tv2 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addInput(tv2);
-
- auto tv3 = add(tv0, new Double(1.0));
- auto tv4 = add(tv3, tv1);
- auto tv5 = add(tv3, tv2);
- auto tv6 = add(tv4, tv5);
-
- fusion.addOutput(tv6);
-
- constexpr int x = 63, y = 33;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({x, y}, options);
- at::Tensor t1 = at::randn({x, y}, options);
- at::Tensor t2 = at::randn({x, y}, options);
-
- FusionExecutor fe;
- tv6->merge(0);
- tv6->split(0, 128);
- tv6->split(0, 4);
-
- tv6->axis(0)->parallelize(ParallelType::BIDx);
-
- tv0->computeAt(tv6, 1);
- tv1->computeAt(tv6, 1);
- tv2->computeAt(tv6, 1);
-
- tv3->axis(-2)->parallelize(ParallelType::Unroll);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv4->axis(-2)->parallelize(ParallelType::Unroll);
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
- tv5->axis(-2)->parallelize(ParallelType::Unroll);
- tv5->axis(-1)->parallelize(ParallelType::TIDx);
- tv6->axis(-1)->parallelize(ParallelType::TIDx);
-
- std::vector<IValue> aten_inputs = {t0, t1, t2};
-
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- auto t3 = t0.add(1.0);
- auto t4 = t3.add(t1);
- auto t5 = t3.add(t2);
- auto aten_output = t4.add(t5);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSimpleBCast1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- TensorView* tv1 = add(tv0, new Double(1.5));
-
- TensorView* tv2 = makeSymbolicTensor(2);
- fusion.addInput(tv2);
- TensorView* tv3 = makeSymbolicTensor(2);
- fusion.addInput(tv3);
- TensorView* tv4 = sub(tv2, tv3);
-
- TensorView* tv5 = broadcast(tv1, {false, false, true});
- TensorView* tv6 = broadcast(tv4, {true, false, false});
-
- TensorView* tv7 = add(tv5, tv6);
- fusion.addOutput(tv7);
-
- tv7->split(-1, 4);
- tv7->split(0, 8);
-
- tv0->computeAt(tv7, -1);
- tv2->computeAt(tv7, -1);
-
- tv7->axis(0)->parallelize(ParallelType::BIDx);
- tv7->axis(-1)->parallelize(ParallelType::TIDx);
-
- constexpr int x = 63, y = 33, z = 15;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({x, y}, options);
- at::Tensor t1 = t0.add(1.5);
-
- at::Tensor t2 = at::randn({y, z}, options);
- at::Tensor t3 = at::randn({y, z}, options);
-
- at::Tensor t4 = t2.sub(t3);
- at::Tensor t5 = t1.unsqueeze(-1).expand({x, y, z});
-
- at::Tensor t6 = t4.expand({x, y, z});
-
- at::Tensor aten_output = t5.add(t6);
-
- std::vector<IValue> aten_inputs = {t0, t2, t3};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSimpleBCast2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
-
- TensorView* tv2 = add(tv0, tv1);
-
- TensorView* tv3 = broadcast(tv2, {false, false, true});
-
- TensorView* tv4 = makeSymbolicTensor(2);
- fusion.addInput(tv4);
-
- TensorView* tv5 = sub(tv4, new Double(0.1));
-
- TensorView* tv6 = broadcast(tv5, {true, false, false});
-
- TensorView* tv7 = add(tv3, tv6);
-
- fusion.addOutput(tv7);
-
- tv7->merge(0, 1);
-
- tv0->computeAt(tv7, -1);
- tv4->computeAt(tv7, -1);
-
- tv7->axis(0)->parallelize(ParallelType::BIDx);
- tv7->axis(-1)->parallelize(ParallelType::TIDx);
-
- constexpr int x = 63, y = 33, z = 15;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({x, y}, options);
- at::Tensor t1 = at::randn({x, y}, options);
- at::Tensor t2 = t0.add(t1);
- at::Tensor t3 = t2.unsqueeze(-1).expand({x, y, z});
-
- at::Tensor t4 = at::randn({y, z}, options);
- at::Tensor t5 = t4.sub(0.1);
- at::Tensor t6 = t5.expand({x, y, z});
- at::Tensor aten_output = t3.add(t6);
-
- at::Tensor cg_output = at::empty({x, y, z}, options);
-
- std::vector<IValue> aten_inputs = {t0, t1, t4};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion(aten_inputs, {cg_output});
-
- testValidate(
- &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSimpleBCast3_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- std::vector<IterDomain*> dom;
- dom.push_back(new IterDomain(new Int(0), new Int()));
- dom.push_back(new IterDomain(
- new Int(0),
- new Int(1),
- ParallelType::Serial,
- IterType::BroadcastWithStride));
-
- // tv0[I1, B{1}]
- TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float);
- fusion.addInput(tv0);
-
- // tv2[I0, I1, I2]
- TensorView* tv2 = makeSymbolicTensor(3);
- fusion.addInput(tv2);
-
- TensorView* tv3 = add(tv0, tv2);
-
- fusion.addOutput(tv3);
-
- tv3->merge(0);
- tv3->merge(0);
-
- tv0->computeAt(tv3, -1);
- tv2->computeAt(tv3, -1);
-
- tv3->axis(0)->parallelize(ParallelType::BIDx);
-
- constexpr int x = 2, y = 3, z = 4;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({y, 1}, options);
- at::Tensor t2 = at::randn({x, y, z}, options);
- auto aten_output = t0.add(t2);
-
- std::vector<IValue> aten_inputs = {t0, t2};
- at::Tensor cg_output = at::empty({x, y, z}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion(aten_inputs, {cg_output});
-
- testValidate(
- &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
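-
-// The hand-built IterDomain above (start 0, extent 1, BroadcastWithStride)
-// plays the role of a size-1 input dimension. A minimal ATen sketch of the
-// same broadcast (helper name illustrative):
-static void sizeOneBroadcastSketch() {
- at::Tensor t0 = at::randn({3, 1}); // like tv0[I1, B{1}]
- at::Tensor t2 = at::randn({2, 3, 4}); // like tv2[I0, I1, I2]
- TORCH_INTERNAL_ASSERT((t0 + t2).sizes() == at::IntArrayRef({2, 3, 4}));
-}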
-
-TEST(NVFuserTest, FusionSimpleBCast4_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- std::vector<IterDomain*> dom;
- dom.push_back(new IterDomain(
- new Int(0),
- new Int(1),
- ParallelType::Serial,
- IterType::BroadcastWithStride));
- dom.push_back(new IterDomain(new Int(0), new Int()));
- TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float);
-
- TensorView* tv1 = makeSymbolicTensor(3);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- TensorView* tv3 = add(tv0, tv1);
-
- tv3->merge(0);
- tv3->merge(0);
- tv3->split(0, 128);
- tv3->split(0, 4);
-
- fusion.addOutput(tv3);
-
- tv0->computeAt(tv3, -1);
- tv1->computeAt(tv3, -1);
-
- tv3->axis(0)->parallelize(ParallelType::BIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-2)->parallelize(ParallelType::Unroll);
-
- constexpr int x = 63, y = 33, z = 15;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({1, z}, options);
- at::Tensor t1 = at::randn({x, y, z}, options);
-
- auto aten_output = t0.add(t1);
-
- at::Tensor cg_output = at::empty({x, y, z}, options);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion(aten_inputs, {cg_output});
-
- testValidate(
- &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSimpleBCast5_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- constexpr int m = 2, k = 3, n = 4;
-
- auto zero = new Int(0);
- auto M = new IterDomain(zero, new Int(m));
- auto K = new IterDomain(zero, new Int(k));
- auto N = new IterDomain(zero, new Int(n));
-
- // Set up your input tensor views
- TensorView* tv0 =
- new TensorView(new TensorDomain({M, K}, {true, true}), DataType::Float);
- // Note: IterDomain must not be reused, so K needs to be cloned.
- TensorView* tv1 = new TensorView(
- new TensorDomain({K->clone(), N}, {true, true}), DataType::Float);
-
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- TensorView* tv2 = broadcast(tv0, {false, false, true});
- TensorView* tv3 = broadcast(tv1, {true, false, false});
-
- TensorView* tv4 = add(tv2, tv3);
-
- fusion.addOutput(tv4);
-
- tv4->merge(0);
- tv4->merge(0);
-
- tv0->computeAt(tv4, -1);
- tv1->computeAt(tv4, -1);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({m, k}, options);
- at::Tensor t1 = at::randn({k, n}, options);
-
- auto t2 = t0.unsqueeze(-1).expand({m, k, n});
- auto t3 = t1.expand({m, k, n});
- auto aten_output = t2.add(t3);
-
- at::Tensor cg_output = at::empty({m, k, n}, options);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion(aten_inputs, {cg_output});
-
- testValidate(
- &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionComplexBCast1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- int x = 2, y = 3, z = 4;
-
- auto tv0 = makeConcreteTensor({y});
- auto tv1 = div(tv0, new Double(2.0));
- auto tv2 = broadcast(tv1, {false, true});
- auto tv3 = makeConcreteTensor({y, z});
- auto tv4 = mul(tv2, tv3);
- auto tv5 = broadcast(tv4, {true, false, false});
- auto tv6 = makeConcreteTensor({x, y, z});
- auto tv7 = add(tv5, tv6);
-
- // tv0[ i1 ] = input
- // tv1[ i1 ] = tv0/2.0
- // tv2[ i1, b2] = bcast(tv1)
- // tv3[ i1, i2] = input
- // tv4[ i1, i2] = tv2 * tv3
- // tv5[b0, i1, i2] = bcast(tv4)
- // tv6[i0, i1, i2] = input
- // tv7[i0, i1, i2] = tv5 + tv6
-
- // tv4 = bcast(tv1) * tv3
- // tv7 = bcast(tv4) + tv6
-
- fusion.addInput(tv0);
- fusion.addInput(tv3);
- fusion.addInput(tv6);
-
- fusion.addOutput(tv7);
-
- tv7->merge(0);
- tv7->merge(0);
- tv0->computeAt(tv7, -1);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({y}, options);
- at::Tensor t3 = at::randn({y, z}, options);
- at::Tensor t6 = at::randn({x, y, z}, options);
-
- auto t4 = t0.div(2.0).unsqueeze(-1).expand({y, z}) * t3;
- auto aten_output = t4.unsqueeze(0).expand({x, y, z}) + t6;
-
- std::vector<IValue> aten_inputs = {t0, t3, t6};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionComplexBCast2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- int x = 2, y = 3, z = 4;
-
- auto tv0 = makeConcreteTensor({y, z});
- auto tv1 = div(tv0, new Double(2.0));
- auto tv2 = sum(tv1, {1});
- auto tv3 = broadcast(tv2, {true, false});
- auto tv4 = makeConcreteTensor({x, y});
- auto tv5 = add(tv3, tv4);
-
- // tv0[ i1, i2] = input
- // tv1[ i1, i2] = tv0/2.0
- // tv2[ i1 ] = sum(tv1, 1)
- // tv3[b0, i1 ] = bcast(tv2)
- // tv4[i0, i1 ] = input
- // tv5[i0, i1 ] = tv3 + tv4
-
- // tv2 = sum(tv0/2.0, 1)
- // tv5 = bcast(tv2) + tv4
-
- fusion.addInput(tv0);
- fusion.addInput(tv4);
-
- fusion.addOutput(tv5);
-
- tv5->merge(0);
- tv0->computeAt(tv5, -1);
- tv1->computeAt(tv2, -1);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({y, z}, options);
- at::Tensor t4 = at::randn({x, y}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({t0, t4});
-
- auto t1 = t0.div(2.0);
- auto t2 = t1.to(at::kDouble).sum(1);
- auto t3 = t2.unsqueeze(0).expand({x, y});
- auto aten_output = t3.add(t4);
-
- testValidate(
- &fusion, cg_outputs, {t0, t4}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- int w = 3, x = 4, y = 7, z = 8;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- auto tv0 = makeSymbolicTensor(3);
- auto tv1 = makeSymbolicTensor(4);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = add(tv0, new Double(1.0));
- auto tv3 = broadcast(tv2, {true, false, false, false});
- auto tv4 = add(tv3, tv1);
-
- fusion.addOutput(tv4);
-
- tv4->merge(0);
- tv4->merge(0);
- tv4->merge(0);
-
- tv4->split(0, 128);
- tv4->split(0, 4);
-
- tv2->computeAt(tv4, 1);
-
- tv4->axis(0)->parallelize(ParallelType::BIDx);
- tv4->axis(1)->parallelize(ParallelType::Unroll);
- tv4->axis(2)->parallelize(ParallelType::TIDx);
-
- tv3->axis(1)->parallelize(ParallelType::Unroll);
- tv3->axis(2)->parallelize(ParallelType::TIDx);
-
- tv2->axis(1)->parallelize(ParallelType::Unroll);
- tv2->axis(2)->parallelize(ParallelType::TIDx);
-
- FusionExecutor fe;
-
- at::Tensor t0 = at::randn({x, y, z}, options);
- at::Tensor t1 = at::randn({w, x, y, z}, options);
-
- auto t3 = t0.add(1.0);
- auto aten_output = t3.add(t1);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- int w = 3, x = 4, y = 7, z = 8;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- auto tv0 = makeSymbolicTensor(3);
- auto tv1 = makeSymbolicTensor(4);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = add(tv0, new Double(1.0));
- auto tv3 = broadcast(tv2, {true, false, false, false});
- auto tv4 = add(tv3, tv1);
-
- fusion.addOutput(tv4);
-
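- // Unlike AdvancedIndexing1, collapse the domain from the inner side: each
- // merge(-2) fuses the last two axes, so [W, X, Y, Z] becomes [W, X, Y*Z],
- // then [W, X*Y*Z], then [W*X*Y*Z].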
- tv4->merge(-2);
- tv4->merge(-2);
- tv4->merge(-2);
-
- tv4->split(0, 128);
- tv4->split(0, 4);
-
- tv2->computeAt(tv4, 1);
-
- tv4->axis(0)->parallelize(ParallelType::BIDx);
- tv4->axis(1)->parallelize(ParallelType::Unroll);
- tv4->axis(2)->parallelize(ParallelType::TIDx);
-
- tv3->axis(1)->parallelize(ParallelType::Unroll);
- tv3->axis(2)->parallelize(ParallelType::TIDx);
-
- tv2->axis(1)->parallelize(ParallelType::Unroll);
- tv2->axis(2)->parallelize(ParallelType::TIDx);
-
- FusionExecutor fe;
-
- at::Tensor t0 = at::randn({x, y, z}, options);
- at::Tensor t1 = at::randn({w, x, y, z}, options);
-
- auto t3 = t0.add(1.0);
- auto aten_output = t3.add(t1);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing3_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- int w = 3, x = 4, y = 7, z = 8;
-
- auto tv0 = makeSymbolicTensor(3);
- auto tv1 = makeSymbolicTensor(4);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = add(tv0, new Double(1.0));
- auto tv3 = add(tv2, tv1);
- fusion.addOutput(tv3);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({x, y, z}, options);
- at::Tensor t1 = at::randn({w, x, y, z}, options);
-
- auto t2 = t0.add(1.0);
- auto aten_output = t2.add(t1);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- auto lparams = schedulePointwise(&fusion, aten_inputs);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing4_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeConcreteTensor({10, 20});
- fusion.addInput(tv0);
- TensorView* tv1 = makeConcreteTensor({10, 10, 20});
- fusion.addInput(tv1);
-
- TensorView* tv2 = add(tv0, new Double(1));
- TensorView* tv3 = broadcast(tv2, {true, false, false});
- TensorView* tv4 = add(tv3, tv1);
- fusion.addOutput(tv4);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({10, 20}, options);
- at::Tensor t1 = at::randn({10, 10, 20}, options);
-
- auto t2 = t0.add(1.0);
- auto aten_output = t2.add(t1);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing5_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(3);
- fusion.addInput(tv1);
-
- TensorView* tv2 = add(tv0, new Double(1));
- TensorView* tv3 = broadcast(tv2, {true, false, true});
- TensorView* tv4 = add(tv3, tv1);
- fusion.addOutput(tv4);
-
- tv3->merge(0)->merge(0)->split(0, 2)->split(0, 3);
- tv4->merge(0)->merge(0)->split(0, 2)->split(0, 3);
-
- tv0->computeAt(tv4, 1);
- tv1->computeAt(tv4, 1);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({7}, options);
- at::Tensor t1 = at::randn({5, 7, 11}, options);
-
- auto t2 = t0.add(1.0);
- auto aten_output = t2.unsqueeze(-1).add(t1);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing6_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- std::vector<int64_t> tensor0_shape{7, 4, 7};
- std::vector<int64_t> tensor1_shape{4, 7};
-
- TensorView* tv0 = makeSymbolicTensor(tensor0_shape.size());
- fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(tensor1_shape.size());
- fusion.addInput(tv1);
-
- TensorView* tv2 = add(tv0, tv1);
- TensorView* tv3 = sum(tv2, {0, 1});
- fusion.addOutput(tv3);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor input0 = at::randn(tensor0_shape, options);
- at::Tensor input1 = at::randn(tensor1_shape, options);
-
- std::vector<int64_t> reduction_axes{0, 1};
- auto reduction_params = getReductionHeuristics(&fusion, {input0, input1});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
- scheduleReduction(&fusion, reduction_params.value());
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs =
- fe.runFusion({input0, input1}, reduction_params.value().lparams);
-
- auto aten_output = input0.add(input1).to(at::kDouble).sum(reduction_axes);
-
- testValidate(
- &fusion,
- cg_outputs,
- {input0, input1},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- reduction_params.value().lparams);
-}
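-
-// The pattern above is the heuristics-driven scheduling path: ask
-// getReductionHeuristics for ReductionParams matching the inputs, apply them
-// with scheduleReduction, then thread the resulting LaunchParams through both
-// runFusion and testValidate so execution matches the schedule.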
-
-TEST(NVFuserTest, FusionAdvancedIndexing7_CUDA) {
- // This test might be usable in place of test 6, since the heuristics in 6
- // may change; both tests cover the same issue.
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- auto tv1 = broadcast(tv0, {false, true});
-
- auto tv2 = makeSymbolicTensor(2);
- fusion.addInput(tv2);
-
- auto tv3 = add(tv1, tv2);
- auto tv4 = sum(tv3, {0, 1});
- fusion.addOutput(tv4);
-
- tv4->merge(0, 1);
- tv4->split(0, 128);
- tv4->split(0, 4);
-
- auto tv5 = tv4->rFactor({0, 1});
-
- tv5->computeAt(tv4, -1);
- tv0->computeAt(tv5, -1);
-
- tv4->axis(0)->parallelize(ParallelType::TIDx);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- const int numel_x = 100;
- const int numel_y = 200;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto at_t0 = at::randn({numel_x}, options);
- auto at_t1 = at::randn({numel_x, numel_y}, options);
-
- auto cg_outputs = fe.runFusion({at_t0, at_t1});
-
- auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1)
- .to(at::kDouble)
- .sum();
-
- testValidate(
- &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing8_CUDA) {
- // Same as 7 but with outer splits instead of inner
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- auto tv1 = broadcast(tv0, {false, true});
-
- auto tv2 = makeSymbolicTensor(2);
- fusion.addInput(tv2);
-
- auto tv3 = add(tv1, tv2);
- auto tv4 = sum(tv3, {0, 1});
- fusion.addOutput(tv4);
-
- tv4->merge(0, 1);
- tv4->split(0, 128, false);
- tv4->split(0, 4, false);
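-  // Passing 'false' requests outer splits: the factor sizes the outer
-  // output domain rather than the inner one.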
-
- auto tv5 = tv4->rFactor({0, 1});
-
- tv5->computeAt(tv4, -1);
- tv0->computeAt(tv5, -1);
-
- tv4->axis(0)->parallelize(ParallelType::TIDx);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- const int numel_x = 100;
- const int numel_y = 200;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto at_t0 = at::randn({numel_x}, options);
- auto at_t1 = at::randn({numel_x, numel_y}, options);
-
- auto cg_outputs = fe.runFusion({at_t0, at_t1});
-
- auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1)
- .to(at::kDouble)
- .sum();
-
- testValidate(
- &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing9_CUDA) {
-  // Unlike 7 and 8, this uses the pointwise scheduler, with a broadcast-fed
-  // intermediate (tv2) that is both a fusion output and consumed by tv4.
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- auto tv1 = broadcast(tv0, {false, true});
-
- auto tv2 = mul(tv1, new Double(2));
- fusion.addOutput(tv2);
-
- auto tv3 = makeSymbolicTensor(3);
- fusion.addInput(tv3);
-
- auto tv4 = add(tv3, tv2);
- fusion.addOutput(tv4);
-
- const int numel_x = 200;
- const int numel_y = 300;
- const int numel_z = 400;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto at_t0 = at::randn({numel_y}, options);
- auto at_t3 = at::randn({numel_x, numel_y, numel_z}, options);
- std::vector<IValue> aten_inputs = {at_t0, at_t3};
-
- auto lparams = schedulePointwise(&fusion, aten_inputs);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
- auto at_t1 = at_t0.unsqueeze(-1);
- auto at_t2 = at_t1.mul(2.0);
-
- auto at_t4 = at_t3.add(at_t2);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {at_t2, at_t4}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing10_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeContigTensor(2);
- TensorView* tv1 = makeContigTensor(2);
-
- // Register your inputs
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
-  // Do math with it; it returns a `Val*` but can be static_cast back to a
-  // TensorView
- TensorView* tv2 = add(tv1, new Double(2.0));
- TensorView* tv3 = add(tv0, tv2);
-
- // Register your outputs
- fusion.addOutput(tv3);
-
- auto tv0_cache = tv0->cache_after();
- auto tv1_cache = tv1->cache_after();
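-  // cache_after inserts a copy between each input and its consumers so the
-  // global-memory loads can be scheduled (and vectorized) separately.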
-
- std::vector<TensorView*> tvs = {tv0_cache, tv1_cache, tv2, tv3};
-
- for (auto tv : tvs) {
- tv->split(1, 2, false);
- tv->split(1, 1);
- tv->split(-1, 4);
- // [I0, 2, 1, I1/2/4, 4]
- tv->reorder({{1, 2}, {2, 3}, {3, 1}});
- tv->axis(0)->parallelize(ParallelType::BIDx);
- tv->axis(1)->parallelize(ParallelType::TIDx);
- }
-
-  // For all inputs, computeAt the output inline; temporaries should be
-  // squeezed between them
- tv0->computeAt(tv3, 1);
- tv1->computeAt(tv3, 1);
-
- tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize);
- tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize);
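-  // Only the cache tensors are vectorized; their innermost extent of 4
-  // floats corresponds to a 16-byte vectorized access.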
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor input1 = at::randn({64, 128}, options);
- at::Tensor input2 = at::rand_like(input1);
- at::Tensor output = at::empty_like(input1);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({input1, input2}, {output});
-
- at::Tensor tv2_ref = input2 + 2.0;
- at::Tensor output_ref = input1 + tv2_ref;
-
- TORCH_CHECK(output_ref.equal(output));
-}
-
-TEST(NVFuserTest, FusionAdvancedIndexing11_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- int w = 3, x = 4, y = 7, z = 8;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- auto tv0 = makeSymbolicTensor(4);
- auto tv1 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = add(tv1, new Double(1.0));
- auto tv3 = broadcast(tv2, {true, false, true, true});
- auto tv4 = add(tv3, tv0);
-
- fusion.addOutput(tv4);
-
- tv4->merge(0);
- tv4->merge(1);
-
- tv4->split(1, 32);
- tv4->split(0, 1);
-
- tv4->reorder({{2, 1}});
-
- tv2->computeAt(tv4, 3);
-
- tv2->setMemoryType(MemoryType::Global);
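-  // tv2 is not TIDx-parallelized while its consumer tv3 is, so it must live
-  // in shared or global memory to be visible across threads.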
-
- tv4->axis(0)->parallelize(ParallelType::BIDx);
- tv4->axis(1)->parallelize(ParallelType::BIDy);
- tv4->axis(2)->parallelize(ParallelType::Unswitch);
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
- FusionExecutor fe;
-
- at::Tensor t0 = at::randn({w, x, y, z}, options);
- at::Tensor t1 = at::randn({x}, options);
-
- auto t3 = t1.add(1.0).unsqueeze(-1).unsqueeze(-1);
- auto aten_output = t3.add(t0);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-// Intended to stress the lowering of our code generator
-TEST(NVFuserTest, FusionAdvancedLowering1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeConcreteTensor({9, 5});
- fusion.addInput(tv0);
-
- TensorView* tv1 = add(tv0, new Double(1));
- TensorView* tv2 = add(tv1, new Double(2));
- TensorView* tv3 = add(tv1, new Double(3));
- TensorView* tv4 = sum(tv3, {1});
-
- fusion.addOutput(tv2);
- fusion.addOutput(tv4);
-
- tv4->split(1, 4);
- auto tv5 = tv4->rFactor({2});
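-  // tv5 reduces the inner factor-4 chunks; tv4 finishes the reduction over
-  // the remaining outer axis.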
-
- tv1->computeAt(tv5, 2);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(1);
- at::Tensor aten_input = at::randn({9, 5}, options);
-
- auto t1 = aten_input.add(1.0);
- auto t2 = t1.add(2.0);
- auto t3 = t1.add(3.0);
- auto t4 = t3.sum(1);
-
- std::vector<at::Tensor> aten_outputs = {t2, t4};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto cg_outputs = fe.runFusion({aten_input});
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedLowering2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Progressively broadcast tensors
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
- TensorView* tv2 = makeSymbolicTensor(3);
- fusion.addInput(tv2);
-
- TensorView* tv3 = add(tv0, new Double(1));
- TensorView* tv4 = broadcast(tv3, {false, true});
- TensorView* tv5 = add(tv4, tv1);
- TensorView* tv6 = add(tv5, tv2);
-
- fusion.addOutput(tv6);
-
- // Split inner dimension
- tv6->split(1, 4);
- // Merge middle dims with outer dimensions
- tv6->merge(2);
- tv6->merge(0);
-
- // tv6[I0*I1o, I1i*I2]
-
- // Compute everything inline
- tv0->computeAt(tv6, -1);
-
- tv6->axis(0)->parallelize(ParallelType::BIDx);
- tv6->axis(1)->parallelize(ParallelType::TIDx);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- int x = 13, y = 9, z = 5;
- at::Tensor t0 = at::randn({y}, options);
- at::Tensor t1 = at::randn({y, z}, options);
- at::Tensor t2 = at::randn({x, y, z}, options);
-
- auto t3 = t0.add(1.0);
- auto t4 = t3.unsqueeze(-1);
- auto t5 = t4.add(t1);
- auto t6 = t5.add(t2);
-
- std::vector<IValue> aten_inputs = {t0, t1, t2};
- std::vector<at::Tensor> aten_outputs = {t6};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
-}
-
-// TODO: Complete test
-TEST(NVFuserTest, FusionAdvancedLowering3_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeConcreteTensor({1, -1});
- auto tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- // [b0, i1]
- auto tv2 = add(tv0, new Double(2.0));
-
- // [i0, i1]
- auto tv3 = add(tv1, new Double(3.0));
-
- // [b0, i1]
- auto tv4 = add(tv2, new Double(4.0));
-
-  // [i0, i1]
- auto tv5 = add(tv2, tv3);
-
- fusion.addOutput(tv4);
- fusion.addOutput(tv5);
-
- tv0->computeAt(tv4, -1);
-
- tv3->setMemoryType(MemoryType::Global);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- int x = 13, y = 9;
- at::Tensor t0 = at::randn({1, y}, options);
- at::Tensor t1 = at::randn({x, y}, options);
-
- auto t4 = t0 + 2 + 4;
- auto t5 = t0 + 2 + t1 + 3;
-
- std::vector<IValue> aten_inputs = {t0, t1};
- std::vector<at::Tensor> aten_outputs = {t4, t5};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
-}
-
-// This exercises indexing with broadcast root axes. Non-broadcast
-// axes need to be preferred when propagating index exprs to root
-// axes. See, e.g., Index::getConsumerIndex_impl.
-TEST(NVFuserTest, FusionAdvancedLowering4_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- auto tv1 = broadcast(tv0, {false, true});
- auto tv2 = broadcast(tv1, {false, false, true});
- auto tv3 = makeSymbolicTensor(3);
- fusion.addInput(tv3);
- auto tv4 = add(tv2, tv3);
- fusion.addOutput(tv4);
-
- tv4->merge(1)->merge(0);
- tv4->split(0, 8);
- tv0->computeAt(tv4, 1);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int bx = 10;
- const int by = 20;
- const int bz = 30;
- at::Tensor t0 = at::randn({bx}, options);
- at::Tensor t3 = at::randn({bx, by, bz}, options);
- std::vector<IValue> aten_inputs = {t0, t3};
-
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- auto aten_output =
- t0.unsqueeze(-1).expand({bx, by}).unsqueeze(-1).expand({bx, by, bz}) + t3;
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionAdvancedLowering5_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeConcreteTensor({5, 4, 3});
- fusion.addInput(tv0);
-
- TensorView* tv1 = makeConcreteTensor({5, 3});
- fusion.addInput(tv1);
-
- auto tv2 = broadcast(tv1, {false, true, false});
-
- auto tv3 = add(tv0, tv2);
-
- fusion.addOutput(tv3);
-
- tv2->merge(0);
- tv1->computeAt(tv2, 1);
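-  // Note: tv2's merge(0) fused a concrete axis with a broadcast axis;
-  // computing tv1 at that merged domain is the case being exercised.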
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(1);
- at::Tensor t0 = at::randn({5, 4, 3}, options);
- at::Tensor t1 = at::randn({5, 3}, options);
- auto t2 = t1.unsqueeze(1);
- auto t3 = t0 + t2;
-
- std::vector<IValue> aten_inputs = {t0, t1};
- std::vector<at::Tensor> aten_outputs = {t3};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
-}
-
-// Test a simple Gemm but also play around with fusion executor features
-TEST(NVFuserTest, FusionSimpleGemm_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2); // M, K
- TensorView* tv1 = makeSymbolicTensor(2); // K, N
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- TensorView* tv2 = broadcast(tv0, {false, false, true});
- // tv2[I0, I1, B] = tv0[I0, I1]
-
- TensorView* tv3 = broadcast(tv1, {true, false, false});
- // tv3[B, I1, I2] = tv1[I1, I2]
-
- // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
- TensorView* tv4 = mul(tv2, tv3);
- // tv5[I0, R1, I2] = tv4[I0, I1, I2]
- TensorView* tv5 = sum(tv4, {1});
- fusion.addOutput(tv5);
-
- tv5->split(1, 32);
- // tv5[I0, R1o, R1i{32}, I2]
-
- auto tv6 = tv5->rFactor({1});
- // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2]
- // tv5[I0, , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2]
-
- tv5->split(0, 4);
- tv5->split(-1, 4);
-  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
-
- tv0->computeAt(tv5, -1);
- tv1->computeAt(tv5, -1);
-
- // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
- // tv5[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}]
- //--> (line symbolizes compute at location)
- // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
- // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
- // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
-
- tv0->computeAt(tv6, -1);
- tv1->computeAt(tv6, -1);
- // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
- // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
- // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
-
- tv5->axis(0)->parallelize(ParallelType::BIDz);
- tv5->axis(1)->parallelize(ParallelType::TIDz);
-
- tv5->axis(-2)->parallelize(ParallelType::BIDy);
- tv5->axis(-1)->parallelize(ParallelType::TIDy);
-
- tv5->axis(2)->parallelize(ParallelType::TIDx);
- tv6->axis(2)->parallelize(ParallelType::TIDx);
-
- constexpr int M = 65, K = 33, N = 17;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({M, K}, options);
- at::Tensor t1 = at::randn({K, N}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-  // Let's specify a few bounds in launch params to make sure it works
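-  // LaunchParams is (gdimx, gdimy, gdimz, bdimx, bdimy, bdimz); -1 leaves a
-  // dimension to be inferred from the schedule.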
- fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4));
-
- // Make sure bad launch params throws
- // TODO: Re-enable once we have parallelization validation in.
- // ASSERT_ANY_THROW(fe.runFusion({t0, t1}, LaunchParams(1, 2, 3, 4, 5, 6)));
-
- // Don't specify any launch params
- auto cg_outputs = fe.runFusion({t0, t1});
-
- auto aten_output = t0.to(at::kDouble).matmul(t1.to(at::kDouble));
-
- testValidate(
- &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Softmax with a 1D tensor. Parallelized only with a single thread block.
-TEST(NVFuserTest, FusionSoftmax1D_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const int tidx = 128;
- const int dimx = 1000;
-
- // Set up your input tensor views
- TensorView* input_tv0 = makeSymbolicTensor(1);
- fusion.addInput(input_tv0);
-
- TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0);
- TensorView* sum_exp_tv2 = sum(exp_tv1, {-1});
- TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {true});
-
-  // Replicate exp_tv1 as exp_tv1_copy because exp_tv1 is going to be
-  // computed at sum_exp_rf_tv5.
- TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0);
-
- TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3);
-
- fusion.addOutput(output_tv4);
-
- bcast_sum_tv3->split(0, tidx);
-
- sum_exp_tv2->split(-1, tidx);
- TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2});
-
- output_tv4->split(-1, tidx);
-
- exp_tv1->computeAt(sum_exp_rf_tv5, -1);
- exp_tv1_copy->computeAt(output_tv4, -1);
-
- TensorView* tensors_to_parallelize[] = {
- sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5};
-
- for (auto tv : tensors_to_parallelize) {
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({dimx}, options);
- at::Tensor cg_output = at::empty({dimx}, options);
- at::Tensor t3_output = at::empty_like(cg_output, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({t0}, {cg_output});
-
- auto aten_output = at::_softmax(t0.to(at::kDouble), -1, false);
-
- testValidate(&fusion, {cg_output}, {t0}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Softmax with a 1D tensor with input normalization.
-TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const int tidx = 128;
- const int dimx = 1000;
-
- // Set up your input tensor views
- TensorView* input_tv0 = makeSymbolicTensor(1);
- fusion.addInput(input_tv0);
-
- // Normalize with the max value before computing exp.
- TensorView* max_val_tv1 =
- reductionOp(BinaryOpType::Max, {-1}, new Double(0), input_tv0);
- TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {true});
- TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2);
- TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3);
- TensorView* sum_exp_tv5 = sum(exp_tv4, {-1});
- TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {true});
-
-  // Replicate the sub/exp chain as sub_tv3_copy and exp_tv4_copy because
-  // sub_tv3 is going to be computed at sum_exp_rf_tv9.
- TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2);
- TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy);
-
- TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6);
-
- fusion.addOutput(output_tv7);
- bcast_max_tv2->split(0, tidx);
- bcast_sum_tv6->split(0, tidx);
-
- max_val_tv1->split(-1, tidx);
- TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2});
-
- sum_exp_tv5->split(-1, tidx);
- TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2});
-
- output_tv7->split(-1, tidx);
-
- sub_tv3->computeAt(sum_exp_rf_tv9, -1);
- sub_tv3_copy->computeAt(output_tv7, -1);
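-  // The max/sum reductions form the first pass; the replicated sub/exp
-  // chain recomputes the numerator in the second pass producing the output.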
-
- TensorView* tensors_to_parallelize[] = {
- max_val_tv1,
- bcast_max_tv2,
- sum_exp_tv5,
- bcast_sum_tv6,
- output_tv7,
- max_val_rf_tv8,
- sum_exp_rf_tv9};
-
- for (auto tv : tensors_to_parallelize) {
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({dimx}, options);
- at::Tensor t3_output = at::empty({dimx}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
-
- auto aten_output = at::_softmax(input.to(at::kDouble), -1, false);
-
- testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Softmax with a 3D tensor, where the inner-most 3rd dimension is
-// normalized. Parallelized with multiple thread blocks.
-TEST(NVFuserTest, FusionSoftmax3D_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const int tidx = 32;
- const int dimx = 32;
- const int dimy = 16;
- const int dimz = 130;
-
- // Set up your input tensor views
- TensorView* input_tv0 = makeSymbolicTensor(3);
- fusion.addInput(input_tv0);
-
- TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0);
- TensorView* sum_exp_tv2 = sum(exp_tv1, {-1});
- TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true});
-
-  // Replicate exp_tv1 as exp_tv1_copy because exp_tv1 is going to be
-  // computed at sum_exp_rf_tv5.
- TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0);
-
- TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3);
-
- fusion.addOutput(output_tv4);
-
- bcast_sum_tv3->split(-1, tidx);
-
- sum_exp_tv2->split(-1, tidx);
- TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2});
-
- output_tv4->split(-1, tidx);
-
- exp_tv1->computeAt(sum_exp_rf_tv5, -1);
- exp_tv1_copy->computeAt(output_tv4, -1);
-
- TensorView* tensors_to_parallelize[] = {
- sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5};
-
- for (auto tv : tensors_to_parallelize) {
- tv->axis(0)->parallelize(ParallelType::BIDx);
- tv->axis(1)->parallelize(ParallelType::BIDy);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({dimx, dimy, dimz}, options);
-
- at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({input}, {cg_output});
-
- auto aten_output = at::_softmax(input.to(at::kDouble), -1, false);
-
- testValidate(
- &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Softmax with a 3D tensor with input normalization.
-TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const int tidx = 32;
- const int dimx = 32;
- const int dimy = 16;
- const int dimz = 130;
-
- // Set up your input tensor views
- TensorView* input_tv0 = makeSymbolicTensor(3);
- fusion.addInput(input_tv0);
-
- // Normalize with the max value before computing exp.
- TensorView* max_val_tv1 =
- reductionOp(BinaryOpType::Max, {-1}, new Double(0), input_tv0);
- TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true});
- TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2);
- TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3);
- TensorView* sum_exp_tv5 = sum(exp_tv4, {-1});
- TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {false, false, true});
-
-  // Replicate the sub/exp chain as sub_tv3_copy and exp_tv4_copy because
-  // sub_tv3 is going to be computed at sum_exp_rf_tv9.
- TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2);
- TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy);
-
- TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6);
-
- fusion.addOutput(output_tv7);
-
- bcast_max_tv2->split(-1, tidx);
- bcast_sum_tv6->split(-1, tidx);
-
- max_val_tv1->split(-1, tidx);
- TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2});
-
- sum_exp_tv5->split(-1, tidx);
- TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2});
-
- output_tv7->split(-1, tidx);
-
- sub_tv3->computeAt(sum_exp_rf_tv9, -1);
- sub_tv3_copy->computeAt(output_tv7, -1);
-
- TensorView* tensors_to_parallelize[] = {
- max_val_tv1,
- bcast_max_tv2,
- sum_exp_tv5,
- bcast_sum_tv6,
- output_tv7,
- max_val_rf_tv8,
- sum_exp_rf_tv9};
-
- for (auto tv : tensors_to_parallelize) {
- tv->axis(0)->parallelize(ParallelType::BIDx);
- tv->axis(1)->parallelize(ParallelType::BIDy);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({dimx, dimy, dimz}, options);
- at::Tensor t3_output = at::empty({dimx, dimy, dimz}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
-
- auto aten_output = at::_softmax(input.to(at::kDouble), -1, false);
-
- testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSoftmaxComputeAt_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- auto tv1 = sum(tv0, {1});
- auto tv2 = broadcast(tv1, {false, true});
-
- auto tv3 = add(tv0, new Double(1.0));
-
- auto tv4 = mul(tv2, tv3);
-
- auto tv5 = sum(tv4, {1});
- auto tv6 = broadcast(tv5, {false, true});
-
- auto tv7 = sub(tv6, tv4);
- fusion.addOutput(tv7);
-
- tv1->computeAt(tv7, 1);
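-  // Fully inlining (-1) is expected to fail: tv1's reduction cannot be
-  // placed inside the inner loops of its broadcast-dependent consumers.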
- ASSERT_ANY_THROW(tv1->computeAt(tv7, -1));
-}
-
-// Similar to FusionReduction but uses grid reduction
-TEST(NVFuserTest, FusionGridReduction1_CUDA) {
- const int gdimx = 32;
- const int bdimx = 128;
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- // tv1[I0, R1] = tv0[I0, I1]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
- tv1->split(1, bdimx);
- // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
- tv1->split(1, gdimx);
- // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
-
- TensorView* tv2 = tv1->rFactor({1});
- // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
- // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
-
- // Incrementally, can print in between for debugging
- tv0->computeAt(tv2, 1);
- tv2->computeAt(tv1, 1);
-
-  // Redo it all at once, because why not.
- tv0->computeAt(tv1, 1);
-
- tv1->axis(0)->parallelize(ParallelType::BIDy);
- tv1->axis(1)->parallelize(ParallelType::BIDx);
- tv2->axis(2)->parallelize(ParallelType::BIDx);
-
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
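-  // With BIDx mapped to a reduction axis, this becomes a grid reduction:
-  // partial results are combined across blocks through global memory.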
-
-  // Reduced shape to avoid OOM on upstream CI
- int numel_x = 1000;
- int numel_y = 65000;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y}, options);
- at::Tensor cg_output = at::empty({numel_x}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({input}, {cg_output});
-
- auto aten_output = input.to(at::kDouble).sum({1});
-
- testValidate(
- &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Same test as the above but uses BIDy and TIDx for reduction
-TEST(NVFuserTest, FusionGridReduction2_CUDA) {
- const int gdimy = 32;
- const int bdimx = 128;
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- // tv1[I0, R1] = tv0[I0, I1]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
- tv1->split(1, bdimx);
- // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
- tv1->split(1, gdimy);
- // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
-
- TensorView* tv2 = tv1->rFactor({1});
- // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
- // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
-
- // Incrementally, can print in between for debugging
- tv0->computeAt(tv2, 1);
- tv2->computeAt(tv1, 1);
-
-  // Redo it all at once, because why not.
- tv0->computeAt(tv1, 1);
-
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(1)->parallelize(ParallelType::BIDy);
- tv2->axis(2)->parallelize(ParallelType::BIDy);
-
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
-  // Reduced shape to avoid OOM on upstream CI
- int numel_x = 1000;
- int numel_y = 65000;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
-
- auto aten_output = input.to(at::kDouble).sum({1});
-
- testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Same test but uses BIDy and BIDz for reduction. No TID used.
-TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) {
-  // Grid reductions without any thread parallelism are serial reductions.
-  // Keep these numbers low so our error isn't too high compared to normal
-  // CUDA reductions.
- const int gdimz = 15;
- const int gdimy = 9;
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- // tv1[I0, R1] = tv0[I0, I1]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
- tv1->split(1, gdimy);
-  // tv1[I0, R1o, R1i{9}] = tv0[I0, I1]
- tv1->split(1, gdimz);
-  // tv1[I0, R1oo, R1oi{15}, R1i{9}] = tv0[I0, I1]
-
- TensorView* tv2 = tv1->rFactor({1});
-  // tv2[I0, R1oo, Ir1oi{15}, Ir1i{9}] = tv0[I0, I1]
-  // tv1[I0, R1oi{15}, R1i{9}] = tv2[I0, R1oo, Ir1oi{15}, Ir1i{9}]
-
- // Incrementally, can print in between for debugging
- tv0->computeAt(tv2, 1);
- tv2->computeAt(tv1, 1);
-
-  // Redo it all at once, because why not.
- tv0->computeAt(tv1, 1);
-
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(1)->parallelize(ParallelType::BIDz);
- tv2->axis(2)->parallelize(ParallelType::BIDz);
- tv1->axis(-1)->parallelize(ParallelType::BIDy);
- tv2->axis(-1)->parallelize(ParallelType::BIDy);
-
- int numel_x = 100;
- int numel_y = 6500;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y}, options);
- at::Tensor cg_output = at::empty({numel_x}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({input}, {cg_output});
-
- auto aten_output = input.to(at::kDouble).sum({1});
- testValidate(
- &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Same as FusionGridReduction3dim1 but reduces dimension 0
-TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) {
-  // Grid reductions without any thread parallelism are serial reductions.
-  // Keep these numbers low so our error isn't too high compared to normal
-  // CUDA reductions.
- const int gdimz = 15;
- const int gdimy = 9;
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- // tv1[R0, I1] = tv0[I0, I1]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {0}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
- tv1->split(0, gdimy);
-  // tv1[R0o, R0i{9}, I1] = tv0[I0, I1]
- tv1->split(0, gdimz);
-  // tv1[R0oo, R0oi{15}, R0i{9}, I1] = tv0[I0, I1]
-
- TensorView* tv2 = tv1->rFactor({0});
-  // tv2[R0oo, I0oi{15}, I0i{9}, I1] = tv0[I0, I1]
-  // tv1[ R0oi{15}, R0i{9}, I1] = tv2[R0oo, I0oi{15}, I0i{9}, I1]
-
- // Note that computeAt isn't going to make anything better as there
- // is no dynamically sized dimension.
-
- // Map parallelism as [Serial, BIDz, BIDy, BIDx]
- tv1->axis(-1)->parallelize(ParallelType::BIDx);
- tv2->axis(-1)->parallelize(ParallelType::BIDx);
- tv1->axis(-2)->parallelize(ParallelType::BIDy);
- tv2->axis(-2)->parallelize(ParallelType::BIDy);
- tv1->axis(-3)->parallelize(ParallelType::BIDz);
- tv2->axis(-3)->parallelize(ParallelType::BIDz);
-
- int numel_x = 6500;
- int numel_y = 100;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
-
- auto aten_output = input.to(at::kDouble).sum({0});
-
- testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// This is similar to FusionReduction, but swaps BIDx and TIDx
-TEST(NVFuserTest, FusionGridReduction4_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const int bdimx = 128;
- const int gdimx = 1024;
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- // tv1[I0, R1] = tv0[I0, I1]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
- tv1->split(1, gdimx);
- // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1]
- tv1->split(1, 4);
-  // tv1[I0, R1oo, R1oi{4}, R1i{1024}] = tv0[I0, I1]
-
- TensorView* tv2 = tv1->rFactor({1});
- // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
- // tv1[I0, R1oi{4}, R1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
-
- TensorView* tv3 = tv1->rFactor({1});
- // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
- // tv3[I0, R1oi{4}, Ir1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
- // tv1[I0, R1i{1024}] = tv3[I0, R1oi{4}, Ir1i{1024}]
-
- // Incrementally, can print in between for debugging
- tv0->computeAt(tv2, 1);
- tv2->computeAt(tv3, 1);
- tv3->computeAt(tv1, 1);
-
-  // Redo it all at once, because why not.
- tv0->computeAt(tv1, 1);
-
- tv2->axis(2)->parallelize(ParallelType::Unroll);
- tv1->axis(0)->parallelize(ParallelType::TIDx);
-
- tv1->axis(-1)->parallelize(ParallelType::BIDx);
- tv2->axis(-1)->parallelize(ParallelType::BIDx);
- tv3->axis(-1)->parallelize(ParallelType::BIDx);
-
- int numel_x = bdimx;
- int numel_y = 65000;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y}, options);
- at::Tensor cg_output = at::empty({numel_x}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({input}, {cg_output});
-
- auto aten_output = input.to(at::kDouble).sum({1});
- testValidate(
- &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Grid reduction with 2D thread blocks but only TIDx and BIDx are
-// mapped to a reduction dim
-TEST(NVFuserTest, FusionGridReduction5_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const int bdimx = 64;
- const int bdimy = 16;
- const int gdimx = 4;
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- // tv1[I0, R1] = tv0[I0, I1]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
- tv1->split(1, bdimx);
- // tv1[I0, R1o, R1i{64}] = tv0[I0, I1]
- tv1->split(1, gdimx);
- // tv1[I0, R1oo, R1oi{4}, R1i{64}] = tv0[I0, I1]
-
- TensorView* tv2 = tv1->rFactor({1});
- // tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] = tv0[I0, I1]
- // tv1[I0, R1oi{4}, R1i{64}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}]
-
- tv0->computeAt(tv1, 1);
-
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv1->axis(-2)->parallelize(ParallelType::BIDx);
- tv2->axis(-2)->parallelize(ParallelType::BIDx);
-
- tv1->axis(0)->parallelize(ParallelType::TIDy);
-
- int numel_x = bdimy;
- int numel_y = 6500;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
-
- auto aten_output = input.to(at::kDouble).sum({1});
- testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Similar to FusionGridReduction1 but with 3D tensors
-TEST(NVFuserTest, FusionGridReduction6_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(3);
- fusion.addInput(tv0);
-
- // tv1[I0, R1, R2] = tv0[I0, I1, I2]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
-
- // Splitting for TID
- tv1->split(2, 128);
- // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]
-
- // Splitting for BID
- tv1->split(1, 128);
-
- // tv1[I0, R1o, R1i{128}, R2o, R2i{128}] = tv0[I0, I1, I2]
-
- TensorView* tv2 = tv1->rFactor({3});
- // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
- // tv1[I0, R1o, R1i{128}, R2i{128}]
-
- TensorView* tv3 = tv1->rFactor({1});
- // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
- // tv3[I0, R1o, I1i{128}, I2i{128}]
- // tv1[I0, R1i{128}, R2i{128}]
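-  // The two rFactors peel off the serial part of each reduction axis: tv2
-  // reduces R2o, tv3 reduces R1o, and tv1 is left with the BIDx- and
-  // TIDx-parallelized inner axes.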
-
- tv3->computeAt(tv1, 1);
- tv2->computeAt(tv3, 3);
-
- tv1->axis(0)->parallelize(ParallelType::BIDy);
-
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv1->axis(-2)->parallelize(ParallelType::BIDx);
- tv2->axis(-3)->parallelize(ParallelType::BIDx);
- tv3->axis(-2)->parallelize(ParallelType::BIDx);
-
- int numel_x = 6500;
- int numel_y = 200;
- int numel_z = numel_y;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y, numel_z}, options);
- at::Tensor cg_output = at::empty({numel_x}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({input}, {cg_output});
-
- auto aten_output = input.to(at::kDouble).sum({1, 2});
-
- testValidate(
- &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// See issue #1049
-TEST(NVFuserTest, FusionGridReduction7_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- auto tv1 = sum(tv0, {0});
- fusion.addOutput(tv1);
-
- tv1->split(0, 1000);
-
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(1)->parallelize(ParallelType::BIDy);
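-  // With numel_x == 1 below, the grid reduction degenerates to a single
-  // element, the corner case from issue #1049.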
-
- const int numel_x = 1;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x}, options);
- at::Tensor cg_output = at::empty({numel_x}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto out = fe.runFusion({input});
-
- auto aten_output = input.sum({0});
-
- testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) {
- int bid_x = 3;
- int tid_x = 2;
- int red_dim = 0;
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- TensorView* tv1 =
- reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- tv1->split(-1, tid_x);
- tv1->axis(-2)->parallelize(ParallelType::BIDx);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({16, bid_x * tid_x}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
-
- auto aten_output = input.to(at::kDouble).sum({red_dim});
-
- testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSplitBCast_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* input_tv0 = makeSymbolicTensor(3);
- TensorView* input_tv1 = makeSymbolicTensor(3);
- fusion.addInput(input_tv0);
- fusion.addInput(input_tv1);
-
- TensorView* sum_tv2 =
- reductionOp(BinaryOpType::Add, {2}, new Double(0), input_tv0);
- TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true});
- TensorView* output_tv4 = div(input_tv1, bcast_tv3);
-
- sum_tv2->split(-1, 32);
- TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2});
-
- bcast_tv3->split(-1, 32);
- output_tv4->split(-1, 32);
-
- sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx);
- sum_tv2->axis(0)->parallelize(ParallelType::BIDx);
- bcast_tv3->axis(0)->parallelize(ParallelType::BIDx);
- output_tv4->axis(0)->parallelize(ParallelType::BIDx);
-
- sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy);
- sum_tv2->axis(1)->parallelize(ParallelType::BIDy);
- bcast_tv3->axis(1)->parallelize(ParallelType::BIDy);
- output_tv4->axis(1)->parallelize(ParallelType::BIDy);
-
- sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx);
- sum_tv2->axis(-1)->parallelize(ParallelType::TIDx);
- bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx);
- output_tv4->axis(-1)->parallelize(ParallelType::TIDx);
-
- fusion.addOutput(output_tv4);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({32, 32, 128}, options);
- at::Tensor t1 = at::randn({32, 32, 128}, options);
- at::Tensor cg_output = at::empty({32, 32, 128}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({t0, t1}, {cg_output});
-}
-
-TEST(NVFuserTest, FusionBCastInnerDim_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- // reduce then broadcast
- auto tv1 = sum(tv0, {0});
- auto tv2 = broadcast(tv1, {false, true});
-
- TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast());
-}
-
-TEST(NVFuserTest, FusionBCastReduce_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
-
- auto tv1 = broadcast(tv0, {true, false, false});
- auto tv2 = sum(tv1, {1});
- TORCH_CHECK(
- tv2->axis(0)->isBroadcast() && tv2->axis(1)->isReduction() &&
- !tv2->axis(2)->isBroadcast() && !tv2->axis(2)->isReduction());
-}
-
-// Multiple consumer reduction with computeAt
-// https://github.com/csarofeen/pytorch/issues/110
-TEST(NVFuserTest, FusionReductionMultiConsumer_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = unaryOp(UnaryOpType::Exp, tv0);
- auto tv2 = reductionOp(BinaryOpType::Max, {-1}, new Double(0), tv1);
- auto tv3 = reductionOp(BinaryOpType::Min, {-1}, new Double(0), tv1);
- auto tv4 = add(tv2, tv3);
- fusion.addOutput(tv4);
- tv1->computeAt(tv2, -1, ComputeAtMode::BestEffort);
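-  // BestEffort stops at the deepest legal position instead of throwing;
-  // with two reduction consumers of tv1 that position is 2, checked below.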
-
- TORCH_CHECK(tv1->getComputeAtPosition() == 2);
-}
-
-TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) {
- for (int i = 0; i < 2; ++i) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv0, new Double(1));
- TensorView* tv3 = add(tv1, tv2);
-    // Output tv2 or tv1 (depending on the iteration), then tv3
- if (i == 0) {
- fusion.addOutput(tv2);
- } else {
- fusion.addOutput(tv1);
- }
- fusion.addOutput(tv3);
-
- if (i == 0) {
- tv1->computeAt(tv3, -1);
- } else {
- tv2->computeAt(tv3, -1);
- }
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({100}, options);
- std::vector<at::Tensor> aten_outputs = {
- aten_input + 1, (aten_input + 1) * 2};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
- }
-}
-
-TEST(NVFuserTest, FusionComputeAtExprOrder2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv0, new Double(1));
- TensorView* tv3 = add(tv1, tv2);
- fusion.addOutput(tv3);
-
- tv3->split(-1, 32);
-
- tv1->computeAt(tv3, -1);
- tv2->computeAt(tv3, -2);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({100, 100}, options);
- auto aten_output = (aten_input + 1) * 2;
-
- at::Tensor cg_output = at::empty_like(aten_input, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({aten_input}, {cg_output});
-
- testValidate(
- &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionComputeAtExprOrder3_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const size_t dimx = 13;
- const size_t dimy = 15;
-
- TensorView* tv0 = makeConcreteTensor({dimx, dimy});
- fusion.addInput(tv0);
- TensorView* tv1 = add(tv0, new Double(1));
- TensorView* tv2 = add(tv1, new Double(2));
- TensorView* tv3 = add(tv2, new Double(3));
- TensorView* tv4 = add(tv3, new Double(4));
- TensorView* tv5 = mul(tv2, tv4);
- fusion.addOutput(tv5);
-
- tv1->computeAt(tv2, 2);
- tv3->computeAt(tv4, 1);
- tv4->computeAt(tv5, 2);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({dimx, dimy}, options);
- auto t1 = aten_input.add(1.);
- auto t2 = t1.add(2.);
- auto t3 = t2.add(3.);
- auto t4 = t3.add(4.);
- auto aten_output = t2.mul(t4);
-
- torch::jit::fuser::cuda::FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionZeroDimComputeAt_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- auto tv1 = sum(tv0, {0});
- auto tv2 = add(tv1, new Double(1));
- fusion.addOutput(tv2);
- TORCH_CHECK(tv2->nDims() == 0);
- tv1->computeAt(tv2, 0);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({100}, options);
- auto aten_output = aten_input.to(at::kDouble).sum() + 1;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionZeroDimBroadcast_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(0);
- fusion.addInput(tv0);
-
- auto tv1 = broadcast(tv0, {true, true});
- TORCH_CHECK(tv1->nDims() == 2);
-
- TensorView* tv2 = makeSymbolicTensor(2);
- fusion.addInput(tv2);
-
- auto tv3 = add(tv1, tv2);
- auto tv4 = sum(tv3, {0, 1});
- fusion.addOutput(tv4);
-
- tv3->computeAt(tv4, -1);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({}, options);
- at::Tensor t1 = at::randn({10, 10}, options);
-
- auto aten_output = (t0.unsqueeze(-1).unsqueeze(-1).expand({10, 10}) + t1)
- .to(at::kDouble)
- .sum();
-
- std::vector<IValue> aten_inputs = {t0, t1};
- at::Tensor cg_output = at::empty({}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion(aten_inputs, {cg_output});
-
- testValidate(
- &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionZeroDimReduction_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const int bdimx = 32;
- const int gdimx = 32;
-
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- auto tv1 = sum(tv0, {0});
- fusion.addOutput(tv1);
-
- tv1->split(0, bdimx);
- tv1->split(0, gdimx);
- auto tv2 = tv1->rFactor({0});
-
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv1->axis(-2)->parallelize(ParallelType::BIDx);
- tv2->axis(-2)->parallelize(ParallelType::BIDx);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({1000}, options);
- auto aten_output = aten_input.to(at::kDouble).sum();
-
- at::Tensor cg_output = at::empty({}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({aten_input}, {cg_output});
-
- testValidate(
- &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionBCastAfterReduce_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
- const int tidx = 128;
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- auto tv1 = sum(tv0, {1});
- auto tv2 = broadcast(tv1, {false, true});
-
- tv1->split(1, tidx);
- auto tv3 = tv1->rFactor({-2});
-
- TensorView* tv4 = makeSymbolicTensor(2);
- fusion.addInput(tv4);
-
- auto tv5 = add(tv2, tv4);
- fusion.addOutput(tv5);
- tv5->split(1, tidx);
-
- tv3->computeAt(tv5, 1);
-
- tv2->split(1, tidx);
-
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv5->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv5->axis(0)->parallelize(ParallelType::BIDx);
-
- int x = 63, y = 200;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor t0 = at::randn({x, y}, options);
- at::Tensor t4 = at::randn({x, y}, options);
-
- auto t3 = t0.to(at::kDouble).sum({1}).unsqueeze(-1).expand({x, y});
- auto aten_output = t3.add(t4);
-
- std::vector<IValue> aten_inputs = {t0, t4};
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({t0, t4});
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionOutputBroadcast_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeConcreteTensor({2, 3});
- fusion.addInput(tv0);
-
- TensorView* tv1 = broadcast(tv0, {true, false, true, false, true});
-
- fusion.addOutput(tv1);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::randn({2, 3}, options);
- auto aten_output = aten_input.unsqueeze(2).unsqueeze(1).unsqueeze(0);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto cg_outputs = fe.runFusion({aten_input});
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReductionKeepDimBasic_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeConcreteTensor({2, 3, 4, 5, 6});
- fusion.addInput(tv0);
-
- TensorView* tv1 = sum(tv0, {0, 2, 4}, /*keep_dim=*/true);
-
- fusion.addOutput(tv1);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::randn({2, 3, 4, 5, 6}, options);
- auto aten_output =
- aten_input.to(at::kDouble).sum({0, 2, 4}, /*keepdim=*/true);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto cg_outputs = fe.runFusion({aten_input});
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) {
- constexpr int bid_x = 80;
- constexpr int tid_x = 4096;
- constexpr int red_dim = 1;
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeConcreteTensor({bid_x, tid_x});
- fusion.addInput(tv0);
-
- TensorView* tv1 = reductionOp(
- BinaryOpType::Add, {red_dim}, new Double(0), tv0, /*keep_dim=*/true);
-
- fusion.addOutput(tv1);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::randn({bid_x, tid_x}, options);
- auto aten_output =
- aten_input.to(at::kDouble).sum({red_dim}, /*keepdim=*/true);
-
- // Apply reduction heuristic
- auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
- scheduleReduction(&fusion, reduction_params.value());
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto lparams = reduction_params.value().lparams;
-
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-TEST(NVFuserTest, FusionSumTo_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- std::vector<int64_t> tensor_shape{2, 3, 4, 5, 6};
- std::vector<int64_t> sum_to_shape{1, 5, 6};
-
- std::vector<int64_t> tensor_shape_ref{2, 3, 4, 5, 6};
- std::vector<int64_t> sum_to_shape_ref{1, 5, 6};
-
- std::vector<Int*> sum_to_symb;
- std::transform(
- sum_to_shape.begin(),
- sum_to_shape.end(),
- std::back_inserter(sum_to_symb),
- [](int s) -> Int* { return new Int(s); });
-
- TensorView* tv0 = makeConcreteTensor(tensor_shape);
- fusion.addInput(tv0);
-
- TensorView* tv1 = sum_to(tv0, sum_to_symb);
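-  // sum_to reduces tv0 to the given shape: leading dims are summed away and
-  // dims marked 1 are summed but kept, mirroring at::sum_to.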
- fusion.addOutput(tv1);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::randn(tensor_shape_ref, options);
- auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto cg_outputs = fe.runFusion({aten_input});
-
- TORCH_CHECK(
- cg_outputs[0].dim() == sum_to_shape.size(),
- "sum_to not keeping the final dimension");
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionSumToNoop_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- std::vector<int64_t> tensor_shape{4, 5, 6};
- std::vector<int64_t> sum_to_shape{4, 5, 6};
-
- std::vector<int64_t> tensor_shape_ref{4, 5, 6};
- std::vector<int64_t> sum_to_shape_ref{4, 5, 6};
-
- std::vector<Int*> sum_to_symb;
- std::transform(
- sum_to_shape.begin(),
- sum_to_shape.end(),
- std::back_inserter(sum_to_symb),
- [](int s) -> Int* { return new Int(s); });
-
- TensorView* tv0 = makeConcreteTensor(tensor_shape);
- fusion.addInput(tv0);
-
- TensorView* tv1 = sum_to(tv0, sum_to_symb);
-
-  // Dummy op so that tv0 is not both an input and an output
- TensorView* tv2 = add(tv1, new Double(0));
- fusion.addOutput(tv2);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::randn(tensor_shape_ref, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto cg_outputs = fe.runFusion({aten_input});
- auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref);
-
- TORCH_CHECK(
- cg_outputs[0].dim() == sum_to_shape.size(),
- "sum_to not keeping the final dimension");
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReductionScheduler_CUDA) {
- constexpr int bid_x = 80;
- constexpr int tid_x = 4096;
- constexpr int red_dim = 1;
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- TensorView* tv1 =
- reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::randn({bid_x, tid_x}, options);
- auto aten_output = aten_input.to(at::kDouble).sum({red_dim});
-
- // Apply reduction heuristic
- auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
- scheduleReduction(&fusion, reduction_params.value());
-
- auto lparams = reduction_params.value().lparams;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-  // No broadcasting needed; omitting the last optional argument.
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-// Simple reduction parallelized on a symbolic size.
-TEST(NVFuserTest, FusionSymbolicReduction_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- // tv1[I0, R1] = tv0[I0, I1]
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
-  // The interface should just be a direct split with a parallel type; the
-  // parallelize call could then be folded into the split.
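-  // (Hypothetically, something like tv1->split(1, ParallelType::TIDx).)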
- tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
-  // tv1[I0, R1o, R1i{TIDx}] = tv0[I0, I1]
-
- TensorView* tv2 = tv1->rFactor({1});
-  // tv2[I0, R1o, Ir1i{TIDx}] = tv0[I0, I1]
-  // tv1[I0, R1i{TIDx}] = tv2[I0, R1o, Ir1i{TIDx}]
-
- // Incrementally, can print in between for debugging
- tv0->computeAt(tv2, 1);
- tv2->computeAt(tv1, 1);
-
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
-
- int numel_x = 65000;
- int numel_y = 1025;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
- auto aten_output = aten_input.to(at::kDouble).sum({1});
-
- // How many threads to use for the block reduction
- int runtime_threadIdx_dim = 128;
-
- LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);
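-  // Only bdimx is bound (to 128); the other dimensions are inferred since
-  // the split above used a symbolic TIDx extent.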
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) {
- const std::vector<int> red_dims = {0, 2};
-  // A copy is needed because CodeGen requires int while PyTorch requires
-  // int64_t for a vector of reduction dimensions
- const std::vector<int64_t> red_dims64 = {0, 2};
- const std::vector<int64_t> tensor_dims_in = {5, 10, 15, 20};
- const std::vector<int64_t> tensor_dims_out = {10, 20};
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size());
- fusion.addInput(tv0);
-
- TensorView* tv1 =
- reductionOp(BinaryOpType::Add, red_dims, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn(tensor_dims_in, options);
- auto aten_output = aten_input.to(at::kDouble).sum(red_dims64);
- at::Tensor cg_output = at::empty(tensor_dims_out, options);
-
- // Apply reduction heuristic
- auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
- scheduleReduction(&fusion, reduction_params.value());
- auto lparams = reduction_params.value().lparams;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({aten_input}, {cg_output}, lparams);
-
- testValidate(
- &fusion,
- {cg_output},
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) {
- const std::vector<int> red_dims = {1, 3};
-  // A copy is needed because CodeGen requires int while PyTorch requires
-  // int64_t for a vector of reduction dimensions
- const std::vector<int64_t> red_dims64 = {1, 3};
- const std::vector<int64_t> tensor_dims_in = {5, 10, 15, 20};
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size());
- fusion.addInput(tv0);
-
- TensorView* tv1 =
- reductionOp(BinaryOpType::Add, red_dims, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn(tensor_dims_in, options);
- auto aten_output = aten_input.to(at::kDouble).sum(red_dims64);
-
- auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
- scheduleReduction(&fusion, reduction_params.value());
- auto lparams = reduction_params.value().lparams;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-TEST(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) {
- std::vector<DataType> dtypes = {
- DataType::Double, DataType::Float, DataType::Half};
- std::vector<int> red_dims;
-
- // Cut down the number of iterations by testing only every other power
- // of 2.
- for (int i = 1; i <= 1024 * 1024; i <<= 2) {
- red_dims.push_back(i);
- }
-
- for (auto dtype : dtypes) {
- at::ScalarType aten_dtype = data_type_to_aten(dtype);
- for (auto& rdim : red_dims) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- bool is_fp16 = dtype == DataType::Half;
-
- TensorView* tv0 = makeSymbolicTensor(1, dtype);
- fusion.addInput(tv0);
-
- TensorView* tv0_cast = tv0;
- if (is_fp16) {
- tv0_cast = castOp(DataType::Float, tv0);
- }
-
- TensorView* tv1 = sum(tv0_cast, {0});
-
- TensorView* tv1_cast = tv1;
- if (is_fp16) {
- tv1_cast = castOp(DataType::Half, tv1);
- }
-
- fusion.addOutput(tv1_cast);
-
- auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::randn({rdim}, options);
- auto aten_output = aten_input.to(at::kDouble).sum({0});
-
- auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
- TORCH_CHECK(reduction_params.has_value(), "Reduction is not found!");
- scheduleReduction(&fusion, reduction_params.value());
- auto lparams = reduction_params.value().lparams;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
- }
- }
-}
-
-TEST(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) {
- std::vector<DataType> dtypes = {
- DataType::Double, DataType::Float, DataType::Half};
- std::vector<int> red_axis = {1, 0};
- std::vector<int> output_dims = {160, 320};
- std::vector<int> red_dims;
-
- // Cut down the number of iterations by testing only every other power
- // of 2.
- for (int i = 1; i <= 1024 * 1024; i <<= 2) {
- red_dims.push_back(i);
- }
-
- for (auto dtype : dtypes) {
- at::ScalarType aten_dtype = data_type_to_aten(dtype);
- for (auto& axis : red_axis) {
- for (auto& odim : output_dims) {
- for (auto& rdim : red_dims) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- bool is_fp16 = dtype == DataType::Half;
-
- TensorView* tv0 = makeSymbolicTensor(2, dtype);
- fusion.addInput(tv0);
-
- TensorView* tv0_cast = tv0;
- if (is_fp16) {
- tv0_cast = castOp(DataType::Float, tv0);
- }
-
- TensorView* tv1 = sum(tv0_cast, {axis});
-
- TensorView* tv1_cast = tv1;
- if (is_fp16) {
- tv1_cast = castOp(DataType::Half, tv1);
- }
-
- fusion.addOutput(tv1_cast);
-
- auto options =
- at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);
-
- at::Tensor aten_input =
- (axis ? at::randn({odim, rdim}, options)
- : at::randn({rdim, odim}, options));
-
- auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
- TORCH_CHECK(reduction_params.has_value(), "Reduction is not found!");
- scheduleReduction(&fusion, reduction_params.value());
- auto lparams = reduction_params.value().lparams;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
- auto aten_output = aten_input.to(at::kDouble).sum({axis});
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
- }
- }
- }
- }
-}
-
-TEST(NVFuserTest, FusionCacheBefore_CUDA) {
- // TVM Cache Write
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = add(tv0, new Double(1.0));
- TensorView* tv2 = mul(tv1, new Double(3.0));
- fusion.addInput(tv0);
- fusion.addOutput(tv2);
-
- // Before: TV2 = TV1 * 3
- // After: TV3 = TV1 * 3;
- // TV2 = TV3;
- TensorView* tv3 = tv2->cache_before();
-
- constexpr int BSX = 32;
- tv2->split(-1, BSX);
- tv0->computeAt(tv2, -1);
-
- // Thread and Block binding
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
- constexpr int M = 32, N = 750;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({M, N}, options);
- at::Tensor aten_output = (aten_input + 1.0) * 3.0;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
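-
-// A side-by-side sketch of the two caching directions used in this group of
-// tests (names illustrative):
-//
-//   TensorView* w = out->cache_before(); // compute into w, then copy
-//                                        // w -> out   (TVM cache_write)
-//   TensorView* r = in->cache_after();   // copy in -> r, then read r
-//                                        // downstream (TVM cache_read)
-//
-// Either cache can then be staged in registers or shared memory via
-// setMemoryType().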
-
-TEST(NVFuserTest, FusionCacheAfter_CUDA) {
- // TVM Cache Read
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = add(tv0, new Double(1.0));
- TensorView* tv2 = mul(tv1, new Double(3.0));
- fusion.addInput(tv0);
- fusion.addOutput(tv2);
-
- // Before: TV1 = TV0 + 1
- // After: TV3 = TV0;
- // TV1 = TV3 + 1
- TensorView* tv3 = tv0->cache_after();
-
- constexpr int BSX = 32;
- tv2->split(-1, BSX);
- tv0->computeAt(tv2, -1);
-
- // Thread and Block binding
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
- constexpr int M = 32, N = 457;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({M, N}, options);
- at::Tensor aten_output = (aten_input + 1.0) * 3.0;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionCacheFork_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = add(tv0, new Double(1.0));
- TensorView* tv2 = mul(tv1, new Double(3.0));
- fusion.addInput(tv0);
- fusion.addOutput(tv1);
- fusion.addOutput(tv2);
- // Before: TV1 = TV0 + 1
- // TV2 = TV1 * 3
- // Output: TV1, TV2
-
- // After: TV1 = TV0 + 1
- // TV3 = TV1
- // TV2 = TV1 * 3
- // Output: TV3, TV2
-
- // cache_fork !!does not!! automatically apply ComputeAt to the cache
- auto tv3 = tv1->cache_fork();
-
- constexpr int BSX = 32;
- tv2->split(-1, BSX);
- tv0->computeAt(tv2, -1);
-
- // Thread and Block binding
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
- constexpr int M = 32, N = 457;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({M, N}, options);
- at::Tensor aten_output1 = aten_input + 1.0;
- at::Tensor aten_output2 = aten_output1 * 3.0;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output1, aten_output2},
- __LINE__,
- __FILE__);
-}
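-
-// cache_fork covers the case the Before/After comments above describe: an
-// intermediate that is also a fusion output. A sketch of the effect:
-//
-//   fusion.addOutput(tv1);                 // tv1 feeds tv2 and is an output
-//   TensorView* tv3 = tv1->cache_fork();   // tv3 now carries the output
-//
-// tv1 remains the internal producer of tv2, while tv3 is the tensor that is
-// actually written out, so scheduling tv1 cannot clobber the output copy.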
-
-TEST(NVFuserTest, FusionCacheIndirect_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
- TensorView* tv2 = makeSymbolicTensor(2);
- TensorView* tv3 = makeSymbolicTensor(2);
- TensorView* tv4 = sub(tv2, tv3);
- TensorView* tv5 = add(tv1, tv4);
- TensorView* tv6 = sub(tv5, tv0);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addInput(tv2);
- fusion.addInput(tv3);
- fusion.addOutput(tv6);
- // t6 = ((t1 + (t2 - t3)) - t0)
-
- tv5->cache_after();
- tv5->cache_before();
-
- // cache_after/cache_before calls are placed before scheduling
- constexpr int BSX = 32;
- tv6->split(-1, BSX);
- tv2->computeAt(tv6, -1);
-
- // Thread and Block binding
- tv6->axis(0)->parallelize(ParallelType::BIDx);
- tv6->axis(-1)->parallelize(ParallelType::TIDx);
-
- constexpr int M = 32, N = 810;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({M, N}, options);
- at::Tensor t1 = at::randn({M, N}, options);
- at::Tensor t2 = at::randn({M, N}, options);
- at::Tensor t3 = at::randn({M, N}, options);
-
- std::vector<IValue> aten_inputs = {t0, t1, t2, t3};
- at::Tensor aten_output = (t1 + (t2 - t3)) - t0;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionCacheBcast_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Algorithm
- TensorView* tv0 = makeSymbolicTensor(1); // (M)
- TensorView* tv1 = broadcast(tv0, {false, true}); // (M, 1)
- TensorView* tv2 = makeSymbolicTensor(1); // (N)
- TensorView* tv3 = broadcast(tv2, {true, false}); // (1, N)
- TensorView* tv4 = mul(tv1, tv3);
- fusion.addInput(tv0);
- fusion.addInput(tv2);
- fusion.addOutput(tv4);
-
- // Case 1
- tv0->cache_after();
-
- // Case 2
- tv1->cache_before();
-
- // Case 3
- tv1->cache_after();
-
- // Case 4
- TensorView* tv8 = tv4->cache_before();
-
- constexpr int BSX = 128;
- tv4->split(0, BSX);
- tv4->split(-1, BSX);
- tv4->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
- // M/BSX, N/BSX, BSX, BSX
- tv0->computeAt(tv4, 2);
- tv2->computeAt(tv4, 2);
- // 0, 1 | 2, 3, 4
-
- tv4->axis(0)->parallelize(ParallelType::BIDx);
- tv4->axis(1)->parallelize(ParallelType::BIDy);
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
- // Manual Replay on TV3
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv8->axis(-1)->parallelize(ParallelType::TIDx);
-
- constexpr int M = 92, N = 500;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({M}, options);
- at::Tensor t1 = at::randn({N}, options);
- std::vector<IValue> aten_inputs = {t0, t1};
- at::Tensor aten_output =
- t0.to(at::kDouble).unsqueeze(1).matmul(t1.to(at::kDouble).unsqueeze(0));
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionCacheMultiConsumer_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(1);
- TensorView* tv1 = add(tv0, new Double(1));
- TensorView* tv2 = add(tv1, new Double(2));
- TensorView* tv3 = add(tv0, new Double(1));
- TensorView* tv4 = add(tv3, new Double(2));
-
- fusion.addInput(tv0);
- fusion.addOutput(tv2);
- fusion.addOutput(tv4);
-
- auto tv5 = tv1->cache_before();
- auto tv6 = tv3->cache_before();
- tv5->setMemoryType(MemoryType::Shared);
- tv6->setMemoryType(MemoryType::Shared);
-
- tv1->computeAt(tv2, -1);
- tv3->computeAt(tv4, -1);
-
- // Fails because tensor must be recomputed twice
- // auto tv7 = tv0->cache_after();
-
- constexpr int N = 800;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({N}, options);
- auto aten_output = (aten_input + 1) + 2;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output, aten_output},
- __LINE__,
- __FILE__);
-}
-
-TEST(NVFuserTest, FusionSmem_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Algorithm
- TensorView* tv0 = makeSymbolicTensor(2); // (M, N)
- TensorView* tv1 = makeSymbolicTensor(2); // (M, N)
- TensorView* tv2 = mul(tv0, tv1);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addOutput(tv2);
-
- // Schedule
- TensorView* tv3 = tv0->cache_after();
- TensorView* tv4 = tv1->cache_after();
- tv3->setMemoryType(MemoryType::Shared);
- tv4->setMemoryType(MemoryType::Shared);
-
- constexpr int BSY = 32;
- constexpr int BSX = 128;
- tv2->split(0, BSY);
- tv2->split(2, BSX);
- // M/BSY, BSY, N/BSX, BSX
- tv2->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
- // M/BSY, N/BSX, BSY, BSX
-
- tv0->computeAt(tv2, 2);
- tv1->computeAt(tv2, 2);
-
- // Thread and Block binding
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(1)->parallelize(ParallelType::BIDy);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- // Manual Binding
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
-
- constexpr int M = 128, N = 10240;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({M, N}, options);
- at::Tensor t1 = at::randn({M, N}, options);
- at::Tensor aten_output = mul(t0, t1);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({t0, t1});
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-
- TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
-}
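-
-// The staging idiom above in isolation (a sketch, names illustrative):
-//
-//   TensorView* stage = in->cache_after();
-//   stage->setMemoryType(MemoryType::Shared);
-//   stage->axis(-1)->parallelize(ParallelType::TIDx); // cooperative copy
-//
-// The war_hazard_syncs_count assertions in this group appear to count the
-// synchronizations the lowering inserts to guard write-after-read hazards
-// when a shared-memory buffer is reused across loop iterations.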
-
-TEST(NVFuserTest, FusionSmemReduce_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Algorithm
- TensorView* tv0 = makeSymbolicTensor(3); // M, K, N
- TensorView* tv1 = sum(tv0, {1}); // M, R, N
- fusion.addInput(tv0);
- fusion.addOutput(tv1);
-
- TensorView* tv2 = tv0->cache_after();
- tv2->setMemoryType(MemoryType::Shared);
-
- // Schedule
- constexpr int BSX = 32;
- tv1->split(2, BSX);
- tv1->split(1, 128);
- tv1->split(0, BSX);
- // M/BSX, BSX, K/128, 128, N/BSX, BSX
- tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}});
- TensorView* tv3 = tv1->rFactor({-2});
-
- tv0->computeAt(tv1, -2);
- tv0->computeAt(tv3, -2);
-
- // Thread and Block binding
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(1)->parallelize(ParallelType::BIDy);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- // Manual Binding
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
- constexpr int M = 154, K = 45, N = 1524;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({M, K, N}, options);
- at::Tensor aten_output = sum(aten_input.to(at::kDouble), {1});
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
- TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
-}
-
-TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Algorithm
- TensorView* tv0 = makeSymbolicTensor(2); // (M, K)
- TensorView* tv1 = makeSymbolicTensor(2); // (K, N)
- TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
- TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
- TensorView* tv4 = mul(tv2, tv3); // M, K, N
- TensorView* tv5 = sum(tv4, {1}); // M, R, N
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addOutput(tv5);
-
- // Schedule
- constexpr int BSX = 16;
- tv5->split(2, BSX);
- tv5->split(1, BSX);
- tv5->split(0, BSX);
- // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
- tv5->reorder({{0, 0}, {1, 3}, {2, 2}, {3, 5}, {4, 1}, {5, 4}});
- // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
- TensorView* tv6 = tv5->rFactor({-1});
-
- tv2->setMemoryType(MemoryType::Shared);
- tv3->setMemoryType(MemoryType::Shared);
- tv4->setMemoryType(MemoryType::Shared);
- tv6->setMemoryType(MemoryType::Shared);
-
- tv0->computeAt(tv5, 3);
- tv1->computeAt(tv5, 3);
-
- // Thread and Block binding
- tv5->axis(0)->parallelize(ParallelType::BIDx);
- tv5->axis(1)->parallelize(ParallelType::BIDy);
- tv5->axis(-2)->parallelize(ParallelType::TIDy);
- tv5->axis(-1)->parallelize(ParallelType::TIDx);
- // Manual Binding
- tv2->axis(-3)->parallelize(ParallelType::TIDy);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv4->axis(-3)->parallelize(ParallelType::TIDy);
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
- tv6->axis(-3)->parallelize(ParallelType::TIDy);
- tv6->axis(-2)->parallelize(ParallelType::TIDx);
-
- constexpr int M = 154, K = 45, N = 1524;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({M, K}, options);
- at::Tensor t1 = at::randn({K, N}, options);
-
- std::vector<IValue> aten_inputs = {t0, t1};
- at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble));
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({t0, t1});
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-
- TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
-}
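-
-// The broadcast/mul/sum chain above is how these tests express a GEMM; the
-// dataflow, written out with shapes (a sketch):
-//
-//   tv2 = broadcast(tv0, {false, false, true}); // [M, K] -> [M, K, 1]
-//   tv3 = broadcast(tv1, {true, false, false}); // [K, N] -> [1, K, N]
-//   tv4 = mul(tv2, tv3);                        // [M, K, N] products
-//   tv5 = sum(tv4, {1});                        // reduce K -> [M, N]
-//
-// which is exactly C[m, n] = sum_k A[m, k] * B[k, n].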
-
-TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Algorithm
- TensorView* tv0 = makeSymbolicTensor(2); // (M, K)
- TensorView* tv1 = makeSymbolicTensor(2); // (K, N)
- TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
- TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
- TensorView* tv4 = mul(tv2, tv3); // M, K, N
- TensorView* tv5 = sum(tv4, {1}); // M, R, N
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addOutput(tv5);
-
- // Schedule
- // Remove reduction axis from tv5
- // tv6 = (M, R, N)
- // tv5 = (M, N)
- TensorView* tv6 = tv5->cache_before();
-
- constexpr int BSX = 16;
- tv5->split(1, BSX);
- tv5->split(0, BSX);
- // M/BSX, BSX, N/BSX, BSX
- tv5->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
- // tv5 = M/BSX, N/BSX, MSX, NSX
-
- tv6->computeAt(tv5, 2);
-
- tv6->split(-1, BSX);
- // M/BSX, N/BSX, MSX, NSX, K/BSX, KSX
- tv6->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 4}, {4, 2}, {5, 5}});
- // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
- TensorView* tv7 = tv6->rFactor({-1});
- // tv7 = M/BSX, N/BSX, K/BSXrf, MSX, NSX, KSXr
- // tv6 = M/BSX, N/BSX, K/BSXr, MSX, NSX
-
- tv0->computeAt(tv6, 3);
- tv1->computeAt(tv6, 3);
-
- tv0->computeAt(tv7, 3);
- tv1->computeAt(tv7, 3);
-
- tv2->setMemoryType(MemoryType::Shared);
- tv3->setMemoryType(MemoryType::Shared);
- tv4->setMemoryType(MemoryType::Shared);
- tv6->setMemoryType(MemoryType::Shared);
- tv7->setMemoryType(MemoryType::Shared);
- // Memory Type
-
- // Thread and Block binding
- tv5->axis(0)->parallelize(ParallelType::BIDx);
- tv5->axis(1)->parallelize(ParallelType::BIDy);
- tv5->axis(-2)->parallelize(ParallelType::TIDy);
- tv5->axis(-1)->parallelize(ParallelType::TIDx);
- // Manual Binding
- tv2->axis(-3)->parallelize(ParallelType::TIDy);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv4->axis(-3)->parallelize(ParallelType::TIDy);
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv7->axis(-3)->parallelize(ParallelType::TIDy);
- tv7->axis(-2)->parallelize(ParallelType::TIDx);
-
- tv6->axis(-2)->parallelize(ParallelType::TIDy);
- tv6->axis(-1)->parallelize(ParallelType::TIDx);
-
- constexpr int M = 154, K = 45, N = 1524;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({M, K}, options);
- at::Tensor t1 = at::randn({K, N}, options);
- at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble));
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-
- TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
-}
-
-TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* x = makeSymbolicTensor(2);
- fusion.addInput(x);
- TensorView* max_val =
- reductionOp(BinaryOpType::Max, {-1}, new Double(FLT_MIN), x); // (M)
- TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B)
- TensorView* x_max_sub = sub(x, bcast_max); // (M, N)
- TensorView* exp = unaryOp(UnaryOpType::Exp, x_max_sub); // (M, N)
- TensorView* sum_exp = sum(exp, {-1}); // (M, R)
- TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B)
- TensorView* softmax = div(exp, bcast_sum); // (M, N)
- fusion.addOutput(softmax);
-
- // Read the input into shared memory: both the cached load of x and the
- // pointwise exp result are staged in shared memory below.
- auto cache_x = x->cache_after();
- cache_x->setMemoryType(MemoryType::Shared);
- exp->setMemoryType(MemoryType::Shared);
-
- std::vector<TensorView*> all_tensors(
- {x,
- cache_x,
- max_val,
- bcast_max,
- x_max_sub,
- exp,
- sum_exp,
- bcast_sum,
- softmax});
-
- auto tidx = new Int();
- fusion.addInput(tidx);
-
- for (auto tensor : all_tensors) {
- tensor->split(-1, tidx);
- }
-
- auto sum_exp_rf = sum_exp->rFactor({1});
- all_tensors.push_back(sum_exp_rf);
-
- // computeAt
- x->computeAt(x_max_sub, 1);
- exp->computeAt(softmax, 1);
- x_max_sub->computeAt(exp, 2);
-
- softmax->axis(0)->parallelize(ParallelType::BIDx);
- for (auto tensor : all_tensors) {
- tensor->axis(-1)->parallelize(ParallelType::TIDx);
- }
-
- const size_t dimx = 1024;
- const size_t dimy = 4096;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({dimx, dimy}, options);
- auto aten_output = at::_softmax(aten_input.to(at::kDouble), -1, false);
-
- torch::jit::fuser::cuda::FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input, 128});
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input, 128},
- {aten_output},
- __LINE__,
- __FILE__);
-}
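-
-// The tensor chain above implements the numerically stable softmax
-// decomposition. For one row x (a sketch of the math only):
-//
-//   m = max_j x[j]
-//   softmax(x)[i] = exp(x[i] - m) / sum_j exp(x[j] - m)
-//
-// Subtracting the row max keeps every exponent <= 0 so exp() cannot
-// overflow, and the result is unchanged because the exp(-m) factors cancel.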
-
-TEST(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const int kReductionAxis = 3;
- std::vector<int64_t> input_shape{10, 10, 10, 67};
- TensorView* input = makeSymbolicTensor(input_shape.size());
- fusion.addInput(input);
-
- auto output = softmax(input, kReductionAxis);
-
- fusion.addOutput(output);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn(input_shape, options);
- auto aten_output =
- at::_softmax(aten_input.to(at::kDouble), kReductionAxis, false);
-
- auto reduction_params = getNormalizationHeuristics(&fusion, {aten_input});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-
- scheduleNormalization(&fusion, reduction_params.value());
-
- auto lparams = reduction_params.value().lparams;
-
- torch::jit::fuser::cuda::FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-TEST(NVFuserTest, FusionMagicSchedulerLayerNormBackward_CUDA) {
- std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
- Fusion& fusion = *fusion_ptr.get();
- FusionGuard fg(&fusion);
-
- std::vector<int64_t> shape{20, 100, 35, 67};
- std::vector<int64_t> norm_shape{67};
-
- const size_t kM = shape.size();
- const size_t kN = norm_shape.size();
- const size_t kOuterNumDims = kM - kN;
-
- std::vector<int64_t> outer_shape;
- for (size_t idx = 0; idx < kOuterNumDims; ++idx) {
- outer_shape.push_back(shape[idx]);
- }
- for (size_t idx = kOuterNumDims; idx < kM; ++idx) {
- outer_shape.push_back(1);
- }
-
- auto grad_out = makeSymbolicTensor(shape.size());
- auto input = makeSymbolicTensor(shape.size());
- auto mean = makeConcreteTensor(outer_shape);
- auto rstd = makeConcreteTensor(outer_shape);
- auto weight = makeSymbolicTensor(norm_shape.size());
- auto bias = makeSymbolicTensor(norm_shape.size());
- fusion.addInput(grad_out);
- fusion.addInput(input);
- fusion.addInput(mean);
- fusion.addInput(rstd);
- fusion.addInput(weight);
- fusion.addInput(bias);
-
- auto grads = layer_norm_backward(
- grad_out,
- input,
- norm_shape,
- mean,
- rstd,
- weight,
- bias,
- {true, true, true});
-
- fusion.addOutput(grads.grad_input);
- fusion.addOutput(grads.grad_weight);
- fusion.addOutput(grads.grad_bias);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_grad_out = at::randn(shape, options);
- at::Tensor aten_input = at::randn(shape, options);
- at::Tensor aten_weight = at::randn(norm_shape, options);
- at::Tensor aten_bias = at::randn(norm_shape, options);
- auto at_weight = c10::optional<at::Tensor>(aten_weight);
- auto at_bias = c10::optional<at::Tensor>(aten_bias);
-
- const float kEps = 1e-5;
- auto aten_results =
- at::native_layer_norm(aten_input, norm_shape, at_weight, at_bias, kEps);
- auto aten_output = std::get<0>(aten_results);
- auto aten_mean = std::get<1>(aten_results);
- auto aten_rstd = std::get<2>(aten_results);
-
- FusionExecutorCache fec(std::move(fusion_ptr));
- std::vector<IValue> aten_inputs = {
- aten_grad_out, aten_input, aten_mean, aten_rstd, aten_weight, aten_bias};
- auto cg_outputs = fec.runFusionWithInputs(aten_inputs);
-
- auto aten_gradients = at::native_layer_norm_backward(
- aten_grad_out.to(at::kDouble),
- aten_input.to(at::kDouble),
- norm_shape,
- aten_mean.to(at::kDouble),
- aten_rstd.to(at::kDouble),
- c10::optional<at::Tensor>(aten_weight.to(at::kDouble)),
- c10::optional<at::Tensor>(aten_bias.to(at::kDouble)),
- {true, true, true});
-
- testValidate(
- &fusion,
- cg_outputs,
- aten_inputs,
- {std::get<0>(aten_gradients),
- std::get<1>(aten_gradients),
- std::get<2>(aten_gradients)},
- __LINE__,
- __FILE__);
-}
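-
-// The outer_shape loop above builds the broadcast shape of the saved mean
-// and rstd: outer dims are kept and each normalized dim collapses to 1.
-// Worked through for this test's shapes:
-//
-//   shape = {20, 100, 35, 67}, norm_shape = {67}
-//   kM = 4, kN = 1, kOuterNumDims = 3  =>  outer_shape = {20, 100, 35, 1}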
-
-TEST(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) {
- std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
- Fusion& fusion = *fusion_ptr.get();
- FusionGuard fg(&fusion);
-
- const float kEps = 1e-5;
- Double* eps_ptr = new Double(kEps);
-
- std::vector<int64_t> input_shape{20, 100, 35, 67};
- std::vector<int64_t> norm_shape{67};
-
- auto input = makeSymbolicTensor(input_shape.size());
- fusion.addInput(input);
-
- auto result = layer_norm(input, norm_shape, nullptr, nullptr, eps_ptr);
-
- fusion.addOutput(result.output);
- fusion.addOutput(result.mean);
- fusion.addOutput(result.invstd);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn(input_shape, options);
- c10::optional<at::Tensor> aten_weight = c10::nullopt;
- c10::optional<at::Tensor> aten_bias = c10::nullopt;
- auto aten_outputs = at::native_layer_norm(
- aten_input, norm_shape, aten_weight, aten_bias, kEps);
-
- // Check that the reduction axis is the same for all reductions and
- // generate launch parameters.
- auto reduction_params = getNormalizationHeuristics(&fusion, {aten_input});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-
- scheduleNormalization(&fusion, reduction_params.value());
- auto lparams = reduction_params.value().lparams;
-
- torch::jit::fuser::cuda::FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {std::get<0>(aten_outputs),
- std::get<1>(aten_outputs),
- std::get<2>(aten_outputs)},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
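-
-// For reference, the quantities the scheduler must produce here, per
-// normalized slice x of length N (a sketch of the math; weight and bias are
-// null in this test):
-//
-//   mean   = (1/N) * sum_i x[i]
-//   invstd = 1 / sqrt((1/N) * sum_i (x[i] - mean)^2 + eps)
-//   y[i]   = (x[i] - mean) * invstd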
-
-TEST(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
-
- const float kMomentum = 0.1;
- const float kEps = 1e-5;
- const bool kTraining = true;
- std::vector<int64_t> input_shape{20, 100, 35, 45};
-
- auto input = makeSymbolicTensor(input_shape.size());
- auto weight = makeSymbolicTensor(1);
- auto bias = makeSymbolicTensor(1);
- auto running_mean = makeSymbolicTensor(1);
- auto running_var = makeSymbolicTensor(1);
- fusion->addInput(input);
- fusion->addInput(weight);
- fusion->addInput(bias);
- fusion->addInput(running_mean);
- fusion->addInput(running_var);
-
- Double* momentum = new Double(kMomentum);
- Double* eps = new Double(kEps);
-
- auto result = batch_norm(
- input, weight, bias, running_mean, running_var, kTraining, momentum, eps);
-
- fusion->addOutput(result.output);
- fusion->addOutput(result.mean);
- fusion->addOutput(result.invstd);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto at_input = at::randn(input_shape, options);
- auto at_weight = at::ones({input_shape[1]}, options);
- auto at_bias = at::zeros({input_shape[1]}, options);
- auto at_run_mean = at::zeros({input_shape[1]}, options);
- auto at_run_var = at::ones({input_shape[1]}, options);
-
- std::vector<IValue> aten_inputs = {
- at_input, at_weight, at_bias, at_run_mean, at_run_var};
-
- FusionExecutorCache executor_cache(std::move(fusion));
-
- auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);
-
- auto aten_outputs = at::native_batch_norm(
- at_input,
- c10::optional<at::Tensor>(at_weight),
- c10::optional<at::Tensor>(at_bias),
- c10::optional<at::Tensor>(at_run_mean),
- c10::optional<at::Tensor>(at_run_var),
- kTraining,
- kMomentum,
- kEps);
-
- testValidate(
- executor_cache.fusion(),
- cg_outputs,
- aten_inputs,
- {at_run_mean,
- at_run_var,
- std::get<0>(aten_outputs),
- std::get<1>(aten_outputs),
- std::get<2>(aten_outputs)},
- __LINE__,
- __FILE__,
- "");
-}
-
- // Disabled for now because the memory reuse pass needs to be fixed.
-#if 0
-TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const int pixels_per_thread = 64;
- const int TIDX = 128;
- const int static_size = pixels_per_thread * TIDX;
-
- TensorView* sx = makeConcreteTensor({-1, static_size});
- TensorView* dx = makeSymbolicTensor(2);
- fusion.addInput(sx);
- fusion.addInput(dx);
-
- TensorView* max_sx =
- reductionOp(BinaryOpType::Max, {-1}, new Double(FLT_MIN), sx); // (M)
- TensorView* max_dx =
- reductionOp(BinaryOpType::Max, {-1}, new Double(FLT_MIN), dx); // (M)
-
- // Reduction => merge local and shared memory TensorViews
- TensorView* max_val = binaryOp(BinaryOpType::Max, max_sx, max_dx);
- TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B)
-
- TensorView* sx_max_sub = sub(sx, bcast_max); // (M, N)
- TensorView* dx_max_sub = sub(dx, bcast_max); // (M, N)
-
- TensorView* sx_exp = unaryOp(UnaryOpType::Exp, sx_max_sub); // (M, N)
- TensorView* dx_exp = unaryOp(UnaryOpType::Exp, dx_max_sub); // (M, N)
-
- TensorView* sx_sum_exp = sum(sx_exp, {-1}); // (M, R)
- TensorView* dx_sum_exp = sum(dx_exp, {-1}); // (M, R)
-
- // Reduction => merge local and shared memory TensorViews
- TensorView* sum_exp = binaryOp(BinaryOpType::Add, sx_sum_exp, dx_sum_exp);
- TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B)
-
- TensorView* sx_softmax = div(sx_exp, bcast_sum); // (M, N)
- TensorView* dx_softmax = div(dx_exp, bcast_sum); // (M, N)
- fusion.addOutput(sx_softmax);
- fusion.addOutput(dx_softmax);
-
- auto sx_cache = sx->cache_after();
- auto dx_cache = dx->cache_after();
- dx_cache->setMemoryType(MemoryType::Shared);
- dx_exp->setMemoryType(MemoryType::Shared);
-
- // Reduction and Broadcast Tensors common to both memory TVs
- std::vector<TensorView*> common_tensors(
- {max_val, sum_exp, bcast_max, bcast_sum});
-
- // Static Local Memory TVs
- std::vector<TensorView*> static_tensors(
- {sx, sx_cache, max_sx, sx_max_sub, sx_exp, sx_sum_exp, sx_softmax});
-
- // Dynamic Local Memory TVs
- std::vector<TensorView*> dynamic_tensors(
- {dx, dx_cache, max_dx, dx_max_sub, dx_exp, dx_sum_exp, dx_softmax});
-
- std::vector<TensorView*> all_tensors;
- all_tensors.insert(
- all_tensors.end(), common_tensors.begin(), common_tensors.end());
- all_tensors.insert(
- all_tensors.end(), static_tensors.begin(), static_tensors.end());
- all_tensors.insert(
- all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end());
-
- // M => M
- // M, N => M, N/128, 128
- for (auto tensor : all_tensors) {
- if (tensor->nDims() > 1) {
- tensor->split(-1, TIDX);
- }
- }
-
- auto sx_sum_exp_rf = sx_sum_exp->rFactor({1});
- auto dx_sum_exp_rf = dx_sum_exp->rFactor({1});
- all_tensors.push_back(sx_sum_exp_rf);
- all_tensors.push_back(dx_sum_exp_rf);
-
- // computeAt
- sx->computeAt(sx_max_sub, 1);
- dx->computeAt(dx_max_sub, 1);
-
- sx_exp->computeAt(sx_softmax, 1);
- dx_exp->computeAt(dx_softmax, 1);
-
- sx_max_sub->computeAt(sx_exp, 2);
- dx_max_sub->computeAt(dx_exp, 2);
-
- sx_softmax->axis(0)->parallelize(ParallelType::BIDx);
- dx_softmax->axis(0)->parallelize(ParallelType::BIDx);
- for (auto tensor : all_tensors) {
- if (tensor->nDims() > 1) {
- tensor->axis(-1)->parallelize(ParallelType::TIDx);
- }
- }
-
- const size_t dimx = 1024;
- const size_t dimy = 16384;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({dimx, dimy}, options);
- at::Tensor aten_static_in = aten_input.narrow(1, 0, static_size);
- at::Tensor aten_dynamic_in =
- aten_input.narrow(1, static_size, dimy - static_size);
-
- at::Tensor out = at::zeros({dimx, dimy}, options);
- at::Tensor cg_static_out = out.narrow(1, 0, static_size);
- at::Tensor cg_dynamic_out = out.narrow(1, static_size, dimy - static_size);
-
- std::vector<at::Tensor> aten_outputs;
-
- auto aten_output = at::_softmax(aten_input.to(at::kDouble), -1, false);
- at::Tensor aten_static_out = aten_output.narrow(1, 0, static_size);
- at::Tensor aten_dynamic_out =
- aten_output.narrow(1, static_size, dimy - static_size);
-
- torch::jit::fuser::cuda::FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion(
- {aten_static_in, aten_dynamic_in}, {cg_static_out, cg_dynamic_out});
-
- testValidate(
- &fusion,
- {cg_static_out, cg_dynamic_out},
- {aten_static_in, aten_dynamic_in},
- {aten_static_out, aten_dynamic_out},
- __LINE__,
- __FILE__);
-}
-#endif
-
-// DISABLED. TODO: https://github.com/csarofeen/pytorch/issues/743
-TEST(NVFuserTest, FusionPersistentNormLocalShared_CUDA) {
- return;
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const int pixels_per_thread = 64;
- const int TIDX = 128;
- const int static_size = pixels_per_thread * TIDX;
-
- TensorView* sx = makeConcreteTensor({-1, static_size});
- TensorView* dx = makeSymbolicTensor(2);
- fusion.addInput(sx);
- fusion.addInput(dx);
-
- Double* gamma = new Double();
- Double* beta = new Double();
- Double* eps = new Double();
- Int* N = new Int();
- fusion.addInput(gamma);
- fusion.addInput(beta);
- fusion.addInput(eps);
- fusion.addInput(N);
-
- // Reduction
- auto sx_sum = sum(sx, {-1}); // (M, R)
- auto dx_sum = sum(dx, {-1}); // (M, R)
- // Reduction => merge local and shared memory TensorViews
- auto x_sum = binaryOp(BinaryOpType::Add, sx_sum, dx_sum);
-
- // Broadcast
- auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B)
- // Pwise
- auto x_mean = div(x_sum_bcast, N); // (M, B)
-
- auto sx_mean_sub = sub(sx, x_mean); // (M, N)
- auto dx_mean_sub = sub(dx, x_mean); // (M, N)
-
- auto sx_mean_sub_pow = mul(sx_mean_sub, sx_mean_sub); // (M, N)
- auto dx_mean_sub_pow = mul(dx_mean_sub, dx_mean_sub); // (M, N)
-
- // Reduction
- auto sx_var_sum = sum(sx_mean_sub_pow, {-1}); // (M, R)
- auto dx_var_sum = sum(dx_mean_sub_pow, {-1}); // (M, R)
- // Reduction => merge local and shared memory TensorViews
- auto var_sum = binaryOp(BinaryOpType::Add, sx_var_sum, dx_var_sum);
-
- // Broadcast
- auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B)
- // Pwise
- auto var = div(var_sum_bcast, N); // (M, B)
- auto var_eps = add(var, eps); // (M, B)
- auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B)
-
- auto sx_norm = mul(sx_mean_sub, rvar);
- auto dx_norm = mul(dx_mean_sub, rvar);
-
- auto sx_norm_gamma = mul(sx_norm, gamma);
- auto dx_norm_gamma = mul(dx_norm, gamma);
-
- auto sx_norm_gamma_beta = add(sx_norm_gamma, beta);
- auto dx_norm_gamma_beta = add(dx_norm_gamma, beta);
-
- fusion.addOutput(sx_norm_gamma_beta);
- fusion.addOutput(dx_norm_gamma_beta);
-
- // Read Input into Shared Memory
- // Read Input minus Input_Mean into Shared Memory
- auto sx_cache = sx->cache_after();
- auto dx_cache = dx->cache_after();
- dx_cache->setMemoryType(MemoryType::Shared);
- dx_mean_sub->setMemoryType(MemoryType::Shared);
-
- std::vector<TensorView*> common_tensors(
- {x_sum, x_sum_bcast, x_mean, var_sum, var_sum_bcast, var, var_eps, rvar});
-
- std::vector<TensorView*> static_tensors(
- {sx,
- sx_cache,
- sx_sum,
- sx_mean_sub,
- sx_mean_sub_pow,
- sx_var_sum,
- sx_norm,
- sx_norm_gamma,
- sx_norm_gamma_beta});
-
- std::vector<TensorView*> dynamic_tensors(
- {dx,
- dx_cache,
- dx_sum,
- dx_mean_sub,
- dx_mean_sub_pow,
- dx_var_sum,
- dx_norm,
- dx_norm_gamma,
- dx_norm_gamma_beta});
-
- std::vector<TensorView*> all_tensors;
- all_tensors.insert(
- all_tensors.end(), common_tensors.begin(), common_tensors.end());
- all_tensors.insert(
- all_tensors.end(), static_tensors.begin(), static_tensors.end());
- all_tensors.insert(
- all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end());
-
- // M => M
- // M, N => M, N/128, 128
- for (auto tensor : all_tensors) {
- if (tensor->nDims() > 1) {
- tensor->split(-1, TIDX);
- }
- }
-
- // Local Sum => Block Broadcast
- TensorView* sx_sum_rf = sx_sum->rFactor({1});
- TensorView* sx_var_sum_rf = sx_var_sum->rFactor({1});
- TensorView* dx_sum_rf = dx_sum->rFactor({1});
- TensorView* dx_var_sum_rf = dx_var_sum->rFactor({1});
- all_tensors.push_back(sx_sum_rf);
- all_tensors.push_back(sx_var_sum_rf);
- all_tensors.push_back(dx_sum_rf);
- all_tensors.push_back(dx_var_sum_rf);
-
- // ComputeAt
- sx->computeAt(sx_mean_sub_pow, 1);
- dx->computeAt(dx_mean_sub_pow, 1);
-
- var_sum->computeAt(rvar, 1);
-
- sx_mean_sub_pow->computeAt(sx_var_sum_rf, 2);
- dx_mean_sub_pow->computeAt(dx_var_sum_rf, 2);
-
- sx_norm->computeAt(sx_norm_gamma_beta, 2);
- dx_norm->computeAt(dx_norm_gamma_beta, 2);
-
- sx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx);
- dx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx);
- for (auto tensor : all_tensors) {
- if (tensor->nDims() > 1) {
- tensor->axis(-1)->parallelize(ParallelType::TIDx);
- }
- }
-
- const int dimx = 1024;
- const int dimy = 16384;
- const float kGamma = 1.0f;
- const float kBeta = 0.0f;
- const float kEps = 1e-5;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::randn({dimx, dimy}, options);
- at::Tensor aten_static_in = aten_input.narrow(1, 0, static_size);
- at::Tensor aten_dynamic_in =
- aten_input.narrow(1, static_size, dimy - static_size);
-
- at::Tensor out = at::zeros({dimx, dimy}, options);
- at::Tensor cg_static_out = out.narrow(1, 0, static_size);
- at::Tensor cg_dynamic_out = out.narrow(1, static_size, dimy - static_size);
-
- std::vector<IValue> aten_inputs = {
- aten_static_in, aten_dynamic_in, kGamma, kBeta, kEps, dimy};
-
- torch::jit::fuser::cuda::FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion(aten_inputs, {cg_static_out, cg_dynamic_out});
-
- auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1);
- auto at_var = at::var(aten_input.to(at::kDouble), -1, false).unsqueeze(1);
- auto at_rvar = at::rsqrt(at::add(at_var, kEps));
- auto at_norm = at::mul(at::sub(aten_input, at_mu), at_rvar);
- auto aten_output = at::add(at::mul(at_norm, kGamma), kBeta);
- at::Tensor aten_static_out = aten_output.narrow(1, 0, static_size);
- at::Tensor aten_dynamic_out =
- aten_output.narrow(1, static_size, dimy - static_size);
-
- testValidate(
- &fusion,
- {cg_static_out, cg_dynamic_out},
- aten_inputs,
- {aten_static_out, aten_dynamic_out},
- __LINE__,
- __FILE__);
-}
-
-TEST(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- auto x = makeSymbolicTensor(2);
- Double* gamma = new Double();
- Double* beta = new Double();
- Double* eps = new Double();
- Int* N = new Int();
- fusion.addInput(x);
- fusion.addInput(gamma);
- fusion.addInput(beta);
- fusion.addInput(eps);
- fusion.addInput(N);
-
- // Reduction
- auto x_sum = sum(x, {-1}); // (M, R)
- // Broadcast
- auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B)
- // Pwise
- auto x_mean = div(x_sum_bcast, N); // (M, B)
- auto x_mean_sub = sub(x, x_mean); // (M, N)
- auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub); // (M, N)
- // Reduction
- auto var_sum = sum(x_mean_sub_pow, {-1}); // (M, R)
- // Broadcast
- auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B)
- // Pwise
- auto var = div(var_sum_bcast, N); // (M, B)
- auto var_eps = add(var, eps); // (M, B)
- auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B)
- auto norm = mul(x_mean_sub, rvar);
- auto norm_gamma = mul(norm, gamma);
- auto norm_gamma_beta = add(norm_gamma, beta);
- fusion.addOutput(norm_gamma_beta);
-
- // Read Input into Shared Memory
- // Read Input minus Input_Mean into Shared Memory
- auto cache_x = x->cache_after();
- cache_x->setMemoryType(MemoryType::Shared);
- x_mean_sub->setMemoryType(MemoryType::Shared);
-
- std::vector<TensorView*> all_tensors(
- {x_sum,
- x_mean,
- cache_x,
- x_sum_bcast,
- x_mean_sub,
- x_mean_sub_pow,
- var_sum,
- var_sum_bcast,
- var,
- var_eps,
- rvar,
- norm,
- norm_gamma,
- norm_gamma_beta});
-
- auto tidx = new Int();
- fusion.addInput(tidx);
-
- for (auto tensor : all_tensors) {
- tensor->split(-1, tidx);
- }
-
- // Local Sum => Block Broadcast
- TensorView* x_sum_rf = x_sum->rFactor({1});
- TensorView* var_sum_rf = var_sum->rFactor({1});
- all_tensors.push_back(x_sum_rf);
- all_tensors.push_back(var_sum_rf);
-
- // ComputeAt
- x->computeAt(x_mean_sub_pow, 1);
- var_sum->computeAt(rvar, 1);
- x_mean_sub_pow->computeAt(var_sum_rf, 2);
- norm->computeAt(norm_gamma_beta, 2);
-
- for (auto tv : all_tensors) {
- tv->axis(0)->parallelize(ParallelType::BIDx);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
-
- const int dimx = 128;
- const int dimy = 2048;
- const float kGamma = 1.0f;
- const float kBeta = 0.0f;
- const float kEps = 1e-5;
- const int TIDX = 128;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({dimx, dimy}, options);
- auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1);
- auto at_var = at::var(aten_input.to(at::kDouble), -1, false).unsqueeze(1);
- auto at_rvar = at::rsqrt(at::add(at_var, kEps));
- auto at_norm = at::mul(at::sub(aten_input, at_mu), at_rvar);
- auto aten_output = at::add(at::mul(at_norm, kGamma), kBeta);
-
- std::vector<IValue> aten_inputs = {
- aten_input, kGamma, kBeta, kEps, dimy, TIDX};
-
- torch::jit::fuser::cuda::FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
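-
-// The chain above is a persistent normalization kernel; per row x of length
-// N it materializes (a sketch of the math, using the biased 1/N variance):
-//
-//   mean   = sum_i x[i] / N
-//   var    = sum_i (x[i] - mean)^2 / N
-//   out[i] = ((x[i] - mean) * rsqrt(var + eps)) * gamma + beta
-//
-// "Persistent" here refers to each block keeping the row resident (cached in
-// shared memory) across both reduction passes instead of re-reading global
-// memory.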
-
-TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
- fusion.addInput(tv0);
- fusion.addOutput(tv1);
- // tv1[I0, R1] = tv0[I0, I1]
-
- // The interface should just be a direct split with a Parallel type; the
- // parallelize call below could then be folded into the split.
- tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
- // tv1[I0, R1o, R1i{blockDim.x}] = tv0[I0, I1]
-
- TensorView* tv2 = tv1->rFactor({2});
- tv2->setMemoryType(MemoryType::Shared);
- // tv2[I0, Ir1o, R1i{blockDim.x}] = tv0[I0, I1]
- // tv1[I0, R1o] = tv2[I0, Ir1o, R1i{blockDim.x}]
-
- tv0->computeAt(tv1, 1);
-
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv1->axis(0)->parallelize(ParallelType::BIDx);
-
- constexpr int numel_x = 65000, numel_y = 1024;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
- auto aten_output = aten_input.to(at::kDouble).sum({1});
-
- // How many threads to use for the block reduction
- constexpr int runtime_threadIdx_dim = 128;
-
- LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
- TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0);
-}
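-
-// A note on the LaunchParams pattern used in these tests: judging by how
-// runtime_threadIdx_dim is passed, the constructor takes the grid dims
-// followed by the block dims, and -1 marks a dimension the executor should
-// infer:
-//
-//   LaunchParams lparams(-1, -1, -1, /*bdimx=*/128, -1, -1);
-//
-// Because the reduction was split by blockDim.x symbolically, the compiled
-// kernel is not specialized to any one bdimx value.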
-
-TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Algorithm
- Int* sym_bsx = new Int();
- TensorView* tv0 = makeSymbolicTensor(3); // M, K, N
- fusion.addInput(tv0);
- fusion.addInput(sym_bsx);
-
- TensorView* tv1 = sum(tv0, {1}); // M, R, N
- fusion.addOutput(tv1);
-
- TensorView* tv2 = tv0->cache_after();
- tv2->setMemoryType(MemoryType::Shared);
-
- // Schedule
- constexpr int BSX = 32;
- tv1->split(2, BSX);
- tv1->split(1, sym_bsx);
- tv1->split(0, BSX);
- // M/BSX, BSX, K/sym_bsx, sym_bsx, N/BSX, BSX
- tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}});
- TensorView* tv3 = tv1->rFactor({-2});
-
- tv0->computeAt(tv1, -2);
- tv0->computeAt(tv3, -2);
-
- // Thread and Block binding
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(1)->parallelize(ParallelType::BIDy);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- // Manual Binding
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
- constexpr int M = 154, K = 45, N = 1524;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({M, K, N}, options);
- at::Tensor aten_output = aten_input.to(at::kDouble).sum({1});
-
- // How many threads to use for the block reduction
- constexpr int runtime_threadIdx_dim = 128;
-
- auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input, runtime_threadIdx_dim}, lparams);
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input, runtime_threadIdx_dim},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-
- TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
-}
-
-TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- Int* sym_bsx = new Int();
- TensorView* tv0 = makeSymbolicTensor(2); // (M, K)
- TensorView* tv1 = makeSymbolicTensor(2); // (K, N)
- TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
- TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
- TensorView* tv4 = mul(tv2, tv3); // M, K, N
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addInput(sym_bsx);
- fusion.addOutput(tv4);
- // Algorithm
-
- tv2->setMemoryType(MemoryType::Shared);
- tv3->setMemoryType(MemoryType::Shared);
-
- constexpr int BSX = 32;
- tv4->split(2, BSX);
- tv4->split(1, sym_bsx);
- tv4->split(0, BSX);
- // M/BSX, BSX, K/sym_bsx, sym_bsx, N/BSX, BSX
- tv4->reorder({{0, 0}, {1, 3}, {2, 1}, {3, 4}, {4, 2}, {5, 5}});
- // M/BSX, K/sym_bsx, N/BSX, MSX, KSX, NSX
-
- tv0->computeAt(tv4, 3);
- tv1->computeAt(tv4, 3);
- // Schedule
-
- tv4->axis(0)->parallelize(ParallelType::BIDx);
- tv4->axis(2)->parallelize(ParallelType::BIDy);
- // Manual Binding
- tv2->axis(-2)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- // Thread and Block binding
-
- constexpr int M = 128, K = 457, N = 1024;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({M, K}, options);
- at::Tensor t1 = at::randn({K, N}, options);
- at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0));
- std::vector<IValue> aten_inputs = {t0, t1, BSX};
-
- LaunchParams lparams(-1, -1, -1, BSX, -1, -1);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
- testValidate(
- &fusion,
- cg_outputs,
- aten_inputs,
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-
- TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
-}
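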
-
-TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Symbolic integers we will use for runtime tiling
- Int* symbolic_m_tile_dim = new Int(); // bound to threadIdx.z
- Int* symbolic_split_k_tile_dim = new Int(); // bound to blockIdx.x
- Int* symbolic_block_k_tile_dim = new Int(); // bound to threadIdx.x
- // Compile-time integer for tiling
- int n_smem_tile = 8; // bound to threadIdx.y
-
- // Symbolic 2D tensors TV0[M, K], TV1[K, N]
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
-
- // Broadcast tv0 to [M, K, *]
- TensorView* tv2 = broadcast(tv0, {false, false, true});
- // Broadcast tv1 to [*, K, N]
- TensorView* tv3 = broadcast(tv1, {true, false, false});
-
- // Pointwise multiplication resulting in tv4[M, K, N]
- TensorView* tv4 = mul(tv2, tv3);
-
- // Turn the K-dimension of tv4 into a reduction dimension
- TensorView* tv5 = sum(tv4, {1});
-
- // Register inputs and outputs
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addOutput(tv5);
-
- // Register runtime tile dims as inputs
- fusion.addInput(symbolic_m_tile_dim);
- fusion.addInput(symbolic_split_k_tile_dim);
- fusion.addInput(symbolic_block_k_tile_dim);
-
- // Make a 3D tile from a mix of symbolic and constant sizes; split in
- // reverse order because each split inserts a new dim.
- tv5->split(2, n_smem_tile);
- tv5->split(1, symbolic_block_k_tile_dim);
- tv5->split(1, symbolic_split_k_tile_dim);
- tv5->split(0, symbolic_m_tile_dim);
-
- // Reorder so all outer tiles are in the leftmost 3 positions
- tv5->reorder({{1, 5}, {5, 1}});
-
- // Factor out the outer reduction IterDomain, then perform the inter-CTA
- // reduction followed by the intra-CTA reduction.
- auto tv6 = tv5->rFactor({2});
-
- // Scope computations
- tv6->computeAt(tv5, 2);
-
- // rFactor moves reduction axes around; reorder to match the ordering of tv5.
- tv6->reorder({
- {2, -2},
- {3, -1},
- {4, 2},
- {5, 3},
- {6, 4},
- });
-
- // Setup compute at schedule
- tv0->computeAt(tv6, 3);
- tv1->computeAt(tv6, 3);
- tv4->computeAt(tv6, -1);
- //
- // T2[Mo, bNo, Koo, Koi, Kii, Mi, bNi] CA(4, 3)
- // T3[bMo, No, Koo, Koi, Kii, bMi, Ni] CA(4, 3)
- // T4[ Mo, No, Koo, Koi, Kii, Mi, Ni]
- // T6[ Mo, No, rKoo, Koi, Kii, Mi, Ni]
- // T5[ Mo, No, rKoi, rKii, Mi, Ni]
-
- // Cache smem tiles
- tv2->setMemoryType(MemoryType::Shared);
- tv3->setMemoryType(MemoryType::Shared);
- tv4->setMemoryType(MemoryType::Local);
- tv6->setMemoryType(MemoryType::Local);
-
- tv5->axis(0)->parallelize(ParallelType::BIDz);
- tv5->axis(1)->parallelize(ParallelType::BIDy);
-
- std::vector<TensorView*> tv_list = {tv2, tv3, tv4, tv5, tv6};
- for (auto tv : tv_list) {
- tv->axis(-2)->parallelize(ParallelType::TIDz);
- tv->axis(-1)->parallelize(ParallelType::TIDy);
- }
- tv2->axis(3)->parallelize(ParallelType::TIDx);
- tv3->axis(3)->parallelize(ParallelType::TIDx);
- tv4->axis(3)->parallelize(ParallelType::TIDx);
- tv6->axis(3)->parallelize(ParallelType::TIDx);
- tv5->axis(2)->parallelize(ParallelType::TIDx);
-
- tv2->axis(4)->parallelize(ParallelType::BIDx);
- tv3->axis(4)->parallelize(ParallelType::BIDx);
- tv4->axis(4)->parallelize(ParallelType::BIDx);
- tv6->axis(4)->parallelize(ParallelType::BIDx);
- tv5->axis(3)->parallelize(ParallelType::BIDx);
-
- constexpr int M = 31, K = 65, N = 33;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({M, K}, options);
- at::Tensor t1 = at::randn({K, N}, options);
-
- FusionExecutor fe;
- // Generate CUDA and compile with nvRTC
- fe.compileFusion(&fusion);
-
- // Runtime tiling
- int m_tile = 4; // bound to threadIdx.z
- int split_k = 7; // bound to blockIdx.x
- int intra_cta = 8; // bound to threadIdx.x
-
- std::vector<IValue> aten_inputs = {t0, t1, m_tile, split_k, intra_cta};
- at::Tensor aten_output =
- mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1);
-
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-
- TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1);
-}
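-
-// The key idea above, reduced to a sketch: tile extents are ordinary fusion
-// inputs, so one compiled kernel serves many tilings:
-//
-//   Int* tile = new Int();               // symbolic, bound at run time
-//   fusion.addInput(tile);
-//   tv->split(0, tile);                  // inner extent = tile
-//   ...
-//   fe.runFusion({t0, t1, /*tile=*/4});  // choose the tile size per call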
-
-TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
- fusion.addInput(tv0);
- fusion.addOutput(tv1);
- // tv1[I0, R1] = tv0[I0, I1]
-
- // The interface should just be a direct split with a Parallel type; the
- // parallelize call below could then be folded into the split.
- tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
- // tv1[I0, R1o, R1i{blockDim.x}] = tv0[I0, I1]
-
- TensorView* tv2 = tv1->rFactor({2});
- tv2->setMemoryType(MemoryType::Global);
- // tv2[I0, Ir1o, R1i{blockDim.x}] = tv0[I0, I1]
- // tv1[I0, R1o] = tv2[I0, Ir1o, R1i{blockDim.x}]
-
- tv0->computeAt(tv1, 1);
-
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv1->axis(0)->parallelize(ParallelType::BIDx);
-
- constexpr int numel_x = 65000, numel_y = 1024;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y}, options);
-
- // How many threads to use for the block reduction
- constexpr int runtime_threadIdx_dim = 128;
-
- auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input}, lparams);
-
- auto aten_output = input.to(at::kDouble).sum({1});
- testValidate(
- &fusion,
- cg_outputs,
- {input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-TEST(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
- TensorView* tv2 = makeSymbolicTensor(2);
- TensorView* tv3 = makeSymbolicTensor(2);
- TensorView* tv4 = sub(tv2, tv3);
- TensorView* tv5 = add(tv1, tv4);
- TensorView* tv6 = sub(tv5, tv0);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addInput(tv2);
- fusion.addInput(tv3);
- fusion.addOutput(tv6);
- // t6 = ((t1 + (t2 - t3)) - t0)
-
- tv4->setMemoryType(MemoryType::Global);
- tv5->setMemoryType(MemoryType::Global);
- tv6->setMemoryType(MemoryType::Global);
-
- constexpr int M = 32, N = 810;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({M, N}, options);
- at::Tensor t1 = at::randn({M, N}, options);
- at::Tensor t2 = at::randn({M, N}, options);
- at::Tensor t3 = at::randn({M, N}, options);
-
- at::Tensor aten_output = (t1 + (t2 - t3)) - t0;
-
- std::vector<IValue> aten_inputs = {t0, t1, t2, t3};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({t0, t1, t2, t3});
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionConstCheck_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto one = new Int(1);
- TORCH_CHECK(one->isConstScalar());
-
- auto one_x2 = mul(one, one);
- TORCH_CHECK(one_x2->isConstScalar());
-
- auto one_x3 = mul(one_x2, one);
- TORCH_CHECK(one_x3->isConstScalar());
-
- auto one_x4 = mul(one_x3, one);
- TORCH_CHECK(one_x4->isConstScalar());
-}
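-
-// What FusionConstCheck demonstrates: isConstScalar() is closed under
-// arithmetic on constants, e.g. (a sketch):
-//
-//   auto v = mul(new Int(2), new Int(3)); // both operands are constant
-//   TORCH_CHECK(v->isConstScalar());      // so the product is constant too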
-
-TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) {
- const std::vector<int64_t> tensor_dims_in = {128, 128};
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size());
- fusion.addInput(tv0);
-
- TensorView* tv1 = add(tv0, new Double(0));
- TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv1);
- fusion.addOutput(tv2);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn(tensor_dims_in, options);
- at::Tensor cg_output = at::empty({tensor_dims_in[0]}, options);
-
- // Schedule
- tv2->split(1, 32);
- tv2->split(1, 4); // unroll
-
- auto tv2_rf = tv2->rFactor({-3, -2});
-
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv2_rf->axis(0)->parallelize(ParallelType::BIDx);
- tv2_rf->axis(-1)->parallelize(ParallelType::TIDx);
- tv2_rf->axis(-2)->parallelize(ParallelType::Unroll);
-
- tv1->computeAt(tv2_rf, -1);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
-
- auto aten_output = (input + 0).to(at::kDouble).sum(1);
-
- testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Test isZeroInt
-TEST(NVFuserTest, FusionIsZeroInt_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- Int* x = new Int(0);
- Int* y = new Int(1);
- Val* z = mul(x, y);
- TORCH_CHECK(x->isZeroInt());
- TORCH_CHECK(!y->isZeroInt());
- TORCH_CHECK(!z->isZeroInt());
-}
-
-// Test isOneInt
-TEST(NVFuserTest, FusionIsOneInt_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- Int* x = new Int(1);
- Int* y = new Int(1);
- Val* z = mul(x, y);
- TORCH_CHECK(x->isOneInt());
- TORCH_CHECK(y->isOneInt());
- TORCH_CHECK(!z->isOneInt());
-}
-
-// This verifies that no computeAt cycle is created. A more complex
-// variation of this pattern appears in one of the Python tests
-// (test_random_topo).
-TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- // Common intermediate tensor
- auto tv1 = add(tv0, new Double(1));
- // tv1 -> tv2
- auto tv2 = add(tv1, new Double(2));
- // tv1 -> tv3 -> tv4
- auto tv3 = add(tv1, new Double(3));
- auto tv4 = add(tv3, new Double(4));
-
- // NOTE: This should no longer occur as of PR #201.
- // The order of adding outputs matters. If tv3 is added before tv4,
- // it should be fine. However, if tv4 is added before tv3, there
- // will be a cycle of tv3->tv4 and tv4->tv3. tv3->tv4 is created
- // first, and then tv4->tv3 is created at the final phase of
- // computeAt (ComputeAt::setupOutputs).
- fusion.addOutput(tv2);
- fusion.addOutput(tv4);
- fusion.addOutput(tv3);
-
- tv0->computeAt(tv2, -1);
-
- TORCH_CHECK(tv3->hasComputeAt());
- TORCH_CHECK(!tv4->hasComputeAt());
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn(100, options);
-
- auto t1 = aten_input + 1;
- auto t2 = t1 + 2;
- auto t3 = t1 + 3;
- auto t4 = t3 + 4;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
-
- std::vector<at::Tensor> aten_outputs = {t2, t4, t3};
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTraversalOrder1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- TensorView* tv1 = add(tv0, new Double(1));
- TensorView* tv2 = add(tv0, new Double(2));
- TensorView* tv3 = add(tv1, new Double(3));
- TensorView* tv4 = add(tv1, new Double(4));
-
- fusion.addOutput(tv2);
- fusion.addOutput(tv3);
- fusion.addOutput(tv4);
-
- tv1->computeAt(tv3, -1);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({10, 10}, options);
-
- auto t1 = aten_input + 1;
- auto t2 = aten_input + 2;
- auto t3 = t1 + 3;
- auto t4 = t1 + 4;
-
- std::vector<at::Tensor> aten_outputs = {t2, t3, t4};
-
- std::vector<at::Tensor> cg_outputs = {
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options)};
-
- fe.runFusion({aten_input}, cg_outputs);
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTraversalOrder2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- TensorView* tv1 = add(tv0, new Double(1));
- TensorView* tv2 = add(tv1, new Double(2));
-
- TensorView* tv3 = add(tv0, new Double(3));
- TensorView* tv4 = add(tv3, new Double(4));
-
- TensorView* tv5 = add(tv1, tv3);
-
- fusion.addOutput(tv2);
- fusion.addOutput(tv4);
- fusion.addOutput(tv5);
-
- tv1->computeAt(tv5, -1);
- tv3->computeAt(tv5, -1);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({10, 10}, options);
-
- auto t1 = aten_input + 1;
- auto t2 = t1 + 2;
- auto t3 = aten_input + 3;
- auto t4 = t3 + 4;
- auto t5 = t1 + t3;
-
- std::vector<at::Tensor> aten_outputs = {t2, t4, t5};
-
- std::vector<at::Tensor> cg_outputs = {
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options)};
-
- fe.runFusion({aten_input}, cg_outputs);
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTraversalOrder3_CUDA) {
- for (int i = 0; i < 2; ++i) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- TensorView* tv1 = add(tv0, new Double(1));
- TensorView* tv2 = add(tv1, new Double(2));
-
- TensorView* tv3 = add(tv0, new Double(3));
- TensorView* tv4 = add(tv3, new Double(4));
-
- TensorView* tv5 = add(tv1, tv3);
-
- fusion.addOutput(tv2);
- fusion.addOutput(tv4);
- fusion.addOutput(tv5);
-
- const int tile = 32;
-
- tv1->split(-1, tile);
- tv2->split(-1, tile);
- tv3->split(-1, tile);
- tv4->split(-1, tile);
- tv5->split(-1, tile);
-
- auto compute_at_outer = tv1;
- auto compute_at_inner = tv3;
- if (i == 1) {
- std::swap(compute_at_inner, compute_at_outer);
- }
-
- compute_at_outer->computeAt(tv5, -2);
- compute_at_inner->computeAt(tv5, -1);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({100}, options);
- auto t1 = aten_input + 1;
- auto t2 = t1 + 2;
- auto t3 = aten_input + 3;
- auto t4 = t3 + 4;
- auto t5 = t1 + t3;
-
- std::vector<at::Tensor> aten_outputs = {t2, t4, t5};
-
- std::vector<at::Tensor> cg_outputs = {
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options)};
-
- fe.runFusion({aten_input}, cg_outputs);
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
- }
-}
-
-TEST(NVFuserTest, FusionTraversalOrder4_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // First tree
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- TensorView* tv1 = add(tv0, new Double(1));
- TensorView* tv2 = add(tv1, new Double(2));
- TensorView* tv3 = add(tv1, new Double(3));
- fusion.addOutput(tv2);
- fusion.addOutput(tv3);
-
- // Second tree
- TensorView* tv4 = makeSymbolicTensor(1);
- fusion.addInput(tv4);
- TensorView* tv5 = add(tv4, new Double(5));
- TensorView* tv6 = add(tv5, new Double(6));
- TensorView* tv7 = add(tv5, new Double(7));
- fusion.addOutput(tv6);
- fusion.addOutput(tv7);
-
- tv1->computeAt(tv2, -1);
- tv5->computeAt(tv6, -1);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({100}, options);
- at::Tensor t4 = at::rand_like(t0, options);
-
- auto t1 = t0 + 1;
- auto t2 = t1 + 2;
- auto t3 = t1 + 3;
- auto t5 = t4 + 5;
- auto t6 = t5 + 6;
- auto t7 = t5 + 7;
-
- std::vector<at::Tensor> aten_outputs = {t2, t3, t6, t7};
- std::vector<IValue> aten_inputs = {t0, t4};
- std::vector<at::Tensor> cg_outputs = {
- at::empty_like(t0, options),
- at::empty_like(t0, options),
- at::empty_like(t0, options),
- at::empty_like(t0, options)};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion(aten_inputs, cg_outputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTraversalOrder5_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- TensorView* tv1 = add(tv0, new Double(1));
- TensorView* tv2 = add(tv1, new Double(2));
- TensorView* tv3 = add(tv0, new Double(3));
- TensorView* tv4 = add(tv3, new Double(4));
- TensorView* tv5 = add(tv2, tv4);
-
- fusion.addOutput(tv1);
- fusion.addOutput(tv3);
- fusion.addOutput(tv5);
-
- tv2->computeAt(tv5, -1);
- tv4->computeAt(tv5, -1);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({100}, options);
- std::vector<at::Tensor> cg_outputs = {
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options),
- at::empty_like(aten_input, options)};
-
- fe.runFusion({aten_input}, cg_outputs);
-
- auto t1 = aten_input + 1;
- auto t2 = t1 + 2;
- auto t3 = aten_input + 3;
- auto t4 = t3 + 4;
- auto t5 = t2 + t4;
-
- std::vector<at::Tensor> aten_outputs = {t1, t3, t5};
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTraversalOrder6_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- TensorView* tv1 = add(tv0, new Double(1));
- TensorView* tv2 = add(tv0, new Double(2));
- TensorView* tv3 = add(tv1, tv2);
- TensorView* tv4 = add(tv3, new Double(4));
-
- fusion.addOutput(tv4);
-
- tv1->split(0, 32);
- tv2->split(0, 32);
- tv3->split(0, 32);
- tv4->split(0, 32);
-
- tv3->computeAt(tv4, -2);
- tv1->computeAt(tv3, -1);
- tv2->computeAt(tv3, -2);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({100}, options);
-
- auto t1 = aten_input + 1;
- auto t2 = aten_input + 2;
- auto t3 = t1 + t2;
- auto aten_output = t3 + 4;
-
- at::Tensor cg_output = at::empty_like(aten_input, options);
-
- fe.runFusion({aten_input}, {cg_output});
-
- testValidate(
- &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTraversalOrder7_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- TensorView* tv1 = add(tv0, new Double(1));
- TensorView* tv2 = add(tv1, new Double(2));
- TensorView* tv3 = add(tv0, new Double(3));
- TensorView* tv4 = add(tv3, new Double(4));
- TensorView* tv5 = add(tv2, tv4);
-
- fusion.addOutput(tv5);
-
- TensorView* tvs[] = {tv1, tv2, tv3, tv4, tv5};
- for (auto tv : tvs) {
- tv->split(0, 2);
- tv->split(0, 4);
- tv->split(0, 8);
- }
-
- // computeAt into inner loop nests
- tv1->computeAt(tv2, -1);
- tv3->computeAt(tv4, -2);
-
- tv2->computeAt(tv5, -4);
- tv4->computeAt(tv5, -3);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({100}, options);
-
- auto t1 = aten_input + 1;
- auto t2 = t1 + 2;
- auto t3 = aten_input + 3;
- auto t4 = t3 + 4;
- auto aten_output = t2 + t4;
-
- at::Tensor cg_output = at::empty_like(aten_input, options);
- fe.runFusion({aten_input}, {cg_output});
-
- testValidate(
- &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-// Test predication of grid reduction
-TEST(NVFuserTest, FusionThreadPredicate_CUDA) {
- const int gdimx = 4;
- const int bdimx = 128;
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv0);
- TensorView* tv2 = unaryOp(UnaryOpType::Neg, tv1);
- TensorView* tv3 = add(tv0, new Double(2));
-
- fusion.addOutput(tv3);
- fusion.addOutput(tv2);
-
- tv1->split(1, bdimx);
- tv1->split(1, gdimx);
- tv3->split(1, bdimx);
- tv3->split(1, gdimx);
-
- TensorView* tv1_rf = tv1->rFactor({1});
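- // tv1_rf performs the serial outer chunk of the reduction; tv1 then reduces
- // across BIDx and TIDx, i.e. a grid reduction, whose predication this test
- // exercises.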
-
- tv1->computeAt(tv2, -1);
-
- tv1->axis(0)->parallelize(ParallelType::BIDy);
- tv1_rf->axis(0)->parallelize(ParallelType::BIDy);
- tv2->axis(0)->parallelize(ParallelType::BIDy);
- tv1->axis(-2)->parallelize(ParallelType::BIDx);
- tv1_rf->axis(-2)->parallelize(ParallelType::BIDx);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv1_rf->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv3->axis(3)->parallelize(ParallelType::TIDx);
- tv3->axis(2)->parallelize(ParallelType::BIDx);
- tv3->axis(0)->parallelize(ParallelType::BIDy);
-
- int numel_x = 100;
- int numel_y = 1000;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
-
- auto t2 = -aten_input.to(at::kDouble).sum({1});
- auto t3 = aten_input + 2.0;
-
- std::vector<at::Tensor> aten_outputs = {t3, t2};
-
- std::vector<at::Tensor> cg_outputs = {
- at::empty_like(aten_input, options), at::empty({numel_x}, options)};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({aten_input}, cg_outputs);
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionLSTMCell_CUDA) {
- const int hidden_features = 512;
- const int batch_size = 64;
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- TensorView* tvs[16];
- for (size_t i = 0; i < 16; i++) {
- tvs[i] = makeSymbolicTensor(2);
- fusion.addInput(tvs[i]);
- }
-
- auto ingate = unaryOp(
- UnaryOpType::Sigmoid, add(add(add(tvs[0], tvs[1]), tvs[2]), tvs[3]));
-
- auto forgetgate = unaryOp(
- UnaryOpType::Sigmoid, add(add(add(tvs[4], tvs[5]), tvs[6]), tvs[7]));
-
- auto cellgate = unaryOp(
- UnaryOpType::Tanh, add(add(add(tvs[8], tvs[9]), tvs[10]), tvs[11]));
-
- auto outgate = unaryOp(
- UnaryOpType::Sigmoid, add(add(add(tvs[12], tvs[13]), tvs[14]), tvs[15]));
-
- auto cx = makeContigTensor(2);
- fusion.addInput(cx);
-
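- // Standard LSTM cell update:
- //   cy = forgetgate * cx + ingate * cellgate;  hy = outgate * tanh(cy)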
- auto cy = add(mul(forgetgate, cx), mul(ingate, cellgate));
-
- auto hy = mul(outgate, unaryOp(UnaryOpType::Tanh, cy));
-
- fusion.addOutput(cy);
- fusion.addOutput(hy);
-
- std::vector<c10::IValue> aten_inputs;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor large_tensor0 =
- at::randn({batch_size, hidden_features * 4}, options);
- at::Tensor large_tensor1 =
- at::randn({batch_size, hidden_features * 4}, options);
- at::Tensor large_tensor2 =
- at::randn({batch_size, hidden_features * 4}, options);
- at::Tensor large_tensor3 =
- at::randn({batch_size, hidden_features * 4}, options);
-
- auto chunked0 = large_tensor0.chunk(4, 1);
- auto chunked1 = large_tensor1.chunk(4, 1);
- auto chunked2 = large_tensor2.chunk(4, 1);
- auto chunked3 = large_tensor3.chunk(4, 1);
-
- aten_inputs.insert(aten_inputs.end(), chunked0.begin(), chunked0.end());
- aten_inputs.insert(aten_inputs.end(), chunked1.begin(), chunked1.end());
- aten_inputs.insert(aten_inputs.end(), chunked2.begin(), chunked2.end());
- aten_inputs.insert(aten_inputs.end(), chunked3.begin(), chunked3.end());
-
- auto at_ingate =
- chunked0[0].add(chunked0[1]).add(chunked0[2]).add(chunked0[3]).sigmoid();
- auto at_forgetgate =
- chunked1[0].add(chunked1[1]).add(chunked1[2]).add(chunked1[3]).sigmoid();
- auto at_cellgate =
- chunked2[0].add(chunked2[1]).add(chunked2[2]).add(chunked2[3]).tanh();
- auto at_outgate =
- chunked3[0].add(chunked3[1]).add(chunked3[2]).add(chunked3[3]).sigmoid();
-
- auto at_cx = at::randn({batch_size, hidden_features}, options);
- aten_inputs.push_back(at_cx);
- auto at_cy = at_forgetgate.mul(at_cx).add(at_ingate.mul(at_cellgate));
- auto at_hy = at_outgate.mul(at_cy.tanh());
-
- auto lparams = schedulePointwise(&fusion, aten_inputs);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {at_cy, at_hy}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionComputeAtMultiBCast_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- TensorView* tv1 = mul(tv0, new Double(0.5));
- TensorView* tv2 = broadcast(tv1, {true, false});
- TensorView* tv3 = broadcast(tv1, {false, true});
- TensorView* tv4 = add(tv2, tv3);
- fusion.addOutput(tv4);
-
- // Not possible to do computeAt at position -1 as recomputation
- // would be required. An exception should be thrown.
- ASSERT_ANY_THROW(tv1->computeAt(tv3, -1));
-}
-
-TEST(NVFuserTest, FusionReductionHalf_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(3, DataType::Half);
- fusion.addInput(tv0);
-
- auto tv1 = castOp(DataType::Float, tv0);
- auto tv2 = add(tv1, new Double(1.0));
- auto tv3 = sum(tv2, {2});
- auto tv4 = castOp(DataType::Half, tv3);
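- // Compute in float and cast back to half: the sum accumulates in fp32 for
- // accuracy.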
-
- fusion.addOutput(tv4);
-
- const auto options =
- at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({8, 8, 16}, options);
-
- auto reduction_tv = tv3;
-
- auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
- scheduleReduction(&fusion, reduction_params.value());
-
- auto lparams = reduction_params.value().lparams;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- // no broadcasting needed, omitting the last optional argument;
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
-
- auto aten_output = aten_input.add(1.0).to(at::kDouble).sum({2});
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-TEST(NVFuserTest, FusionReduceSingle_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeConcreteTensor({100, 1});
- fusion.addInput(tv0);
- auto tv1 = sum(tv0, {1});
- fusion.addOutput(tv1);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({100, 1}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- // no broadcasting needed, omitting the last optional argument;
- auto cg_outputs = fe.runFusion({aten_input});
-
- auto aten_output = aten_input.to(at::kDouble).sum({1});
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) {
- constexpr int bid_x = 80;
- constexpr int tid_x = 4096;
- constexpr int red_dim = 1;
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1});
- fusion.addInput(tv0);
-
- TensorView* tv1 =
- reductionOp(BinaryOpType::Add, {red_dim, 2}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({bid_x, tid_x, 1}, options);
-
- // Apply reduction heuristic
- auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
- scheduleReduction(&fusion, reduction_params.value());
- auto lparams = reduction_params.value().lparams;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- // no broadcasting needed, omitting the last optional argument;
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
- auto aten_output = aten_input.to(at::kDouble).sum({red_dim, 2});
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-TEST(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) {
- constexpr int bid_x = 80;
- constexpr int tid_x = 4096;
- constexpr int red_dim = 1;
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1});
- fusion.addInput(tv0);
-
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {2}, new Double(0), tv0);
-
- TensorView* tv2 =
- reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv1);
- fusion.addOutput(tv2);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({bid_x, tid_x, 1}, options);
-
- // Apply reduction heuristic
- auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-
- scheduleReduction(&fusion, reduction_params.value());
- auto lparams = reduction_params.value().lparams;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- // no broadcasting needed, omitting the last optional argument;
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
- auto aten_output = aten_input.to(at::kDouble).sum({1, 2});
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-TEST(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) {
- constexpr int bid_x = 80;
- constexpr int tid_x = 4096;
- constexpr int red_dim = 1;
-
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1});
- fusion.addInput(tv0);
-
- TensorView* tv1 =
- reductionOp(BinaryOpType::Add, {red_dim}, new Double(0), tv0);
-
- TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Double(0), tv1);
- fusion.addOutput(tv2);
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({bid_x, tid_x, 1}, options);
-
- // Apply reduction heuristic
- auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
- scheduleReduction(&fusion, reduction_params.value());
- auto lparams = reduction_params.value().lparams;
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- // no broadcasting needed, omitting the last optional argument;
- auto cg_outputs = fe.runFusion({aten_input}, lparams);
- auto aten_output = aten_input.to(at::kDouble).sum({2, 1});
-
- testValidate(
- &fusion,
- cg_outputs,
- {aten_input},
- {aten_output},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-TEST(NVFuserTest, FusionTrivialReduction_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Set up your input tensor views
- TensorView* tv0 = makeConcreteTensor({10, 20, 1});
- fusion.addInput(tv0);
- TensorView* tv1 = reductionOp(BinaryOpType::Add, {2}, new Double(0), tv0);
- fusion.addOutput(tv1);
-
- TORCH_CHECK(!fusion.hasReduction(), "Trivial reduction picked up by fusion");
-
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({10, 20, 1}, options);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
- auto aten_output = aten_input.to(at::kDouble).sum({2});
-
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTrivialReduction2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- int w = 1, x = 1, y = 7, z = 8;
-
- auto tv0 = makeSymbolicTensor(2);
- auto tv1 = makeConcreteTensor({w, x, y, z});
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = sum(tv1, {0});
- auto tv3 = sum(tv2, {0});
- auto tv4 = add(tv3, tv0);
-
- fusion.addOutput(tv4);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({y, z}, options);
- at::Tensor t1 = at::randn({w, x, y, z}, options);
- auto aten_output = t1.to(at::kDouble).sum({0}).sum({0}).add(t0);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- auto lparams = schedulePointwise(&fusion, aten_inputs);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionTrivialReduction3_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- int v = 1, w = 1, x = 1, y = 7, z = 8;
-
- auto tv0 = makeSymbolicTensor(2);
- auto tv1 = makeConcreteTensor({v, w, x, y, z});
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = sum(tv1, {0, 1, 2});
- auto tv3 = add(tv2, tv0);
-
- fusion.addOutput(tv3);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({y, z}, options);
- at::Tensor t1 = at::randn({v, w, x, y, z}, options);
- auto aten_output = t1.sum({0, 1, 2}).add(t0);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- auto lparams = schedulePointwise(&fusion, aten_inputs);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+/*
+ * Templatized helper function to generate a single-op comparison between the
+ * CUDA JIT codegen and the ATen library.
+ */
-// Make sure trivial reductions are correctly detected even with
-// scheduling applied.
-TEST(NVFuserTest, FusionDetectTrivialReduction1_CUDA) {
+using OutputPair = std::pair<ValType, DataType>;
+template <
+ typename AtenFunc,
+ typename JitFunc,
+ typename InputTuple,
+ size_t... NumInputs>
+void test_op(
+ int blocks,
+ int threads,
+ std::string op_str,
+ AtenFunc af,
+ JitFunc jf,
+ OutputPair op,
+ InputTuple it,
+ std::index_sequence<NumInputs...>) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- auto tv1 = broadcast(tv0, {false, true});
- auto tv2 = sum(tv1, {1});
- fusion.addOutput(tv2);
-
- tv2->split(1, 4);
- tv2->split(1, 8);
- auto tv3 = tv2->rFactor({-1});
- auto tv4 = tv2->rFactor({-1});
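- // The summed axis is a broadcast of extent 1, so all of these reductions are
- // trivial and should lower to plain set ops (checked below).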
-
- auto tv5 = broadcast(tv0, {true, false});
- auto tv6 = add(tv5, new Double(1));
- auto tv7 = sub(tv6, new Double(1));
- auto tv8 = sum(tv7, {0});
- fusion.addOutput(tv8);
-
- auto tv9 = broadcast(tv0, {false, true, true});
- auto tv10 = sum(tv9, {1});
- auto tv11 = sum(tv10, {1});
- fusion.addOutput(tv11);
-
- tv8->split(0, 3);
- tv10->split(1, 4);
- tv11->split(1, 5);
-
- tv0->computeAt(tv2, -1);
- tv0->computeAt(tv8, -1);
- tv0->computeAt(tv11, 1);
-
- // Test indexing to gmem-backed tensors
- tv3->setMemoryType(MemoryType::Global);
- tv8->setMemoryType(MemoryType::Global);
+ // Generate the JIT function inputs and add them as inputs to the fusion
+ // graph
+ std::array<Val*, sizeof...(NumInputs)> jit_inputs = {
+ gen_jit_operand(std::get<NumInputs>(it))...};
+ std::for_each(jit_inputs.begin(), jit_inputs.end(), [&fusion](Val* v) {
+ fusion.addInput(v);
+ });
+ TensorView* out =
+ static_cast<TensorView*>(jf(std::get<NumInputs>(jit_inputs)...));
+ fusion.addOutput(out);
- GpuLower gpulw(&fusion);
+ std::for_each(jit_inputs.begin(), jit_inputs.end(), [out](Val* v) {
+ if (v->getValType() == ValType::TensorView)
+ static_cast<TensorView*>(v)->computeAt(out, -1);
+ });
+ out->axis(0)->parallelize(ParallelType::BIDx);
+ out->axis(-1)->parallelize(ParallelType::TIDx);
- // No kir::ReductionOp should be generated as all the reduction
- // exprs should be replaced with a unary set op.
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- TORCH_CHECK(!kir_node->isA<kir::ReductionOp>());
- }
+ std::array<IValue, sizeof...(NumInputs)> aten_inputs = {gen_aten_operand(
+ std::get<NumInputs>(it), blocks, threads, /*rand*/ true)...};
+ const at::ArrayRef<IValue> aten_inputs_ivalues(aten_inputs);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({100}, options);
- std::vector<IValue> aten_inputs = {t0};
+ at::Tensor output =
+ gen_aten_operand(op, blocks, threads, /*rand*/ false).toTensor();
+ std::vector<at::Tensor> output_vect = {output};
+ cudaDeviceSynchronize();
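+ // Reseed so a stochastic fusion (e.g. RandLike) sees the same RNG state as
+ // the ATen reference run below.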
+ if (fusion.isStochastic())
+ at::manual_seed(0);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {t0, t0, t0}, __LINE__, __FILE__);
-}
-
-// Test detection of partially trivial reduction
-TEST(NVFuserTest, FusionDetectTrivialReduction2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = sum(tv0, {1});
- auto tv2 = add(tv1, new Double(1));
- fusion.addOutput(tv2);
-
- tv1->split(1, 1);
- // tv1->axis(1): non-trivial
- // tv1->axis(2): trivial
+ fe.runFusion(aten_inputs_ivalues, output_vect);
+ cudaDeviceSynchronize();
- auto tv3 = tv1->rFactor({-1});
+ if (fusion.isStochastic())
+ at::manual_seed(0);
+ at::Tensor ref_output = af(aten_inputs);
+ cudaDeviceSynchronize(); // This sync shouldn't be necessary.
- GpuLower gpulw(&fusion);
+ std::function<std::string()> aten_inputs_to_str =
+ [&aten_inputs]() -> std::string {
+ int input_cnt = 1;
+ std::stringstream ss;
+ std::for_each(
+ aten_inputs.begin(), aten_inputs.end(), [&input_cnt, &ss](IValue& iv) {
+ ss << "\nINPUT" << input_cnt++ << ": " << iv.toTensor();
+ });
+ return ss.str();
+ };
- // tv3's reduction axis is a trivial reduction. The only
- // kir::ReductionOp should be for tv1.
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (kir_node->isA<kir::ReductionOp>()) {
- auto reduction_out =
- kir_node->as<kir::ReductionOp>()->outputs()[0]->as<kir::TensorView>();
- TORCH_CHECK(reduction_out->fuserTv() == tv1);
- }
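+ // Bool tensors cannot be subtracted, so use elementwise equality for the
+ // diff instead.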
+ at::Tensor diff;
+ if (output.scalar_type() == at::kBool) {
+ diff = at::eq(output, ref_output);
+ } else {
+ diff = at::sub(output, ref_output);
}
-}
-
-TEST(NVFuserTest, FusionInputsIdLookup_CUDA) {
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({16, 8, 8}, options);
- at::Tensor t1 = at::randn({8, 8}, options);
- at::Tensor t2 = at::randn({6, 4}, options);
-
- // create a cache with max size 2;
- torch::jit::fuser::cuda::InputsIdLookup inputs_id_lookup(2);
-
- // basic functionality: inputs that differ only in scalar value share an
- // encoding
- auto id_0 = inputs_id_lookup.lookupId({t0, t1, 5.0});
- auto id_0_lookup = inputs_id_lookup.lookupId({t0, t1, 2.5});
- TORCH_CHECK(id_0.id == id_0_lookup.id);
- TORCH_CHECK(inputs_id_lookup.size() == 1);
- TORCH_CHECK(id_0.eviction == false);
-
- // new input: even though the shape is the same, the signature differs
- // because of the missing scalar input
- auto id_1 = inputs_id_lookup.lookupId({t0, t1});
- auto id_1_lookup = inputs_id_lookup.lookupId({t0, t1});
- TORCH_CHECK(id_1.id == id_1_lookup.id);
- TORCH_CHECK(inputs_id_lookup.size() == 2);
- TORCH_CHECK(id_1.eviction == false);
-
- // eviction should happen at this point
- auto id_2 = inputs_id_lookup.lookupId({t2, t1});
- TORCH_CHECK(id_2.id != id_0.id);
- TORCH_CHECK(id_2.id != id_1.id);
- TORCH_CHECK(inputs_id_lookup.size() == 2);
- TORCH_CHECK(id_2.eviction == true);
- TORCH_CHECK(id_2.evict_id == id_0.id);
-
- // look at input 1 again
- auto id_1_relook = inputs_id_lookup.lookupId({t0, t1});
- TORCH_CHECK(id_1_relook.id == id_1.id);
- TORCH_CHECK(id_1_relook.eviction == false);
-}
-
-TEST(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) {
- std::vector<int64_t> sizes_vec({16, 8, 8});
- std::vector<int64_t> strides_vec({64, 8, 1});
- auto tensor_type = TensorType::create(
- at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
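- // complyWith() checks whether a runtime tensor can safely reuse a kernel
- // compiled for this guarded type: dtype, rank, contiguity, and broadcast
- // pattern must all be compatible.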
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- // pass with identical shape
- auto t0 = at::randn({16, 8, 8}, options);
- TORCH_CHECK(complyWith(t0, tensor_type));
-
- // pass with dynamic shape
- auto t1 = at::randn({16, 16, 8}, options);
- TORCH_CHECK(complyWith(t1, tensor_type));
-
- // broadcasting semantic change failure
- auto t2 = at::randn({16, 1, 8}, options);
- TORCH_CHECK(!complyWith(t2, tensor_type));
-
- // contiguity failure via slicing
- auto t3 = t0.slice(1, 0, 8, 2);
- TORCH_CHECK(!complyWith(t3, tensor_type));
-
- // contiguity failure via slicing
- auto t4 = t0.slice(2, 0, 8, 2);
- TORCH_CHECK(!complyWith(t4, tensor_type));
-
- // rank failure
- auto t5 = at::randn({16, 8, 8, 8}, options);
- TORCH_CHECK(!complyWith(t5, tensor_type));
-
- // contiguity on stride 1 dimension with implicit broadcasting
- auto t = at::randn({4}, options);
- auto t6 = t.unsqueeze(1).expand({4, 8});
- TORCH_CHECK(complyWith(t6, TensorType::create(t6)));
-}
-
-TEST(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) {
- std::vector<int64_t> sizes_vec({16, 1, 8});
- std::vector<int64_t> strides_vec({8, 8, 1});
- auto tensor_type = TensorType::create(
- at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- // broadcasting semantic change
- auto t0 = at::randn({16, 8, 8}, options);
- TORCH_CHECK(!complyWith(t0, tensor_type));
-
- // dtype failure
- auto t1 = at::randn({16, 1, 8}, options.dtype(at::kHalf));
- TORCH_CHECK(!complyWith(t1, tensor_type));
-
- // identical shape and dtype should pass
- auto t2 = at::randn({16, 1, 8}, options);
- TORCH_CHECK(complyWith(t2, tensor_type));
-
- // device inconsistency shouldn't fail
- auto t3 = at::randn({16, 1, 8}, options.device(at::kCPU, 0));
- TORCH_CHECK(complyWith(t3, tensor_type));
+ TORCH_CHECK(
+ (output.scalar_type() == at::kBool
+ ? output.equal(ref_output)
+ :
+ // The absolute tolerance was raised from 1e-08 to 1e-07 to allow
+ // the remainder function to pass.
+ output.allclose(ref_output, /*rtol*/ 1e-05, /*atol*/ 1e-07)),
+ "\nOp Type: -- ",
+ op_str,
+ " -- had a mismatch.",
+ aten_inputs_to_str(),
+ "\nABS MAX DIFF: ",
+ output.sub(ref_output).abs().max(),
+ "\n");
}
-TEST(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) {
- std::vector<int64_t> sizes_vec({16, 8, 8});
- std::vector<int64_t> strides_vec({64, 1, 8});
- auto tensor_type = TensorType::create(
- at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- // failing permutation
- auto t0 = at::randn({16, 8, 8}, options);
- TORCH_CHECK(!complyWith(t0, tensor_type));
-
- // passing with dynamic shape
- auto t1 = t0.permute({0, 2, 1});
- TORCH_CHECK(complyWith(t1, tensor_type));
+/*
+ * Templatized helper function that uses variadic templates to
+ * process a variable-length input tuple of different operand types.
+ */
+template <typename AtenFunc, typename JitFunc, typename InputTuple>
+void test_op(
+ int blocks,
+ int threads,
+ std::string op_str,
+ AtenFunc af,
+ JitFunc jf,
+ OutputPair op,
+ InputTuple it) {
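+ // Expand the tuple into an index sequence so the main test_op overload can
+ // unpack each operand.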
+ static constexpr auto size = std::tuple_size<InputTuple>::value;
+ test_op(
+ blocks,
+ threads,
+ op_str,
+ af,
+ jf,
+ op,
+ it,
+ std::make_index_sequence<size>{});
}
-TEST(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) {
- std::vector<int64_t> sizes_vec({16, 8, 8});
- std::vector<int64_t> strides_vec({128, 16, 1});
- auto tensor_type = TensorType::create(
- at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- // contiguity check passes although it differs
- auto t0 = at::randn({16, 16, 8}, options);
- TORCH_CHECK(complyWith(t0, tensor_type));
-
- // passing with dynamic shape
- auto t1 = t0.slice(1, 0, 16, 2);
- TORCH_CHECK(complyWith(t1, tensor_type));
-}
+TEST(NVFuserTest, FusionUnaryOps_CUDA) {
+ using OpTuple =
+ std::tuple<at::Tensor (*)(const at::Tensor&), UnaryOpType, std::string>;
-TEST(NVFuserTest, FusionDisjointSet_CUDA) {
- DisjointSet<int> set;
+ // [Note: explicit tuple type for uniform initialization list]
+ // Tuple type must be explicitly specified for each uniform initialization
+ // list within the vector to keep this code compatible with some older
+ // environments we still need to support, e.g. gcc 5.4 + CUDA 9.2.
+ std::vector<OpTuple> ops{
+ OpTuple{at::abs, UnaryOpType::Abs, "abs"},
+ OpTuple{at::acos, UnaryOpType::Acos, "acos"},
+ OpTuple{at::asin, UnaryOpType::Asin, "asin"},
+ OpTuple{at::atan, UnaryOpType::Atan, "atan"},
+ // There does not appear to be an appropriate ATen function for atanh
+ // OpTuple{at::atanh, UnaryOpType::Atanh, "atanh" },
+ OpTuple{at::ceil, UnaryOpType::Ceil, "ceil"},
+ OpTuple{at::cos, UnaryOpType::Cos, "cos"},
+ OpTuple{at::cosh, UnaryOpType::Cosh, "cosh"},
+ OpTuple{at::erf, UnaryOpType::Erf, "erf"},
+ OpTuple{at::erfc, UnaryOpType::Erfc, "erfc"},
+ OpTuple{at::exp, UnaryOpType::Exp, "exp"},
+ OpTuple{at::expm1, UnaryOpType::Expm1, "expm1"},
+ OpTuple{at::floor, UnaryOpType::Floor, "floor"},
+ OpTuple{at::frac, UnaryOpType::Frac, "frac"},
+ OpTuple{at::gelu, UnaryOpType::Gelu, "gelu"},
+ OpTuple{at::lgamma, UnaryOpType::Lgamma, "lgamma"},
+ OpTuple{at::log, UnaryOpType::Log, "log"},
+ OpTuple{at::log10, UnaryOpType::Log10, "log10"},
+ OpTuple{at::log1p, UnaryOpType::Log1p, "log1p"},
+ OpTuple{at::log2, UnaryOpType::Log2, "log2"},
+ OpTuple{at::neg, UnaryOpType::Neg, "neg"},
+ OpTuple{at::reciprocal, UnaryOpType::Reciprocal, "reciprocal"},
+ OpTuple{at::relu, UnaryOpType::Relu, "relu"},
+ OpTuple{at::round, UnaryOpType::Round, "round"},
+ OpTuple{at::rsqrt, UnaryOpType::Rsqrt, "rsqrt"},
+ OpTuple{at::sigmoid, UnaryOpType::Sigmoid, "sigmoid"},
+ OpTuple{at::sin, UnaryOpType::Sin, "sin"},
+ OpTuple{at::sinh, UnaryOpType::Sinh, "sinh"},
+ OpTuple{at::sqrt, UnaryOpType::Sqrt, "sqrt"},
+ OpTuple{at::tan, UnaryOpType::Tan, "tan"},
+ OpTuple{at::tanh, UnaryOpType::Tanh, "tanh"},
+ OpTuple{at::trunc, UnaryOpType::Trunc, "trunc"}};
- const std::set<int> group_x({0, 1, 2});
- const std::set<int> group_y({3, 4, 5});
- const std::set<int> group_z({6, 7, 8});
- const std::vector<std::set<int>> groups({group_x, group_y, group_z});
- std::set<int> group_all;
- std::for_each(groups.begin(), groups.end(), [&](const auto& g) {
- group_all.insert(g.begin(), g.end());
+ std::for_each(ops.begin(), ops.end(), [](OpTuple& op) {
+ test_op(
+ /*blocks*/ 640,
+ /*threads*/ 64,
+ /*name*/ std::get<2>(op),
+ /*Aten Func */
+ [&op](std::array<IValue, 1>& vals) {
+ return std::get<0>(op)(vals[0].toTensor());
+ },
+ /*JIT Func */
+ [&op](Val* in1) -> Val* { return unaryOp(std::get<1>(op), in1); },
+ /*Output */ std::make_pair(ValType::TensorView, DataType::Float),
+ /*Inputs Tuple*/
+ std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
});
- // Initially, nothing should be considered equivalent
- for (auto i : group_all) {
- for (auto j : group_all) {
- TORCH_CHECK(!set.areEquivalent(i, j));
- }
- }
-
- // Join the values in group_x so they become equivalent
- for (auto i : group_x) {
- for (auto j : group_x) {
- set.join(i, j);
- TORCH_CHECK(set.contains(i));
- TORCH_CHECK(set.contains(j));
- }
- }
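+ // rand_like is stochastic, so test_op reseeds the RNG before both the fused
+ // run and the ATen reference run (see fusion.isStochastic() above).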
+ test_op(
+ /*blocks*/ 128,
+ /*threads*/ 64,
+ /*name*/ "rand_like",
+ /*Aten Func */
+ [](std::array<IValue, 1>& vals) {
+ return at::rand_like(vals[0].toTensor());
+ },
+ /*JIT Func */
+ [](Val* in1) -> Val* { return unaryOp(UnaryOpType::RandLike, in1); },
+ /*Output */ std::make_pair(ValType::TensorView, DataType::Float),
+ /*Inputs Tuple*/
+ std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
+}
- // All values in group_x should now be equivalent to each other
- for (auto i : group_x) {
- for (auto j : group_x) {
- TORCH_CHECK(set.areEquivalent(i, j));
- }
- }
- // But nothing else should be equivalent
- for (auto i : group_all) {
- for (auto j : group_y) {
- TORCH_CHECK(!set.areEquivalent(i, j));
- }
- for (auto j : group_z) {
- TORCH_CHECK(!set.areEquivalent(i, j));
- }
- }
+TEST(NVFuserTest, FusionBinaryOps_CUDA) {
+ using AtenFuncSig = at::Tensor (*)(const at::Tensor&, const at::Tensor&);
+ using OpTuple = std::tuple<AtenFuncSig, BinaryOpType, std::string>;
- // Join the values in group_y so they become equivalent
- for (auto i : group_y) {
- for (auto j : group_y) {
- set.join(i, j);
- TORCH_CHECK(set.contains(i));
- TORCH_CHECK(set.contains(j));
- }
- }
+ // see [Note: explicit tuple type for uniform initialization list]
+ std::vector<OpTuple> logic_ops{
+ OpTuple{at::eq, BinaryOpType::Eq, "eq"},
+ OpTuple{at::ge, BinaryOpType::GE, "ge"},
+ OpTuple{at::gt, BinaryOpType::GT, "gt"},
+ OpTuple{at::le, BinaryOpType::LE, "le"},
+ OpTuple{at::lt, BinaryOpType::LT, "lt"},
+ OpTuple{at::ne, BinaryOpType::NE, "ne"}};
- // group_x should still be equivalent
- for (auto i : group_x) {
- for (auto j : group_x) {
- TORCH_CHECK(set.areEquivalent(i, j));
- }
- }
- // group_y should now be equivalent
- for (auto i : group_y) {
- for (auto j : group_y) {
- TORCH_CHECK(set.areEquivalent(i, j));
- }
- }
- // But group_z should not be equivalent to anything yet
- for (auto i : group_all) {
- for (auto j : group_z) {
- TORCH_CHECK(!set.areEquivalent(i, j));
- }
- }
+ std::for_each(logic_ops.begin(), logic_ops.end(), [](OpTuple& op) {
+ test_op(
+ /*blocks*/ 640,
+ /*threads*/ 64,
+ /*name*/ std::get<2>(op),
+ /*Aten Func */
+ [&op](std::array<IValue, 2>& vals) {
+ return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor());
+ },
+ /*JIT Func */
+ [&op](Val* in1, Val* in2) -> Val* {
+ return binaryOp(std::get<1>(op), in1, in2);
+ },
+ /*Output */ std::make_pair(ValType::TensorView, DataType::Bool),
+ /*Inputs Tuple*/
+ std::make_tuple(
+ std::make_pair(ValType::TensorView, DataType::Float),
+ std::make_pair(ValType::TensorView, DataType::Float)));
+ });
- // Join the values in group_z so they become equivalent
- for (auto i : group_z) {
- for (auto j : group_z) {
- set.join(i, j);
- TORCH_CHECK(set.contains(i));
- TORCH_CHECK(set.contains(j));
- }
- }
+ // see [Note: explicit tuple type for uniform initialization list]
+ std::vector<OpTuple> math_ops{
+ OpTuple{at::atan2, BinaryOpType::Atan2, "atan2"},
+ OpTuple{at::div, BinaryOpType::Div, "div"},
+ OpTuple{at::fmod, BinaryOpType::Fmod, "fmod"},
+ OpTuple{at::max, BinaryOpType::Max, "max"},
+ OpTuple{at::min, BinaryOpType::Min, "min"},
+ OpTuple{at::mul, BinaryOpType::Mul, "mul"},
+ OpTuple{at::pow, BinaryOpType::Pow, "pow"},
+ // NOTE: Remainder does not match the Aten impl exactly
+ // despite using an identical function.
+ OpTuple{at::remainder, BinaryOpType::Remainder, "remainder"},
+ };
- // Now each of the three groups should be equivalent within each
- // group
- for (size_t gi = 0; gi < groups.size(); ++gi) {
- for (size_t gj = 0; gj < groups.size(); ++gj) {
- for (auto i : groups[gi]) {
- for (auto j : groups[gj]) {
- TORCH_CHECK(
- (gi == gj && set.areEquivalent(i, j)) ||
- (gi != gj && !set.areEquivalent(i, j)));
- }
- }
- }
- }
+ std::for_each(math_ops.begin(), math_ops.end(), [](OpTuple& op) {
+ test_op(
+ /*blocks*/ 640,
+ /*threads*/ 64,
+ /*name*/ std::get<2>(op),
+ /*Aten Func */
+ [&op](std::array<IValue, 2>& vals) {
+ return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor());
+ },
+ /*JIT Func */
+ [&op](Val* in1, Val* in2) -> Val* {
+ return binaryOp(std::get<1>(op), in1, in2);
+ },
+ /*Output */ std::make_pair(ValType::TensorView, DataType::Float),
+ /*Inputs Tuple*/
+ std::make_tuple(
+ std::make_pair(ValType::TensorView, DataType::Float),
+ std::make_pair(ValType::TensorView, DataType::Float)));
+ });
- auto all_elements = set.getAllElements();
- std::sort(all_elements.begin(), all_elements.end());
- std::vector<int> group_all_vec(group_all.begin(), group_all.end());
- std::sort(group_all_vec.begin(), group_all_vec.end());
- TORCH_CHECK(all_elements == group_all_vec);
+ test_op(
+ /*blocks*/ 640,
+ /*threads*/ 64,
+ /*name*/ "add_alpha",
+ /*Aten Func */
+ [](std::array<IValue, 3>& vals) {
+ return at::add(
+ vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
+ },
+ /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&add_alpha),
+ /*Output */ std::make_pair(ValType::TensorView, DataType::Float),
+ /*Inputs Tuple*/
+ std::make_tuple(
+ std::make_pair(ValType::TensorView, DataType::Float),
+ std::make_pair(ValType::TensorView, DataType::Float),
+ std::make_pair(ValType::Scalar, DataType::Float)));
+ test_op(
+ /*blocks*/ 640,
+ /*threads*/ 64,
+ /*name*/ "sub_alpha",
+ /*Aten Func */
+ [](std::array<IValue, 3>& vals) {
+ return at::sub(
+ vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
+ },
+ /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&sub_alpha),
+ /*Output */ std::make_pair(ValType::TensorView, DataType::Float),
+ /*Inputs Tuple*/
+ std::make_tuple(
+ std::make_pair(ValType::TensorView, DataType::Float),
+ std::make_pair(ValType::TensorView, DataType::Float),
+ std::make_pair(ValType::Scalar, DataType::Float)));
+}
- set.clear();
- all_elements = set.getAllElements();
- TORCH_CHECK(all_elements.size() == 0);
+TEST(NVFuserTest, FusionTernaryOps_CUDA) {
+ test_op(
+ /*blocks*/ 640,
+ /*threads*/ 64,
+ /*name*/ "clamp",
+ /*Aten Func */
+ [](std::array<IValue, 1>& vals) {
+ return at::clamp(vals[0].toTensor(), 0.f, 1.f);
+ },
+ /*JIT Func */
+ [](Val* in1) -> Val* {
+ return clamp(in1, new Float(0.f), new Float(1.f));
+ },
+ /*Output */ std::make_pair(ValType::TensorView, DataType::Float),
+ /*Inputs Tuple*/
+ std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
+ test_op(
+ /*blocks*/ 640,
+ /*threads*/ 64,
+ /*name*/ "threshold",
+ /*Aten Func */
+ [](std::array<IValue, 1>& vals) {
+ return at::threshold(vals[0].toTensor(), 0.f, 1.f);
+ },
+ /*JIT Func */
+ [](Val* in1) -> Val* {
+ return threshold(in1, new Float(0.f), new Float(1.f));
+ },
+ /*Output */ std::make_pair(ValType::TensorView, DataType::Float),
+ /*Inputs Tuple*/
+ std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
+ test_op(
+ /*blocks*/ 640,
+ /*threads*/ 64,
+ /*name*/ "where",
+ /*Aten Func */
+ [](std::array<IValue, 3>& vals) {
+ return at::where(
+ vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
+ },
+ /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&where),
+ /*Output */ std::make_pair(ValType::TensorView, DataType::Float),
+ /*Inputs Tuple*/
+ std::make_tuple(
+ std::make_pair(ValType::TensorView, DataType::Bool),
+ std::make_pair(ValType::TensorView, DataType::Float),
+ std::make_pair(ValType::TensorView, DataType::Float)));
+}
- // All cleared. Nothing should be considered equivalent.
- for (auto i : group_all) {
- for (auto j : group_all) {
- TORCH_CHECK(!set.areEquivalent(i, j));
- }
- }
+TEST(NVFuserTest, FusionCompoundOps_CUDA) {
+ test_op(
+ /*blocks*/ 640,
+ /*threads*/ 64,
+ /*name*/ "lerp",
+ /*Aten Func */
+ [](std::array<IValue, 3>& vals) {
+ return at::lerp(
+ vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
+ },
+ /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*)>(&lerp),
+ /*Output */ std::make_pair(ValType::TensorView, DataType::Float),
+ /*Inputs Tuple*/
+ std::make_tuple(
+ std::make_pair(ValType::TensorView, DataType::Float),
+ std::make_pair(ValType::TensorView, DataType::Float),
+ std::make_pair(ValType::TensorView, DataType::Float)));
+ test_op(
+ /*blocks*/ 640,
+ /*threads*/ 64,
+ /*name*/ "addcmul",
+ /*Aten Func */
+ [](std::array<IValue, 4>& vals) {
+ return at::addcmul(
+ vals[0].toTensor(),
+ vals[1].toTensor(),
+ vals[2].toTensor(),
+ vals[3].toScalar());
+ },
+ /*JIT Func */ static_cast<Val* (*)(Val*, Val*, Val*, Val*)>(&addcmul),
+ /*Output */ std::make_pair(ValType::TensorView, DataType::Float),
+ /*Inputs Tuple*/
+ std::make_tuple(
+ std::make_pair(ValType::TensorView, DataType::Float),
+ std::make_pair(ValType::TensorView, DataType::Float),
+ std::make_pair(ValType::TensorView, DataType::Float),
+ std::make_pair(ValType::Scalar, DataType::Float)));
}
-TEST(NVFuserTest, FusionNonUniqueBroadcastSize_CUDA) {
+TEST(NVFuserTest, FusionCastOps_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
- auto tv1 = makeSymbolicTensor(2);
- auto tv2 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addInput(tv2);
-
- auto tv3 = broadcast(tv0, {false, true});
- auto tv4 = add(tv3, tv1);
- auto tv5 = add(tv3, tv2);
-
- fusion.addOutput(tv4);
- fusion.addOutput(tv5);
+ TensorView* tv0 = makeDummyTensor(2, DataType::Half);
- // In order to do this, tv1->axis(1) and tv2->axis(1) must have the
- // same size, but we can't prove it, so this should throw an error.
- ASSERT_ANY_THROW(tv3->computeAt(tv4, -1));
-}
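+ // Round-trip cast: half -> float -> half. The output should exactly match
+ // ATen's casts, checked with equal() below.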
+ TensorView* intrm1 = castOp(DataType::Float, tv0);
+ TensorView* out = castOp(DataType::Half, intrm1);
-TEST(NVFuserTest, FusionBiasGeluFwd_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ fusion.addInput(tv0);
+ fusion.addOutput(out);
+ tv0->computeAt(out, -1);
- const float k_079 = 0.79788456;
- const float k_004 = 0.044715;
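- // k_079 ~= sqrt(2/pi); together these give the tanh approximation of GELU:
- //   0.5 * x * (1 + tanh(k_079 * x * (1 + k_004 * x * x)))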
-
- // bias vector
- auto t0 = makeSymbolicTensor(1, DataType::Half);
- fusion.addInput(t0);
- auto t1 = castOp(DataType::Float, t0);
- // input tensor
- auto t2 = makeSymbolicTensor(3, DataType::Half);
- fusion.addInput(t2);
- auto t3 = castOp(DataType::Float, t2);
- auto t4 = broadcast(t1, {true, true, false});
- auto t5 = add(t4, t3);
- auto t6 = mul(t5, new Double(0.5));
- auto t7 = mul(t5, new Double(k_079));
- auto t8 = mul(t5, new Double(k_004));
- auto t9 = mul(t8, t5);
- auto t10 = add(t9, new Int(1));
- auto t11 = mul(t7, t10);
- auto t12 = unaryOp(UnaryOpType::Tanh, t11);
- auto t13 = add(t12, new Double(1));
- auto t14 = mul(t6, t13);
- auto t15 = castOp(DataType::Half, t14);
- fusion.addOutput(t15);
+ out->axis(0)->parallelize(ParallelType::BIDx);
+ out->axis(-1)->parallelize(ParallelType::TIDx);
auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
- at::manual_seed(0);
- std::vector<int64_t> input_shape{6, 512, 4096};
- std::vector<int64_t> bias_shape{4096};
- auto at_input = at::randn(input_shape, options);
- auto at_bias = at::randn(bias_shape, options);
-
- auto at_x =
- at_bias.to(c10::ScalarType::Float) + at_input.to(c10::ScalarType::Float);
- auto aten_output_float =
- at_x * 0.5 * (1.0 + (k_079 * at_x * (1 + k_004 * at_x * at_x)).tanh());
- auto aten_output = aten_output_float.to(c10::ScalarType::Half);
+ at::Tensor input1 = at::rand({1, 4}, options);
+ at::Tensor ref_output = at::empty_like(input1);
- std::vector<IValue> aten_inputs = {at_bias, at_input};
- auto lparams = schedulePointwise(&fusion, aten_inputs);
+ std::array<IValue, 1> inputs = {input1};
+ const at::ArrayRef<IValue> input_ivalues(inputs);
FusionExecutor fe;
fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion(input_ivalues);
- auto cg_outputs = fe.runFusion(aten_inputs, lparams);
+ ref_output = at::_cast_Half(at::_cast_Float(input1));
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ TORCH_CHECK(
+ outputs[0].equal(ref_output),
+ "\nOp Type: -- ",
+ "cast FP16->FP32->FP16",
+ " -- had a mismatch.\n",
+ "\nABS MAX DIFF: ",
+ outputs[0].sub(ref_output).abs().max(),
+ "\n");
}
-TEST(NVFuserTest, FusionBiasGeluBwd_CUDA) {
- // skipping on pre-volta device
- if (at::cuda::getDeviceProperties(c10::cuda::current_device())->major < 7) {
- return;
- }
+// We want split/merge/reorder all tested both on and off rfactor domains; we
+// also want computeAt into the rfactor domain and into its consumer.
+TEST(NVFuserTest, FusionRFactorReplay_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- const float k_079 = 0.79788456;
- const float k_004 = 0.044715;
- const float k_010 = 0.1070322243;
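- // k_010 ~= 3 * k_079 * k_004, the constant that appears when differentiating
- // the tanh approximation of GELU.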
-
- // gradient tensor
- auto t0 = makeSymbolicTensor(3, DataType::Half);
- fusion.addInput(t0);
- auto t1 = castOp(DataType::Float, t0);
- // bias tensor
- auto t2 = makeSymbolicTensor(1, DataType::Half);
- fusion.addInput(t2);
- auto t3 = castOp(DataType::Float, t2);
- // input tensor
- auto t4 = makeSymbolicTensor(3, DataType::Half);
- fusion.addInput(t4);
- auto t5 = castOp(DataType::Float, t4);
- auto t6 = broadcast(t3, {true, true, false});
- auto t7 = add(t6, t5);
- auto t8 = mul(t7, new Double(k_079));
- auto t9 = mul(t7, new Double(k_004));
- auto t10 = mul(t9, t7);
- auto t11 = add(t10, new Int(1));
- auto t12 = mul(t8, t11);
- auto t13 = unaryOp(UnaryOpType::Tanh, t12);
- auto t14 = mul(t7, new Double(0.5));
- auto t15 = mul(t13, t13);
- auto t16 = unaryOp(UnaryOpType::Neg, t15);
- auto t17 = add(t16, new Int(1));
- auto t18 = mul(t7, new Double(k_010));
- auto t19 = mul(t18, t7);
- auto t20 = add(t19, new Double(k_079));
- auto t21 = mul(t17, t20);
- auto t22 = mul(t14, t21);
- auto t23 = add(t13, new Int(1));
- auto t24 = mul(t23, new Double(0.5));
- auto t25 = add(t22, t24);
- auto t26 = mul(t25, t1);
- // Save float output for validation
- fusion.addOutput(t26);
- auto t27 = castOp(DataType::Half, t26);
- fusion.addOutput(t27);
-
- auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
- at::manual_seed(1);
- std::vector<int64_t> input_shape{6, 512, 4096};
- std::vector<int64_t> bias_shape{4096};
- auto at_input = at::randn(input_shape, options);
- auto at_bias = at::randn(bias_shape, options);
- auto at_grad = at::randn(input_shape, options);
-
- auto at_x =
- at_bias.to(c10::ScalarType::Float) + at_input.to(c10::ScalarType::Float);
- auto at_tanh_out = (k_079 * at_x * (1 + k_004 * at_x * at_x)).tanh();
- auto at_ff = 0.5 * at_x *
- ((1 - at_tanh_out * at_tanh_out) * (k_079 + k_010 * at_x * at_x)) +
- 0.5 * (1 + at_tanh_out);
- auto at_out = at_ff * at_grad;
- auto at_out_half = at_out.to(c10::ScalarType::Half);
-
- std::vector<IValue> aten_inputs = {at_grad, at_bias, at_input};
- std::vector<at::Tensor> aten_outputs = {at_out, at_out_half};
-
- auto lparams = schedulePointwise(&fusion, aten_inputs);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto cg_outputs = fe.runFusion(aten_inputs, lparams);
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
-}
-
-// Reproducer of issue #459
-TEST(NVFuserTest, FusionIssue459_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
- auto tv0 = makeSymbolicTensor(1);
+ // Register your inputs
fusion.addInput(tv0);
- auto tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
- auto tv2 = add(tv0, new Double(1));
- auto tv3 = broadcast(tv2, {true, false});
- auto tv4 = add(tv1, tv3);
+ // Do math with it; it returns a `Val*` that can be static_cast back to a
+ // TensorView
+ TensorView* tv1 = sum(tv0, {1});
+ // tv1[I0, R1]
+ tv1->split(0, 32);
+ // tv1[I0o, I0i{32}, R1]
+ tv1->split(0, 16);
+ // tv1[I0oo, I0oi{16}, I0i{32}, R1]
+ tv1->split(-1, 8);
+ // tv1[I0oo, I0oi{16}, I0i{32}, R1o, R1i{8}]
+ tv1->split(-2, 4);
+ // tv1[I0oo, I0oi{16}, I0i{32}, R1oo, R1oi{4}, R1i{8}]
+ tv1->reorder({{0, -2}, {2, -1}, {-3, 0}, {-1, 1}});
+ // tv1[R1oo, R1i{8}, I0oi{16}, R1oi{4}, I0oo, I0i{32}]
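+ // (reorder takes {old_position, new_position} pairs; negative positions
+ // count from the end)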
- // Create two outputs from the final arithmetic result
- auto tv5 = add(tv4, new Double(1));
- fusion.addOutput(tv5);
- auto tv6 = add(tv4, new Double(1));
- fusion.addOutput(tv6);
+ tv1->merge(0);
+ tv1->merge(-2);
- // Scheduling
- for (auto output : ir_utils::filterByType<TensorView>(fusion.outputs())) {
- output->merge(-2, -1);
- }
- for (auto output : ir_utils::filterByType<TensorView>(fusion.outputs())) {
- output->split(0, 128);
- }
+ // tv1[R1oo*R1i{8}, I0oi{16}, R1oi{4}, I0oo*I0i{32}]
+ TensorDomain* new_domain = TransformRFactor::runReplay(tv1->domain(), {0});
+ // new_domain[r(R1oo*R1i{8})rf, I0oi{16}, ir1oi{4}rf, I0oo*I0i{32}]
- tv0->computeAt(tv5, -1);
+ TensorDomain* new_domain2 = TransformRFactor::runReplay2(tv1->domain(), {0});
+ // new_domain2[I0oi{16}, I0oo*I0i{32}, R1oi{4}]
- tv6->axis(0)->parallelize(ParallelType::BIDx);
- tv6->axis(1)->parallelize(ParallelType::TIDx);
+ // Move rfactor axis to end, keep iter rfactor axis
+ new_domain->reorder({{0, -1}, {2, 2}});
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- const int numel_x = 10;
- const int numel_y = 20;
- auto t0 = at::randn({numel_x}, options);
- auto t1 = at::randn({numel_y, numel_x}, options);
- auto aten_output = (t0 + 1).unsqueeze(0) + t1 + 1;
+ // Replay CasP (consumer as producer): replay new_domain2 as new_domain
+ // reordered_new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf]
+ auto replay_casp = TransformReplay::replayCasP(new_domain2, new_domain, 2);
+ TensorDomain* casp = replay_casp.first;
+ // new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf]
+ // casp[I0oi{16}, I0oo*I0i{32}, R1oi{4}]
- std::vector<IValue> aten_inputs = {t0, t1};
+ casp->split(1, new Int(2));
+ // casp [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4} ]
+ // new_domain[I0oi{16}, I0oo*I0i{32} , ir1oi{4}rf,
+ // R(R1oo*R1i{8})rf]
- torch::jit::fuser::cuda::FusionExecutor fe;
- fe.compileFusion(&fusion);
+ auto replay_pasc = TransformReplay::replayPasC(new_domain, casp, 2);
+ TensorDomain* pasc = replay_pasc.first;
+ // pasc [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4}rf,
+ // R(R1oo*R1i{8})rf]
+
+ TORCH_CHECK(
+ new_domain->nDims() - 1 == new_domain2->nDims() &&
+ casp->nDims() == new_domain2->nDims() + 1 &&
+ pasc->nDims() == new_domain->nDims() + 1,
+ "Error in rfactor, number of dimensions is not correct.");
- auto cg_outputs = fe.runFusion(aten_inputs);
+ TORCH_CHECK(
+ !casp->sameAs(new_domain2) && !pasc->sameAs(new_domain) &&
+ !new_domain->sameAs(new_domain2) &&
+ !tv1->domain()->sameAs(new_domain) &&
+ !tv1->domain()->sameAs(new_domain2),
+ "Error in rfactor, number of dimensions is not correct.");
- testValidate(
- &fusion,
- cg_outputs,
- aten_inputs,
- {aten_output, aten_output},
- __LINE__,
- __FILE__);
+ auto dom = new_domain->getRootDomain();
+ TORCH_CHECK(
+ !dom[0]->isReduction() &&
+ std::any_of(
+ dom.begin(),
+ dom.end(),
+ [](IterDomain* id) { return id->isReduction(); }) &&
+ std::any_of(
+ dom.begin(),
+ dom.end(),
+ [](IterDomain* id) { return id->isRFactorProduct(); }),
+ "Error in rFactor, there seems to be something wrong in root domain.");
+
+ auto dom2 = new_domain2->getRootDomain();
+ TORCH_CHECK(
+ !dom2[0]->isReduction() &&
+ std::any_of(
+ dom2.begin(),
+ dom2.end(),
+ [](IterDomain* id) { return id->isReduction(); }),
+ "Error in rFactor, there seems to be something wrong in root domain.");
}
-TEST(NVFuserTest, FusionSmemIndexingSimple_CUDA) {
+// Start off simple: block on the outer dim,
+// block stride + thread all-reduce + unrolling on the inner dim
+TEST(NVFuserTest, FusionReduction_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(1));
- auto tv3 = add(tv2, new Double(1));
- fusion.addOutput(tv3);
- tv3->axis(0)->parallelize(ParallelType::BIDx);
- tv3->axis(1)->parallelize(ParallelType::TIDx);
-
- tv0->computeAt(tv3, -1);
-
- tv1->setMemoryType(MemoryType::Shared);
- tv2->setMemoryType(MemoryType::Global);
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addOutput(tv1);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ tv1->split(1, 128);
+ // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+ tv1->split(1, 4);
+ // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1]
- auto aten_input = at::randn({12, 34}, options);
- at::Tensor aten_output = aten_input + 1.0 + 1.0 + 1.0;
+ TensorView* tv2 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1]
+ // tv1[I0, R1oi{4}, R1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}]
- auto cg_outputs = fe.runFusion({aten_input});
+ TensorView* tv3 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1]
+ // tv3[I0, R1oi{4}, Ir1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}]
+ // tv1[I0, R1i{128}] = tv3[I0, R1oi{4}, Ir1i{128}]
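+ // The two rFactor calls stage the reduction: tv2 reduces R1oo, tv3 then
+ // reduces R1oi{4}, and tv1 performs the final R1i{128} reduction.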
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
+ // Schedule incrementally; can print in between for debugging
+ tv0->computeAt(tv2, 1);
+ tv2->computeAt(tv3, 1);
+ tv3->computeAt(tv1, 1);
-TEST(NVFuserTest, FusionSmemIndexing_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ // Redo it all at once, because why not.
+ tv0->computeAt(tv1, 1);
- // Symbolic integers we will use for runtime tiling
- Int* symbolic_m_tile_dim = new Int();
- Int* symbolic_split_k_tile_dim = new Int();
- Int* symbolic_block_k_tile_dim = new Int();
- // Compile-time integer for tiling
- int n_smem_tile = 32;
+ tv2->axis(2)->parallelize(ParallelType::Unroll);
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
- // Symbolic 2D tensors TV0[M, K], TV1[K, N]
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
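+ // Net effect: one CTA per I0 element (BIDx); each thread serially
+ // reduces R1oo in tv2 (unrolled over the 4-wide axis), folds its four
+ // partials in tv3, and tv1 finishes with a cross-thread reduction over
+ // TIDx.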
- // Broadcast tv0 to [M, K, *]
- TensorView* tv2 = broadcast(tv0, {false, false, true});
- // Broadcast tv1 to [*, K, N]
- TensorView* tv3 = broadcast(tv1, {true, false, false});
+ int numel_x = 65000;
+ int numel_y = 1025;
- // Pointwise multiplication resulting in tv3[M, K, N]
- TensorView* tv4 = mul(tv2, tv3);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
- // Sum the K-dim
- TensorView* tv5 = sum(tv4, {1});
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ fe.runFusion({input}, {cg_output});
- // Register inputs and outputs
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addOutput(tv5);
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
+}
- // Register runtime tile dims as inputs
- fusion.addInput(symbolic_m_tile_dim);
- fusion.addInput(symbolic_split_k_tile_dim);
- fusion.addInput(symbolic_block_k_tile_dim);
+TEST(NVFuserTest, FusionReduction2_CUDA) {
+ {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- // Make a 3D tile, mix of symbolic and constant, do in reverse order because
- // dims are inserted
- // [M, rK, N]
- tv5->split(2, n_smem_tile);
- // [M, rK, No, Ni{32}]
- tv5->split(1, symbolic_block_k_tile_dim);
- // [M, rKo, rKi{i2}, No, Ni{32}]
- tv5->split(1, symbolic_split_k_tile_dim);
- // [M, rKoo, rKoi{i1}, rKi{i2}, No, Ni{32}]
- tv5->split(0, symbolic_m_tile_dim);
- // [Mo, Mi{i0}, rKoo, rKoi{i1}, rKi{i2}, No, Ni{32}]
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
- // Reorder so all outer tiles are in the leftmost 3 positions
- // [Mo, Mi{i0}, rKoo, rKoi{i1}, rKi{i2}, No, Ni{32}]
- // [Mo, No, rKoo, rKoi{i1}, rKi{i2}, Mi{i0}, Ni{32}]
- tv5->reorder({{1, 5}, {5, 1}});
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
- // Factor out the outer reduction IterDomain, then run the inter-cta
- // reduction, and intra-cta reduction
- // [Mo, No, rKoo, Koi{i1}, Ki{i2}, Mi{i0}, Ni{32}]
- // [Mo, No, rKoi{i1}, rKi{i2}, Mi{i0}, Ni{32}]
- auto tv6 = tv5->rFactor({2});
+ fusion.addOutput(tv1);
- // Scope computations
- tv6->computeAt(tv5, 2);
+ // Switches to try some different scenarios; ideally we would iterate
+ // over all permutations.
+ bool bind_bidx = true;
+ bool bind_tidx = true;
+ bool bind_tidy = true;
+ bool bind_unroll = true;
- // [Mo, No, rKoo, Koi{i1}, Ki{i2}, Mi{i0}, Ni{32}]
- // [Mo, No, Ki{i2}, Mi{i0}, Ni{32}, rKoo, Koi{i1}]
- tv6->reorder({
- {2, -2},
- {3, -1},
- {4, 2},
- {5, 3},
- {6, 4},
- });
+ int numel_x = 1025; // Cannot exceed block dim max size / tidy
+ int numel_y = 129;
+ int tidx = 16;
+ int tidy = 8;
+ int unroll_factor = 4;
- // Setup compute at schedule
- tv0->computeAt(tv6, 3);
- tv1->computeAt(tv6, 3);
- tv4->computeAt(tv6, -1);
+ tv1->split(1, tidx);
+ // tv1[I0, R1o, R1i{tidx}] = tv0[I0, I1]
- // Cache smem tiles
- tv2->setMemoryType(MemoryType::Shared);
- tv3->setMemoryType(MemoryType::Shared);
- tv4->setMemoryType(MemoryType::Shared);
- tv6->setMemoryType(MemoryType::Shared);
+ tv1->split(1, unroll_factor);
+ // tv1[I0, R1oo, R1oi{unroll}, R1i{tidx}] = tv0[I0, I1]
- tv5->axis(0)->parallelize(ParallelType::BIDz);
- tv5->axis(1)->parallelize(ParallelType::BIDy);
+ tv1->split(0, tidy);
- std::vector<TensorView*> tv_list = {tv2, tv3, tv4, tv5, tv6};
- for (auto tv : tv_list) {
- tv->axis(-2)->parallelize(ParallelType::TIDz);
- tv->axis(-1)->parallelize(ParallelType::TIDy);
- }
+ TensorView* tv2 = tv1->rFactor({-3});
+ // tv2[I0, >R1oo<, Ir1oi{unroll}, Ir1i{tidx}]
+ // tv1[I0o, I0i{tidy}, R1oi{unroll}, R1i{tidx}]
- constexpr int M = 31, K = 65, N = 32;
+ TensorView* tv3 = tv1->rFactor({-2});
+ // tv2[I0, >R1oo<, Ir1oi{unroll}, Ir1i{tidx}]
+ // tv3[I0, R1oi{unroll}, Ir1i{tidx}]
+ // tv1[I0o, I0i{tidy}, R1i{tidx}]
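+ // (>R1oo< marks the axis factored out into tv2's reduction; the
+ // remaining reduction axes stay with tv3 and tv1.)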
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({M, K}, options);
- at::Tensor t1 = at::randn({K, N}, options);
+ tv0->computeAt(tv1, -2);
- at::Tensor aten_output =
- mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1);
+ if (bind_unroll)
+ tv2->axis(-2)->parallelize(ParallelType::Unroll);
+ if (bind_bidx)
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
+ if (bind_tidy)
+ tv1->axis(1)->parallelize(ParallelType::TIDy);
- // A, B, m_tile_dim, split_k, intra_cta_tile
- std::vector<IValue> aten_inputs = {t0, t1, 3, 4, 5};
+ if (bind_tidx) {
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ }
- torch::jit::fuser::cuda::FusionExecutor fe;
- fe.compileFusion(&fusion);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({input});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(outputs[0]));
+ }
-// Reproducer of issue 408
-TEST(NVFuserTest, FusionCacheBeforeReduction_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ {
+ // What if Z participates in the reduction with X?
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = sum(tv1, {1});
- fusion.addOutput(tv2);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
- tv2->split(0, 4);
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
- auto tv3 = tv2->cache_before();
+ fusion.addOutput(tv1);
- tv0->computeAt(tv3, -1);
- tv3->computeAt(tv2, -1);
+ int numel_x = 1025; // Cannot exceed block dim max size / tidz
+ int numel_y = 129;
+ int tidx = 16;
+ int tidz = 8;
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ tv1->split(1, tidz);
+ // tv1[I0, R1o, R1i{tidz}] = tv0[I0, I1]
- FusionExecutor fe;
- fe.compileFusion(&fusion);
+ tv1->split(1, tidx);
+ // tv1[I0, R1oo, R1oi{tidx}, R1i{tidz}] = tv0[I0, I1]
- const int numel_x = 100;
- const int numel_y = 200;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ TensorView* tv2 = tv1->rFactor({-3});
+ // tv2[I0, >R1oo<, Ir1oi{tidx}, Ir1i{tidz}]
+ // tv1[I0o, R1oi{tidx}, R1i{tidz}]
- at::Tensor aten_input = at::randn({numel_x, numel_y}, options);
- at::Tensor cg_output = at::empty({numel_x}, options);
+ tv0->computeAt(tv1, -3);
- auto aten_output = (aten_input + 1).to(at::kDouble).sum({1});
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
+ tv1->axis(-2)->parallelize(ParallelType::TIDx);
+ tv1->axis(-1)->parallelize(ParallelType::TIDz);
+
+ tv2->axis(-2)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDz);
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
- fe.runFusion({aten_input}, {cg_output});
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ fe.runFusion({input}, {cg_output});
- testValidate(
- &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__);
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
+ }
}
-TEST(NVFuserTest, FusionCacheBeforeReduction2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+TEST(NVFuserTest, FusionReduction3_CUDA) {
+ {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(3);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = sum(tv1, {1});
- auto tv3 = add(tv2, new Double(1));
- fusion.addOutput(tv2);
- fusion.addOutput(tv3);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
- auto tv4 = tv2->cache_before();
+ TensorView* tv2 = add(tv0, tv1);
+ // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1]
- tv4->computeAt(tv3, 1);
- tv0->computeAt(tv4, -1);
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
- tv3->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
+ TensorView* tv3 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv2);
+ // tv3[I0, R1] = tv2[I0, I1]
- FusionExecutor fe;
- fe.compileFusion(&fusion);
+ TensorView* tv4 = makeDummyTensor(1);
+ fusion.addInput(tv4);
- const int numel_x = 10;
- const int numel_y = 20;
- const int numel_z = 30;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ // tv5[I0] = tv3[I0, R1] * tv4[I0]
+ TensorView* tv5 = mul(tv3, tv4);
+ fusion.addOutput(tv5);
- at::Tensor aten_input = at::randn({numel_x, numel_y, numel_z}, options);
- auto t2 = (aten_input + 1).to(at::kDouble).sum({1});
- auto t3 = t2 + 1;
- std::vector<at::Tensor> aten_outputs = {t2, t3};
+ int tidx = 16;
- auto cg_outputs = fe.runFusion({aten_input});
+ // RFactor the reduction
+ tv3->split(1, tidx);
+ // tv3[I0, R1o, R1i{tidx}] = tv2[I0, I1]
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
+ TensorView* tv6 = tv3->rFactor({-2});
+ // tv6[I0, R1o, iR1i{tidx}] = tv2[I0, I1]
+ // tv3[I0, R1i{tidx}] = tv6[I0, R1o, iR1i{tidx}]
+ tv2->computeAt(tv6, 2);
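+ // Inlining tv2 at position 2 of tv6 fuses the pointwise add into the
+ // reduction loop, so tv2 is computed on the fly rather than materialized
+ // in full.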
-TEST(NVFuserTest, FusionIssue367_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ // Compute-at inline with tv5 (which is only 1D)
+ tv6->computeAt(tv3, 1);
+ tv3->computeAt(tv5, 1);
- // Symbolic integers we will use for runtime tiling
- Int* symbolic_m_tile_dim = new Int();
- Int* symbolic_split_k_tile_dim = new Int();
- Int* symbolic_block_k_tile_dim = new Int();
- // Compile-time integer for tiling
- int n_smem_tile = 32;
+ tv5->axis(0)->parallelize(ParallelType::BIDx);
- // Symbolic 2D tensors TV0[M, K], TV1[K, N]
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
+ // Only the intermediate tensors need this, but it doesn't hurt to do it
+ // on the inputs tv0, tv1, and tv4 as well
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ tv6->axis(-1)->parallelize(ParallelType::TIDx);
- // Broadcast tv0 to [M, K, *]
- TensorView* tv2 = broadcast(tv0, {false, false, true});
- // Broadcast tv1 to [*, K, N]
- TensorView* tv3 = broadcast(tv1, {true, false, false});
+ int numel_x = 1025;
+ int numel_y = 129;
- // Pointwise multiplication resulting in tv3[M, K, N]
- TensorView* tv4 = mul(tv2, tv3);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor t0 = at::rand({numel_x, numel_y}, options);
+ at::Tensor t1 = at::rand({numel_x, numel_y}, options);
+ auto t2 = t0.add(t1);
+ auto t3 = t2.sum({1});
+ at::Tensor t4 = at::rand({numel_x}, options);
+ auto t5 = t3.mul(t4);
- // Sum the K-dim
- TensorView* tv5 = sum(tv4, {1});
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({t0, t1, t4});
- // Register inputs and outputs
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addOutput(tv5);
+ TORCH_CHECK(
+ t5.allclose(outputs[0]), "Error of: ", t5.sub(outputs[0]).abs().max());
+ }
+}
- // Register runtime tile dims as inputs
- fusion.addInput(symbolic_m_tile_dim);
- fusion.addInput(symbolic_split_k_tile_dim);
- fusion.addInput(symbolic_block_k_tile_dim);
+TEST(NVFuserTest, FusionReduction4_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- // Make a 3D tile, mix of symbolic and constant, do in reverse order because
- // dims are inserted
- tv5->split(2, n_smem_tile);
- tv5->split(1, symbolic_block_k_tile_dim);
- tv5->split(1, symbolic_split_k_tile_dim);
- tv5->split(0, symbolic_m_tile_dim);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(3);
- // tv5[M/m_tile, m_tile, r{K/split_k/block_k}, r{split_k}, r{block_k}, N/32,
- // 32]
- tv5->reorder({{1, 5}, {5, 1}});
- // tv5[M/m_tile, N/32, r{K/split_k/block_k}, r{split_k}, r{block_k}, m_tile,
- // 32]
+ fusion.addInput(tv0);
- auto tv6 = tv5->rFactor({2});
- auto tv7 = tv5->rFactor({2});
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
- // Scope computations
- tv6->computeAt(tv5, 2);
+ fusion.addOutput(tv1);
- tv6->reorder({
- {2, -2},
- {3, -1},
- {4, 2},
- {5, 3},
- {6, 4},
- });
+ int bidy = 2;
+ int tidy = 4;
+ int tidx = 5;
- tv7->reorder({
- {2, -2},
- {3, -1},
- {-2, 2},
- {-1, 3},
- });
+ int dim1 = 11;
- tv0->computeAt(tv6, 3);
- tv1->computeAt(tv6, 3);
- tv4->computeAt(tv6, -1);
+ tv1->split(-2, tidy);
- // Cache smem tiles
- tv2->setMemoryType(MemoryType::Shared);
- tv3->setMemoryType(MemoryType::Shared);
- tv4->setMemoryType(MemoryType::Local);
- tv6->setMemoryType(MemoryType::Local);
- tv7->setMemoryType(MemoryType::Local);
+ TensorView* tv2 = tv1->rFactor({-3});
- tv5->axis(0)->parallelize(ParallelType::BIDz);
- tv5->axis(1)->parallelize(ParallelType::BIDy);
+ tv0->computeAt(tv1, 1);
+ tv1->axis(0)->parallelize(ParallelType::BIDy);
- std::vector<TensorView*> tv_list = {tv2, tv3, tv4, tv5, tv6, tv7};
- for (auto tv : tv_list) {
- tv->axis(-2)->parallelize(ParallelType::TIDz);
- tv->axis(-1)->parallelize(ParallelType::TIDy);
+ for (auto* val : fusion.vals()) {
+ if (!fusion.hasInput(val) &&
+ val->getValType().value() == ValType::TensorView) {
+ val->as<TensorView>()->axis(-1)->parallelize(ParallelType::TIDx);
+ }
}
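+ // The loop above binds TIDx on the innermost axis of every non-input
+ // TensorView, which covers both tv1 and the rfactor tensor tv2.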
- tv2->axis(3)->parallelize(ParallelType::TIDx);
- tv3->axis(3)->parallelize(ParallelType::TIDx);
- tv4->axis(3)->parallelize(ParallelType::TIDx);
- tv6->axis(3)->parallelize(ParallelType::TIDx);
- tv7->axis(2)->parallelize(ParallelType::TIDx);
-
- tv2->axis(4)->parallelize(ParallelType::BIDx);
- tv3->axis(4)->parallelize(ParallelType::BIDx);
- tv4->axis(4)->parallelize(ParallelType::BIDx);
- tv6->axis(4)->parallelize(ParallelType::BIDx);
- tv7->axis(3)->parallelize(ParallelType::BIDx);
- tv5->axis(2)->parallelize(ParallelType::BIDx);
- constexpr int M = 3, K = 6, N = 16;
+ tv2->axis(-2)->parallelize(ParallelType::TIDy);
+ tv1->axis(-2)->parallelize(ParallelType::TIDy);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::randn({bidy, dim1, tidx}, options);
- at::Tensor t0 = at::randn({M, K}, options);
- at::Tensor t1 = at::randn({K, N}, options);
-
- // A, B, m, split_k, block_k
- std::vector<IValue> aten_inputs = {t0, t1, 2, 2, 3};
- at::Tensor aten_output =
- mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1);
+ at::Tensor cg_output = at::empty({bidy, tidx}, options);
- torch::jit::fuser::cuda::FusionExecutor fe;
+ FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ fe.runFusion({input}, {cg_output});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(
+ aten_output.allclose(cg_output, 1e-5, 1e-7),
+ "Error of: ",
+ aten_output.sub(cg_output).abs().max());
}
-TEST(NVFuserTest, FusionIssue468_CUDA) {
+TEST(NVFuserTest, FusionReduction5_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
+ const int bdimx = 64;
+ const int bdimy = 8;
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(3);
fusion.addInput(tv0);
- auto tv1 = sum(tv0, {1});
- auto tv2 = sum(tv1, {0});
- fusion.addOutput(tv2);
- tv1->axis(0)->parallelize(ParallelType::TIDy);
- tv1->axis(1)->parallelize(ParallelType::TIDx);
+ // tv1[I0, R1, R2] = tv0[I0, I1, I2]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Float(0), tv0);
+ fusion.addOutput(tv1);
+
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+ tv1->split(2, bdimx);
+ // tv1[I0, R1, R2o, R2i{64}] = tv0[I0, I1, I2]
+ tv1->split(1, bdimy);
+ // tv1[I0, R1o, R1i{8}, R2o, R2i{64}] = tv0[I0, I1, I2]
- tv2->axis(0)->parallelize(ParallelType::TIDy);
+ TensorView* tv2 = tv1->rFactor({3});
+ // tv2[I0, I1o, I1i{8}, R2o, I2i{64}] = tv0[I0, I1, I2]
+ // tv1[I0, R1o, R1i{8}, R2i{64}] = tv2[I0, I1o, I1i{8}, R2o, I2i{64}]
+
+ TensorView* tv3 = tv1->rFactor({1});
+ // tv2[I0, I1o, I1i{8}, R2o, I2i{64}] = tv0[I0, I1, I2]
+ // tv3[I0, R1o, I1i{8}, I2i{64}] = tv2[I0, I1o, I1i{8}, R2o, I2i{64}]
+ // tv1[I0, R1i{8}, R2i{64}] = tv3[I0, R1o, I1i{8}, I2i{64}]
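+ // A 2D reduction ({1, 2}) is staged with two rFactors: tv2 reduces the
+ // R2o tile, tv3 reduces R1o, and tv1 finishes with a block reduction
+ // over TIDy and TIDx.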
+
+ tv3->computeAt(tv1, 1);
+ tv2->computeAt(tv3, 2);
+
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
+ tv2->axis(0)->parallelize(ParallelType::BIDx);
+ tv3->axis(0)->parallelize(ParallelType::BIDx);
+
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+ tv1->axis(-2)->parallelize(ParallelType::TIDy);
+ tv3->axis(-2)->parallelize(ParallelType::TIDy);
+ tv2->axis(-3)->parallelize(ParallelType::TIDy);
+
+ int numel_x = 650;
+ int numel_y = 1000;
+ int numel_z = 4;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({10, 100}, options);
- at::Tensor aten_output = aten_input.to(at::kDouble).sum({1}).sum({0});
+ at::Tensor input = at::rand({numel_x, numel_y, numel_z}, options);
- torch::jit::fuser::cuda::FusionExecutor fe;
+ FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
+ auto outputs = fe.runFusion({input});
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
+ auto aten_output = input.sum({1, 2});
+ TORCH_CHECK(aten_output.allclose(outputs[0]));
}
-TEST(NVFuserTest, FusionIssue363_CUDA) {
+TEST(NVFuserTest, FusionReductionTFT_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- // Symbolic 2D tensors TV0[M, K], TV1[K, N]
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(2);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
- // Broadcast tv0 to [M, K, *]
- TensorView* tv2 = broadcast(tv0, {false, false, true});
- // Broadcast tv1 to [*, K, N]
- TensorView* tv3 = broadcast(tv1, {true, false, false});
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
- // Pointwise multiplication resulting in tv3[M, K, N]
- TensorView* tv4 = mul(tv2, tv3);
+ fusion.addOutput(tv1);
- // Sum the K-dim
- TensorView* tv5 = sum(tv4, {1});
+ int numel_x = 1025;
+ int numel_y = 129;
+ int tidx = 16;
+ int tidy = 8;
+ int tidz = 8;
- // Register inputs and outputs
- fusion.addInput(tv0);
- fusion.addInput(tv1);
- fusion.addOutput(tv5);
+ tv1->split(1, tidx);
+ // tv1[I0, R1o, R1i{tidx}]
- tv2->setMemoryType(MemoryType::Global);
- tv3->setMemoryType(MemoryType::Global);
- tv4->setMemoryType(MemoryType::Global);
+ tv1->split(1, tidz);
+ // tv1[I0, R1oo, R1oi{tidz}, R1i{tidx}]
+
+ tv1->split(0, tidy);
+ // tv1[I0o, I0i{tidy}, R1oo, R1oi{tidz}, R1i{tidx}]
+
+ TensorView* tv2 = tv1->rFactor({2});
+ // tv2[I0o, I0i{tidy}, R1oo, Ir1oi{tidz}, Ir1i{tidx}]
+ // tv1[I0o, I0i{tidy}, R1oi{tidz}, R1i{tidx}]
- tv0->computeAt(tv5, -1);
- tv1->computeAt(tv5, -1);
+ tv2->computeAt(tv1, 2);
- tv5->axis(0)->parallelize(ParallelType::BIDz);
- tv5->axis(1)->parallelize(ParallelType::BIDy);
+ tv1->axis(1)->parallelize(ParallelType::TIDy);
- tv5->axis(2)->parallelize(ParallelType::BIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
- constexpr int M = 3, K = 6, N = 16;
+ tv1->axis(-2)->parallelize(ParallelType::TIDz);
+ tv2->axis(-2)->parallelize(ParallelType::TIDz);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
- at::Tensor t0 = at::randn({M, K}, options);
- at::Tensor t1 = at::randn({K, N}, options);
- at::Tensor aten_output =
- mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- torch::jit::fuser::cuda::FusionExecutor fe;
+ FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ fe.runFusion({input}, {cg_output});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
}
-TEST(NVFuserTest, FusionIssue484_CUDA) {
+TEST(NVFuserTest, FusionBranches_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
+ TensorView* tv2 = makeDummyTensor(2);
fusion.addInput(tv0);
- auto tv1 = sum(tv0, {1});
- auto tv2 = add(tv1, new Double(0));
- fusion.addOutput(tv2);
+ fusion.addInput(tv1);
+ fusion.addInput(tv2);
+
+ auto tv3 = add(tv0, new Float(1.0));
+ auto tv4 = add(tv3, tv1);
+ auto tv5 = add(tv3, tv2);
+ auto tv6 = add(tv4, tv5);
- tv1->setMemoryType(MemoryType::Global);
- tv1->axis(1)->parallelize(ParallelType::TIDx);
+ fusion.addOutput(tv6);
- constexpr int M = 100;
+ constexpr int x = 63, y = 33;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_input = at::randn({M, M}, options);
- at::Tensor aten_output = aten_input.to(at::kDouble).sum({1});
+ at::Tensor t0 = at::randn({x, y}, options);
+ at::Tensor t1 = at::randn({x, y}, options);
+ at::Tensor t2 = at::randn({x, y}, options);
+
+ FusionExecutor fe;
+ tv6->merge(0);
+ tv6->split(0, 128);
+ tv6->split(0, 4);
+
+ tv6->axis(0)->parallelize(ParallelType::BIDx);
+
+ tv0->computeAt(tv6, 1);
+ tv1->computeAt(tv6, 1);
+ tv2->computeAt(tv6, 1);
+
+ tv3->axis(-2)->parallelize(ParallelType::Unroll);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ tv4->axis(-2)->parallelize(ParallelType::Unroll);
+ tv4->axis(-1)->parallelize(ParallelType::TIDx);
+ tv5->axis(-2)->parallelize(ParallelType::Unroll);
+ tv5->axis(-1)->parallelize(ParallelType::TIDx);
+ tv6->axis(-1)->parallelize(ParallelType::TIDx);
- torch::jit::fuser::cuda::FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
+ auto outputs = fe.runFusion({t0, t1, t2});
+
+ auto t3 = t0.add(1.0);
+ auto t4 = t3.add(t1);
+ auto t5 = t3.add(t2);
+ auto t6 = t4.add(t5);
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
+ TORCH_CHECK(t6.allclose(outputs[0]));
}
-TEST(NVFuserTest, Issue329_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+TEST(NVFuserTest, FusionSimpleBCast_CUDA) {
+ {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = sum(tv1, {1});
- fusion.addOutput(tv2);
- auto tv3 = sum(tv1, {1});
- fusion.addOutput(tv3);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+ TensorView* tv1 = add(tv0, new Float(1.5));
- tv1->computeAt(tv2, -1);
+ TensorView* tv2 = makeDummyTensor(2);
+ fusion.addInput(tv2);
+ TensorView* tv3 = makeDummyTensor(2);
+ fusion.addInput(tv3);
+ TensorView* tv4 = sub(tv2, tv3);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ TensorView* tv5 = broadcast(tv1, {false, false, true});
+ TensorView* tv6 = broadcast(tv4, {true, false, false});
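+ // broadcast() inserts a broadcast axis wherever the flag is true:
+ // tv5[I0, I1, B] and tv6[B, I1, I2], which then add elementwise into
+ // tv7.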
- std::vector<int64_t> t0_shape{17, 19};
- auto aten_input = at::randn(t0_shape, options);
- auto t2 = (aten_input + 1).to(at::kDouble).sum({1});
- auto t3 = (aten_input + 1).to(at::kDouble).sum({1});
- std::vector<at::Tensor> aten_outputs = {t2, t3};
+ TensorView* tv7 = add(tv5, tv6);
+ fusion.addOutput(tv7);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
+ tv7->split(-1, 4);
+ tv7->split(0, 8);
- auto cg_outputs = fe.runFusion({aten_input});
+ tv0->computeAt(tv7, -1);
+ tv2->computeAt(tv7, -1);
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
-}
+ tv7->axis(0)->parallelize(ParallelType::BIDx);
+ tv7->axis(-1)->parallelize(ParallelType::TIDx);
-TEST(NVFuserTest, FusionIssue382_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ constexpr int x = 63, y = 33, z = 15;
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = broadcast(tv1, {false, false, true});
- auto tv3 = makeSymbolicTensor(3);
- fusion.addInput(tv3);
- auto tv4 = add(tv2, tv3);
- fusion.addOutput(tv4);
+ at::Tensor t0 = at::randn({x, y}, options);
+ at::Tensor t1 = t0.add(1.5);
- tv2->merge(1);
- tv4->merge(1);
+ at::Tensor t2 = at::randn({y, z}, options);
+ at::Tensor t3 = at::randn({y, z}, options);
- tv1->computeAt(tv4, 1);
+ at::Tensor t4 = t2.sub(t3);
+ at::Tensor t5 = t1.unsqueeze(-1).expand({x, y, z});
- tv4->axis(0)->parallelize(ParallelType::BIDx);
+ at::Tensor t6 = t4.expand({x, y, z});
+ at::Tensor t7 = t5.add(t6);
- tv1->setMemoryType(MemoryType::Global);
- tv2->setMemoryType(MemoryType::Global);
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({t0, t2, t3});
- torch::jit::fuser::cuda::FusionExecutor fe;
- fe.compileFusion(&fusion);
+ TORCH_CHECK(t7.allclose(outputs[0]));
+ }
- const int numel_x = 12;
- const int numel_y = 34;
- const int numel_z = 56;
+ {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- auto t0 = at::randn({numel_x, numel_y}, options);
- auto t3 = at::randn({numel_x, numel_y, numel_z}, options);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+ TensorView* tv1 = makeDummyTensor(2);
+ fusion.addInput(tv1);
- std::vector<IValue> aten_inputs = {t0, t3};
- auto aten_output = (t0 + 1).unsqueeze(-1) + t3;
+ TensorView* tv2 = add(tv0, tv1);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ TensorView* tv3 = broadcast(tv2, {false, false, true});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+ TensorView* tv4 = makeDummyTensor(2);
+ fusion.addInput(tv4);
-TEST(NVFuserTest, Issue507_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ TensorView* tv5 = sub(tv4, new Float(0.1));
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(1));
- fusion.addOutput(tv2);
+ TensorView* tv6 = broadcast(tv5, {true, false, false});
- tv1->setMemoryType(MemoryType::Shared);
+ TensorView* tv7 = add(tv3, tv6);
- tv1->axis(1)->parallelize(ParallelType::TIDx);
- tv2->axis(1)->parallelize(ParallelType::TIDx);
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(0)->parallelize(ParallelType::BIDx);
+ fusion.addOutput(tv7);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ tv7->merge(0, 1);
- std::vector<int64_t> t0_shape{17, 19};
- auto aten_input = at::randn(t0_shape, options);
- auto t1 = (aten_input + 1);
- auto aten_output = (t1 + 1);
+ tv0->computeAt(tv7, -1);
+ tv4->computeAt(tv7, -1);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
+ tv7->axis(0)->parallelize(ParallelType::BIDx);
+ tv7->axis(-1)->parallelize(ParallelType::TIDx);
- auto cg_outputs = fe.runFusion({aten_input});
+ constexpr int x = 63, y = 33, z = 15;
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
-}
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-TEST(NVFuserTest, FusionIssue532_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ at::Tensor t0 = at::randn({x, y}, options);
+ at::Tensor t1 = at::randn({x, y}, options);
+ at::Tensor t2 = t0.add(t1);
+ at::Tensor t3 = t2.unsqueeze(-1).expand({x, y, z});
- // Algorithm
- TensorView* tv0 = makeSymbolicTensor(1);
- TensorView* tv1 = add(tv0, new Double(1));
- TensorView* tv2 = add(tv1, new Double(1));
- fusion.addInput(tv0);
- fusion.addOutput(tv2);
+ at::Tensor t4 = at::randn({y, z}, options);
+ at::Tensor t5 = t4.sub(0.1);
+ at::Tensor t6 = t5.expand({x, y, z});
+ at::Tensor t7 = t3.add(t6);
- const int M_BLOCK = 64;
- const int M_THREAD = 4;
+ at::Tensor cg_output = at::empty({x, y, z}, options);
- tv2->split(0, M_BLOCK);
- // tv2: [M/M_BLOCK, M_BLOCK]
- tv1->computeAt(tv2, 1);
- // tv1: [M/M_BLOCK, M_BLOCK]
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ fe.runFusion({t0, t1, t4}, {cg_output});
- tv1->split(-1, M_BLOCK / M_THREAD);
- // tv1: [M/M_BLOCK, M_THREAD, M_BLOCK / M_THREAD]
+ TORCH_CHECK(t7.allclose(cg_output));
+ }
- tv2->split(-1, M_THREAD);
- // tv2: [M/M_BLOCK, M_BLOCK / M_THREAD, M_THREAD]
+ {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- constexpr int M = 1000;
+ // Set up your input tensor views
+ std::vector<IterDomain*> dom;
+ dom.push_back(new IterDomain(new Int(0), new Int()));
+ dom.push_back(new IterDomain(
+ new Int(0),
+ new Int(1),
+ ParallelType::Serial,
+ IterType::BroadcastWithStride));
+
+ // tv0[I1, B{1}]
+ TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float);
+ fusion.addInput(tv0);
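+ // Building the TensorDomain by hand models an input of shape [I1, 1]
+ // whose size-1 axis is a broadcast carrying a real stride
+ // (IterType::BroadcastWithStride), matching the randn({y, 1}) input
+ // below.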
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({M}, options);
- std::vector<IValue> aten_inputs = {t0};
+ // tv1[I0, I1, I2]
+ TensorView* tv2 = makeDummyTensor(3);
+ fusion.addInput(tv2);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(aten_inputs);
+ TensorView* tv3 = add(tv0, tv2);
- at::Tensor aten_output = t0 + 1 + 1;
+ fusion.addOutput(tv3);
- testValidate(
- &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+ tv3->merge(0);
+ tv3->merge(0);
-TEST(NVFuserTest, FusionLoopUnswitch_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ tv0->computeAt(tv3, -1);
+ tv2->computeAt(tv3, -1);
- // Algorithm
- TensorView* tv0 = makeSymbolicTensor(1);
- TensorView* tv1 = add(tv0, new Double(1));
- TensorView* tv2 = add(tv1, new Double(1));
- fusion.addInput(tv0);
- fusion.addOutput(tv2);
+ tv3->axis(0)->parallelize(ParallelType::BIDx);
- tv2->split(0, 32);
- tv1->computeAt(tv2, -1);
+ constexpr int x = 2, y = 3, z = 4;
- tv2->axis(1)->parallelize(ParallelType::Unswitch);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- constexpr int M = 1000;
+ at::Tensor t0 = at::randn({y, 1}, options);
+ at::Tensor t2 = at::randn({x, y, z}, options);
+ auto t3 = t0.add(t2);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({M}, options);
- std::vector<IValue> aten_inputs = {t0};
+ at::Tensor cg_output = at::empty({x, y, z}, options);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(aten_inputs);
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ fe.runFusion({t0, t2}, {cg_output});
- at::Tensor aten_output = t0 + 1 + 1;
+ TORCH_CHECK(t3.allclose(cg_output));
+ }
- testValidate(
- &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+ {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
-TEST(NVFuserTest, FusionIssue549_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ // Set up your input tensor views
+ std::vector<IterDomain*> dom;
+ dom.push_back(new IterDomain(
+ new Int(0),
+ new Int(1),
+ ParallelType::Serial,
+ IterType::BroadcastWithStride));
+ dom.push_back(new IterDomain(new Int(0), new Int()));
+ TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float);
+
+ TensorView* tv1 = makeDummyTensor(3);
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2); // M, K
- TensorView* tv1 = makeSymbolicTensor(2); // K, N
- fusion.addInput(tv0);
- fusion.addInput(tv1);
+ TensorView* tv3 = add(tv0, tv1);
- auto tv2 = add(tv0, new Double(1));
+ tv3->merge(0);
+ tv3->merge(0);
+ tv3->split(0, 128);
+ tv3->split(0, 4);
- TensorView* tv3 = broadcast(tv2, {false, false, true});
- // tv3[I0, I1, B] = tv0[I0, I1]
+ fusion.addOutput(tv3);
- TensorView* tv4 = broadcast(tv1, {true, false, false});
- // tv4[B, I1, I2] = tv1[I1, I2]
+ tv0->computeAt(tv3, -1);
+ tv1->computeAt(tv3, -1);
- // tv5[I0, I1, I2] = tv3[I0, I1, B] * tv4[B, I1, I2]
- TensorView* tv5 = mul(tv3, tv4);
- // tv6[I0, R1, I2] = tv5[I0, I1, I2]
- TensorView* tv6 = sum(tv5, {1});
- fusion.addOutput(tv6);
+ tv3->axis(0)->parallelize(ParallelType::BIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-2)->parallelize(ParallelType::Unroll);
- tv6->split(1, 32);
- // tv6[I0, R1o, R1i{32}, I2]
+ constexpr int x = 63, y = 33, z = 15;
- auto tv7 = tv6->rFactor({1});
- // tv7[I0, R1o, I1i{32}, I2] = tv5[I0, I1, I2]
- // tv6[I0, , R1i{32}, I2] = tv7[I0, R1o, I1i{32}, I2]
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- tv6->split(0, 4);
- tv6->split(-1, 4);
- // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
- // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
+ at::Tensor t0 = at::randn({1, z}, options);
+ at::Tensor t1 = at::randn({x, y, z}, options);
- tv0->computeAt(tv6, -1);
- tv1->computeAt(tv6, -1);
+ at::Tensor cg_output = at::empty({x, y, z}, options);
- // tv7[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
- // tv6[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}]
- //--> (line symbolizes compute at location)
- // tv5[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
- // tv7[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
- // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ fe.runFusion({t0, t1}, {cg_output});
- tv0->computeAt(tv7, -1);
- tv1->computeAt(tv7, -1);
- // tv5[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
- // tv7[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
- // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
+ auto t3 = t0.add(t1);
- tv6->axis(0)->parallelize(ParallelType::BIDz);
- tv6->axis(1)->parallelize(ParallelType::TIDz);
+ TORCH_CHECK(t3.allclose(cg_output));
+ }
- tv6->axis(-2)->parallelize(ParallelType::BIDy);
- tv6->axis(-1)->parallelize(ParallelType::TIDy);
+ {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- tv6->axis(2)->parallelize(ParallelType::TIDx);
- tv7->axis(2)->parallelize(ParallelType::TIDx);
+ constexpr int m = 2, k = 3, n = 4;
- constexpr int M = 65, K = 33, N = 17;
+ auto zero = new Int(0);
+ auto M = new IterDomain(zero, new Int(m));
+ auto K = new IterDomain(zero, new Int(k));
+ auto N = new IterDomain(zero, new Int(n));
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ // Set up your input tensor views
+ TensorView* tv0 =
+ new TensorView(new TensorDomain({M, K}, {true, true}), DataType::Float);
+ TensorView* tv1 =
+ new TensorView(new TensorDomain({K, N}, {true, true}), DataType::Float);
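+ // These TensorViews have concrete extents (m, k, n); the second
+ // TensorDomain argument marks every axis contiguous.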
- at::Tensor t0 = at::randn({M, K}, options);
- at::Tensor t1 = at::randn({K, N}, options);
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- // Lets specify a few bounds in launch params to make sure it works
- fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4));
+ TensorView* tv2 = broadcast(tv0, {false, false, true});
+ TensorView* tv3 = broadcast(tv1, {true, false, false});
- // Make sure bad launch params throws
- // TODO: Re-enable once we have parallelization validation in.
- // ASSERT_ANY_THROW(fe.runFusion({t0, t1}, LaunchParams(1, 2, 3, 4, 5, 6)));
+ TensorView* tv4 = add(tv2, tv3);
- // Don't specify any launch params
- auto cg_outputs = fe.runFusion({t0, t1});
+ fusion.addOutput(tv4);
- auto aten_output = (t0 + 1).to(at::kDouble).matmul(t1.to(at::kDouble));
+ tv4->merge(0);
+ tv4->merge(0);
- testValidate(
- &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__);
-}
+ tv0->computeAt(tv4, -1);
+ tv1->computeAt(tv4, -1);
-TEST(NVFuserTest, simplecompileRtc_CUDA) {
- FusionExecutor fe;
- std::string kernel = R"(
-__global__ void kernel1(Tensor<float, 1> T0, Tensor<float, 1> T1) {
- if(threadIdx.x==0){
- for(size_t ki28 = 0; ki28 < T0.size[0]; ++ki28) {
- T1[ki28*T1.stride[0]] = T0[ki28*T0.stride[0]]*2;
- }
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+ at::Tensor t0 = at::randn({m, k}, options);
+ at::Tensor t1 = at::randn({k, n}, options);
+
+ at::Tensor cg_output = at::empty({m, k, n}, options);
+
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ fe.runFusion({t0, t1}, {cg_output});
+
+ auto t2 = t0.unsqueeze(-1).expand({m, k, n});
+ auto t3 = t1.expand({m, k, n});
+ auto t4 = t2.add(t3);
+
+ TORCH_CHECK(t4.allclose(cg_output));
}
}
- )";
- fe.compileRtc(kernel, "CudaCodeGen::kernel1");
- LaunchParams lp(
- 256, // gdimx
- 1, // gdimy
- 1, // gdimz
- 1, // bdimx
- 1, // bdimy
- 1 // bdimz
- );
- lp.setSmem(0);
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const std::vector<int64_t> tensor_dims = {8};
- auto in0 = at::randn(tensor_dims, options);
- auto out0 = at::empty_like(in0);
- fe.runRtc(lp, {in0, out0});
- auto out_ref = in0 * 2;
- TORCH_CHECK(out_ref.allclose(out0));
-}
+TEST(NVFuserTest, FusionComplexBCast_CUDA) {
+ {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
-TEST(NVFuserTest, serialWelford_CUDA) {
- FusionExecutor fe;
- int x = 128, y = 64, z = 64;
-
- std::string kernel = R"(
-__global__ void kernel1(
- Tensor<float,3> inp,
- Tensor<float,1> out_var,
- Tensor<float,1> out_avg
-){
- for(int i0=0;i0<inp.size[0];i0++){
- float tmp_M2=0;
- float tmp_avg=0;
- long tmp_N=0;
- for(int i1=0;i1<inp.size[1];i1++){
- for(int i2=0;i2<inp.size[2];i2++){
- welfordCombine(
- tmp_avg,
- tmp_M2,
- tmp_N,
- inp[i0*inp.stride[0]+
- i1*inp.stride[1]+
- i2*inp.stride[2]],
- 0.f,
- (long)1
- );
- }
- }
- out_var[i0*out_var.stride[0]]=
- tmp_M2/(tmp_N);
- out_avg[i0*out_avg.stride[0]]=
- tmp_avg;
- }
-}
- )";
- fe.compileRtc(kernel, "CudaCodeGen::kernel1");
- LaunchParams lp(
- 1, // gdimx
- 1, // gdimy
- 1, // gdimz
- 1, // bdimx
- 1, // bdimy
- 1 // bdimz
- );
- lp.setSmem(0);
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const std::vector<int64_t> tensor_dims = {x, y, z};
- auto in0 = at::randn(tensor_dims, options);
- auto out_var = at::empty({x}, options);
- auto out_avg = at::empty({x}, options);
- fe.runRtc(lp, {in0, out_var, out_avg});
+ int x = 2, y = 3, z = 4;
+
+ auto tv0 = makeConcreteTensor({y});
+ auto tv1 = div(tv0, new Float(2.0));
+ auto tv2 = broadcast(tv1, {false, true});
+ auto tv3 = makeConcreteTensor({y, z});
+ auto tv4 = mul(tv2, tv3);
+ auto tv5 = broadcast(tv4, {true, false, false});
+ auto tv6 = makeConcreteTensor({x, y, z});
+ auto tv7 = add(tv5, tv6);
+
+ // tv0[ i1 ] = input
+ // tv1[ i1 ] = tv0/2.0
+ // tv2[ i1, b2] = bcast(tv1)
+ // tv3[ i1, i2] = input
+ // tv4[ i1, i2] = tv2 * tv3
+ // tv5[b0, i1, i2] = bcast(tv4)
+ // tv6[i0, i1, i2] = input
+ // tv7[i0, i1, i2] = tv5 + tv6
+
+ // tv4 = bcast(tv1) * tv3
+ // tv7 = bcast(tv4) + tv6
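+ // Each broadcast adds a different axis: tv2 reuses tv1's values across
+ // i2, and tv5 reuses tv4's values across i0.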
- TORCH_CHECK(in0.var({1, 2}, false).allclose(out_var));
- TORCH_CHECK(in0.mean({1, 2}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6));
-}
+ fusion.addInput(tv0);
+ fusion.addInput(tv3);
+ fusion.addInput(tv6);
-TEST(NVFuserTest, blockWelford_CUDA) {
- FusionExecutor fe;
- int x = 7, y = 8, z = 9;
-
- std::string kernel = R"(
-__global__ void kernel1(
- Tensor<float,2> inp,
- Tensor<float,1> out_avg,
- Tensor<float,1> out_var,
- Tensor<float,1> init_avg,
- Tensor<float,1> init_var,
- Tensor<long,0> init_N
-){
- //actual generated kernel will use dynamic shared mem,
- // here is just for prototype
- __shared__ float mem_avg[512];
- __shared__ float mem_M2[512];
- __shared__ long mem_N[512];
- float in=inp[threadIdx.x*inp.stride[0]+
- threadIdx.y*inp.stride[1]];
- float tmp_avg=0;
- float tmp_M2=0;
- long tmp_N=0;
- blockWelford<false,true,false>(
- tmp_avg,
- tmp_M2,
- tmp_N,
- in,
- 0.f,
- (long)1,
- threadIdx,
- blockDim,
- (float*)mem_avg,
- (float*)mem_M2,
- (long*)mem_N,
- (bool)(threadIdx.x<inp.size[0]),
- 0.f);
- __syncthreads();
- if(threadIdx.x<out_var.size[0] && threadIdx.y==0){
- welfordCombine(
- tmp_avg,
- tmp_M2,
- tmp_N,
- init_avg[threadIdx.x*init_avg.stride[0]],
- init_var[threadIdx.x*init_var.stride[0]]*init_N[0],
- init_N[0]
- );
- out_avg[threadIdx.x*out_avg.stride[0]]=tmp_avg;
- out_var[threadIdx.x*out_var.stride[0]]=tmp_M2/(tmp_N);
- }
-}
- )";
- fe.compileRtc(kernel, "CudaCodeGen::kernel1");
- LaunchParams lp(
- 1, // gdimx
- 1, // gdimy
- 1, // gdimz
- x, // bdimx
- y, // bdimy
- 1 // bdimz
- );
- lp.setSmem(0);
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const std::vector<int64_t> tensor_dims = {x, y};
- const std::vector<int64_t> init_dims = {x, z};
-
- // generate initial values
- auto init_in = at::randn(init_dims, options);
- auto init_var = init_in.var({1}, false);
- auto init_avg = init_in.mean({1});
- auto init_N =
- at::tensor(z, at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0));
-
- auto in0 = at::randn(tensor_dims, options);
-
- // run kernel
- auto out_var = at::zeros({x}, options);
- auto out_avg = at::zeros({x}, options);
- fe.runRtc(lp, {in0, out_avg, out_var, init_avg, init_var, init_N});
-
- // compare with reference output
- auto cat_tensor = at::cat({init_in, in0}, 1);
- TORCH_CHECK(cat_tensor.var({1}, false).allclose(out_var));
- TORCH_CHECK(
- cat_tensor.mean({1}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6));
-}
+ fusion.addOutput(tv7);
-TEST(NVFuserTest, blockWelfordNoInit_CUDA) {
- FusionExecutor fe;
- int x = 7, y = 8, z = 9;
-
- // need support IValue for integer input as initial count
- std::string kernel = R"(
-__global__ void kernel1(
- Tensor<float,3> inp,
- Tensor<float,1> out_avg,
- Tensor<float,1> out_var
-){
- //actual generated kernel will use dynamic shared mem,
- // here is just for prototype
- __shared__ float mem_avg[512];
- __shared__ float mem_M2[512];
- __shared__ long mem_N[512];
- float in=inp[threadIdx.x*inp.stride[0]+
- threadIdx.y*inp.stride[1]+
- threadIdx.z*inp.stride[2]];
- float tmp_avg=0;
- float tmp_M2=0;
- long tmp_N=0;
- block_sync::init();
- blockWelford<false,true,true>(
- tmp_avg,
- tmp_M2,
- tmp_N,
- in,
- 0.f,
- (long) 1,
- threadIdx,
- blockDim,
- (float*)mem_avg,
- (float*)mem_M2,
- (long*)mem_N,
- (bool)(threadIdx.x<inp.size[0]),
- 0.f);
- __syncthreads();
- if(threadIdx.x<out_var.size[0] && threadIdx.y==0 && threadIdx.z==0){
- out_avg[threadIdx.x*out_var.stride[0]]=tmp_avg;
- out_var[threadIdx.x*out_var.stride[0]]=tmp_M2/(tmp_N);
- }
-}
- )";
- fe.compileRtc(kernel, "CudaCodeGen::kernel1");
- LaunchParams lp(
- 1, // gdimx
- 1, // gdimy
- 1, // gdimz
- x, // bdimx
- y, // bdimy
- z // bdimz
- );
- lp.setSmem(0);
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const std::vector<int64_t> tensor_dims = {x, y, z};
- auto in0 = at::randn(tensor_dims, options);
- auto out_var = at::empty({x}, options);
- auto out_avg = at::empty({x}, options);
- fe.runRtc(lp, {in0, out_avg, out_var});
+ tv7->merge(0);
+ tv7->merge(0);
+ tv0->computeAt(tv7, -1);
- TORCH_CHECK(in0.var({1, 2}, false).allclose(out_var));
- TORCH_CHECK(in0.mean({1, 2}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6));
-}
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-TEST(NVFuserTest, gridWelfordNoInit_CUDA) {
- FusionExecutor fe;
- int x = 128, y = 64, z = 128;
-
- std::string kernel = R"(
-__global__ void kernel1(
- Tensor<float,3> inp,
- Tensor<float,1> out_avg,
- Tensor<float,1> out_var,
- Tensor<float,1> work_buf_avg,
- Tensor<float,1> work_buf_M2,
- Tensor<long,1> work_buf_N,
- Tensor<int64_t,1> sync_flag
-){
- __shared__ float shared_buf_avg[512];
- __shared__ float shared_buf_M2[512];
- __shared__ long shared_buf_N[512];
- float tmp_avg=0;
- float tmp_M2=0;
- long tmp_N=0;
- float in = inp[ blockIdx.x * inp.stride[0]+
- blockIdx.y * inp.stride[1]+
- threadIdx.x * inp.stride[2]];
- bool T_pred;
- block_sync::init();
- T_pred=welford::gridWelford<
- true,true,false,
- true,false,false
- >(
- tmp_avg,
- tmp_M2,
- tmp_N,
- in,
- 0.f,
- (long) 1,
- &work_buf_avg[0],
- &work_buf_M2[0],
- &work_buf_N[0],
- sync_flag,
- (float*)shared_buf_avg,
- (float*)shared_buf_M2,
- (long*)shared_buf_N,
- threadIdx.x<out_var.size[0],
- threadIdx.x<out_var.size[0],
- 0.f);
- if(T_pred){
- out_avg[threadIdx.x*out_avg.stride[0]]=tmp_avg;
- out_var[threadIdx.x*out_var.stride[0]]=tmp_M2/tmp_N;
- }
-}
- )";
- fe.compileRtc(kernel, "CudaCodeGen::kernel1");
- LaunchParams lp(
- x, // gdimx
- y, // gdimy
- 1, // gdimz
- z, // bdimx
- 1, // bdimy
- 1 // bdimz
- );
- lp.setSmem(0);
- const auto options =
- at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const auto options_int =
- at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
-
- const std::vector<int64_t> tensor_dims = {x, y, z};
- auto in0 = at::randn(tensor_dims, options);
-
- auto out_avg = at::empty({z}, options);
- auto out_var = at::empty({z}, options);
- auto work_buf_avg = at::empty({x * y * z}, options);
- auto work_buf_var = at::empty({x * y * z}, options);
- auto work_buf_N = at::empty({x * y * z}, options_int);
- auto sync_flag = at::zeros({1}, options_int);
- fe.runRtc(
- lp,
- {in0,
- out_avg,
- out_var,
- work_buf_avg,
- work_buf_var,
- work_buf_N,
- sync_flag});
- std::vector<int64_t> dims{0, 1};
-
- TORCH_CHECK(in0.mean(dims).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6));
- TORCH_CHECK(in0.var(dims, false).allclose(out_var));
-}
-
-TEST(NVFuserTest, FusionWelfordOp_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ at::Tensor t0 = at::randn({y}, options);
+ at::Tensor t3 = at::randn({y, z}, options);
+ at::Tensor t6 = at::randn({x, y, z}, options);
- int M = 64, N = 128;
+ auto t4 = t0.div(2.0).unsqueeze(-1).expand({y, z}) * t3;
+ auto t7 = t4.unsqueeze(0).expand({x, y, z}) + t6;
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = mul(tv0, new Double(1));
- auto tvs = Welford(tv1, {1});
- auto tv_avg = tvs.avg;
- auto tv_M2 = tvs.var_sum;
- auto tv_N = tvs.n;
- fusion.addOutput(tv_avg);
- fusion.addOutput(tv_M2);
- fusion.addOutput(tv_N);
-
- tv_avg->split(1, 32);
- tv_avg->split(0, 32);
- tv_avg->split(0, 4);
- tv_avg->reorder({{-1, -3}, {-3, -1}});
- tv1->computeAt(tv_avg, -1);
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({t0, t3, t6});
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({M, N}, options);
+ TORCH_CHECK(t7.allclose(outputs[0]));
+ }
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion({t0});
+ {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- // by default Welford outputs sum of square diff so need to divide to get var
- outputs[1] /= N;
+ int x = 2, y = 3, z = 4;
- testValidate(
- &fusion,
- outputs,
- {t0},
- {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N},
- __LINE__,
- __FILE__);
-}
+ auto tv0 = makeConcreteTensor({y, z});
+ auto tv1 = div(tv0, new Float(2.0));
+ auto tv2 = sum(tv1, {1});
+ auto tv3 = broadcast(tv2, {true, false});
+ auto tv4 = makeConcreteTensor({x, y});
+ auto tv5 = add(tv3, tv4);
-TEST(NVFuserTest, FusionBlockWelfordOp_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ // tv0[ i1, i2] = input
+ // tv1[ i1, i2] = tv0/2.0
+ // tv2[ i1 ] = sum(tv1, 1)
+ // tv3[b0, i1 ] = bcast(tv2)
+ // tv4[i0, i1 ] = input
+ // tv5[i0, i1 ] = tv3 + tv4
- int M = 64, N = 128;
+ // tv2 = sum(tv0/2.0, 1)
+ // tv5 = bcast(tv2) + tv4
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = mul(tv0, new Double(1));
- auto tvs = Welford(tv1, {1});
- auto tv_avg = tvs.avg;
- auto tv_M2 = tvs.var_sum;
- auto tv_N = tvs.n;
- fusion.addOutput(tv_avg);
- fusion.addOutput(tv_M2);
- fusion.addOutput(tv_N);
+ fusion.addInput(tv0);
+ fusion.addInput(tv4);
- tv_avg->axis(-1)->parallelize(ParallelType::TIDx);
+ fusion.addOutput(tv5);
- tv1->computeAt(tv_avg, -1);
+ tv5->merge(0);
+ tv0->computeAt(tv5, -1);
+ tv1->computeAt(tv2, -1);
- //
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({M, N}, options);
- at::Tensor t_var = at::empty({M}, options);
- at::Tensor t_avg = at::empty({M}, options);
- at::Tensor t_N = at::empty({M}, options_int);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion({t0});
+ at::Tensor t0 = at::randn({y, z}, options);
+ auto t1 = t0.div(2.0);
+ auto t2 = t1.sum(1);
+ auto t3 = t2.unsqueeze(0).expand({x, y});
+ at::Tensor t4 = at::randn({x, y}, options);
+ auto t5 = t3.add(t4);
- // by default Welford outputs sum of square diff so need to divide to get var
- outputs[1] /= N;
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({t0, t4});
- testValidate(
- &fusion,
- outputs,
- {t0},
- {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N},
- __LINE__,
- __FILE__);
+ TORCH_CHECK(t5.allclose(outputs[0]));
+ }
}
-TEST(NVFuserTest, FusionGridWelfordOp_CUDA) {
+TEST(NVFuserTest, FusionAdvancedIndexing1_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- int M = 64, N = 128;
+ int w = 3, x = 4, y = 7, z = 8;
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto tv0 = makeSymbolicTensor(2);
+ auto tv0 = makeDummyTensor(3);
+ auto tv1 = makeDummyTensor(4);
fusion.addInput(tv0);
- auto tv1 = mul(tv0, new Double(1));
- auto tvs = Welford(tv1, {1});
- auto tv_avg = tvs.avg;
- auto tv_M2 = tvs.var_sum;
- auto tv_N = tvs.n;
- fusion.addOutput(tv_avg);
- fusion.addOutput(tv_M2);
- fusion.addOutput(tv_N);
+ fusion.addInput(tv1);
- tv_avg->axis(0)->parallelize(ParallelType::TIDx);
- tv_avg->axis(-1)->parallelize(ParallelType::BIDx);
+ auto tv2 = add(tv0, new Float(1.0));
+ auto tv3 = broadcast(tv2, {true, false, false, false});
+ auto tv4 = add(tv3, tv1);
- tv1->computeAt(tv_avg, -1);
+ fusion.addOutput(tv4);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({M, N}, options);
- at::Tensor t_avg = at::empty({M}, options);
- at::Tensor t_var = at::empty({M}, options);
- at::Tensor t_N = at::empty({M}, options_int);
+ tv4->merge(0);
+ tv4->merge(0);
+ tv4->merge(0);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion({t0});
+ tv4->split(0, 128);
+ tv4->split(0, 4);
- // by default Welford outputs sum of square diff so need to divide to get var
- outputs[1] /= N;
+ tv2->computeAt(tv4, 1);
- testValidate(
- &fusion,
- outputs,
- {t0},
- {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N},
- __LINE__,
- __FILE__);
-}
+ tv4->axis(0)->parallelize(ParallelType::BIDx);
+ tv4->axis(1)->parallelize(ParallelType::Unroll);
+ tv4->axis(2)->parallelize(ParallelType::TIDx);
-TEST(NVFuserTest, FusionRfactorWelfordOp_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ tv3->axis(1)->parallelize(ParallelType::Unroll);
+ tv3->axis(2)->parallelize(ParallelType::TIDx);
- int M = 64, N = 128;
+ tv2->axis(1)->parallelize(ParallelType::Unroll);
+ tv2->axis(2)->parallelize(ParallelType::TIDx);
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = mul(tv0, new Double(1));
- auto tvs = Welford(tv1, {1});
- auto tv_avg = tvs.avg;
- auto tv_M2 = tvs.var_sum;
- auto tv_N = tvs.n;
- fusion.addOutput(tv_avg);
- fusion.addOutput(tv_M2);
- fusion.addOutput(tv_N);
-
- tv_avg->split(1, 4);
- auto rtvs = tvs.rFactor({2});
- tv1->computeAt(tv_avg, -1);
+ FusionExecutor fe;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({M, N}, options);
- at::Tensor t_avg = at::empty({M}, options);
- at::Tensor t_var = at::empty({M}, options);
- at::Tensor t_N = at::empty({M}, options_int);
+ at::Tensor t0 = at::randn({x, y, z}, options);
+ at::Tensor t1 = at::randn({w, x, y, z}, options);
- FusionExecutor fe;
fe.compileFusion(&fusion);
- auto outputs = fe.runFusion({t0});
+ auto outputs = fe.runFusion({t0, t1});
- // by default Welford outputs sum of square diff so need to divide to get var
- outputs[1] /= N;
+ auto t3 = t0.add(1.0);
+ auto t4 = t3.add(t1);
- testValidate(
- &fusion,
- outputs,
- {t0},
- {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N},
- __LINE__,
- __FILE__);
+ TORCH_CHECK(t4.allclose(outputs[0]));
}
-TEST(NVFuserTest, FusionWelfordSchedule_CUDA) {
+TEST(NVFuserTest, FusionAdvancedIndexing2_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- int M = 64, N = 128;
+ int w = 3, x = 4, y = 7, z = 8;
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto tv0 = makeSymbolicTensor(2);
+ auto tv0 = makeDummyTensor(3);
+ auto tv1 = makeDummyTensor(4);
fusion.addInput(tv0);
- auto tv1 = mul(tv0, new Double(1));
- auto tvs = Welford(tv1, {1});
- auto tv_avg = tvs.avg;
- auto tv_M2 = tvs.var_sum;
- auto tv_N = tvs.n;
- fusion.addOutput(tv_avg);
- fusion.addOutput(tv_M2);
- fusion.addOutput(tv_N);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({M, N}, options);
- // TODO: Why do we use launch params from here, but not scheduling???
- auto reduction_params = getReductionHeuristics(&fusion, {t0});
- scheduleReduction(&fusion, reduction_params.value());
+ fusion.addInput(tv1);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion({t0}, reduction_params.value().lparams);
+ auto tv2 = add(tv0, new Float(1.0));
+ auto tv3 = broadcast(tv2, {true, false, false, false});
+ auto tv4 = add(tv3, tv1);
- // by default Welford outputs sum of square diff so need to divide to get var
- outputs[1] /= N;
+ fusion.addOutput(tv4);
- auto at_avg = t0.mean({1});
- auto at_var = t0.var({1}, false);
- auto at_n = at::ones({M}, options_int) * N;
+ tv4->merge(-2);
+ tv4->merge(-2);
+ tv4->merge(-2);
- testValidate(
- &fusion,
- outputs,
- {t0},
- {at_avg, at_var, at_n},
- __LINE__,
- __FILE__,
- "validate welford",
- reduction_params.value().lparams);
-}
+ tv4->split(0, 128);
+ tv4->split(0, 4);
-namespace {
-void testWelford(DataType dtype, int red_axis, int odim, int rdim) {
- const int axis = red_axis;
- at::ScalarType aten_dtype = data_type_to_aten(dtype);
+ tv2->computeAt(tv4, 1);
- Fusion fusion;
- FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(2, dtype);
- bool is_fp16 = dtype == DataType::Half;
- TensorView* tv0_cast = tv0;
- if (is_fp16) {
- tv0_cast = castOp(DataType::Float, tv0);
- }
- fusion.addInput(tv0);
- auto tv1 = mul(tv0_cast, new Double(1));
- auto tvs = Welford(tv1, {axis});
- auto tv_avg = tvs.avg;
- auto tv_M2 = tvs.var_sum;
- auto tv_N = tvs.n;
-
- TensorView* avg_cast = tv_avg;
- TensorView* M2_cast = tv_M2;
-
- if (is_fp16) {
- avg_cast = castOp(DataType::Half, tv_avg);
- M2_cast = castOp(DataType::Half, tv_M2);
- }
+ tv4->axis(0)->parallelize(ParallelType::BIDx);
+ tv4->axis(1)->parallelize(ParallelType::Unroll);
+ tv4->axis(2)->parallelize(ParallelType::TIDx);
- fusion.addOutput(avg_cast);
- fusion.addOutput(M2_cast);
- fusion.addOutput(tv_N);
+ tv3->axis(1)->parallelize(ParallelType::Unroll);
+ tv3->axis(2)->parallelize(ParallelType::TIDx);
- auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);
- auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
- at::manual_seed(0);
- std::vector<TensorView*> outputs_of_red;
- at::Tensor aten_input =
- (axis ? at::randn({odim, rdim}, options)
- : at::randn({rdim, odim}, options));
-
- if (is_fp16) {
- outputs_of_red.push_back(avg_cast);
- outputs_of_red.push_back(M2_cast);
- }
+ tv2->axis(1)->parallelize(ParallelType::Unroll);
+ tv2->axis(2)->parallelize(ParallelType::TIDx);
- auto reduction_params = getReductionHeuristics(&fusion, {aten_input});
- scheduleReduction(&fusion, reduction_params.value());
+ FusionExecutor fe;
- auto lparams = reduction_params.value().lparams;
+ at::Tensor t0 = at::randn({x, y, z}, options);
+ at::Tensor t1 = at::randn({w, x, y, z}, options);
- FusionExecutor fe;
fe.compileFusion(&fusion);
- auto outputs = fe.runFusion({aten_input}, reduction_params.value().lparams);
-
- // by default Welford outputs sum of square diff so need to divide to
- // get var
-
- outputs[1] /= rdim;
-
- auto at_avg = aten_input.mean({axis});
- auto at_var = aten_input.var({axis}, false);
- auto at_n =
- (axis ? at::ones({odim, rdim}, options)
- : at::ones({rdim, odim}, options));
- at_n = at_n.sum({axis});
-
- testValidate(
- &fusion,
- outputs,
- {aten_input},
- {at_avg, at_var, at_n},
- __LINE__,
- __FILE__,
- "validate welford",
- reduction_params.value().lparams);
-}
-} // namespace
+ auto outputs = fe.runFusion({t0, t1});
-TEST(NVFuserTest, FusionWelfordShmoo_CUDA) {
- std::vector<DataType> dtypes = {
- DataType::Double, DataType::Float, DataType::Half};
- std::vector<int> red_axis = {1, 0};
- std::vector<int> output_dims = {160, 320};
- std::vector<int> red_dims;
-
- // Tried to cut down the number iterations with just
- // doing every other power of 2.
- for (int i = 1; i <= 1024 * 1024; i <<= 2) {
- red_dims.push_back(i);
- }
+ auto t3 = t0.add(1.0);
+ auto t4 = t3.add(t1);
- for (auto dtype : dtypes) {
- for (auto& axis : red_axis) {
- for (auto& odim : output_dims) {
- for (auto& rdim : red_dims) {
- // TODO: original welford algorithm actually keeps a running sum of
- // squares, i.e. M_{2n} in the
- // cf:
- // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
- // algorithm notation, and it can reach inf for large numbers
- // with half precision. skipping too large volumes for half for
- // nwo might need further numerical experiments to re-design
- // this.
- if (rdim > 32768 && dtype == DataType::Half) {
- continue;
- }
- testWelford(dtype, axis, odim, rdim);
- }
- }
- }
- }
+ TORCH_CHECK(t4.allclose(outputs[0]));
}
-TEST(NVFuserTest, FusionTranspose1_CUDA) {
+TEST(NVFuserTest, FusionAdvancedIndexing3_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- constexpr int M = 10;
- constexpr int N = 20;
+ int w = 3, x = 4, y = 7, z = 8;
- auto tv0 = makeSymbolicTensor(2);
- auto tv1 = transpose(tv0, {{0, 1}});
+ auto tv0 = makeDummyTensor(3);
+ auto tv1 = makeDummyTensor(4);
fusion.addInput(tv0);
- fusion.addOutput(tv1);
+ fusion.addInput(tv1);
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(1)->parallelize(ParallelType::TIDx);
+ auto tv2 = add(tv0, new Float(1.0));
+ auto tv3 = add(tv2, tv1);
+ fusion.addOutput(tv3);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({M, N}, options);
- std::vector<IValue> aten_inputs = {t0};
+ at::Tensor t0 = at::randn({x, y, z}, options);
+ at::Tensor t1 = at::randn({w, x, y, z}, options);
+
+ scheduleFusion(&fusion, {t0, t1});
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(aten_inputs);
+ auto outputs = fe.runFusion({t0, t1});
- at::Tensor aten_output = t0.t();
+ auto t2 = t0.add(1.0);
+ auto t3 = t2.add(t1);
- testValidate(
- &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ TORCH_CHECK(t3.allclose(outputs[0]));
}
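+
+// Unlike its hand-scheduled siblings, the test above hands scheduling to
+// scheduleFusion(), which derives a pointwise schedule from the runtime
+// inputs. A minimal sketch of that flow (illustration only; assumes the same
+// helpers as the surrounding tests, and the function name is hypothetical):
+static void sketchAutoSchedule() {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
+  TensorView* tv1 = add(tv0, new Float(1.0));
+  fusion.addOutput(tv1);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({8, 8}, options);
+
+  // Heuristic schedule chosen from the concrete inputs
+  scheduleFusion(&fusion, {t0});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({t0});
+  TORCH_CHECK(t0.add(1.0).allclose(outputs[0]));
+}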
-TEST(NVFuserTest, FusionTranspose2_CUDA) {
+TEST(NVFuserTest, FusionAdvancedIndexing4_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- constexpr int M = 10;
- constexpr int N = 20;
-
- auto tv0 = makeSymbolicTensor(2);
- auto tv1 = transpose(tv0, {{0, 1}});
+ // Set up your input tensor views
+ TensorView* tv0 = makeConcreteTensor({10, 20});
fusion.addInput(tv0);
- fusion.addOutput(tv1);
-
- tv1->merge(0);
- tv1->split(0, 32);
+ TensorView* tv1 = makeConcreteTensor({10, 10, 20});
+ fusion.addInput(tv1);
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(1)->parallelize(ParallelType::TIDx);
+ TensorView* tv2 = add(tv0, new Float(1));
+ TensorView* tv3 = broadcast(tv2, {true, false, false});
+ TensorView* tv4 = add(tv3, tv1);
+ fusion.addOutput(tv4);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({M, N}, options);
- std::vector<IValue> aten_inputs = {t0};
+ at::Tensor t0 = at::randn({10, 20}, options);
+ at::Tensor t1 = at::randn({10, 10, 20}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(aten_inputs);
+ auto outputs = fe.runFusion({t0, t1});
- at::Tensor aten_output = t0.t();
+ auto t2 = t0.add(1.0);
+ auto t3 = t2.add(t1);
- testValidate(
- &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ TORCH_CHECK(t3.allclose(outputs[0]));
}
-TEST(NVFuserTest, FusionSimpleGemmTransposed_CUDA) {
+// Test a simple Gemm but also play around with fusion executor features
+TEST(NVFuserTest, FusionSimpleGemm_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
// Set up your input tensor views
-
- TensorView* tv0 = makeSymbolicTensor(2); // K, M
- TensorView* tv1 = makeSymbolicTensor(2); // N, K
+ TensorView* tv0 = makeDummyTensor(2); // M, K
+ TensorView* tv1 = makeDummyTensor(2); // K, N
fusion.addInput(tv0);
fusion.addInput(tv1);
- TensorView* tv0_t = transpose(tv0, {{0, 1}});
- TensorView* tv1_t = transpose(tv1, {{0, 1}});
-
- TensorView* tv2 = broadcast(tv0_t, {false, false, true});
+ TensorView* tv2 = broadcast(tv0, {false, false, true});
// tv2[I0, I1, B] = tv0[I0, I1]
- TensorView* tv3 = broadcast(tv1_t, {true, false, false});
+ TensorView* tv3 = broadcast(tv1, {true, false, false});
// tv3[B, I1, I2] = tv1[I1, I2]
// tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
// tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
// tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
- tv0_t->computeAt(tv5, -1);
- tv1_t->computeAt(tv5, -1);
+ tv0->computeAt(tv5, -1);
+ tv1->computeAt(tv5, -1);
// tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
// tv5[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}]
// tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
// tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
- tv0_t->computeAt(tv6, -1);
- tv1_t->computeAt(tv6, -1);
+ tv0->computeAt(tv6, -1);
+ tv1->computeAt(tv6, -1);
// tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
// tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
// tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({K, M}, options);
- at::Tensor t1 = at::randn({N, K}, options);
+ at::Tensor t0 = at::randn({M, K}, options);
+ at::Tensor t1 = at::randn({K, N}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
// Let's specify a few bounds in launch params to make sure it works
fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4));
- // Don't specify any launch params
- auto cg_outputs = fe.runFusion({t0, t1});
+ // Make sure bad launch params throws
+ ASSERT_ANY_THROW(fe.runFusion({t0, t1}, LaunchParams(1, 2, 3, 4, 5, 6)));
- auto aten_output = t0.t().to(at::kDouble).matmul(t1.t().to(at::kDouble));
+ // Don't specify any launch params
+ auto outputs = fe.runFusion({t0, t1});
- testValidate(
- &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__);
+ auto t2 = t0.matmul(t1);
+ TORCH_CHECK(
+ t2.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ t2.sub(outputs[0]).abs().max());
}
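+
+// Note on the executor features exercised above: runFusion() optionally takes
+// a LaunchParams to pin launch dimensions, as in
+//   fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4));
+// where -1 entries are left for the executor to infer. Params that contradict
+// what the schedule requires make runFusion() throw (checked above with
+// ASSERT_ANY_THROW), and omitting the argument infers everything.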
-TEST(NVFuserTest, FusionSoftmax3DTransposed_CUDA) {
+// Softmax with a 1D tensor. Parallelized only with a single thread block.
+TEST(NVFuserTest, FusionSoftmax1D_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- const int tidx = 32;
- const int dimx = 32;
- const int dimy = 16;
- const int dimz = 130;
+ const int tidx = 128;
+ const int dimx = 1000;
// Set up your input tensor views
- TensorView* input_tv0 = makeSymbolicTensor(3);
+ TensorView* input_tv0 = makeDummyTensor(1);
fusion.addInput(input_tv0);
- TensorView* input_t = transpose(input_tv0, {{1, 2}});
-
- TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_t);
+ TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0);
TensorView* sum_exp_tv2 = sum(exp_tv1, {-1});
- TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true});
+ TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {true});
// Replicate exp_tv1 as exp_tv1_copy because exp_tv1 is going to be
// computed at sum_exp_rf_tv5.
- TensorView* input_t_copy = transpose(input_tv0, {{1, 2}});
- TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_t_copy);
+ TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0);
TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3);
fusion.addOutput(output_tv4);
- bcast_sum_tv3->split(-1, tidx);
+ bcast_sum_tv3->split(0, tidx);
sum_exp_tv2->split(-1, tidx);
TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2});
output_tv4->split(-1, tidx);
- input_t->computeAt(sum_exp_rf_tv5, -1);
- input_t_copy->computeAt(output_tv4, -1);
+ exp_tv1->computeAt(sum_exp_rf_tv5, -1);
+ exp_tv1_copy->computeAt(output_tv4, -1);
TensorView* tensors_to_parallelize[] = {
sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5};
for (auto tv : tensors_to_parallelize) {
- tv->axis(0)->parallelize(ParallelType::BIDx);
- tv->axis(1)->parallelize(ParallelType::BIDy);
tv->axis(-1)->parallelize(ParallelType::TIDx);
}
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({dimx, dimz, dimy}, options);
-
- at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options);
+ at::Tensor t0 = at::randn({dimx}, options);
+ at::Tensor cg_output = at::empty({dimx}, options);
+ at::Tensor t3_output = at::empty_like(cg_output, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- fe.runFusion({input}, {cg_output});
-
- auto aten_input_t = at::transpose(input, 1, 2);
- auto aten_output = at::_softmax(aten_input_t.to(at::kDouble), -1, false);
+ fe.runFusion({t0}, {cg_output});
- testValidate(
- &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
+ auto t2 = at::_softmax(t0, -1, false);
+ TORCH_CHECK(
+ t2.allclose(cg_output, 1e-5, 1e-5),
+ "Error of: ",
+ t2.sub(cg_output).abs().max());
}
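+
+// The split + rFactor step above is the core of a block-parallel reduction:
+// splitting the reduced axis by the block size and rFactor-ing the outer
+// piece makes each thread reduce a strided slice serially before the threads
+// combine partials. A standalone sketch (illustration only; same assumptions
+// as above, hypothetical function name):
+static void sketchBlockReduction() {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* tv0 = makeDummyTensor(1);
+  fusion.addInput(tv0);
+  TensorView* tv1 = sum(tv0, {0});
+  fusion.addOutput(tv1);
+
+  tv1->split(0, 128);
+  // tv1[R0o, R0i{128}]
+  TensorView* tv2 = tv1->rFactor({0});
+  // tv2[R0o, I0i{128}] holds the per-thread serial partials
+  // tv1[R0i{128}] combines them across the block
+
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+}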
-TEST(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) {
- // Case 1
- // tv1 = tv0 * 0.5
- // tv2 = tv1 * -1
- // tv3 = tv1 + 3
- // tv4 = tv1 * 2
- // tv5 = tv3 + tv2
- // tv6 = tv5 + tv4
- // tv7 = tv1 + tv4
+// Softmax with a 1D tensor with input normalization.
+TEST(NVFuserTest, FusionSoftmax1DNormalized_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
+ const int tidx = 128;
+ const int dimx = 1000;
- tv0 = transpose(tv0, {{0, 1}});
+ // Set up your input tensor views
+ TensorView* input_tv0 = makeDummyTensor(1);
+ fusion.addInput(input_tv0);
- TensorView* tv1 = mul(tv0, new Double(0.5));
- TensorView* tv2 = mul(tv1, new Double(-1.0));
- TensorView* tv3 = add(tv1, new Double(3.0));
- TensorView* tv4 = mul(tv1, new Double(2.0));
- TensorView* tv5 = add(tv3, tv2);
+ // Normalize with the max value before computing exp.
+ TensorView* max_val_tv1 =
+ reductionOp(BinaryOpType::Max, {-1}, new Float(0), input_tv0);
+ TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {true});
+ TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2);
+ TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3);
+ TensorView* sum_exp_tv5 = sum(exp_tv4, {-1});
+ TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {true});
- TensorView* tv6 = add(tv5, tv4);
- TensorView* tv7 = add(tv1, tv4);
+ // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
+ // computed at sum_exp_rf_tv9.
+ TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2);
+ TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy);
- fusion.addOutput(tv6);
- fusion.addOutput(tv7);
+ TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6);
- // Lets setup to actually run
- tv7->merge(0);
- tv7->split(0, 128);
- tv7->split(0, 4);
+ fusion.addOutput(output_tv7);
+ bcast_max_tv2->split(0, tidx);
+ bcast_sum_tv6->split(0, tidx);
- tv7->axis(0)->parallelize(ParallelType::BIDx);
+ max_val_tv1->split(-1, tidx);
+ TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2});
- tv0->computeAt(tv7, 1);
+ sum_exp_tv5->split(-1, tidx);
+ TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2});
- // The this-position of the last tensor should be zero.
- TORCH_CHECK(
- tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 &&
- tv7->getMaxProducerPosition() == 1);
- TORCH_CHECK(
- tv6->nDims() == 3 && tv6->getComputeAtPosition() == 0 &&
- tv6->getMaxProducerPosition() == 1);
- // The position of every other tensor should be 1.
- for (auto tv : {tv1, tv2, tv3, tv4, tv5}) {
- TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1);
- }
+ output_tv7->split(-1, tidx);
- for (Val* val : fusion.vals()) {
- if (!fusion.hasInput(val) &&
- val->getValType().value() == ValType::TensorView) {
- TensorView* tv = static_cast<TensorView*>(val);
- tv->axis(1)->parallelize(ParallelType::Unroll);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
+ sub_tv3->computeAt(sum_exp_rf_tv9, -1);
+ sub_tv3_copy->computeAt(output_tv7, -1);
+
+ TensorView* tensors_to_parallelize[] = {
+ max_val_tv1,
+ bcast_max_tv2,
+ sum_exp_tv5,
+ bcast_sum_tv6,
+ output_tv7,
+ max_val_rf_tv8,
+ sum_exp_rf_tv9};
+
+ for (auto tv : tensors_to_parallelize) {
+ tv->axis(-1)->parallelize(ParallelType::TIDx);
}
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::randn({129, 127}, options);
+ at::Tensor t0 = at::randn({dimx}, options);
+ at::Tensor t3_output = at::empty({dimx}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
-
- at::Tensor aten_input_t = aten_input.t();
-
- auto t1 = aten_input_t.mul({0.5});
- auto t2 = t1.mul({-1.0});
- auto t3 = t1.add({3.0});
- auto t4 = t1.mul({2.0});
- auto t5 = t3.add(t2);
- auto t6 = t5.add(t4);
- auto t7 = t1.add(t4);
-
- std::vector<at::Tensor> aten_outputs = {t6, t7};
+ auto outputs = fe.runFusion({t0});
- testValidate(
- &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__);
+ auto t2 = at::_softmax(t0, -1, false);
+ TORCH_CHECK(
+ t2.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ t2.sub(outputs[0]).abs().max());
}
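+
+// For reference, the "normalized" variants compute the numerically stable
+// form of softmax. Softmax is shift-invariant:
+//   softmax(x)_i = exp(x_i) / sum_j exp(x_j)
+//                = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
+// Subtracting max(x) keeps every exponent <= 0, so exp() cannot overflow for
+// large inputs; that is what the extra max-reduction and subtraction buy.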
-TEST(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) {
- // Case 2
- // tv1 = tv0 * -1
- // tv2 = tv0 + 3
- // tv3 = tv0 * 2
- // tv4 = tv2 + tv1
- // tv5 = tv4 + tv3
- // tv6 = tv5 + tv3
+// Softmax with a 3D tensor, where the innermost 3rd dimension is
+// normalized. Parallelized with multiple thread blocks.
+TEST(NVFuserTest, FusionSoftmax3D_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
+ const int tidx = 32;
+ const int dimx = 32;
+ const int dimy = 16;
+ const int dimz = 130;
- tv0 = transpose(tv0, {{0, 1}});
+ // Set up your input tensor views
+ TensorView* input_tv0 = makeDummyTensor(3);
+ fusion.addInput(input_tv0);
- TensorView* tv1 = mul(tv0, new Double(-1.0));
- TensorView* tv2 = add(tv0, new Double(3.0));
- TensorView* tv3 = mul(tv0, new Double(2.0));
- TensorView* tv4 = add(tv2, tv1);
+ TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0);
+ TensorView* sum_exp_tv2 = sum(exp_tv1, {-1});
+ TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true});
- TensorView* tv5 = add(tv4, tv3);
- TensorView* tv6 = add(tv5, tv3);
+ // Replicate exp_tv1 as exp_tv1_copy because exp_tv1 is going to be
+ // computed at sum_exp_rf_tv5.
+ TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0);
- fusion.addOutput(tv5);
- fusion.addOutput(tv6);
+ TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3);
- // Lets setup to actually run
- tv6->merge(0);
- tv6->split(0, 128);
- tv6->split(0, 4);
+ fusion.addOutput(output_tv4);
- tv6->axis(0)->parallelize(ParallelType::BIDx);
+ bcast_sum_tv3->split(-1, tidx);
- tv0->computeAt(tv6, 1);
+ sum_exp_tv2->split(-1, tidx);
+ TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2});
- for (Val* val : fusion.vals()) {
- if (!fusion.hasInput(val) &&
- val->getValType().value() == ValType::TensorView) {
- TensorView* tv = static_cast<TensorView*>(val);
+ output_tv4->split(-1, tidx);
- tv->axis(1)->parallelize(ParallelType::Unroll);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
+ exp_tv1->computeAt(sum_exp_rf_tv5, -1);
+ exp_tv1_copy->computeAt(output_tv4, -1);
+
+ TensorView* tensors_to_parallelize[] = {
+ sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5};
+
+ for (auto tv : tensors_to_parallelize) {
+ tv->axis(0)->parallelize(ParallelType::BIDx);
+ tv->axis(1)->parallelize(ParallelType::BIDy);
+ tv->axis(-1)->parallelize(ParallelType::TIDx);
}
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({129, 127}, options);
-
+ at::Tensor t0 = at::randn({dimx, dimy, dimz}, options);
+ at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options);
+ at::Tensor t3_output = at::empty_like(cg_output, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input});
-
- auto input_t = input.t();
- auto t1 = input_t.mul({-1.0});
- auto t2 = input_t.add({3.0});
- auto t3 = input_t.mul({2.0});
- auto t4 = t2.add(t1);
- auto t5 = t4.add(t3);
- auto t6 = t5.add(t3);
-
- std::vector<at::Tensor> aten_outputs = {t5, t6};
+ fe.runFusion({t0}, {cg_output});
- testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__);
+ auto t2 = at::_softmax(t0, -1, false);
+ TORCH_CHECK(
+ t2.allclose(cg_output, 1e-5, 1e-5),
+ "Error of: ",
+ t2.sub(cg_output).abs().max());
}
-TEST(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) {
- // Case 3
- // T2 = T1 * 0.979361
- // T3 = T2 * T0
+// Softmax with a 3D tensor with input normalization.
+TEST(NVFuserTest, FusionSoftmax3DNormalized_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(4);
- fusion.addInput(tv0);
+ const int tidx = 32;
+ const int dimx = 32;
+ const int dimy = 16;
+ const int dimz = 130;
- tv0 = transpose(tv0, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
+ // Set up your input tensor views
+ TensorView* input_tv0 = makeDummyTensor(3);
+ fusion.addInput(input_tv0);
- TensorView* tv1 = makeSymbolicTensor(4);
- fusion.addInput(tv1);
+ // Normalize with the max value before computing exp.
+ TensorView* max_val_tv1 =
+ reductionOp(BinaryOpType::Max, {-1}, new Float(0), input_tv0);
+ TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true});
+ TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2);
+ TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3);
+ TensorView* sum_exp_tv5 = sum(exp_tv4, {-1});
+ TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {false, false, true});
- tv1 = transpose(tv1, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
+ // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be
+ // computed at sum_exp_rf_tv9.
+ TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2);
+ TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy);
- TensorView* tv2 = mul(tv1, new Double(.979361));
- TensorView* tv3 = mul(tv2, tv0);
+ TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6);
- fusion.addOutput(tv3);
+ fusion.addOutput(output_tv7);
- // Lets setup to actually run
- while (tv3->nDims() > 1)
- tv3->merge(0);
- tv3->split(0, 128);
- tv3->split(0, 4);
+ bcast_max_tv2->split(-1, tidx);
+ bcast_sum_tv6->split(-1, tidx);
- tv0->computeAt(tv3, 1);
- tv1->computeAt(tv3, 1);
+ max_val_tv1->split(-1, tidx);
+ TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2});
- tv3->axis(0)->parallelize(ParallelType::BIDx);
+ sum_exp_tv5->split(-1, tidx);
+ TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2});
- for (Val* val : fusion.vals()) {
- if (!fusion.hasInput(val) &&
- val->getValType().value() == ValType::TensorView) {
- TensorView* tv = static_cast<TensorView*>(val);
+ output_tv7->split(-1, tidx);
- tv->axis(1)->parallelize(ParallelType::Unroll);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
+ sub_tv3->computeAt(sum_exp_rf_tv9, -1);
+ sub_tv3_copy->computeAt(output_tv7, -1);
+
+ TensorView* tensors_to_parallelize[] = {
+ max_val_tv1,
+ bcast_max_tv2,
+ sum_exp_tv5,
+ bcast_sum_tv6,
+ output_tv7,
+ max_val_rf_tv8,
+ sum_exp_rf_tv9};
+
+ for (auto tv : tensors_to_parallelize) {
+ tv->axis(0)->parallelize(ParallelType::BIDx);
+ tv->axis(1)->parallelize(ParallelType::BIDy);
+ tv->axis(-1)->parallelize(ParallelType::TIDx);
}
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
- at::Tensor t1 = at::rand_like(t0, options);
-
- std::vector<IValue> aten_inputs = {t0, t1};
+ at::Tensor t0 = at::randn({dimx, dimy, dimz}, options);
+ at::Tensor t3_output = at::empty({dimx, dimy, dimz}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- auto t0_t = t0.permute({3, 0, 1, 2});
- auto t1_t = t1.permute({3, 0, 1, 2});
- auto t2 = t1_t.mul({0.979361});
- auto aten_output = t2.mul(t0_t);
+ auto outputs = fe.runFusion({t0});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto t2 = at::_softmax(t0, -1, false);
+ TORCH_CHECK(
+ t2.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ t2.sub(outputs[0]).abs().max());
}
-TEST(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) {
- // Case 4
- // T4 = T2 - T3
- // T5 = T1 + T4
- // T6 = T5 - T0
+TEST(NVFuserTest, FusionSoftmaxComputeAt_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(4);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- tv0 = transpose(tv0, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
-
- TensorView* tv1 = makeSymbolicTensor(4);
- fusion.addInput(tv1);
-
- tv1 = transpose(tv1, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
-
- TensorView* tv2 = makeSymbolicTensor(4);
- fusion.addInput(tv2);
-
- tv2 = transpose(tv2, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
+ auto tv1 = sum(tv0, {1});
+ auto tv2 = broadcast(tv1, {false, true});
- TensorView* tv3 = makeSymbolicTensor(4);
- fusion.addInput(tv3);
+ auto tv3 = add(tv0, new Float(1.0));
- tv3 = transpose(tv3, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
+ auto tv4 = mul(tv2, tv3);
- TensorView* tv4 = sub(tv2, tv3);
- TensorView* tv5 = add(tv1, tv4);
- TensorView* tv6 = sub(tv5, tv0);
+ auto tv5 = sum(tv4, {1});
+ auto tv6 = broadcast(tv5, {false, true});
- fusion.addOutput(tv6);
+ auto tv7 = sub(tv6, tv4);
+ fusion.addOutput(tv7);
- // Lets setup to actually run
- while (tv6->nDims() > 1)
- tv6->merge(0);
- tv6->split(0, 128);
- tv6->split(0, 4);
+ tv1->computeAt(tv7, 1);
+ ASSERT_ANY_THROW(tv1->computeAt(tv7, -1));
+}
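+
+// The ASSERT_ANY_THROW above is the point of the test: tv1 reduces axis 1 and
+// feeds tv7 through a broadcast of that same axis, so every element of tv7's
+// inner loop needs the completed reduction. computeAt(tv7, -1) would ask for
+// the reduction inside that loop and must be rejected; position 1, outside
+// the reconstructed axis, is the deepest legal inlining point.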
- tv0->computeAt(tv6, 1);
- tv1->computeAt(tv6, 1);
- tv2->computeAt(tv6, 1);
- tv3->computeAt(tv6, 1);
+// Similar to FusionReduction but uses grid reduction
+TEST(NVFuserTest, FusionGridReduction1_CUDA) {
+ const int gdimx = 32;
+ const int bdimx = 128;
- tv6->axis(0)->parallelize(ParallelType::BIDx);
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- for (Val* val : fusion.vals()) {
- if (!fusion.hasInput(val) &&
- val->getValType().value() == ValType::TensorView) {
- TensorView* tv = static_cast<TensorView*>(val);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
- tv->axis(1)->parallelize(ParallelType::Unroll);
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
- }
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addOutput(tv1);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
- at::Tensor t1 = at::rand_like(t0, options);
- at::Tensor t2 = at::rand_like(t0, options);
- at::Tensor t3 = at::rand_like(t0, options);
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
- std::vector<IValue> aten_inputs = {t0, t1, t2, t3};
+ tv1->split(1, bdimx);
+ // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+ tv1->split(1, gdimx);
+ // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ TensorView* tv2 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
+ // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
- auto t0_t = t0.permute({3, 0, 1, 2});
- auto t1_t = t1.permute({3, 0, 1, 2});
- auto t2_t = t2.permute({3, 0, 1, 2});
- auto t3_t = t3.permute({3, 0, 1, 2});
- auto t4 = t2_t.sub(t3_t);
- auto t5 = t1_t.add(t4);
- auto aten_output = t5.sub(t0_t);
+ // Incrementally, can print in between for debugging
+ tv0->computeAt(tv2, 1);
+ tv2->computeAt(tv1, 1);
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+ // Redo it all at once, because why not.
+ tv0->computeAt(tv1, 1);
-TEST(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) {
- // Case 5
- // tv2 = tv0 + 2.0
- // tv3 = tv1 * tv2
- Fusion fusion;
- FusionGuard fg(&fusion);
+ tv1->axis(0)->parallelize(ParallelType::BIDy);
+ tv1->axis(1)->parallelize(ParallelType::BIDx);
+ tv2->axis(2)->parallelize(ParallelType::BIDx);
- // Set up your input tensor views
- TensorView* tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- tv0 = transpose(tv0, {{0, 1}});
- TensorView* tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
- tv1 = transpose(tv1, {{0, 1}});
- TensorView* tv2 = add(tv0, new Double(2.0));
- TensorView* tv3 = mul(tv1, tv2);
- fusion.addOutput(tv3);
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->merge(0);
- tv3->split(-1, 8);
- tv3->split(-1, 4);
+ int numel_x = 10000;
+ int numel_y = 65000;
- tv0->computeAt(tv3, 1);
- tv1->computeAt(tv3, 1);
- tv3->axis(0)->parallelize(ParallelType::BIDx);
+ // fusion.printKernel();
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({63, 65}, options);
- at::Tensor t1 = at::rand_like(t0, options);
-
- std::vector<IValue> aten_inputs = {t0, t1};
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- auto t2 = t0.t().add(2.0);
- auto aten_output = t1.t().mul(t2);
+ fe.runFusion({input}, {cg_output});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
}
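+
+// The schedule above decomposes the row reduction into three nested stages,
+//   sum_{I1} x = sum_{R1oi{32}} sum_{R1i{128}} ( sum_{R1oo} x ),
+// with the serial R1oo partials materialized in tv2 by rFactor, R1i bound to
+// TIDx, and R1oi bound to BIDx -- thread blocks cooperate on one reduction,
+// which is what makes it a "grid" reduction.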
-TEST(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) {
+// Same test as the above but uses BIDy and TIDx for reduction
+TEST(NVFuserTest, FusionGridReduction2_CUDA) {
+ const int gdimy = 32;
+ const int bdimx = 128;
+
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeSymbolicTensor(2);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- tv0 = transpose(tv0, {{0, 1}});
- TensorView* tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
- tv1 = transpose(tv1, {{0, 1}});
- TensorView* tv2 = add(tv0, new Double(2.0));
- TensorView* tv3 = mul(tv1, tv2);
- fusion.addOutput(tv3);
-
- tv2->merge(0);
- tv2->split(-1, 8);
- tv2->split(-1, 4);
- tv3->merge(0);
- tv3->split(-1, 8);
-
- tv0->computeAt(tv3, 1);
- tv1->computeAt(tv3, 1);
-
- tv3->axis(0)->parallelize(ParallelType::BIDx);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({63, 65}, options);
- at::Tensor t1 = at::rand_like(t0, options);
- std::vector<IValue> aten_inputs = {t0, t1};
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addOutput(tv1);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
- auto t2 = t0.t().add(2.0);
- auto aten_output = t1.t().mul(t2);
+ tv1->split(1, bdimx);
+ // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+ tv1->split(1, gdimy);
+ // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+ TensorView* tv2 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
+ // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
-TEST(NVFuserTest, FusionSegmentReducePointwise_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+ // Incrementally, can print in between for debugging
+ tv0->computeAt(tv2, 1);
+ tv2->computeAt(tv1, 1);
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(1);
- TensorView* tv2 = makeSymbolicTensor(2);
+ // Redo it all at once, because why not.
+ tv0->computeAt(tv1, 1);
- fusion->addInput(tv0);
- fusion->addInput(tv1);
- fusion->addInput(tv2);
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
+ tv1->axis(1)->parallelize(ParallelType::BIDy);
+ tv2->axis(2)->parallelize(ParallelType::BIDy);
- TensorView* tv3 = add(tv0, new Double(1)); // Group 0
- TensorView* tv4 =
- max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues)
- TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce,
- // keeps normalization scheduler away)
- TensorView* tv6 = add(tv5, tv2); // Group 1 (Broadcast after reduce)
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
- fusion->addOutput(tv6);
+ int numel_x = 10000;
+ int numel_y = 65000;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({128, 65}, options);
- at::Tensor t1 = at::randn({65}, options);
- at::Tensor t2 = at::randn({128, 65}, options);
-
- auto t3 = t0.add(1.0);
- auto t4 = std::get<0>(at::max(t3, 0));
- auto t5 = t4.add(t1);
- auto t6 = t5.add(t2);
-
- FusionExecutorCache executor_cache(std::move(fusion));
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
- auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2});
-
- TORCH_CHECK(
- executor_cache.getMostRecentKernelRuntime()->isSegmented(),
- "segmentation didn't happen");
- TORCH_CHECK(
- executor_cache.getMostRecentKernelRuntime()
- ->fusionSegments()
- ->groups()
- .size() == 2,
- "segmentation didn't happen as expected");
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({input});
- testValidate(
- executor_cache.fusion(), outputs, {t0, t1, t2}, {t6}, __LINE__, __FILE__);
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(outputs[0]));
}
-TEST(NVFuserTest, FusionMultipleVectorize_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+// Same test but uses BIDy and BIDz for reduction. No TID used.
+TEST(NVFuserTest, FusionGridReduction3dim1_CUDA) {
+ const int gdimz = 32;
+ const int gdimy = 128;
- TensorView* tv0 = makeContigTensor(1);
- TensorView* tv1 = makeContigTensor(1);
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- fusion->addInput(tv0);
- fusion->addInput(tv1);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
- TensorView* tv3 = add(tv0, tv1);
- fusion->addOutput(tv3);
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addOutput(tv1);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({40960}, options);
- at::Tensor t1 = at::randn({40960}, options);
- auto t2 = t0 + t1;
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
- FusionExecutorCache executor_cache(std::move(fusion));
- executor_cache.profile(true);
+ tv1->split(1, gdimy);
+ // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+ tv1->split(1, gdimz);
+ // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
- auto outputs = executor_cache.runFusionWithInputs({t0, t1});
- auto runtime1 = executor_cache.getMostRecentKernelRuntime();
- auto log1 = executor_cache.getMostRecentExecutorInfo().pointwise_params;
- TORCH_CHECK(log1.has_value());
- TORCH_CHECK(log1->vectorize);
+ TensorView* tv2 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
+ // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
- testValidate(
- executor_cache.fusion(), outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
+ // Incrementally, can print in between for debugging
+ tv0->computeAt(tv2, 1);
+ tv2->computeAt(tv1, 1);
- t0 = at::randn({40964}, options);
- t1 = at::randn({40964}, options);
- t2 = t0 + t1;
+ // Redo it all at once, because why not.
+ tv0->computeAt(tv1, 1);
- outputs = executor_cache.runFusionWithInputs({t0, t1});
- auto runtime2 = executor_cache.getMostRecentKernelRuntime();
- auto log2 = executor_cache.getMostRecentExecutorInfo().pointwise_params;
- TORCH_CHECK(log2.has_value());
- TORCH_CHECK(log2->vectorize);
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
+ tv1->axis(1)->parallelize(ParallelType::BIDz);
+ tv2->axis(2)->parallelize(ParallelType::BIDz);
- testValidate(
- executor_cache.fusion(), outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
+ tv1->axis(-1)->parallelize(ParallelType::BIDy);
+ tv2->axis(-1)->parallelize(ParallelType::BIDy);
- t0 = at::randn({40962}, options);
- t1 = at::randn({40962}, options);
- t2 = t0 + t1;
+ int numel_x = 100;
+ int numel_y = 6500;
- outputs = executor_cache.runFusionWithInputs({t0, t1});
- auto runtime3 = executor_cache.getMostRecentKernelRuntime();
- auto log3 = executor_cache.getMostRecentExecutorInfo().pointwise_params;
- TORCH_CHECK(log3.has_value());
- TORCH_CHECK(log3->vectorize);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
- testValidate(
- executor_cache.fusion(), outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ fe.runFusion({input}, {cg_output});
- TORCH_CHECK(runtime1 == runtime2);
- TORCH_CHECK(runtime1 != runtime3);
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
}
-TEST(NVFuserTest, FusionVectorizeSimple_CUDA) {
+// Same as FusionGridReduction3dim1 but reduces dimension 0
+TEST(NVFuserTest, FusionGridReduction3dim0_CUDA) {
+ const int rdim = 0;
+ const int gdimy = 128;
+ const int gdimz = 32;
+
Fusion fusion;
FusionGuard fg(&fusion);
- TensorView* tv0 = makeContigTensor(3);
-
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- auto tv1 = unaryOp(UnaryOpType::Sin, tv0);
-
+ // tv1[R0, I1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {rdim}, new Float(0), tv0);
fusion.addOutput(tv1);
- auto tv0_cache = tv0->cache_after();
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
- auto tv1_cache = tv1->cache_before();
+ tv1->split(rdim, gdimy);
+ // tv1[R0o, R0i{128}, I1] = tv0[I0, I1]
+ tv1->split(rdim, gdimz);
+ // tv1[R0oo, R0oi{32}, R0i{128}, I1] = tv0[I0, I1]
- tv1->merge(0);
- tv1->merge(0);
- tv1->split(0, 4);
- tv1->split(0, 128);
+ TensorView* tv2 = tv1->rFactor({rdim});
+ // tv2[R0oo, I0oi{32}, I0i{128}, I1] = tv0[I0, I1]
+ // tv1[ R0oi{32}, R0i{128}, I1] = tv2[R0oo, I0oi{32}, I0i{128}, I1]
- tv1->axis(0)->parallelize(ParallelType::BIDx);
- tv1->axis(1)->parallelize(ParallelType::TIDx);
+ // Note that computeAt isn't going to make anything better as there
+ // is no dynamically sized dimension.
- tv0->computeAt(tv1, 2);
+ // Map parallelism as [Serial, BIDz, BIDy, BIDx]
+ tv1->axis(-1)->parallelize(ParallelType::BIDx);
+ tv2->axis(-1)->parallelize(ParallelType::BIDx);
+ tv1->axis(-2)->parallelize(ParallelType::BIDy);
+ tv2->axis(-2)->parallelize(ParallelType::BIDy);
+ tv1->axis(-3)->parallelize(ParallelType::BIDz);
+ tv2->axis(-3)->parallelize(ParallelType::BIDz);
- tv0_cache->axis(2)->parallelize(ParallelType::Vectorize);
- tv1->axis(2)->parallelize(ParallelType::Vectorize);
+ int numel_x = 6500;
+ int numel_y = 100;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-
- at::Tensor aten_input = at::empty({2, 6, 32}, options);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({aten_input});
-
- at::Tensor aten_output = aten_input.sin();
+ auto outputs = fe.runFusion({input});
- testValidate(
- &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__);
+ auto aten_output = input.sum({0});
+ TORCH_CHECK(aten_output.allclose(outputs[0]));
}
-TEST(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+// This is similar to FusionReduction, but swaps BIDx and TIDx
+TEST(NVFuserTest, FusionGridReduction4_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
+
+ const int bdimx = 128;
+ const int gdimx = 1024;
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
+
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addOutput(tv1);
+
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
- std::vector<int64_t> input_shape{32, 64, 8};
- const int kReductionAxis = 1;
+ tv1->split(1, gdimx);
+ // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1]
+ tv1->split(1, 4);
+ // tv1[I0, R1oo, R1oi{4}, R1i{1024}] = tv0[I0, I1]
- auto tv0 = TensorViewBuilder()
- .ndims(input_shape.size())
- .dtype(DataType::Double)
- .build();
+ TensorView* tv2 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
+ // tv1[I0, R1oi{4}, R1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
- fusion->addInput(tv0);
+ TensorView* tv3 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
+ // tv3[I0, R1oi{4}, Ir1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
+ // tv1[I0, R1i{1024}] = tv3[I0, R1oi{4}, Ir1i{1024}]
- auto tv1 = add(tv0, new Double(1.0));
- auto tv2 = sum(tv1, {2}); // Group 0
+ // Incrementally, can print in between for debugging
+ tv0->computeAt(tv2, 1);
+ tv2->computeAt(tv3, 1);
+ tv3->computeAt(tv1, 1);
- auto output = softmax(tv2, kReductionAxis); // Group 1
- fusion->addOutput(output);
+ // Redo it all at once, because why not.
+ tv0->computeAt(tv1, 1);
- auto options = at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0);
- at::Tensor at_x = at::randn(input_shape, options);
+ tv2->axis(2)->parallelize(ParallelType::Unroll);
+ tv1->axis(0)->parallelize(ParallelType::TIDx);
- FusionExecutorCache executor_cache(std::move(fusion));
+ tv1->axis(-1)->parallelize(ParallelType::BIDx);
+ tv2->axis(-1)->parallelize(ParallelType::BIDx);
+ tv3->axis(-1)->parallelize(ParallelType::BIDx);
- auto outputs = executor_cache.runFusionWithInputs({at_x});
+ int numel_x = bdimx;
+ int numel_y = 65000;
- auto t1 = at_x.add(1.0);
- auto t2 = t1.sum({2});
- auto t3 = at::_softmax(t2.to(at::kDouble), -1, false);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
- auto optimized_fusion = executor_cache.getMostRecentKernelRuntime();
- TORCH_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen");
- TORCH_CHECK(
- optimized_fusion->fusionSegments()->groups().size() == 2,
- "segmentation didn't happen as expected");
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ fe.runFusion({input}, {cg_output});
- testValidate(
- executor_cache.fusion(), outputs, {at_x}, {t3}, __LINE__, __FILE__);
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(cg_output));
}
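+
+// Note the two back-to-back rFactor({1}) calls above: each peels one
+// reduction axis out of tv1 into a fresh producer, chaining
+//   tv2 (serial R1oo) -> tv3 (R1oi{4}) -> tv1 (R1i{1024}),
+// so the reduction happens in three explicit steps instead of two.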
-TEST(NVFuserTest, FusionSwizzle1_CUDA) {
+// Grid reduction with 2D thread blocks but only TIDx and BIDx are
+// mapped to a reduction dim
+TEST(NVFuserTest, FusionGridReduction5_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
+ const int bdimx = 64;
+ const int bdimy = 16;
+ const int gdimx = 4;
+
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = mul(tv1, new Double(2));
- fusion.addOutput(tv2);
- tv2->split(0, 7);
- tv2->split(0, 9);
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addOutput(tv1);
- tv0->computeAt(tv2, 1);
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
- tv2->axis(0)->parallelize(ParallelType::BIDx);
+ tv1->split(1, bdimx);
+ // tv1[I0, R1o, R1i{64}] = tv0[I0, I1]
+ tv1->split(1, gdimx);
+ // tv1[I0, R1oo, R1oi{4}, R1i{64}] = tv0[I0, I1]
+
+ TensorView* tv2 = tv1->rFactor({1});
+ // tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] = tv0[I0, I1]
+ // tv1[I0, R1oi{4}, R1i{64}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}]
+
+ tv0->computeAt(tv1, 1);
- tv1->setMemoryType(MemoryType::Shared);
- tv1->swizzle(SwizzleType::Transpose, {1, 2});
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv1->axis(1)->parallelize(ParallelType::TIDx);
- tv1->axis(2)->parallelize(ParallelType::TIDy);
+ tv1->axis(-2)->parallelize(ParallelType::BIDx);
+ tv2->axis(-2)->parallelize(ParallelType::BIDx);
- tv2->axis(1)->parallelize(ParallelType::TIDx);
- tv2->axis(2)->parallelize(ParallelType::TIDy);
+ tv1->axis(0)->parallelize(ParallelType::TIDy);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({100}, options);
+ int numel_x = bdimy;
+ int numel_y = 6500;
- std::vector<IValue> aten_inputs = {t0};
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ auto outputs = fe.runFusion({input});
- auto aten_output = (t0 + 1) * 2;
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(outputs[0]));
}
-TEST(NVFuserTest, FusionSwizzle2_CUDA) {
+// Similar to FusionGridReduction1 but with 3D tensors
+TEST(NVFuserTest, FusionGridReduction6_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(3);
fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = mul(tv1, new Double(2));
- fusion.addOutput(tv2);
- tv1->split(-1, 4);
- tv1->split(-2, 4);
+ // tv1[I0, R1, R2] = tv0[I0, I1, I2]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Float(0), tv0);
+ fusion.addOutput(tv1);
- tv2->split(-1, 4);
- tv2->split(-2, 4);
+ TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
- tv0->computeAt(tv2, 1);
+ // Splitting for TID
+ tv1->split(2, 128);
+ // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]
- tv2->reorder({{-1, -2}});
+ // Splitting for BID
+ tv1->split(1, 128);
- tv1->setMemoryType(MemoryType::Shared);
- tv1->swizzle(SwizzleType::Transpose, {-2, -1});
+ // tv1[I0, R1o, R1i{128}, R2o, R2i{128}] = tv0[I0, I1, I2]
+
+ TensorView* tv2 = tv1->rFactor({3});
+ // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
+ // tv1[I0, R1o, R1i{128}, R2i{128}]
+
+ TensorView* tv3 = tv1->rFactor({1});
+ // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
+ // tv3[I0, R1o, I1i{128}, I2i{128}]
+ // tv1[I0, R1i{128}, R2i{128}]
+
+ tv3->computeAt(tv1, 1);
+ tv2->computeAt(tv3, 3);
+
+ tv1->axis(0)->parallelize(ParallelType::BIDy);
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-2)->parallelize(ParallelType::TIDy);
tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv1->axis(-2)->parallelize(ParallelType::TIDy);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({123}, options);
+ tv1->axis(-2)->parallelize(ParallelType::BIDx);
+ tv2->axis(-3)->parallelize(ParallelType::BIDx);
+ tv3->axis(-2)->parallelize(ParallelType::BIDx);
+
+ int numel_x = 6500;
+ int numel_y = 200;
+ int numel_z = numel_y;
- std::vector<IValue> aten_inputs = {t0};
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y, numel_z}, options);
+ at::Tensor cg_output = at::empty({numel_x}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- auto aten_output = (t0 + 1) * 2;
+ fe.runFusion({input}, {cg_output});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto aten_output = input.sum({1, 2});
+ TORCH_CHECK(aten_output.allclose(cg_output));
}
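+
+// When several axes are reduced at once, as above, rFactor can peel them off
+// one at a time: rFactor({3}) first moves the serial R2o partials into tv2,
+// then rFactor({1}) moves R1o into tv3, leaving tv1 with only the
+// parallelized R1i{128} and R2i{128} axes.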
-TEST(NVFuserTest, FusionTransposeWithSwizzle_CUDA) {
+TEST(NVFuserTest, FusionNonRedAxisBind_CUDA) {
+ int bid_x = 3;
+ int tid_x = 2;
+ int red_dim = 0;
+
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- auto tv1 = transpose(tv0, {{0, 1}});
- fusion.addOutput(tv1);
-
- // tv0: [I0, I1]
- // tv1: [I1, I0]
-
- const int BS = 32;
-
- // CTA tiling by BS*BS
- tv1->split(1, BS);
- tv1->split(0, BS);
- tv1->reorder({{1, 2}});
- // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)]
-
- // Create a smem buffer to cache each tile
- auto tv0_cache = tv0->cache_after();
- tv0_cache->setMemoryType(MemoryType::Shared);
-
- tv0->computeAt(tv1, 2);
- // tv0: [I0, I1]
- // tv0_cache: [I1/BS, I0/BS, BS(I1), BS(I0)]
- // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)]
-
- // Assign each thread block to a tile
- tv1->axis(0)->parallelize(ParallelType::BIDy);
- tv1->axis(1)->parallelize(ParallelType::BIDx);
- // Thread mapping for each tile. For both of the input and output
- // tiles, map TIDx to the fastest-changing dimension to facilitate
- // coalesced gmem accesses.
- tv1->axis(2)->parallelize(ParallelType::TIDy);
- tv1->axis(3)->parallelize(ParallelType::TIDx);
- // Note that the fastest-changing axis is next to the inner-most
- // axis since computeAt reorders the axes as the output tensor.
- tv0_cache->axis(2)->parallelize(ParallelType::TIDx);
- tv0_cache->axis(3)->parallelize(ParallelType::TIDy);
+ TensorView* tv1 =
+ reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0);
+ fusion.addOutput(tv1);
- // Swizzles the smem cache to avoid bank conflicts
- tv0_cache->swizzle(SwizzleType::Transpose, {3, 2});
+ tv1->split(-1, tid_x);
+ tv1->axis(-2)->parallelize(ParallelType::BIDx);
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int bx = 100;
- const int by = 200;
- at::Tensor t0 = at::randn({bx, by}, options);
- std::vector<IValue> aten_inputs = {t0};
+ at::Tensor input = at::rand({16, bid_x * tid_x}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({input});
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- auto aten_output = t0.t();
+ auto aten_output = input.sum({red_dim});
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0]),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
}
-TEST(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) {
+TEST(NVFuserTest, FusionSplitBCast_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = transpose(tv0, {{0, 1}});
- fusion.addOutput(tv1);
-
- // tv0: [I0, I1]
- // tv1: [I1, I0]
-
- const int BS = 32;
- const int BDIM = 256;
+ // Set up your input tensor views
+ TensorView* input_tv0 = makeDummyTensor(3);
+ TensorView* input_tv1 = makeDummyTensor(3);
+ fusion.addInput(input_tv0);
+ fusion.addInput(input_tv1);
- // CTA tiling by BS*BS
- tv1->split(1, BS);
- tv1->split(0, BS);
- tv1->reorder({{1, 2}});
- // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)]
+ TensorView* sum_tv2 =
+ reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0);
+ TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true});
+ TensorView* output_tv4 = div(input_tv1, bcast_tv3);
- // Create a smem buffer to cache each tile
- auto tv0_cache = tv0->cache_after();
- tv0_cache->setMemoryType(MemoryType::Shared);
+ sum_tv2->split(-1, 32);
+ TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2});
- tv0->computeAt(tv1, 2);
- // tv0: [I0, I1]
- // tv0_cache: [I1/BS, I0/BS, BS*BS/BDIM, BDIM]
- // tv1: [I1/BS, I0/BS, BS*BS/BDIM, BDIM]
+ bcast_tv3->split(-1, 32);
+ output_tv4->split(-1, 32);
- // Tranform the tile axes for 1D thread mapping
- tv1->merge(-2, -1);
- tv1->split(-1, BDIM);
- // tv1: [I1/BS, I0/BS, BS*BS/BDIM, BDIM]
+ sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx);
+ sum_tv2->axis(0)->parallelize(ParallelType::BIDx);
+ bcast_tv3->axis(0)->parallelize(ParallelType::BIDx);
+ output_tv4->axis(0)->parallelize(ParallelType::BIDx);
- // Transform the cache similarly but apply swizzle to the 2D tile axes.
- tv0_cache->reorder({{-2, -1}});
- tv0_cache->swizzle(SwizzleType::Transpose, {2, 3});
- tv0_cache->merge(-2, -1);
- tv0_cache->split(-1, BDIM);
- // tv0: [I1/BS, I0/BS, BS*BS/BDIM, BDIM]
+ sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy);
+ sum_tv2->axis(1)->parallelize(ParallelType::BIDy);
+ bcast_tv3->axis(1)->parallelize(ParallelType::BIDy);
+ output_tv4->axis(1)->parallelize(ParallelType::BIDy);
- // Assign each thread block to a tile
- tv1->axis(0)->parallelize(ParallelType::BIDy);
- tv1->axis(1)->parallelize(ParallelType::BIDx);
+ sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx);
+ sum_tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ output_tv4->axis(-1)->parallelize(ParallelType::TIDx);
- // Thread mapping for each tile.
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv0_cache->axis(-1)->parallelize(ParallelType::TIDx);
+ fusion.addOutput(output_tv4);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int bx = 100;
- const int by = 200;
- at::Tensor t0 = at::randn({bx, by}, options);
- std::vector<IValue> aten_inputs = {t0};
+ at::Tensor t0 = at::randn({32, 32, 128}, options);
+ at::Tensor t1 = at::randn({32, 32, 128}, options);
+ at::Tensor cg_output = at::empty({32, 32, 128}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
-
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- auto aten_output = t0.t();
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ fe.runFusion({t0, t1}, {cg_output});
}
-// Grid reduction can be executed only once in a kernel. Should result
-// in an error at the time of compilation.
-TEST(NVFuserTest, FusionGridReductionInLoop_CUDA) {
+TEST(NVFuserTest, FusionBCastInnerDim_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- auto tv1 = sum(tv0, {1});
- fusion.addOutput(tv1);
- tv1->axis(1)->parallelize(ParallelType::BIDx);
+ // reduce then broadcast
+ auto tv1 = sum(tv0, {0});
+ auto tv2 = broadcast(tv1, {false, true});
- FusionExecutor fe;
- ASSERT_ANY_THROW(fe.compileFusion(&fusion));
+ TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast());
}
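+
+// Each IterDomain carries its kind, queryable via isReduction() and
+// isBroadcast(); this and the next test use those predicates to assert that
+// reduction axes are dropped from consumer domains while broadcast axes are
+// propagated.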
-TEST(NVFuserTest, FusionIssue633_CUDA) {
+TEST(NVFuserTest, FusionBCastReduce_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- const int dx = 10;
- const int dy = 11;
- const int dz = 12;
-
- auto tv0 = makeConcreteTensor({dx, dy, dz});
- fusion.addInput(tv0);
- auto tv1 = makeConcreteTensor({dx, dy, 1});
- fusion.addInput(tv1);
- auto tv2 = add(tv0, tv1);
- fusion.addOutput(tv2);
-
- tv2->merge(1);
- tv2->merge(0);
- tv2->split(-1, 128);
-
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(1)->parallelize(ParallelType::TIDx);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({dx, dy, dz}, options);
- at::Tensor t1 = at::randn({dx, dy, 1}, options);
- std::vector<IValue> aten_inputs = {t0, t1};
-
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- auto aten_output = t0 + t1;
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto tv1 = broadcast(tv0, {true, false, false});
+ auto tv2 = sum(tv1, {1});
+ TORCH_CHECK(
+ tv2->axis(0)->isBroadcast() && tv2->axis(1)->isReduction() &&
+ !tv2->axis(2)->isBroadcast() && !tv2->axis(2)->isReduction());
}
-TEST(NVFuserTest, FusionKirScoping_CUDA) {
+// Multiple consumer reduction with computeAt
+// https://github.com/csarofeen/pytorch/issues/110
+TEST(NVFuserTest, FusionReductionMultiConsumer_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(2));
- fusion.addOutput(tv2);
+ auto tv1 = unaryOp(UnaryOpType::Exp, tv0);
+ auto tv2 = reductionOp(BinaryOpType::Max, {-1}, new Float(0), tv1);
+ auto tv3 = reductionOp(BinaryOpType::Min, {-1}, new Float(0), tv1);
+ auto tv4 = add(tv2, tv3);
+ fusion.addOutput(tv4);
+ tv1->computeAt(tv2, -1);
- tv2->merge(0);
- tv2->split(0, 4);
- tv0->computeAt(tv2, -1);
+ TORCH_CHECK(
+ (tv1->getComputeAtView() == tv2 || tv1->getComputeAtView() == tv3) &&
+ tv1->getThisComputeAtAxis() == 2 && tv1->getRelativeComputeAtAxis() == 2);
+}
- GpuLower gpulw(&fusion);
+TEST(NVFuserTest, FusionComputeAtExprOrder1_CUDA) {
+ for (int i = 0; i < 2; ++i) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto kir_tv1 = gpulw.lowerValue(tv1);
- auto tv1_scope = kir_tv1->definition()->scope();
- TORCH_CHECK(tv1_scope != nullptr);
- TORCH_CHECK(tv1_scope->owner()->as<kir::IfThenElse>());
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(1);
+ fusion.addInput(tv0);
- auto kir_tv2 = gpulw.lowerValue(tv2);
- auto tv2_scope = kir_tv2->definition()->scope();
- TORCH_CHECK(tv2_scope != nullptr);
- TORCH_CHECK(tv2_scope->owner()->as<kir::IfThenElse>());
+ auto tv1 = add(tv0, new Float(1));
+ auto tv2 = add(tv0, new Float(1));
+ TensorView* tv3 = add(tv1, tv2);
+ if (i == 0) {
+ tv1->computeAt(tv3, -1);
+ fusion.addOutput(tv2);
+ } else {
+ tv2->computeAt(tv3, -1);
+ fusion.addOutput(tv1);
+ }
+ fusion.addOutput(tv3);
- TORCH_CHECK(tv1_scope != tv2_scope);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({100}, options);
- // tv1 and tv2 should have the same inner-most ForLoop
- auto parent_scope = tv1_scope->owner()->scope();
- TORCH_CHECK(parent_scope == tv2_scope->owner()->scope());
- TORCH_CHECK(parent_scope->owner()->as<kir::ForLoop>());
- // There should be one more loop
- parent_scope = parent_scope->owner()->scope();
- TORCH_CHECK(parent_scope->owner()->as<kir::ForLoop>());
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({input});
- // scope() should return nullptr for top-level exprs
- auto top_level_scope = parent_scope->owner()->scope();
- TORCH_CHECK(top_level_scope == nullptr);
+ auto aten_output = (input + 1) * 2;
+ TORCH_CHECK(
+ aten_output.allclose(outputs[1]),
+ "Error of: ",
+ aten_output.sub(outputs[1]).abs().max());
+ }
}
-TEST(NVFuserTest, FusionBroadcastAcrossComputeAt_CUDA) {
+TEST(NVFuserTest, FusionComputeAtExprOrder2_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- std::vector<int64_t> shape{17, 19};
-
- auto tv0 = makeSymbolicTensor(1);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- auto tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
- auto tv2 = broadcast(tv0, {false, true});
- auto tv3 = add(tv1, tv2);
+
+ auto tv1 = add(tv0, new Float(1));
+ auto tv2 = add(tv0, new Float(1));
+ TensorView* tv3 = add(tv1, tv2);
fusion.addOutput(tv3);
- tv3->split(1, 128);
- tv0->computeAt(tv3, 2);
+ tv3->split(-1, 32);
- for (auto tv : {tv2, tv3}) {
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- }
+ tv1->computeAt(tv3, -1);
+ tv2->computeAt(tv3, -2);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({shape[0]}, options);
- at::Tensor t1 = at::randn(shape, options);
- std::vector<IValue> aten_inputs = {t0, t1};
+ at::Tensor input = at::rand({100, 100}, options);
+ at::Tensor output = at::empty_like(input, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ fe.runFusion({input}, {output});
- auto t3 = t0.unsqueeze(-1).expand(shape) + t1;
-
- testValidate(&fusion, cg_outputs, aten_inputs, {t3}, __LINE__, __FILE__);
+ auto aten_output = (input + 1) * 2;
+ TORCH_CHECK(
+ aten_output.allclose(output),
+ "Error of: ",
+ aten_output.sub(output).abs().max());
}
-TEST(NVFuserTest, FusionVectorizeMisalignedPointwise_CUDA) {
+TEST(NVFuserTest, FusionZeroDimComputeAt_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeContigTensor(2);
- auto tv1 = makeContigTensor(2);
+ TensorView* tv0 = makeDummyTensor(1);
fusion.addInput(tv0);
- fusion.addInput(tv1);
- auto tv2 = add(tv0, tv1);
+ auto tv1 = sum(tv0, {0});
+ auto tv2 = add(tv1, new Float(1));
fusion.addOutput(tv2);
-
- const int kTDX = 64;
- const int kVecSize = 4;
- const int kNumElems = kTDX * kVecSize;
-
- tv2->split(1, kNumElems);
-
- auto c0 = tv0->cache_after();
- auto c1 = tv1->cache_after();
- auto c2 = tv2->cache_before();
-
- tv2->split(-1, kVecSize);
-
- c0->computeAt(tv2, -2);
- c1->computeAt(tv2, -2);
-
- c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
- c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(-2)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+ TORCH_CHECK(tv2->nDims() == 0);
+ tv1->computeAt(tv2, 0);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int bx = 128;
- const int by = 457;
- at::Tensor t0 = at::randn({bx, by}, options);
- at::Tensor t1 = at::randn({bx, by}, options);
-
- std::vector<IValue> aten_inputs = {t0, t1};
+ at::Tensor input = at::rand({100}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ auto outputs = fe.runFusion({input});
- auto aten_output = t0 + t1;
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto aten_output = input.sum() + 1;
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0]),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
}
-TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeContig_CUDA) {
+TEST(NVFuserTest, FusionZeroDimBroadcast_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeContigTensor(4);
- auto tv1 = makeContigTensor(4);
+ TensorView* tv0 = makeDummyTensor(0);
fusion.addInput(tv0);
- fusion.addInput(tv1);
- auto tv2 = add(tv0, tv1);
- fusion.addOutput(tv2);
-
- tv2->reorder({{0, 1}, {1, 0}});
- tv2->merge(-2);
-
- const int kTDX = 64;
- const int kVecSize = 2;
- const int kNumElems = kTDX * kVecSize;
-
- tv2->split(-1, kNumElems);
-
- auto c0 = tv0->cache_after();
- auto c1 = tv1->cache_after();
- auto c2 = tv2->cache_before();
-
- tv2->split(0, 128);
- tv2->split(-1, kVecSize);
+ auto tv1 = broadcast(tv0, {true, true});
+ TORCH_CHECK(tv1->nDims() == 2);
- c0->computeAt(tv2, -2);
- c1->computeAt(tv2, -2);
+ TensorView* tv2 = makeDummyTensor(2);
+ fusion.addInput(tv2);
- c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
- c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+ auto tv3 = add(tv1, tv2);
+ auto tv4 = sum(tv3, {0, 1});
+ fusion.addOutput(tv4);
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(1)->parallelize(ParallelType::BIDy);
- tv2->axis(-2)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+ tv3->computeAt(tv4, -1);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int n = 32;
- const int c = 127;
- const int h = 51;
- const int w = 23;
- at::Tensor t0 = at::randn({n, c, h, w}, options);
- at::Tensor t1 = at::randn({n, c, h, w}, options);
-
- std::vector<IValue> aten_inputs = {t0, t1};
+ at::Tensor input1 = at::rand({}, options);
+ at::Tensor input2 = at::rand({10, 10}, options);
+ at::Tensor output = at::empty({}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ fe.runFusion({input1, input2}, {output});
- auto aten_output = t0 + t1;
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto aten_output =
+ (input1.unsqueeze(-1).unsqueeze(-1).expand({10, 10}) + input2).sum();
+ TORCH_CHECK(
+ aten_output.allclose(output),
+ "Error of: ",
+ aten_output.sub(output).abs().max());
}
-TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) {
+TEST(NVFuserTest, FusionZeroDimReduction_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- constexpr int kNumDims = 4;
- constexpr int kTDX = 64;
- constexpr int kVecSize = 2;
- constexpr int kNumElems = kTDX * kVecSize;
+ const int bdimx = 32;
+ const int gdimx = 32;
- auto tv0 = makeSymbolicTensor(kNumDims);
- auto tv1 = makeSymbolicTensor(kNumDims);
+ TensorView* tv0 = makeDummyTensor(1);
fusion.addInput(tv0);
- fusion.addInput(tv1);
- auto tv2 = add(tv0, tv1);
- fusion.addOutput(tv2);
-
- // Create caches for vectorization
- auto c0 = tv0->cache_after();
- auto c1 = tv1->cache_after();
- auto c2 = tv2->cache_before();
-
- // Merge all dimensions together except inner-most dim
- for (int idx = 0; idx < kNumDims - 2; ++idx) {
- tv2->merge(0);
- }
- // Split inner-most dim
- tv2->split(-1, kNumElems);
- tv2->split(-1, kVecSize);
- TransformPropagator::from(tv2);
-
- c0->computeAt(tv2, -2);
- c1->computeAt(tv2, -2);
+ auto tv1 = sum(tv0, {0});
+ fusion.addOutput(tv1);
- // Parallelization Strategy
- c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
- c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+ tv1->split(0, bdimx);
+ tv1->split(0, gdimx);
+ auto tv2 = tv1->rFactor({0});
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(2)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv1->axis(-2)->parallelize(ParallelType::BIDx);
+ tv2->axis(-2)->parallelize(ParallelType::BIDx);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int n = 5;
- const int c = 3;
- const int h = 51;
- const int w = 257;
- at::Tensor t0 = at::randn({n, c, h, w}, options);
- at::Tensor t1 = at::randn({n, c, h, w}, options);
-
- std::vector<IValue> aten_inputs = {t0, t1};
+ at::Tensor input = at::rand({1000}, options);
+ at::Tensor output = at::empty({}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- auto aten_output = t0 + t1;
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- constexpr int kNumDims = 4;
- constexpr int kTDX = 64;
- constexpr int kVecSize = 2;
- constexpr int kNumElems = kTDX * kVecSize;
- std::vector<int64_t> bcast_shape{1, 1, 1, -1};
-
- auto tv0 = makeContigTensor(kNumDims);
- auto tv1 = TensorViewBuilder().shape(bcast_shape).build();
- fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = add(tv0, tv1);
- fusion.addOutput(tv2);
-
- // Create caches for vectorization
- auto c0 = tv0->cache_after();
- auto c1 = tv1->cache_after();
- auto c2 = tv2->cache_before();
-
- // Merge all dimensions together
- // Backward merge order is necessary for vectorize validation
- for (int idx = kNumDims - 1; idx > 0; --idx) {
- tv2->merge(idx - 1);
- }
- tv2->split(-1, kNumElems);
- tv2->split(-1, kVecSize);
- TransformPropagator::from(tv2);
-
- c0->computeAt(tv2, -2);
- c1->computeAt(tv2, -2);
+ fe.runFusion({input}, {output});
- // Parallelization Strategy
- c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
- c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int n = 32;
- const int c = 128;
- const int h = 51;
- const int w = 23;
- at::Tensor t0 = at::randn({n, c, h, w}, options);
- at::Tensor t1 = at::randn({1, 1, 1, w}, options);
-
- std::vector<IValue> aten_inputs = {t0, t1};
-
- FusionExecutor fe;
- // TODO: throw assertion - cannot merge non-contiguous vectorization axes
- // Make sure compilation fails
- ASSERT_ANY_THROW(fe.compileFusion(&fusion));
+ auto aten_output = input.sum();
+ TORCH_CHECK(
+ aten_output.allclose(output),
+ "Error of: ",
+ aten_output.sub(output).abs().max());
}
-TEST(NVFuserTest, FusionVectorizeMisalignedRFactor_CUDA) {
+TEST(NVFuserTest, FusionBCastAfterReduce_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
+ const int tidx = 128;
- auto tv0 = makeContigTensor(2);
- auto tv1 = makeContigTensor(2);
-
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = add(tv0, tv1);
- auto tv3 = sum(tv2, {-1});
-
- fusion.addOutput(tv3);
+ auto tv1 = sum(tv0, {1});
+ auto tv2 = broadcast(tv1, {false, true});
- auto c0 = tv0->cache_after();
- auto c1 = tv1->cache_after();
+ tv1->split(1, tidx);
+ auto tv3 = tv1->rFactor({-2});
- tv3->split(-1, 128 * 4);
- tv3->split(-1, 4);
- // Reduce outer dim first
- auto tv4 = tv3->rFactor({-3, -1});
- // Tv3 will reduce threads
+ TensorView* tv4 = makeDummyTensor(2);
+ fusion.addInput(tv4);
- tv0->computeAt(tv3, 1);
- tv1->computeAt(tv3, 1);
+ auto tv5 = add(tv2, tv4);
+ fusion.addOutput(tv5);
+ tv5->split(1, tidx);
- tv3->axis(0)->parallelize(ParallelType::BIDx);
+ tv3->computeAt(tv5, 1);
- tv0->computeAt(tv4, -2);
- tv1->computeAt(tv4, -2);
+ tv2->split(1, tidx);
- c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
- c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ tv5->axis(-1)->parallelize(ParallelType::TIDx);
- tv4->axis(-2)->parallelize(ParallelType::TIDx);
- tv3->axis(1)->parallelize(ParallelType::TIDx);
+ tv5->axis(0)->parallelize(ParallelType::BIDx);
- tv2->computeAt(tv4, -1);
+ int x = 63, y = 200;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int bx = 128;
- const int by = 2050;
- at::Tensor t0 = at::randn({bx, by}, options);
- at::Tensor t1 = at::randn({bx, by}, options);
- std::vector<IValue> aten_inputs = {t0, t1};
+ at::Tensor t0 = at::randn({x, y}, options);
+ at::Tensor t4 = at::randn({x, y}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ auto outputs = fe.runFusion({t0, t4});
- auto aten_output = t0.add(t1).sum(1);
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto t3 = t0.sum({1}).unsqueeze(-1).expand({x, y});
+ auto t5 = t3.add(t4);
+
+ // Error is larger than the default threshold
+ TORCH_CHECK(t5.allclose(outputs[0], 1e-5, 1e-5));
}
-TEST(NVFuserTest, FusionVectorizeMisalignedWrongDimFail_CUDA) {
+TEST(NVFuserTest, FusionReductionScheduler_CUDA) {
+ constexpr int bid_x = 80;
+ constexpr int tid_x = 4096;
+ constexpr int red_dim = 1;
+
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeContigTensor(2);
- auto tv1 = makeContigTensor(2);
-
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = add(tv0, tv1);
- fusion.addOutput(tv2);
-
- tv2->split(1, 16);
- tv2->split(1, 64);
-
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(2)->parallelize(ParallelType::TIDx);
- auto c0 = tv0->cache_after();
- auto c1 = tv1->cache_after();
- auto c2 = tv2->cache_before();
+ TensorView* tv1 =
+ reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0);
+ fusion.addOutput(tv1);
- c0->computeAt(tv2, -2);
- c1->computeAt(tv2, -2);
+ const auto options =
+ at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::randn({bid_x, tid_x}, options);
- std::vector<TensorView*> vectorized_tvs = {c0, c1, tv2};
- for (auto tv : vectorized_tvs) {
- tv->split(-1, 4);
- // Vectorize the wrong dimension
- tv->axis(-2)->parallelize(ParallelType::MisalignedVectorize);
- }
+ // Apply reduction heuristic
+ auto reduction_params = getReductionHeuristics(&fusion, {input}, tv1);
+ TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
+ scheduleReduction(&fusion, reduction_params.value(), tv1, {});
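+ // getReductionHeuristics picks launch parameters from the fusion and input
+ // sizes; scheduleReduction then applies the corresponding schedule to tv1.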
FusionExecutor fe;
- // Make sure compilation fails
- ASSERT_ANY_THROW(fe.compileFusion(&fusion));
+ fe.compileFusion(&fusion);
+ // No broadcasting is needed, so the last optional argument is omitted.
+ auto outputs = fe.runFusion({input}, reduction_params.value().lparams);
+ auto aten_output = input.sum({red_dim});
+
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-04, 1e-04),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
}
-TEST(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) {
+// Simple reduction parallelized on a symbolic size.
+TEST(NVFuserTest, FusionSymbolicReduction_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
- auto tv1 = makeSymbolicTensor(2);
-
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = add(tv0, tv1);
- fusion.addOutput(tv2);
- const int kTDX = 64;
- const int kVecSize = 4;
- const int kNumElems = kTDX * kVecSize;
+ // tv1[I0, R1] = tv0[I0, I1]
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addOutput(tv1);
- tv2->split(1, kNumElems);
+ // Ideally the interface would just be a direct split with a parallel type,
+ // in which case the parallelize call could be folded into the split.
+ tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+ // tv1[I0, R1o, R1i{TIDx}] = tv0[I0, I1]
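+ // NamedScalar::getParallelDim(ParallelType::TIDx) is the symbolic blockDim.x,
+ // so the inner extent is supplied at launch time rather than fixed at compile
+ // time (see the LaunchParams passed to runFusion below).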
- auto c0 = tv0->cache_after();
- auto c1 = tv1->cache_after();
+ TensorView* tv2 = tv1->rFactor({1});
+ // tv2[I0, R1o, I1i{TIDx}] = tv0[I0, I1]
+ // tv1[I0, R1i{TIDx}] = tv2[I0, R1o, I1i{TIDx}]
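+ // rFactor stages the reduction: tv2 computes a partial reduction over the
+ // rfactored axis, and tv1 finishes it over the remaining axis. Roughly
+ // (sketch, with bdx = blockDim.x):
+ //   tv2[i, ti] = sum over oo of tv0[i, oo * bdx + ti]   (partial)
+ //   tv1[i]     = sum over ti of tv2[i, ti]              (final)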
- tv2->split(-1, kVecSize);
+ // Apply computeAt incrementally; the fusion can be printed in between for
+ // debugging.
+ tv0->computeAt(tv2, 1);
+ tv2->computeAt(tv1, 1);
- c0->computeAt(tv2, -2);
- c1->computeAt(tv2, -2);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
- c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
- c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(-2)->parallelize(ParallelType::TIDx);
+ int numel_x = 65000;
+ int numel_y = 1025;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int bx = 128;
- const int by = 2049;
- at::Tensor t0 = at::randn({bx, by}, options).index({"...", Slice(3)});
- at::Tensor t1 = at::randn({bx, by}, options).index({"...", Slice(3)});
- std::vector<IValue> aten_inputs = {t0, t1};
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+
+ // How many threads to use for the block reduction
+ int runtime_threadIdx_dim = 128;
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ auto outputs = fe.runFusion(
+ {input}, LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1));
- auto aten_output = t0 + t1;
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(aten_output.allclose(outputs[0]));
}
-TEST(NVFuserTest, FusionVectorizeMisalignedStrideFail_CUDA) {
+TEST(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) {
+ const std::vector<int> red_dims = {0, 2};
+ // A copy is needed because CodeGen takes int while PyTorch takes int64_t
+ // for the vector of reduction dimensions.
+ const std::vector<int64_t> red_dims64 = {0, 2};
+ const std::vector<int64_t> tensor_dims_in = {5, 10, 15, 20};
+ const std::vector<int64_t> tensor_dims_out = {10, 20};
+
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
- auto tv1 = makeSymbolicTensor(2);
-
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(tensor_dims_in.size());
fusion.addInput(tv0);
- fusion.addInput(tv1);
- auto tv2 = add(tv0, tv1);
- fusion.addOutput(tv2);
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, red_dims, new Float(0), tv0);
+ fusion.addOutput(tv1);
+
+ const auto options =
+ at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::randn(tensor_dims_in, options);
+ at::Tensor cg_output = at::empty(tensor_dims_out, options);
+
+ // Apply reduction heuristic
+ auto reduction_params = getReductionHeuristics(&fusion, {input}, tv1);
+ TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
+ scheduleReduction(&fusion, reduction_params.value(), tv1, {});
+
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({input}, reduction_params.value().lparams);
- const int kTDX = 64;
- const int kVecSize = 4;
- const int kNumElems = kTDX * kVecSize;
+ auto aten_output = input.sum(red_dims64);
- tv2->split(1, kNumElems);
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-04, 1e-04),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
+}
- auto c0 = tv0->cache_after();
- auto c1 = tv1->cache_after();
- auto c2 = tv2->cache_before();
+TEST(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) {
+ const std::vector<int> red_dims = {1, 3};
+ // A copy is needed because CodeGen takes int while PyTorch takes int64_t
+ // for the vector of reduction dimensions.
+ const std::vector<int64_t> red_dims64 = {1, 3};
+ const std::vector<int64_t> tensor_dims_in = {5, 10, 15, 20};
+ const std::vector<int64_t> tensor_dims_out = {5, 15};
- tv2->split(-1, kVecSize);
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- c0->computeAt(tv2, -2);
- c1->computeAt(tv2, -2);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(tensor_dims_in.size());
+ fusion.addInput(tv0);
- c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
- c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, red_dims, new Float(0), tv0);
+ fusion.addOutput(tv1);
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(-2)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize);
+ const auto options =
+ at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::randn(tensor_dims_in, options);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int bx = 128;
- const int by = 2049;
- at::Tensor t0 = at::randn({bx, by}, options).index({"...", Slice(3)});
- at::Tensor t1 = at::randn({bx, by}, options).index({"...", Slice(3)});
- std::vector<IValue> aten_inputs = {t0, t1};
+ auto reduction_params = getReductionHeuristics(&fusion, {input}, tv1);
+ TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
+ scheduleReduction(&fusion, reduction_params.value(), tv1, {});
FusionExecutor fe;
fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({input}, reduction_params.value().lparams);
+
+ auto aten_output = input.sum(red_dims64);
- // Failure because the input + output tensors do not have the same stride
- ASSERT_ANY_THROW(fe.runFusion(aten_inputs));
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-05, 1e-05),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
}
-TEST(NVFuserTest, FusionVectorization1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+TEST(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) {
+ std::vector<bool> fp16_usage = {true, false};
+ std::vector<int> red_axis = {1, 0};
+ std::vector<int> output_dims = {320, 640};
+ std::vector<int> red_dims;
- auto tv0 = makeSymbolicTensor(2);
+ // Making sure we get deterministic results
+ // (see https://github.com/csarofeen/pytorch/issues/399)
+ at::manual_seed(0);
- auto tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
+ // To cut down the number of iterations, only every other power of 2 is
+ // tested.
+ for (int i = 1; i <= 1024 * 1024; i <<= 2) {
+ red_dims.push_back(i);
+ }
- auto tv2 = add(tv0, tv1);
- fusion.addOutput(tv2);
+ for (auto fp16 : fp16_usage) {
+ for (auto& axis : red_axis) {
+ for (auto& odim : output_dims) {
+ for (auto& rdim : red_dims) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- tv2->split(1, 16);
- tv2->split(1, 64);
+ TensorView* tv0 =
+ makeDummyTensor(2, (fp16 ? DataType::Half : DataType::Float));
+ fusion.addInput(tv0);
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(2)->parallelize(ParallelType::TIDx);
+ Val* tv0_cast = nullptr;
+ if (fp16) {
+ tv0_cast = castOp(DataType::Float, tv0);
+ }
- auto c0 = tv0->cache_after();
- auto c1 = tv1->cache_after();
- auto c2 = tv2->cache_before();
+ TensorView* tv1 = reductionOp(
+ BinaryOpType::Add,
+ {axis},
+ new Float(0),
+ (fp16 ? tv0_cast->as<TensorView>() : tv0));
- c0->computeAt(tv2, -2);
- c1->computeAt(tv2, -2);
+ TensorView* tv1_cast = nullptr;
+ if (fp16) {
+ tv1_cast = castOp(DataType::Half, tv1);
+ }
- std::vector<TensorView*> vectorized_tvs = {c0, c1, tv2};
- for (auto tv : vectorized_tvs) {
- tv->split(-1, 4);
- tv->axis(-1)->parallelize(ParallelType::Vectorize);
- }
+ fusion.addOutput((fp16 ? tv1_cast : tv1));
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int bx = 128;
- const int by = 2048;
- at::Tensor t0 = at::randn({bx, by}, options);
- at::Tensor t1 = at::randn({bx, by}, options);
+ auto options = at::TensorOptions()
+ .dtype((fp16 ? at::kHalf : at::kFloat))
+ .device(at::kCUDA, 0);
+ at::Tensor input =
+ (axis ? at::randn({odim, rdim}, options)
+ : at::randn({rdim, odim}, options));
- std::vector<IValue> aten_inputs = {t0, t1};
+ std::vector<TensorView*> outputs_of_red;
+ if (fp16) {
+ outputs_of_red.push_back(tv1_cast);
+ }
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
+ auto reduction_params = getReductionHeuristics(&fusion, {input}, tv1);
+ TORCH_CHECK(reduction_params.has_value(), "Reduction is not found!");
+ scheduleReduction(
+ &fusion, reduction_params.value(), tv1, outputs_of_red);
+
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
- auto aten_output = t0 + t1;
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto outputs =
+ fe.runFusion({input}, reduction_params.value().lparams);
+ auto aten_output = input.sum({axis});
+
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-03, 1e-03),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
+ }
+ }
+ }
+ }
}
-TEST(NVFuserTest, FusionVectorization2_CUDA) {
+TEST(NVFuserTest, FusionCacheBefore_CUDA) {
+ // TVM Cache Write
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
-
- auto tv1 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = add(tv0, new Float(1.0));
+ TensorView* tv2 = mul(tv1, new Float(3.0));
fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = add(tv0, tv1);
fusion.addOutput(tv2);
+ // Before: TV2 = TV1 * 3
+ // After: TV3 = TV1 * 3;
+ // TV2 = TV3;
+
+ constexpr int BSX = 32;
+ tv2->split(-1, BSX);
+ tv0->computeAt(tv2, -1);
- tv2->split(1, 16);
- tv2->split(1, 64);
+ // cache_before automatically applies ComputeAt to the cache TensorView
+ tv2->cache_before();
+ // Thread and Block binding
tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(2)->parallelize(ParallelType::TIDx);
-
- auto c0 = tv0->cache_after();
- auto c1 = tv1->cache_after();
- auto c2 = tv2->cache_before();
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
- c0->computeAt(tv2, -2);
- c1->computeAt(tv2, -2);
+ constexpr int M = 32, N = 750;
- std::vector<TensorView*> vectorized_tvs = {c0, c1, tv2};
- for (auto tv : vectorized_tvs) {
- tv->split(-1, 4);
- // Vectorize the wrong dimension
- tv->axis(-2)->parallelize(ParallelType::Vectorize);
- }
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({M, N}, options);
FusionExecutor fe;
- // Make sure compilation fails
- ASSERT_ANY_THROW(fe.compileFusion(&fusion));
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({input});
+
+ at::Tensor aten_output = (input + 1.0) * 3.0;
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().sum());
}
-TEST(NVFuserTest, FusionVectorization3_CUDA) {
+TEST(NVFuserTest, FusionCacheAfter_CUDA) {
+ // TVM Cache Read
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
-
- auto tv1 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = add(tv0, new Float(1.0));
+ TensorView* tv2 = mul(tv1, new Float(3.0));
fusion.addInput(tv0);
- fusion.addInput(tv1);
-
- auto tv2 = add(tv0, tv1);
fusion.addOutput(tv2);
+ // Before: TV1 = TV0 + 1
+ // After: TV3 = TV0;
+ // TV1 = TV3 + 1
- tv2->split(1, 16);
- tv2->split(1, 64);
-
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(2)->parallelize(ParallelType::TIDx);
+ constexpr int BSX = 32;
+ tv2->split(-1, BSX);
+ tv0->computeAt(tv2, -1);
- auto c0 = tv0->cache_after();
- auto c1 = tv1->cache_after();
- auto c2 = tv2->cache_before();
+ // cache_after automatically applies ComputeAt to the cache TensorView
+ tv0->cache_after();
- c0->computeAt(tv2, -2);
- c1->computeAt(tv2, -2);
+ // Thread and Block binding
+ tv2->axis(0)->parallelize(ParallelType::BIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
- std::vector<TensorView*> vectorized_tvs = {c0, c1, tv2};
- for (auto tv : vectorized_tvs) {
- tv->split(-1, 4);
- tv->axis(-1)->parallelize(ParallelType::Vectorize);
- }
+ constexpr int M = 32, N = 457;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int bx = 128;
- const int by = 2049;
- at::Tensor t0 = at::randn({bx, by}, options);
- at::Tensor t1 = at::randn({bx, by}, options);
+ at::Tensor input = at::rand({M, N}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({input});
- std::vector<IValue> aten_inputs = {t0, t1};
- ASSERT_ANY_THROW(fe.runFusion(aten_inputs));
-
- aten_inputs[0] = t0.index({"...", Slice(1)});
- aten_inputs[1] = t1.index({"...", Slice(1)});
- ASSERT_ANY_THROW(fe.runFusion(aten_inputs));
-
- t0 = at::randn({bx, 2048}, options).index({"...", Slice(4)});
- t1 = at::randn({bx, 2048}, options).index({"...", Slice(4)});
- aten_inputs = {t0, t1};
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- auto aten_output = t0 + t1;
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ at::Tensor aten_output = (input + 1.0) * 3.0;
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().sum());
}
-TEST(NVFuserTest, FusionVectorizationRFactor_CUDA) {
+TEST(NVFuserTest, FusionCacheIndirect_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
-
- auto tv1 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
+ TensorView* tv2 = makeDummyTensor(2);
+ TensorView* tv3 = makeDummyTensor(2);
+ TensorView* tv4 = sub(tv2, tv3);
+ TensorView* tv5 = add(tv1, tv4);
+ TensorView* tv6 = sub(tv5, tv0);
fusion.addInput(tv0);
fusion.addInput(tv1);
+ fusion.addInput(tv2);
+ fusion.addInput(tv3);
+ fusion.addOutput(tv6);
+ // t6 = ((t1 + (t2 - t3)) - t0)
- auto tv2 = add(tv0, tv1);
-
- auto tv3 = sum(tv2, {-1});
-
- fusion.addOutput(tv3);
-
- tv3->split(-1, 128 * 4);
- tv3->split(-1, 4);
- // Reduce outer dim first
- auto tv4 = tv3->rFactor({-3, -1});
- // Tv3 will reduce threads
-
- auto tv6 = tv0->cache_after();
- auto tv7 = tv1->cache_after();
-
- tv0->computeAt(tv3, 1);
- tv1->computeAt(tv3, 1);
-
- tv3->axis(0)->parallelize(ParallelType::BIDx);
+ // cache_after on inputs placed before schedule
+ constexpr int BSX = 32;
+ tv6->split(-1, BSX);
+ tv2->computeAt(tv6, -1);
- tv0->computeAt(tv4, -2);
- tv1->computeAt(tv4, -2);
+ tv5->cache_after();
+ tv5->cache_before();
- tv6->axis(-1)->parallelize(ParallelType::Vectorize);
- tv7->axis(-1)->parallelize(ParallelType::Vectorize);
+ // Thread and Block binding
+ tv6->axis(0)->parallelize(ParallelType::BIDx);
+ tv6->axis(-1)->parallelize(ParallelType::TIDx);
- tv4->axis(-2)->parallelize(ParallelType::TIDx);
- tv3->axis(1)->parallelize(ParallelType::TIDx);
+ constexpr int M = 32, N = 810;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- const int bx = 128;
- const int by = 2048;
- at::Tensor t0 = at::randn({bx, by}, options);
- at::Tensor t1 = at::randn({bx, by}, options);
-
- std::vector<IValue> aten_inputs = {t0, t1};
+ at::Tensor in0 = at::rand({M, N}, options);
+ at::Tensor in1 = at::rand({M, N}, options);
+ at::Tensor in2 = at::rand({M, N}, options);
+ at::Tensor in3 = at::rand({M, N}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
-
- auto aten_output = t0.add(t1).sum(1);
- testValidate(
- &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+ auto outputs = fe.runFusion({in0, in1, in2, in3});
- auto t3 = t0.add(t1).sum(1);
-
- testValidate(&fusion, cg_outputs, aten_inputs, {t3}, __LINE__, __FILE__);
+ at::Tensor aten_output = (in1 + (in2 - in3)) - in0;
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().sum());
}
-// Unswitched loops with extent one may omit else clause.
-TEST(NVFuserTest, FusionSizeOneLoop1_CUDA) {
+TEST(NVFuserTest, FusionCacheBcast_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- // Progressively broadcast tensors
- TensorView* tv0 = makeSymbolicTensor(1);
+ // Algorithm
+ TensorView* tv0 = makeDummyTensor(1); // (M, 1)
+ TensorView* tv1 = broadcast(tv0, {false, true});
+ TensorView* tv2 = makeDummyTensor(1); // (1, N)
+ TensorView* tv3 = broadcast(tv2, {true, false});
+ TensorView* tv4 = mul(tv1, tv3);
fusion.addInput(tv0);
- TensorView* tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
- TensorView* tv2 = makeSymbolicTensor(3);
fusion.addInput(tv2);
+ fusion.addOutput(tv4);
- TensorView* tv3 = broadcast(tv0, {false, true});
- TensorView* tv4 = add(tv3, tv1);
- TensorView* tv5 = add(tv4, tv2);
+ constexpr int BSX = 128;
+ tv4->split(0, BSX);
+ tv4->split(-1, BSX);
+ tv4->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
+ // M/BSX, N/BSY, BSX, BSY
+ tv0->computeAt(tv4, 2);
+ tv2->computeAt(tv4, 2);
+ // 0, 1 | 2, 3, 4
- fusion.addOutput(tv5);
+ // Case 1
+ tv0->cache_after();
- // Split inner dimension
- tv5->split(1, 8);
- // Merge middle dims with outer dimensions
- tv5->merge(2);
- tv5->merge(0);
+ // Case 2
+ tv1->cache_before();
- // tv5[I0*I1o, I1i*I2]
- // Get a dim of size 1 to unswitch
- tv5->split(0, 1, false);
+ // Case 3
+ tv1->cache_after();
- // Compute everything inline
- tv0->computeAt(tv5, -1);
+ // Case 4
+ TensorView* tv8 = tv4->cache_before();
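+ // Cases 1-4 together cover caching at every position: after an input (1),
+ // before (2) and after (3) an intermediate, and before an output (4).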
- tv5->axis(0)->parallelize(ParallelType::Unswitch);
- tv5->axis(1)->parallelize(ParallelType::BIDx);
- tv5->axis(2)->parallelize(ParallelType::TIDx);
+ tv4->axis(0)->parallelize(ParallelType::BIDx);
+ tv4->axis(1)->parallelize(ParallelType::BIDy);
+ tv4->axis(-1)->parallelize(ParallelType::TIDx);
+ // Manual Replay on TV3
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ tv8->axis(-1)->parallelize(ParallelType::TIDx);
- // Make sure the unswitched loop does not have an else clause.
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto fl = dynamic_cast<kir::ForLoop*>(kir_node.get())) {
- if (fl->iter_domain()->parallelType() != ParallelType::Unswitch) {
- continue;
- }
- if (auto pred = dynamic_cast<kir::IfThenElse*>(fl->parentScope())) {
- TORCH_CHECK(!pred->hasElse());
- }
- }
- }
+ constexpr int M = 92, N = 500;
- const int x = 11;
- const int y = 12;
- const int z = 13;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({x}, options);
- at::Tensor t1 = at::randn({x, y}, options);
- at::Tensor t2 = at::randn({z, x, y}, options);
- std::vector<IValue> aten_inputs = {t0, t1, t2};
+ at::Tensor t0 = at::randn({M}, options);
+ at::Tensor t1 = at::randn({N}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
- auto t6 = (t0.unsqueeze(-1) + t1).unsqueeze(0) + t2;
+ auto outputs = fe.runFusion({t0, t1});
- testValidate(&fusion, cg_outputs, aten_inputs, {t6}, __LINE__, __FILE__);
+ at::Tensor aten_output = t0.unsqueeze(1).matmul(t1.unsqueeze(0));
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
}
-// The unswitched loop has extent one but inner loops don't. The else
-// part should not be omitted.
-TEST(NVFuserTest, FusionSizeOneLoop2_CUDA) {
+TEST(NVFuserTest, FusionCacheComplex_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- const int x = 15;
- auto tv0 = makeConcreteTensor({x});
+ TensorView* tv0 = makeDummyTensor(2); // (N, N)
+ TensorView* tv1 = makeDummyTensor(1); // (N)
+ TensorView* tv2 = sum(tv0, {1}); // (N)
+ TensorView* tv3 = broadcast(tv2, {false, true}); // (N, 1)
+ TensorView* tv4 = broadcast(tv1, {true, false}); // (1, N)
+ TensorView* tv5 = mul(tv3, tv4); // (N, N)
fusion.addInput(tv0);
+ fusion.addInput(tv1);
+ fusion.addOutput(tv5);
- auto tv1 = add(tv0, new Double(1));
- fusion.addOutput(tv1);
+ // Exception: Cache-Before on reduction Op
+ // TensorView* tv9 = tv2->cache_before();
+
+ constexpr int BSX = 128;
+ tv5->split(0, BSX);
+ tv5->split(-1, BSX);
+ // M/BSX, BSX, N/BSX, BSX
+ tv5->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
+ // M/BSX, N/BSY, BSX, BSY
+ tv0->computeAt(tv5, 2);
+ tv1->computeAt(tv5, 2);
+ // 0, 1 | 2, 3, 4
+
+ tv2->cache_after();
+ TensorView* tv7 = tv5->cache_before();
- tv1->split(-1, 4);
- tv1->split(-2, 1);
+ tv5->axis(0)->parallelize(ParallelType::BIDx);
+ tv5->axis(1)->parallelize(ParallelType::BIDy);
+ tv5->axis(-1)->parallelize(ParallelType::TIDx);
- tv1->axis(-2)->parallelize(ParallelType::Unswitch);
+ tv4->axis(-1)->parallelize(ParallelType::TIDx);
+ tv7->axis(-1)->parallelize(ParallelType::TIDx);
- // Make sure the size-one unswitched loop does not omit the else clause.
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto fl = dynamic_cast<kir::ForLoop*>(kir_node.get())) {
- if (fl->iter_domain()->parallelType() != ParallelType::Unswitch) {
- continue;
- }
- if (auto pred = dynamic_cast<kir::IfThenElse*>(fl->parentScope())) {
- TORCH_CHECK(pred->hasElse());
- }
- }
- }
+ constexpr int N = 800;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({x}, options);
- std::vector<IValue> aten_inputs = {t0};
+ at::Tensor input1 = at::rand({N, N}, options);
+ at::Tensor input2 = at::rand({N}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion(aten_inputs);
- auto t1 = t0 + 1;
+ auto outputs = fe.runFusion({input1, input2});
- testValidate(&fusion, cg_outputs, aten_inputs, {t1}, __LINE__, __FILE__);
+ at::Tensor aten_output =
+ matmul(sum(input1, 1).unsqueeze(1), input2.unsqueeze(0));
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().sum());
}
-TEST(NVFuserTest, FusionValidateParallelize1_CUDA) {
+TEST(NVFuserTest, FusionCacheMultiConsumer_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
+ TensorView* tv0 = makeDummyTensor(1);
+ TensorView* tv1 = add(tv0, new Float(1));
+ TensorView* tv2 = add(tv1, new Float(2));
+ TensorView* tv3 = add(tv0, new Float(1));
+ TensorView* tv4 = add(tv3, new Float(2));
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(1));
+ fusion.addInput(tv0);
fusion.addOutput(tv2);
+ fusion.addOutput(tv4);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDy);
-
- // Invalid as tv1 and tv2 do have the same ParallelType
- FusionExecutor fe;
- ASSERT_ANY_THROW(fe.compileFusion(&fusion));
-}
+ tv1->computeAt(tv2, -1);
+ tv3->computeAt(tv4, -1);
-TEST(NVFuserTest, FusionValidateParallelize2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ auto tv5 = tv1->cache_before();
+ auto tv6 = tv3->cache_before();
+ tv5->setMemoryType(MemoryType::Shared);
+ tv6->setMemoryType(MemoryType::Shared);
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
+ // Fails because the tensor would have to be recomputed twice
+ // auto tv7 = tv0->cache_after();
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(1));
- fusion.addOutput(tv2);
+ constexpr int N = 800;
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDy);
- tv1->setMemoryType(MemoryType::Shared);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({N}, options);
- // tv1 and tv2 do have the same ParallelType, but tv1 is on shared
- // memory, so it is valid
FusionExecutor fe;
fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({input});
+
+ auto aten_output = (input + 1) + 2;
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().sum());
+ TORCH_CHECK(
+ aten_output.allclose(outputs[1], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[1]).abs().sum());
}
-TEST(NVFuserTest, FusionValidateParallelize3_CUDA) {
+TEST(NVFuserTest, FusionSmem_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
+ // Algorithm
+ TensorView* tv0 = makeDummyTensor(2); // (M, N)
+ TensorView* tv1 = makeDummyTensor(2); // (M, N)
+ TensorView* tv2 = mul(tv0, tv1);
fusion.addInput(tv0);
-
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(1));
+ fusion.addInput(tv1);
fusion.addOutput(tv2);
- tv1->split(-1, 4);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->split(-1, 4);
+ // Schedule
+ TensorView* tv3 = tv0->cache_after();
+ TensorView* tv4 = tv1->cache_after();
+ tv3->setMemoryType(MemoryType::Shared);
+ tv4->setMemoryType(MemoryType::Shared);
+
+ constexpr int BSY = 32;
+ constexpr int BSX = 128;
+ tv2->split(0, BSY);
+ tv2->split(2, BSX);
+ // M/BSX, BSX, N/BSX, BSX
+ tv2->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
+ // M/BSX, N/BSX, BSX, BSX
+
+ tv0->computeAt(tv2, 2);
+ tv1->computeAt(tv2, 2);
+
+ // Thread and Block binding
+ tv2->axis(0)->parallelize(ParallelType::BIDx);
+ tv2->axis(1)->parallelize(ParallelType::BIDy);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ // Manual Binding
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ tv4->axis(-1)->parallelize(ParallelType::TIDx);
+
+ constexpr int M = 128, N = 10240;
- tv1->setMemoryType(MemoryType::Global);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor t0 = at::randn({M, N}, options);
+ at::Tensor t1 = at::randn({M, N}, options);
- // tv1 and tv2 have the same shape and ParallelType
FusionExecutor fe;
fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({t0, t1});
+
+ at::Tensor aten_output = mul(t0, t1);
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
+ TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
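+ // war_hazard_syncs records the __syncthreads() calls inserted to guard
+ // write-after-read hazards on shared memory; none should be needed here.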
}
-TEST(NVFuserTest, FusionValidateParallelize4_CUDA) {
+TEST(NVFuserTest, FusionSmemReduce_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
+ // Algorithm
+ TensorView* tv0 = makeDummyTensor(3); // M, K, N
+ TensorView* tv1 = sum(tv0, {1}); // M, R, N
fusion.addInput(tv0);
+ fusion.addOutput(tv1);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(1));
- fusion.addOutput(tv2);
+ TensorView* tv2 = tv0->cache_after();
+ tv2->setMemoryType(MemoryType::Shared);
+
+ // Schedule
+ constexpr int BSX = 32;
+ tv1->split(2, BSX);
+ tv1->split(1, 128);
+ tv1->split(0, BSX);
+ // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
+ tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}});
+ TensorView* tv3 = tv1->rFactor({-2});
+
+ tv0->computeAt(tv1, -2);
+ tv0->computeAt(tv3, -2);
- tv1->split(-1, 4);
+ // Thread and Block binding
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
+ tv1->axis(1)->parallelize(ParallelType::BIDy);
tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->split(-1, 8);
+ // Manual Binding
tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+ constexpr int M = 154, K = 45, N = 1524;
- tv1->setMemoryType(MemoryType::Global);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor t0 = at::randn({M, K, N}, options);
- // tv1 and tv2 do not have the same shape
FusionExecutor fe;
- ASSERT_ANY_THROW(fe.compileFusion(&fusion));
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({t0});
+
+ at::Tensor aten_output = sum(t0, {1});
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
+ TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 1);
+ TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(24) == 1);
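+ // Exactly one WAR sync is expected: the shared-memory cache is overwritten
+ // on each iteration of the serial loop, so a sync must separate the reads
+ // of one iteration from the writes of the next.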
}
-TEST(NVFuserTest, FusionValidateParallelize5_CUDA) {
+TEST(NVFuserTest, FusionSmemBlockGemm_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
+ // Algorithm
+ TensorView* tv0 = makeDummyTensor(2); // (M, K)
+ TensorView* tv1 = makeDummyTensor(2); // (K, N)
+ TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
+ TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
+ TensorView* tv4 = mul(tv2, tv3); // M, K, N
+ TensorView* tv5 = sum(tv4, {1}); // M, R, N
fusion.addInput(tv0);
+ fusion.addInput(tv1);
+ fusion.addOutput(tv5);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(1));
- fusion.addOutput(tv2);
+ // Schedule
+ constexpr int BSX = 16;
+ tv5->split(2, BSX);
+ tv5->split(1, BSX);
+ tv5->split(0, BSX);
+ // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
+ tv5->reorder({{0, 0}, {1, 3}, {2, 2}, {3, 5}, {4, 1}, {5, 4}});
+ // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
+ TensorView* tv6 = tv5->rFactor({-1});
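+ // tv6 reduces the innermost KSX chunk per thread; tv5 then serially reduces
+ // the remaining K/BSX axis to complete each dot product.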
- tv1->split(-1, 4);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv1->setMemoryType(MemoryType::Shared);
+ tv2->setMemoryType(MemoryType::Shared);
+ tv3->setMemoryType(MemoryType::Shared);
+ tv4->setMemoryType(MemoryType::Shared);
+ tv6->setMemoryType(MemoryType::Shared);
- tv2->split(-1, 8);
+ tv0->computeAt(tv5, 3);
+ tv1->computeAt(tv5, 3);
+
+ // Thread and Block binding
+ tv5->axis(0)->parallelize(ParallelType::BIDx);
+ tv5->axis(1)->parallelize(ParallelType::BIDy);
+ tv5->axis(-2)->parallelize(ParallelType::TIDy);
+ tv5->axis(-1)->parallelize(ParallelType::TIDx);
+ // Manual Binding
tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ tv4->axis(-1)->parallelize(ParallelType::TIDx);
+ tv6->axis(-3)->parallelize(ParallelType::TIDy);
+ tv6->axis(-2)->parallelize(ParallelType::TIDx);
+
+ constexpr int M = 154, K = 45, N = 1524;
+
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor t0 = at::randn({M, K}, options);
+ at::Tensor t1 = at::randn({K, N}, options);
- // tv1 and tv2 do not have the same shape, but tv1 is on shared
- // memory, so it is valid
FusionExecutor fe;
fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({t0, t1});
+
+ at::Tensor aten_output = matmul(t0, t1);
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
+ TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
}
-// See issue #995
-TEST(NVFuserTest, FusionValidateParallelize6_CUDA) {
+TEST(NVFuserTest, FusionSmemBlockGemmCache_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(3);
- auto tv1 = makeSymbolicTensor(4);
+ // Algorithm
+ TensorView* tv0 = makeDummyTensor(2); // (M, K)
+ TensorView* tv1 = makeDummyTensor(2); // (K, N)
+ TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
+ TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
+ TensorView* tv4 = mul(tv2, tv3); // M, K, N
+ TensorView* tv5 = sum(tv4, {1}); // M, R, N
fusion.addInput(tv0);
fusion.addInput(tv1);
+ fusion.addOutput(tv5);
- auto tv2 = add(tv0, new Double(1));
- auto tv3 = broadcast(tv2, {true, false, false, false});
- auto tv4 = add(tv3, tv1);
- fusion.addOutput(tv4);
-
- tv4->merge(0);
- tv4->merge(0);
- tv4->merge(0);
- tv4->split(0, 128);
- tv4->split(0, 1);
- tv4->split(0, 1);
-
- TransformPropagator::from(tv4);
+ // Schedule
+ // Remove reduction axis from tv5
+ // tv6 = (M, R, N)
+ // tv5 = (M, N)
+ TensorView* tv6 = tv5->cache_before();
- tv0->computeAt(tv2, 2);
- tv3->computeAt(tv4, 2);
+ constexpr int BSX = 16;
+ tv5->split(1, BSX);
+ tv5->split(0, BSX);
+ // M/BSX, BSX, N/BSX, BSX
+ tv5->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
+ // tv5 = M/BSX, N/BSX, MSX, NSX
- tv4->axis(0)->parallelize(ParallelType::BIDx);
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(0)->parallelize(ParallelType::BIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv6->computeAt(tv5, 2);
+ tv6->computeAt(tv5, 2);
- // Validation should throw an exception saying the first axes of tv2
- // and tv3 have incompatible parallelization. See also issue #995.
- ASSERT_ANY_THROW(fusion.printKernel());
-}
+ tv6->split(-1, BSX);
+ // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
+ tv6->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 4}, {4, 2}, {5, 5}});
+ // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
+ TensorView* tv7 = tv6->rFactor({-1});
+ // tv7 = M/BSX, N/BSX, K/BSXrf, MSX, NSX, KSXr
+ // tv6 = M/BSX, N/BSX, K/BSXr, MSX, NSX
-TEST(NVFuserTest, FusionDAGMerging_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ tv0->computeAt(tv6, 3);
+ tv1->computeAt(tv6, 3);
- auto tv0 = makeSymbolicTensor(5);
- auto tv1 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- fusion.addInput(tv1);
+ tv0->computeAt(tv7, 3);
+ tv1->computeAt(tv7, 3);
- // Branch 0
- auto tv2 = sum(tv0, {0}); // 0
- auto tv3 = sum(tv2, {0}); // 1
- auto tv4 = sum(tv3, {0}); // 2
- auto tv5 = sum(tv4, {0}); // 3
+ // Memory Type
+ tv2->setMemoryType(MemoryType::Shared);
+ tv3->setMemoryType(MemoryType::Shared);
+ tv4->setMemoryType(MemoryType::Shared);
+ tv6->setMemoryType(MemoryType::Shared);
+ tv7->setMemoryType(MemoryType::Shared);
- // Branch 1
- auto tv6 = add(tv1, new Double(1)); // 4
+ // Thread and Block binding
+ tv5->axis(0)->parallelize(ParallelType::BIDx);
+ tv5->axis(1)->parallelize(ParallelType::BIDy);
+ tv5->axis(-2)->parallelize(ParallelType::TIDy);
+ tv5->axis(-1)->parallelize(ParallelType::TIDx);
+ // Manual Binding
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ tv4->axis(-1)->parallelize(ParallelType::TIDx);
- // Merge
- auto tv7 = add(tv6, tv5); // 5
+ tv7->axis(-3)->parallelize(ParallelType::TIDy);
+ tv7->axis(-2)->parallelize(ParallelType::TIDx);
- // Maximum expected output groups (can improve overtime):
- // {0}, {1}, {2}, {3,4,5}
- // without final merge would have been {0}, {1}, {2}, {3,4}, {5}
+ tv6->axis(-2)->parallelize(ParallelType::TIDy);
+ tv6->axis(-1)->parallelize(ParallelType::TIDx);
- fusion.addOutput(tv7);
+ constexpr int M = 154, K = 45, N = 1524;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({2, 2, 2, 2, 2}, options);
- at::Tensor t1 = at::randn({2}, options);
+ at::Tensor t0 = at::randn({M, K}, options);
+ at::Tensor t1 = at::randn({K, N}, options);
+
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({t0, t1});
- auto fusion_segments = fusion.segment({t0, t1});
- TORCH_CHECK(fusion_segments->groups().size() <= 4);
+ at::Tensor aten_output = matmul(t0, t1);
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
+ TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
}
-TEST(NVFuserTest, FusionDAGScalarMerging_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+TEST(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
+
+ TensorView* x = makeDummyTensor(2);
+ fusion.addInput(x);
+ TensorView* max_val =
+ reductionOp(BinaryOpType::Max, {-1}, new Float(FLT_MIN), x); // (M)
+ TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B)
+ TensorView* x_max_sub = sub(x, bcast_max); // (M, N)
+ TensorView* exp = unaryOp(UnaryOpType::Exp, x_max_sub); // (M, N)
+ TensorView* sum_exp = sum(exp, {-1}); // (M, R)
+ TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B)
+ TensorView* softmax = div(exp, bcast_sum); // (M, N)
+ fusion.addOutput(softmax);
- auto tv0 = makeSymbolicTensor(3);
- auto i0 = new Double();
+ // Read the input into shared memory (load input + pointwise ops into smem)
+ auto cache_x = x->cache_after();
+ cache_x->setMemoryType(MemoryType::Shared);
+ exp->setMemoryType(MemoryType::Shared);
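+ // Keeping the cached input and exp in shared memory lets each row stay
+ // resident across the max, subtract/exp, sum, and divide passes.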
- fusion->addInput(tv0);
- fusion->addInput(i0);
+ std::vector<TensorView*> all_tensors(
+ {x,
+ cache_x,
+ max_val,
+ bcast_max,
+ x_max_sub,
+ exp,
+ sum_exp,
+ bcast_sum,
+ softmax});
- auto i1 = add(i0, new Double(1.0));
- auto i2 = mul(i1, i1);
- auto i3 = add(i2, i1);
+ auto tidx = new Int();
+ fusion.addInput(tidx);
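+ // The split factor tidx is a runtime scalar input, so the same compiled
+ // kernel can be launched with different block sizes (128 below).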
- // Branch 0
- auto tv1 = sum(tv0, {0}); // 0
- auto tv2 = add(tv1, i2);
- // Branch 1
- auto tv3 = sum(tv2, {0}); // 1
- auto tv4 = add(tv3, i3);
+ for (auto tensor : all_tensors) {
+ tensor->split(-1, tidx);
+ }
- auto tv5 = add(tv4, i0);
+ auto sum_exp_rf = sum_exp->rFactor({1});
+ all_tensors.push_back(sum_exp_rf);
- fusion->addOutput(tv5);
+ // computeAt
+ x->computeAt(x_max_sub, 1);
+ exp->computeAt(softmax, 1);
+ x_max_sub->computeAt(exp, 2);
- FusionExecutorCache executor_cache(std::move(fusion));
+ softmax->axis(0)->parallelize(ParallelType::BIDx);
+ for (auto tensor : all_tensors) {
+ tensor->axis(-1)->parallelize(ParallelType::TIDx);
+ }
+ const size_t dimx = 1024;
+ const size_t dimy = 4096;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({16, 16, 16}, options);
- double s0 = 0.5;
-
- auto s1 = s0 + 1.0;
- auto s2 = s1 * s1;
- auto s3 = s2 + s1;
- auto t1 = t0.sum({0});
- auto t2 = t1 + s2;
- auto t3 = sum(t2, {0});
- auto t4 = t3 + s3;
- auto t5 = t4 + s0;
+ at::Tensor t0 = at::randn({dimx, dimy}, options);
- auto outputs = executor_cache.runFusionWithInputs({t0, s0});
+ torch::jit::fuser::cuda::FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({t0, 128});
+ auto t1 = at::_softmax(t0, -1, false);
TORCH_CHECK(
- executor_cache.getMostRecentKernelRuntime()->isSegmented(),
- "segmentation didn't happen");
- TORCH_CHECK(
- executor_cache.getMostRecentKernelRuntime()
- ->fusionSegments()
- ->groups()
- .size() == 2,
- "segmentation didn't happen as expected");
-
- testValidate(
- executor_cache.fusion(), outputs, {t0, s0}, {t5}, __LINE__, __FILE__);
+ t1.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ t1.sub(outputs[0]).abs().max());
}
-TEST(NVFuserTest, FusionBlockReduceInSerialLoop_CUDA) {
+TEST(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- constexpr int M = 10;
- constexpr int N = 20;
- constexpr int K = 20;
+ const int pixels_per_thread = 64;
+ const int TIDX = 128;
+ const int static_size = pixels_per_thread * TIDX;
- auto tv0 = makeSymbolicTensor(3);
- auto tv1 = sum(tv0, {{1, 2}});
- fusion.addInput(tv0);
- fusion.addOutput(tv1);
+ TensorView* sx = makeConcreteTensor({-1, static_size});
+ TensorView* dx = makeDummyTensor(2);
+ fusion.addInput(sx);
+ fusion.addInput(dx);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv1->axis(0)->parallelize(ParallelType::BIDx);
+ TensorView* max_sx =
+ reductionOp(BinaryOpType::Max, {-1}, new Float(FLT_MIN), sx); // (M)
+ TensorView* max_dx =
+ reductionOp(BinaryOpType::Max, {-1}, new Float(FLT_MIN), dx); // (M)
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({M, N, K}, options);
- std::vector<IValue> aten_inputs = {t0};
+ // Reduction => merge local and shared memory TensorViews
+ TensorView* max_val = binaryOp(BinaryOpType::Max, max_sx, max_dx);
+ TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B)
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(aten_inputs);
- at::Tensor aten_output = t0.sum({1, 2});
- testValidate(
- &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
-}
+ TensorView* sx_max_sub = sub(sx, bcast_max); // (M, N)
+ TensorView* dx_max_sub = sub(dx, bcast_max); // (M, N)
-TEST(NVFuserTest, FusionBlockWelfordInSerialLoop_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ TensorView* sx_exp = unaryOp(UnaryOpType::Exp, sx_max_sub); // (M, N)
+ TensorView* dx_exp = unaryOp(UnaryOpType::Exp, dx_max_sub); // (M, N)
- constexpr int M = 10;
- constexpr int N = 20;
- constexpr int K = 20;
+ TensorView* sx_sum_exp = sum(sx_exp, {-1}); // (M, R)
+ TensorView* dx_sum_exp = sum(dx_exp, {-1}); // (M, R)
- auto tv0 = makeSymbolicTensor(3);
- auto tvs = Welford(tv0, {{1, 2}});
- fusion.addInput(tv0);
- auto tv_avg = tvs.avg;
- auto tv_M2 = tvs.var_sum;
- auto tv_N = tvs.n;
- fusion.addOutput(tv_avg);
- fusion.addOutput(tv_M2);
+ // Reduction => merge local and shared memory TensorViews
+ TensorView* sum_exp = binaryOp(BinaryOpType::Add, sx_sum_exp, dx_sum_exp);
+ TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B)
- tv_avg->axis(-1)->parallelize(ParallelType::TIDx);
- tv_avg->axis(0)->parallelize(ParallelType::BIDx);
+ TensorView* sx_softmax = div(sx_exp, bcast_sum); // (M, N)
+ TensorView* dx_softmax = div(dx_exp, bcast_sum); // (M, N)
+ fusion.addOutput(sx_softmax);
+ fusion.addOutput(dx_softmax);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({M, N, K}, options);
- std::vector<IValue> aten_inputs = {t0};
+ auto sx_cache = sx->cache_after();
+ auto dx_cache = dx->cache_after();
+ dx_cache->setMemoryType(MemoryType::Shared);
+ dx_exp->setMemoryType(MemoryType::Shared);
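+ // The statically sized slice stays in thread-local buffers, while the
+ // dynamically sized remainder is staged through shared memory.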
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(aten_inputs);
- at::Tensor aten_avg = t0.mean({1, 2});
- at::Tensor aten_M2 = t0.var({1, 2}, false) * N * K;
- testValidate(
- &fusion, outputs, aten_inputs, {aten_avg, aten_M2}, __LINE__, __FILE__);
-}
+ // Reduction and Broadcast Tensors common to both memory TVs
+ std::vector<TensorView*> common_tensors(
+ {max_val, sum_exp, bcast_max, bcast_sum});
-// See Issue #716
-TEST(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ // Static Local Memory TVs
+ std::vector<TensorView*> static_tensors(
+ {sx, sx_cache, max_sx, sx_max_sub, sx_exp, sx_sum_exp, sx_softmax});
- constexpr int M = 10;
- constexpr int N = 11;
+ // Dynamic Local Memory TVs
+ std::vector<TensorView*> dynamic_tensors(
+ {dx, dx_cache, max_dx, dx_max_sub, dx_exp, dx_sum_exp, dx_softmax});
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
+ std::vector<TensorView*> all_tensors;
+ all_tensors.insert(
+ all_tensors.end(), common_tensors.begin(), common_tensors.end());
+ all_tensors.insert(
+ all_tensors.end(), static_tensors.begin(), static_tensors.end());
+ all_tensors.insert(
+ all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end());
+
+ // M => M
+ // M, N => M, N/128, 128
+ for (auto tensor : all_tensors) {
+ if (tensor->nDims() > 1) {
+ tensor->split(-1, TIDX);
+ }
+ }
+
+ auto sx_sum_exp_rf = sx_sum_exp->rFactor({1});
+ auto dx_sum_exp_rf = dx_sum_exp->rFactor({1});
+ all_tensors.push_back(sx_sum_exp_rf);
+ all_tensors.push_back(dx_sum_exp_rf);
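+ // rFactor({1}) splits each row reduction into a serial per-thread
+ // partial sum over N/TIDX elements (the *_rf tensors) followed by a
+ // TIDX-wide block reduction that produces the final value.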
- std::vector<int> reduction_axes = {1};
- std::vector<bool> broadcast_mask = {false, true};
+ // computeAt
+ sx->computeAt(sx_max_sub, 1);
+ dx->computeAt(dx_max_sub, 1);
+
+ sx_exp->computeAt(sx_softmax, 1);
+ dx_exp->computeAt(dx_softmax, 1);
+
+ sx_max_sub->computeAt(sx_exp, 2);
+ dx_max_sub->computeAt(dx_exp, 2);
- auto tv0_bcast = broadcast(tv0, broadcast_mask);
- auto path1_bcast = add(tv0_bcast, new Double(1.0));
- auto path1 = sum(path1_bcast, reduction_axes);
- fusion.addOutput(path1);
+ sx_softmax->axis(0)->parallelize(ParallelType::BIDx);
+ dx_softmax->axis(0)->parallelize(ParallelType::BIDx);
+ for (auto tensor : all_tensors) {
+ if (tensor->nDims() > 1) {
+ tensor->axis(-1)->parallelize(ParallelType::TIDx);
+ }
+ }
- auto p = path1->split(1, 1);
- path1->rFactor({1});
- path1->axis(0)->parallelize(ParallelType::BIDx);
- tv0->computeAt(path1, 1);
+ const size_t dimx = 1024;
+ const size_t dimy = 16384;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({M}, options);
- at::Tensor t0_ref = t0.clone();
- std::vector<IValue> aten_inputs = {t0};
+ at::Tensor in = at::randn({dimx, dimy}, options);
+ at::Tensor static_in = in.narrow(1, 0, static_size);
+ at::Tensor dynamic_in = in.narrow(1, static_size, dimy - static_size);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
+ at::Tensor out = at::zeros({dimx, dimy}, options);
+ at::Tensor static_out = out.narrow(1, 0, static_size);
+ at::Tensor dynamic_out = out.narrow(1, static_size, dimy - static_size);
- // inplace op, we are adding t0 to itself
- auto outputs = fe.runFusion(aten_inputs, {t0});
+ torch::jit::fuser::cuda::FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs =
+ fe.runFusion({static_in, dynamic_in}, {static_out, dynamic_out});
- TORCH_CHECK(outputs[0].allclose(t0_ref.add(1)));
+ auto t1 = at::_softmax(in, -1, false);
+ TORCH_CHECK(
+ t1.allclose(out, 1e-5, 1e-5), "Error of: ", t1.sub(out).abs().max());
}
-TEST(NVFuserTest, FusionReductionPredicate_CUDA) {
+TEST(NVFuserTest, FusionPersistentBatchNormLocalShared_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = sum(tv0, {0});
- fusion.addOutput(tv1);
+ const int pixels_per_thread = 64;
+ const int TIDX = 128;
+ const int static_size = pixels_per_thread * TIDX;
- auto tv2 = tv0->cache_after();
+ TensorView* sx = makeConcreteTensor({-1, static_size});
+ TensorView* dx = makeDummyTensor(2);
+ fusion.addInput(sx);
+ fusion.addInput(dx);
- const int bdimx = 128;
- tv1->split(1, bdimx);
- tv1->split(1, 4);
- tv1->split(1, 1);
+ Float* gamma = new Float();
+ Float* beta = new Float();
+ Float* eps = new Float();
+ Int* N = new Int();
+ fusion.addInput(gamma);
+ fusion.addInput(beta);
+ fusion.addInput(eps);
+ fusion.addInput(N);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv1->axis(2)->parallelize(ParallelType::Unroll);
- tv1->split(0, 10);
- tv0->computeAt(tv1, 4);
+ // Reduction
+ auto sx_sum = sum(sx, {-1}); // (M, R)
+ auto dx_sum = sum(dx, {-1}); // (M, R)
+ // Reduction => merge local and shared memory TensorViews
+ auto x_sum = binaryOp(BinaryOpType::Add, sx_sum, dx_sum);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ // Broadcast
+ auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B)
+ // Pwise
+ auto x_mean = div(x_sum_bcast, N); // (M, B)
- int numel_x = 650;
- int numel_y = 102;
+ auto sx_mean_sub = sub(sx, x_mean); // (M, N)
+ auto dx_mean_sub = sub(dx, x_mean); // (M, N)
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({numel_x, numel_y}, options);
- at::Tensor cg_output = at::empty({numel_y}, options);
+ auto sx_mean_sub_pow = mul(sx_mean_sub, sx_mean_sub); // (M, N)
+ auto dx_mean_sub_pow = mul(dx_mean_sub, dx_mean_sub); // (M, N)
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- fe.runFusion({input}, {cg_output});
+ // Reduction
+ auto sx_var_sum = sum(sx_mean_sub_pow, {-1}); // (M, R)
+ auto dx_var_sum = sum(dx_mean_sub_pow, {-1}); // (M, R)
+ // Reduction => merge local and shared memory TensorViews
+ auto var_sum = binaryOp(BinaryOpType::Add, sx_var_sum, dx_var_sum);
- auto aten_output = input.to(at::kDouble).sum({0});
+ // Broadcast
+ auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B)
+ // Pwise
+ auto var = div(var_sum_bcast, N); // (M, B)
+ auto var_eps = add(var, eps); // (M, B)
+ auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B)
- testValidate(
- &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
-}
+ auto sx_norm = mul(sx_mean_sub, rvar);
+ auto dx_norm = mul(dx_mean_sub, rvar);
-TEST(NVFuserTest, FusionIssue728_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ auto sx_norm_gamma = mul(sx_norm, gamma);
+ auto dx_norm_gamma = mul(dx_norm, gamma);
- auto tv0 = makeSymbolicTensor(1);
- fusion.addOutput(tv0);
- auto tv1 = makeSymbolicTensor(1);
- fusion.addOutput(tv1);
- auto tv2 = makeSymbolicTensor(1);
- fusion.addOutput(tv2);
+ auto sx_norm_gamma_beta = add(sx_norm_gamma, beta);
+ auto dx_norm_gamma_beta = add(dx_norm_gamma, beta);
+ fusion.addOutput(sx_norm_gamma_beta);
+ fusion.addOutput(dx_norm_gamma_beta);
- auto tv3 = add(tv0, new Double(1));
- auto tv4 = add(tv3, tv1);
- auto tv5 = add(tv4, new Double(1));
- auto tv6 = add(tv2, new Double(1));
- fusion.addOutput(tv5);
- fusion.addOutput(tv6);
+ // Cache the inputs. The runtime-sized dx tensors (the input cache and
+ // input-minus-mean) must live in shared memory; the static sx side
+ // stays in local memory.
+ auto sx_cache = sx->cache_after();
+ auto dx_cache = dx->cache_after();
+ dx_cache->setMemoryType(MemoryType::Shared);
+ dx_mean_sub->setMemoryType(MemoryType::Shared);
- // tv0 -> tv3 -+
- // tv1 --------+-> tv4 -> tv5
- //
- // tv2 -> tv6
+ std::vector<TensorView*> common_tensors(
+ {x_sum, x_sum_bcast, x_mean, var_sum, var_sum_bcast, var, var_eps, rvar});
+
+ std::vector<TensorView*> static_tensors(
+ {sx,
+ sx_cache,
+ sx_sum,
+ sx_mean_sub,
+ sx_mean_sub_pow,
+ sx_var_sum,
+ sx_norm,
+ sx_norm_gamma,
+ sx_norm_gamma_beta});
+
+ std::vector<TensorView*> dynamic_tensors(
+ {dx,
+ dx_cache,
+ dx_sum,
+ dx_mean_sub,
+ dx_mean_sub_pow,
+ dx_var_sum,
+ dx_norm,
+ dx_norm_gamma,
+ dx_norm_gamma_beta});
+
+ std::vector<TensorView*> all_tensors;
+ all_tensors.insert(
+ all_tensors.end(), common_tensors.begin(), common_tensors.end());
+ all_tensors.insert(
+ all_tensors.end(), static_tensors.begin(), static_tensors.end());
+ all_tensors.insert(
+ all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end());
- auto all_vals_under_tv3 =
- DependencyCheck::getAllValsBetween({tv3}, fusion.outputs());
- std::unordered_set<Val*> included_tensors({tv3, tv4, tv5});
- for (auto tv : included_tensors) {
- TORCH_CHECK(
- std::find(all_vals_under_tv3.begin(), all_vals_under_tv3.end(), tv) !=
- all_vals_under_tv3.end(),
- "TV",
- tv->name(),
- " not found");
- }
- for (auto tv : ir_utils::filterByType<TensorView>(fusion.vals())) {
- if (included_tensors.find(tv) == included_tensors.end()) {
- TORCH_CHECK(
- std::find(all_vals_under_tv3.begin(), all_vals_under_tv3.end(), tv) ==
- all_vals_under_tv3.end(),
- "TV",
- tv->name(),
- " should not be found");
+ // M => M
+ // M, N => M, N/128, 128
+ for (auto tensor : all_tensors) {
+ if (tensor->nDims() > 1) {
+ tensor->split(-1, TIDX);
}
}
- auto no_dependency = DependencyCheck::getAllValsBetween({}, fusion.outputs());
- TORCH_CHECK(no_dependency.empty(), "No val should be returned");
-
- auto no_dep_path = DependencyCheck::getAllValsBetween({tv0, tv1}, {tv6});
- TORCH_CHECK(no_dep_path.empty(), "No val should be returned");
+ // rFactor: per-thread local sums feed the block reduction (whose
+ // result is later broadcast across the block)
+ TensorView* sx_sum_rf = sx_sum->rFactor({1});
+ TensorView* sx_var_sum_rf = sx_var_sum->rFactor({1});
+ TensorView* dx_sum_rf = dx_sum->rFactor({1});
+ TensorView* dx_var_sum_rf = dx_var_sum->rFactor({1});
+ all_tensors.push_back(sx_sum_rf);
+ all_tensors.push_back(sx_var_sum_rf);
+ all_tensors.push_back(dx_sum_rf);
+ all_tensors.push_back(dx_var_sum_rf);
- auto no_dep_path2 = DependencyCheck::getAllValsBetween({tv2}, {tv5});
- TORCH_CHECK(no_dep_path2.empty(), "No val should be returned");
+ // ComputeAt
+ sx->computeAt(sx_mean_sub_pow, 1);
+ dx->computeAt(dx_mean_sub_pow, 1);
- auto just_tv3 = DependencyCheck::getAllValsBetween({tv3}, {tv3});
- TORCH_CHECK(
- just_tv3.size() == 1 && *(just_tv3.begin()) == tv3,
- "Only tv3 should be included");
-}
+ var_sum->computeAt(rvar, 1);
-TEST(NVFuserTest, FusionIssue757_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ sx_mean_sub_pow->computeAt(sx_var_sum_rf, 2);
+ dx_mean_sub_pow->computeAt(dx_var_sum_rf, 2);
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = sum(tv0, {1});
- auto tv2 = broadcast(tv1, {false, true});
- auto tv3 = makeSymbolicTensor(2);
- fusion.addInput(tv3);
- auto tv4 = add(tv2, tv3);
- fusion.addOutput(tv4);
+ sx_norm->computeAt(sx_norm_gamma_beta, 2);
+ dx_norm->computeAt(dx_norm_gamma_beta, 2);
- tv1->computeAt(tv4, -1);
+ sx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx);
+ dx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx);
+ for (auto tensor : all_tensors) {
+ if (tensor->nDims() > 1) {
+ tensor->axis(-1)->parallelize(ParallelType::TIDx);
+ }
+ }
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ const int dimx = 1024;
+ const int dimy = 16384;
+ const float kGamma = 1.0f;
+ const float kBeta = 0.0f;
+ const float kEps = 1e-5;
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- int numel_x = 650;
- int numel_y = 102;
+ at::Tensor in = at::randn({dimx, dimy}, options);
+ at::Tensor static_in = in.narrow(1, 0, static_size);
+ at::Tensor dynamic_in = in.narrow(1, static_size, dimy - static_size);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- at::Tensor t3 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0, t3};
+ at::Tensor out = at::zeros({dimx, dimy}, options);
+ at::Tensor static_out = out.narrow(1, 0, static_size);
+ at::Tensor dynamic_out = out.narrow(1, static_size, dimy - static_size);
- FusionExecutor fe;
+ torch::jit::fuser::cuda::FusionExecutor fe;
fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0.sum({1});
- auto t2 = t1.unsqueeze(-1).expand({numel_x, numel_y});
- auto t4 = t2 + t3;
+ auto outputs = fe.runFusion(
+ {static_in, dynamic_in, kGamma, kBeta, kEps, dimy},
+ {static_out, dynamic_out});
- testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__);
+ auto at_mu = at::mean(in, -1).unsqueeze(1);
+ auto at_var = at::var(in, -1).unsqueeze(1);
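+ // at::var is unbiased (divides by N - 1) while the fusion divides by N;
+ // at dimy = 16384 that is a ~6e-5 relative difference, well inside the
+ // 1e-3 tolerance used below.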
+ auto at_rvar = at::rsqrt(at::add(at_var, kEps));
+ auto at_norm = at::mul(at::sub(in, at_mu), at_rvar);
+ auto at_norm_gamma_beta = at::add(at::mul(at_norm, kGamma), kBeta);
+ TORCH_CHECK(
+ at_norm_gamma_beta.allclose(out, 1e-3, 1e-3),
+ "Error of: ",
+ at_norm_gamma_beta.sub(out).abs().max());
}
-// See issue #759
-TEST(NVFuserTest, FusionPredicatedBlockBroadcast_CUDA) {
+TEST(NVFuserTest, FusionSmemDynamicPersistentBatchNorm_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = sum(tv0, {1});
- auto tv2 = broadcast(tv1, {false, true});
- auto tv3 = makeSymbolicTensor(2);
- fusion.addInput(tv3);
- auto tv4 = add(tv2, tv3);
- fusion.addOutput(tv4);
-
- tv4->split(0, 4);
- tv1->computeAt(tv4, -1);
+ // Set up your input tensor views
+ auto x = makeDummyTensor(2);
+ Float* gamma = new Float();
+ Float* beta = new Float();
+ Float* eps = new Float();
+ Int* N = new Int();
+ fusion.addInput(x);
+ fusion.addInput(gamma);
+ fusion.addInput(beta);
+ fusion.addInput(eps);
+ fusion.addInput(N);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(1)->parallelize(ParallelType::TIDy);
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
- tv4->axis(1)->parallelize(ParallelType::TIDy);
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ // Reduction
+ auto x_sum = sum(x, {-1}); // (M, R)
+ // Broadcast
+ auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B)
+ // Pwise
+ auto x_mean = div(x_sum_bcast, N); // (M, B)
+ auto x_mean_sub = sub(x, x_mean); // (M, N)
+ auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub); // (M, N)
+ // Reduction
+ auto var_sum = sum(x_mean_sub_pow, {-1}); // (M, R)
+ // Broadcast
+ auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B)
+ // Pwise
+ auto var = div(var_sum_bcast, N); // (M, B)
+ auto var_eps = add(var, eps); // (M, B)
+ auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B)
+ auto norm = mul(x_mean_sub, rvar);
+ auto norm_gamma = mul(norm, gamma);
+ auto norm_gamma_beta = add(norm_gamma, beta);
+ fusion.addOutput(norm_gamma_beta);
- int numel_x = 100;
- int numel_y = 101;
+ // Read Input into Shared Memory
+ // Read Input minus Input_Mean into Shared Memory
+ auto cache_x = x->cache_after();
+ cache_x->setMemoryType(MemoryType::Shared);
+ x_mean_sub->setMemoryType(MemoryType::Shared);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- at::Tensor t3 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0, t3};
+ std::vector<TensorView*> all_tensors(
+ {x_sum,
+ x_mean,
+ cache_x,
+ x_sum_bcast,
+ x_mean_sub,
+ x_mean_sub_pow,
+ var_sum,
+ var_sum_bcast,
+ var,
+ var_eps,
+ rvar,
+ norm,
+ norm_gamma,
+ norm_gamma_beta});
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(inputs);
+ auto tidx = new Int();
+ fusion.addInput(tidx);
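+ // The split factor is itself a fusion input, so the block size is
+ // chosen at run time (TIDX = 128 is passed to runFusion below).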
- auto t1 = t0.sum({1});
- auto t2 = t1.unsqueeze(-1).expand({numel_x, numel_y});
- auto t4 = t2 + t3;
+ for (auto tensor : all_tensors) {
+ tensor->split(-1, tidx);
+ }
+ norm_gamma->split(1, 1);
+ norm_gamma_beta->split(1, 1);
- testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__);
-}
+ // rFactor: per-thread local sums feed the block reduction (whose
+ // result is later broadcast across the block)
+ TensorView* x_sum_rf = x_sum->rFactor({1});
+ TensorView* var_sum_rf = var_sum->rFactor({1});
+ all_tensors.push_back(x_sum_rf);
+ all_tensors.push_back(var_sum_rf);
-TEST(NVFuserTest, FusionSegmentVerticalMerge_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+ // ComputeAt
+ x->computeAt(x_mean_sub_pow, 1);
+ var_sum->computeAt(rvar, 1);
+ x_mean_sub_pow->computeAt(var_sum_rf, 2);
+ norm->computeAt(norm_gamma_beta, 2);
- auto tv0 = makeSymbolicTensor(3);
+ for (auto tv : all_tensors) {
+ tv->axis(0)->parallelize(ParallelType::BIDx);
+ tv->axis(-1)->parallelize(ParallelType::TIDx);
+ }
- fusion->addInput(tv0);
- // {first kernel}
- auto tv1 = sum(tv0, {0});
- auto tv2 = add(tv1, tv0);
- auto tv3 = sum(tv2, {0});
- auto tv4 = add(tv3, tv0);
- auto tv5 = sum(tv4, {0});
- auto tv6 = sum(tv5, {0});
- // {second kernel}
- auto tv7 = add(tv6, tv5);
- auto tv8 = add(tv7, tv5);
- auto tv9 = sum(tv8, {0});
-
- fusion->addOutput(tv9);
-
- SegmentCandidateFinderOptions segment_options;
- segment_options.run_herrmann_merge = false;
- segment_options.run_final_merge = false;
+ const int dimx = 128;
+ const int dimy = 2048;
+ const float kGamma = 1.0f;
+ const float kBeta = 0.0f;
+ const float kEps = 1e-5;
+ const int TIDX = 128;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({2, 2, 2}, options);
+ at::Tensor t0 = at::randn({dimx, dimy}, options);
- auto segmented_fusion =
- SegmentCandidateFinder::segment(fusion.get(), {t0}, segment_options);
+ torch::jit::fuser::cuda::FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({t0, kGamma, kBeta, kEps, dimy, TIDX});
- TORCH_CHECK(segmented_fusion->groups().size() == 2);
+ auto at_mu = at::mean(t0, -1).unsqueeze(1);
+ auto at_var = at::var(t0, -1).unsqueeze(1);
+ auto at_rvar = at::rsqrt(at::add(at_var, kEps));
+ auto at_norm = at::mul(at::sub(t0, at_mu), at_rvar);
+ auto at_norm_gamma_beta = at::add(at::mul(at_norm, kGamma), kBeta);
+ TORCH_CHECK(
+ at_norm_gamma_beta.allclose(outputs[0], 1e-3, 1e-3),
+ "Error of: ",
+ at_norm_gamma_beta.sub(outputs[0]).abs().max());
}
-TEST(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
-
- auto tv0 = makeSymbolicTensor(3);
- auto i0 = new Double();
+TEST(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- fusion->addInput(tv0);
- fusion->addInput(i0);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ fusion.addInput(tv0);
+ fusion.addOutput(tv1);
+ // tv1[I0, R1] = tv0[I0, I1]
- // Branch 0 {first kernel}
- auto tv1 = sum(tv0, {0});
- auto tv2 = add(tv0, i0);
- auto tv3 = unaryOp(UnaryOpType::Rsqrt, tv2);
- auto tv4 = sum(tv3, {0});
+ // The interface should eventually be just a direct split by a parallel
+ // type, which could fold in the parallelize call as well.
+ tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+ // tv1[I0, R1o, R1i{TIDx}] = tv0[I0, I1]
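+ // NamedScalar::getParallelDim(TIDx) resolves to blockDim.x, so the
+ // split size follows the launch configuration instead of being baked
+ // into the kernel.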
- // Branch 1 {first kernel}
- auto tv5 = unaryOp(UnaryOpType::Rsqrt, tv3);
- auto tv6 = sum(tv5, {0});
+ TensorView* tv2 = tv1->rFactor({2});
+ tv2->setMemoryType(MemoryType::Shared);
+ // tv2[I0, I1o, R1i{TIDx}] = tv0[I0, I1]
+ // tv1[I0, R1o] = tv2[I0, I1o, R1i{TIDx}]
- // Incompatible {second kernel}
- auto tv7 = sum(tv6, {0});
+ tv0->computeAt(tv1, 1);
- fusion->addOutput(tv1);
- fusion->addOutput(tv4);
- fusion->addOutput(tv7);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
- SegmentCandidateFinderOptions segment_options;
- segment_options.run_herrmann_merge = false;
- segment_options.run_final_merge = false;
+ constexpr int numel_x = 65000, numel_y = 1024;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({2, 2, 2}, options);
-
- auto segmented_fusion =
- SegmentCandidateFinder::segment(fusion.get(), {t0, 1.0}, segment_options);
-
- TORCH_CHECK(segmented_fusion->groups().size() == 2);
-}
-
-TEST(NVFuserTest, FusionSegmentMixReduction_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
- auto tv0 = makeSymbolicTensor(3);
-
- fusion->addInput(tv0);
-
- // def of tv1 in kernel 1 through horizontal
- auto tv1 = sum(tv0, {0, 1});
- // kernel 2
- auto tv2 = sum(tv0, {2});
- auto tv3 = broadcast(tv2, {false, false, true});
- auto tv4 = add(tv0, tv3);
- auto tv5 = sum(tv4, {2});
- // end of kernel 2
- // kernel 1
- auto tv6 = unaryOp(UnaryOpType::Rsqrt, tv0);
- auto tv7 = sum(tv6, {0, 1});
- auto tv8 = sum(tv6, {0, 1});
-
- fusion->addOutput(tv1);
- fusion->addOutput(tv5);
- fusion->addOutput(tv7);
- fusion->addOutput(tv8);
-
- SegmentCandidateFinderOptions segment_options;
- segment_options.run_herrmann_merge = false;
- segment_options.run_final_merge = false;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({2, 2, 2}, options);
+ // How many threads to use for the block reduction
+ constexpr int runtime_threadIdx_dim = 128;
- auto segmented_fusion =
- SegmentCandidateFinder::segment(fusion.get(), {t0}, segment_options);
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion(
+ {input}, LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1));
- TORCH_CHECK(segmented_fusion->groups().size() <= 2);
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
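+ // No WAR sync expected: each location of the shared-memory rFactor
+ // buffer is written once and read once per row, so no stale value is
+ // ever overwritten.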
+ TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 0);
}
-TEST(NVFuserTest, FusionSBAR_CUDA) {
+TEST(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- // N, H, W, C format
- std::vector<int64_t> input_shape{656, 7, 7, 64};
-
- auto x = makeContigTensor(4);
- auto y = makeContigTensor(4);
- auto weight = makeContigTensor(1);
- auto bias = makeContigTensor(1);
-
- fusion.addInput(x);
- fusion.addInput(y);
- fusion.addInput(weight);
- fusion.addInput(bias);
-
- const size_t kNumberOfDims = x->nDims();
- std::vector<bool> broadcast_mask(kNumberOfDims, false);
- for (size_t axis = 0; axis < kNumberOfDims - 1; ++axis) {
- broadcast_mask[axis] = true;
- }
+ // Algorithm
+ Int* sym_bsx = new Int();
+ TensorView* tv0 = makeDummyTensor(3); // M, K, N
+ fusion.addInput(tv0);
+ fusion.addInput(sym_bsx);
- auto weight_bcast = broadcast(weight, broadcast_mask);
- auto scale = mul(x, weight_bcast);
- auto bias_bcast = broadcast(bias, broadcast_mask);
- auto scale_bias = add(scale, bias_bcast);
- auto scale_bias_add = add(scale_bias, y);
- auto scale_bias_add_relu = unaryOp(UnaryOpType::Relu, scale_bias_add);
+ TensorView* tv1 = sum(tv0, {1}); // M, R, N
+ fusion.addOutput(tv1);
- fusion.addOutput(scale_bias_add_relu);
+ TensorView* tv2 = tv0->cache_after();
+ tv2->setMemoryType(MemoryType::Shared);
- // inputs
- at::manual_seed(0);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor at_x = at::randn(input_shape, options);
- at::Tensor at_y = at::randn(input_shape, options);
- at::Tensor at_weight = at::ones({input_shape[3]}, options);
- at::Tensor at_bias = at::zeros({input_shape[3]}, options);
+ // Schedule
+ constexpr int BSX = 32;
+ tv1->split(2, BSX);
+ tv1->split(1, sym_bsx);
+ tv1->split(0, BSX);
+ // M/BSX, BSX, K/sym, sym, N/BSX, BSX
+ tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}});
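+ // -> M/BSX, N/BSX, BSX, BSX, K/sym, sym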
+ TensorView* tv3 = tv1->rFactor({-2});
- // inputs
- std::vector<c10::IValue> inputs = {at_x, at_y, at_weight, at_bias};
+ tv0->computeAt(tv1, -2);
+ tv0->computeAt(tv3, -2);
- // outputs
- std::vector<at::Tensor> outputs;
+ // Thread and Block binding
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
+ tv1->axis(1)->parallelize(ParallelType::BIDy);
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ // Manual Binding
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
- auto lparams = schedulePointwise(&fusion, c10::ArrayRef<c10::IValue>(inputs));
+ constexpr int M = 154, K = 45, N = 1524;
- FusionExecutor executor;
- executor.compileFusion(&fusion);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor t0 = at::randn({M, K, N}, options);
- outputs = executor.runFusion(c10::ArrayRef<c10::IValue>(inputs), lparams);
+ // How many threads to use for the block reduction
+ constexpr int runtime_threadIdx_dim = 128;
- auto at_scale = at::mul(at_x, at_weight);
- auto at_scale_bias = at::add(at_scale, at_bias);
- auto pwise_add = at::add(at_scale_bias, at_y);
- auto output = at::relu(pwise_add);
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion(
+ {t0, runtime_threadIdx_dim},
+ LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1));
- testValidate(&fusion, outputs, inputs, {output}, __LINE__, __FILE__);
+ at::Tensor aten_output = sum(t0, {1});
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
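+ // The shared-memory input cache is refilled on each iteration of the
+ // serial reduction loop, so lowering must insert one __syncthreads()
+ // to guard the write-after-read hazard; the key checked below is the
+ // position the lowering pass recorded for that sync.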
+ TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 1);
+ TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(24) == 1);
}
-TEST(NVFuserTest, FusionSingleElement_CUDA) {
+TEST(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(0);
- fusion.addInput(tv0);
+ Int* sym_bsx = new Int();
+ TensorView* tv0 = makeDummyTensor(2); // (M, K)
+ TensorView* tv1 = makeDummyTensor(2); // (K, N)
+ TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
+ TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
+ TensorView* tv4 = mul(tv2, tv3); // M, K, N
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
+ fusion.addInput(sym_bsx);
+ fusion.addOutput(tv4);
+ // Algorithm
+
+ tv2->setMemoryType(MemoryType::Shared);
+ tv3->setMemoryType(MemoryType::Shared);
+
+ constexpr int BSX = 32;
+ tv4->split(2, BSX);
+ tv4->split(1, sym_bsx);
+ tv4->split(0, BSX);
+ // M/BSX, BSX, K/sym, sym, N/BSX, BSX
+ tv4->reorder({{0, 0}, {1, 3}, {2, 1}, {3, 4}, {4, 2}, {5, 5}});
+ // M/BSX, K/sym, N/BSX, MSX, KSX, NSX
+
+ tv0->computeAt(tv4, 3);
+ tv1->computeAt(tv4, 3);
+ // Schedule
- auto tv1 = add(tv0, new Double(2.5));
+ tv4->axis(0)->parallelize(ParallelType::BIDx);
+ tv4->axis(2)->parallelize(ParallelType::BIDy);
+ // Manual Binding
+ tv2->axis(-2)->parallelize(ParallelType::TIDx);
+ tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ // Thread and Block binding
- auto tv2 = add(tv1, new Double(3.5));
- fusion.addOutput(tv2);
+ constexpr int M = 128, K = 457, N = 1024;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input = at::randn({}, options);
-
- at::Tensor cg_output = at::empty({}, options);
-
- auto lparams = schedulePointwise(&fusion, {input});
+ at::Tensor t0 = at::randn({M, K}, options);
+ at::Tensor t1 = at::randn({K, N}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
- fe.runFusion({input}, {cg_output}, lparams);
+ auto outputs =
+ fe.runFusion({t0, t1, BSX}, LaunchParams(-1, -1, -1, BSX, -1, -1));
- auto aten_output = input.add(2.5).add(3.5);
-
- testValidate(
- &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__);
+ at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0));
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
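+ // The tv2/tv3 shared-memory tiles are rewritten on every outer loop
+ // iteration, hence exactly one WAR sync is expected.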
+ TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 1);
+ TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(22) == 1);
}
-TEST(NVFuserTest, FusionBNBackwardRepro_CUDA) {
- std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
- Fusion& fusion = *fusion_ptr.get();
+TEST(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) {
+ Fusion fusion;
FusionGuard fg(&fusion);
- int batch = 4;
- int c = 4;
- int h = 4;
- int w = 4;
- int numDims = 4;
-
- auto input = makeSymbolicTensor(numDims);
- fusion.addInput(input);
- auto weight = makeSymbolicTensor(1);
- fusion.addInput(weight);
- auto running_mean = makeSymbolicTensor(1);
- fusion.addInput(running_mean);
- auto running_var = makeSymbolicTensor(1);
- fusion.addInput(running_var);
- auto save_mean = makeSymbolicTensor(1);
- fusion.addInput(save_mean);
- auto save_invstd = makeSymbolicTensor(1);
- fusion.addInput(save_invstd);
-
- auto grad_out_prev = makeSymbolicTensor(numDims);
- fusion.addInput(grad_out_prev);
- auto gt_0 =
- makeSymbolicTensor(numDims); // single tensor broadcasted is dangerous.
- fusion.addInput(gt_0);
-
- auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, new Int(1));
- auto gt_float = castOp(DataType::Float, gt_bool);
-
- auto grad_out = mul(grad_out_prev, gt_float);
-
- Val* eps_ptr = new Double(1e-5);
-
- auto grads = batch_norm_backward(
- input,
- grad_out,
- weight,
- running_mean,
- running_var,
- save_mean,
- save_invstd,
- true,
- eps_ptr,
- {true, true, true});
-
- fusion.addOutput(grads.grad_input);
- fusion.addOutput(grads.grad_weight);
- fusion.addOutput(grads.grad_bias);
+ // Symbolic integers we will use for runtime tiling
+ Int* symbolic_m_tile_dim = new Int(); // bound to threadIdx.z
+ Int* symbolic_split_k_tile_dim = new Int(); // bound to blockIdx.x
+ Int* symbolic_block_k_tile_dim = new Int(); // bound to threadIdx.x
+ // Compile-time integer for tiling
+ int n_smem_tile = 8; // bound to threadIdx.y
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input0 = at::randn({batch, c, h, w}, options);
- at::Tensor input1 = at::randn({c}, options);
- at::Tensor input2 = at::randn_like(input1);
- at::Tensor input3 = at::randn_like(input1);
- at::Tensor input4 = at::randn_like(input1);
- at::Tensor input5 = at::randn_like(input1);
- at::Tensor input6 = at::randn_like(input0);
- at::Tensor input7 = at::randn_like(input0);
-
- FusionExecutorCache fec(std::move(fusion_ptr));
- std::vector<IValue> inputs = {
- input0, input1, input2, input3, input4, input5, input6, input7};
- auto outputs = fec.runFusionWithInputs(inputs);
-}
-
-// TODO: We only changed inputs, merge this with the test above.
-TEST(NVFuserTest, FusionBNBackwardRepro2_CUDA) {
- std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
- Fusion& fusion = *fusion_ptr.get();
- FusionGuard fg(&fusion);
+ // Symbolic 2D tensors TV0[M, K], TV1[K, N]
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
- int batch = 2;
- int c = 81;
- int h = 1;
- int w = 1;
- int numDims = 4;
-
- // auto input = makeSymbolicTensor(numDims);
- auto input = makeConcreteTensor({-1, -1, 1, 1});
- fusion.addInput(input);
- auto weight = makeSymbolicTensor(1);
- fusion.addInput(weight);
- auto running_mean = makeSymbolicTensor(1);
- fusion.addInput(running_mean);
- auto running_var = makeSymbolicTensor(1);
- fusion.addInput(running_var);
- auto save_mean = makeSymbolicTensor(1);
- fusion.addInput(save_mean);
- auto save_invstd = makeSymbolicTensor(1);
- fusion.addInput(save_invstd);
-
- // auto grad_out_prev = makeSymbolicTensor(numDims);
- auto grad_out_prev = makeConcreteTensor({-1, -1, 1, 1});
- fusion.addInput(grad_out_prev);
- // auto gt_0 =
- // makeSymbolicTensor(numDims); // single tensor broadcasted is dangerous.
- auto gt_0 = makeConcreteTensor({-1, -1, 1, 1});
- fusion.addInput(gt_0);
-
- auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, new Int(1));
- auto gt_float = castOp(DataType::Float, gt_bool);
-
- auto grad_out = mul(grad_out_prev, gt_float);
-
- Val* eps_ptr = new Double(1e-5);
-
- auto grads = batch_norm_backward(
- input,
- grad_out,
- weight,
- running_mean,
- running_var,
- save_mean,
- save_invstd,
- true,
- eps_ptr,
- {true, true, true});
-
- fusion.addOutput(grads.grad_input);
- fusion.addOutput(grads.grad_weight);
- fusion.addOutput(grads.grad_bias);
+ // Broadcast tv0 to [M, K, *]
+ TensorView* tv2 = broadcast(tv0, {false, false, true});
+ // Broadcast tv1 to [*, K, N]
+ TensorView* tv3 = broadcast(tv1, {true, false, false});
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input0 = at::randn({batch, c, h, w}, options);
- at::Tensor input1 = at::randn({c}, options);
- at::Tensor input2 = at::randn_like(input1);
- at::Tensor input3 = at::randn_like(input1);
- at::Tensor input4 = at::randn_like(input1);
- at::Tensor input5 = at::randn_like(input1);
- at::Tensor input6 = at::randn_like(input0);
- at::Tensor input7 = at::randn_like(input0);
+ // Pointwise multiplication resulting in tv4[M, K, N]
+ TensorView* tv4 = mul(tv2, tv3);
- FusionExecutorCache fec(std::move(fusion_ptr));
- std::vector<IValue> inputs = {
- input0, input1, input2, input3, input4, input5, input6, input7};
- auto outputs = fec.runFusionWithInputs(inputs);
-}
+ // Turn the K-dimension of tv4 into a reduction dimension
+ TensorView* tv5 = sum(tv4, {1});
-TEST(NVFuserTest, FusionBNRepro_CUDA) {
- std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
- Fusion& fusion = *fusion_ptr.get();
- FusionGuard fg(&fusion);
+ // Register inputs and outputs
+ fusion.addInput(tv0);
+ fusion.addInput(tv1);
+ fusion.addOutput(tv5);
- const bool kTraining = true;
- const float kMomentum = 0.1;
- const float kEps = 1e-5;
+ // Register runtime tile dims as inputs
+ fusion.addInput(symbolic_m_tile_dim);
+ fusion.addInput(symbolic_split_k_tile_dim);
+ fusion.addInput(symbolic_block_k_tile_dim);
- int batch = 14;
- int c = 65;
- int h = 7;
- int w = 7;
- int numDims = 4;
-
- auto input = makeSymbolicTensor(numDims);
- fusion.addInput(input);
- auto weight = makeSymbolicTensor(1);
- fusion.addInput(weight);
- auto bias = makeSymbolicTensor(1);
- fusion.addInput(bias);
- auto running_mean = makeSymbolicTensor(1);
- fusion.addInput(running_mean);
- auto running_var = makeSymbolicTensor(1);
- fusion.addInput(running_var);
-
- auto momentum_ptr = new Double(kMomentum);
- auto eps_ptr = new Double(kEps);
-
- auto result = batch_norm(
- input,
- weight,
- bias,
- running_mean,
- running_var,
- kTraining,
- momentum_ptr,
- eps_ptr);
-
- fusion.addOutput(result.output);
- fusion.addOutput(result.mean);
- fusion.addOutput(result.invstd);
+ // Make a 3D tile from a mix of symbolic and constant sizes. Split the
+ // innermost axis first (reverse order), since each split inserts a new
+ // dimension and would otherwise shift the later axis indices.
+ tv5->split(2, n_smem_tile);
+ tv5->split(1, symbolic_block_k_tile_dim);
+ tv5->split(1, symbolic_split_k_tile_dim);
+ tv5->split(0, symbolic_m_tile_dim);
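+ // tv5: [M, rK, N] -> [Mo, Mi, rKoo, rKoi, rKii, No, Ni]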
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input1 = at::randn({batch, c, h, w}, options);
- at::Tensor input2 = at::randn({c}, options);
- at::Tensor input3 = at::randn_like(input2);
- at::Tensor input4 = at::randn_like(input2);
- at::Tensor input5 = at::randn_like(input2);
-
- auto input1_ref = input1.clone();
- auto input2_ref = input2.clone();
- auto input3_ref = input3.clone();
- auto input4_ref = input4.clone();
- auto input5_ref = input5.clone();
-
- FusionExecutorCache fec(std::move(fusion_ptr));
- std::vector<IValue> aten_inputs = {input1, input2, input3, input4, input5};
- auto cg_outputs = fec.runFusionWithInputs(aten_inputs);
-
- auto at_results = at::native_batch_norm(
- input1_ref,
- input2_ref,
- input3_ref,
- input4_ref,
- input5_ref,
- kTraining,
- kMomentum,
- kEps);
-
- auto at_output = std::get<0>(at_results);
- auto at_mean = std::get<1>(at_results);
- auto at_invstd = std::get<2>(at_results);
-
- std::vector<at::Tensor> aten_outputs = {
- input4_ref, input5_ref, at_output, at_mean, at_invstd};
-
- testValidate(
- &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionBNRepro2_CUDA) {
- std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
- Fusion& fusion = *fusion_ptr.get();
- FusionGuard fg(&fusion);
+ // Reorder so all outer tiles are in the leftmost 3 positions
+ tv5->reorder({{1, 5}, {5, 1}});
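+ // -> [Mo, No, rKoo, rKoi, rKii, Mi, Ni]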
- const bool kTraining = true;
- const float kMomentum = 0.1;
- const float kEps = 1e-5;
+ // Factor out the outer reduction IterDomain; the remaining reduction
+ // then runs as an inter-CTA stage followed by an intra-CTA stage.
+ auto tv6 = tv5->rFactor({2});
- int batch = 2;
- int c = 4;
- int h = 17;
- int w = 17;
- int numDims = 4;
+ // Scope computations
+ tv6->computeAt(tv5, 2);
- auto input = makeSymbolicTensor(numDims);
- fusion.addInput(input);
+ // RFactor moves reduction axes around, reorder to match ordering of tv5
+ tv6->reorder({
+ {2, -2},
+ {3, -1},
+ {4, 2},
+ {5, 3},
+ {6, 4},
+ });
- Val* momentum_ptr = new Double(kMomentum);
- Val* eps_ptr = new Double(kEps);
+ // Setup compute at schedule
+ tv0->computeAt(tv6, 3);
+ tv1->computeAt(tv6, 3);
+ tv4->computeAt(tv6, -1);
+ //
+ // T2[Mo, bNo, Koo, Koi, Kii, Mi, bNi] CA(4, 3)
+ // T3[bMo, No, Koo, Koi, Kii, bMi, Ni] CA(4, 3)
+ // T4[ Mo, No, Koo, Koi, Kii, Mi, Ni]
+ // T6[ Mo, No, rKoo, Koi, Kii, Mi, Ni]
+ // T5[ Mo, No, rKoi, rKii, Mi, Ni]
- auto result = batch_norm(
- input,
- nullptr,
- nullptr,
- nullptr,
- nullptr,
- kTraining,
- momentum_ptr,
- eps_ptr);
+ // Cache smem tiles
+ tv2->setMemoryType(MemoryType::Shared);
+ tv3->setMemoryType(MemoryType::Shared);
+ tv4->setMemoryType(MemoryType::Local);
+ tv6->setMemoryType(MemoryType::Local);
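+ // The operand tiles (tv2, tv3) are staged in shared memory so all
+ // threads in the CTA can reuse them; the partial products (tv4, tv6)
+ // stay in per-thread registers.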
- fusion.addOutput(result.output);
- fusion.addOutput(result.mean);
- fusion.addOutput(result.invstd);
+ tv5->axis(0)->parallelize(ParallelType::BIDz);
+ tv5->axis(1)->parallelize(ParallelType::BIDy);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input1 = at::randn({batch, c, h, w}, options);
+ std::vector<TensorView*> tv_list = {tv2, tv3, tv4, tv5, tv6};
+ for (auto tv : tv_list) {
+ tv->axis(-2)->parallelize(ParallelType::TIDz);
+ tv->axis(-1)->parallelize(ParallelType::TIDy);
+ }
+ tv2->axis(3)->parallelize(ParallelType::TIDx);
+ tv3->axis(3)->parallelize(ParallelType::TIDx);
+ tv4->axis(3)->parallelize(ParallelType::TIDx);
+ tv6->axis(3)->parallelize(ParallelType::TIDx);
+ tv5->axis(2)->parallelize(ParallelType::TIDx);
- auto input1_ref = input1.clone();
- at::Tensor r_m;
- at::Tensor r_v;
- at::Tensor weight;
- at::Tensor bias;
+ tv2->axis(4)->parallelize(ParallelType::BIDx);
+ tv3->axis(4)->parallelize(ParallelType::BIDx);
+ tv4->axis(4)->parallelize(ParallelType::BIDx);
+ tv6->axis(4)->parallelize(ParallelType::BIDx);
+ tv5->axis(3)->parallelize(ParallelType::BIDx);
+
+ constexpr int M = 31, K = 65, N = 33;
- FusionExecutorCache fec(std::move(fusion_ptr));
- std::vector<IValue> aten_inputs = {input1};
- auto cg_outputs = fec.runFusionWithInputs(aten_inputs);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor A = at::randn({M, K}, options);
+ at::Tensor B = at::randn({K, N}, options);
- auto at_results = at::native_batch_norm(
- input1_ref, r_m, r_v, weight, bias, kTraining, kMomentum, kEps);
+ FusionExecutor fe;
+ // Generate CUDA and compile with nvRTC
+ fe.compileFusion(&fusion);
- auto at_output = std::get<0>(at_results);
- auto at_mean = std::get<1>(at_results);
- auto at_invstd = std::get<2>(at_results);
+ // Runtime tiling
+ int m_tile = 4; // bound to threadIdx.z
+ int split_k = 7; // bound to blockIdx.x
+ int intra_cta = 8; // bound to threadIdx.x
- std::vector<at::Tensor> aten_outputs = {at_output, at_mean, at_invstd};
+ auto fuser_outputs = fe.runFusion({A, B, m_tile, split_k, intra_cta});
+ auto C_fuser = fuser_outputs[0];
- testValidate(
- &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__);
+ at::Tensor aten_C = mul(A.unsqueeze(2), B.unsqueeze(0)).sum(1);
+ TORCH_CHECK(
+ aten_C.allclose(C_fuser, 1e-5, 1e-5),
+ "Error of: ",
+ aten_C.sub(C_fuser).abs().max());
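+ // As above, the shared-memory tiles are refilled on every K iteration,
+ // so one WAR sync is expected.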
+ TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.size() == 1);
+ TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs.count(41) == 1);
}
-TEST(NVFuserTest, FusionZeroSizeTensorPW_CUDA) {
+TEST(NVFuserTest, FusionGlobalIntermediate_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
fusion.addInput(tv0);
+ fusion.addOutput(tv1);
+ // tv1[I0, R1] = tv0[I0, I1]
- auto tv1 = makeConcreteTensor({0});
- fusion.addInput(tv1);
+ // The interface should eventually be just a direct split by a parallel
+ // type, which could fold in the parallelize call as well.
+ tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+ // tv1[I0, R1o, R1i{TIDx}] = tv0[I0, I1]
- auto tv2 = add(tv0, new Double(2.5));
- fusion.addOutput(tv2);
+ TensorView* tv2 = tv1->rFactor({2});
+ tv2->setMemoryType(MemoryType::Global);
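+ // Same schedule as FusionSmemDynamicReductionSymbolic above, but the
+ // intermediate partial reductions spill to global memory instead of
+ // shared memory.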
+ // tv2[I0, I1o, R1i{TIDx}] = tv0[I0, I1]
+ // tv1[I0, R1o] = tv2[I0, I1o, R1i{TIDx}]
- auto tv3 = makeConcreteTensor({0});
- fusion.addOutput(tv3);
+ tv0->computeAt(tv1, 1);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ tv1->axis(0)->parallelize(ParallelType::BIDx);
+
+ constexpr int numel_x = 65000, numel_y = 1024;
- at::Tensor input0 = at::randn({2}, options);
- at::Tensor input1 = at::randn({0}, options);
- at::Tensor cg_output2 = at::empty({2}, options);
- at::Tensor cg_output3 = at::empty({0}, options);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
- auto lparams = schedulePointwise(&fusion, {input0, input1});
+ // How many threads to use for the block reduction
+ constexpr int runtime_threadIdx_dim = 128;
FusionExecutor fe;
fe.compileFusion(&fusion);
- fe.runFusion({input0, input1}, {cg_output2, cg_output3}, lparams);
+ auto outputs = fe.runFusion(
+ {input}, LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1));
- auto aten_output2 = input0.add(2.5);
- at::Tensor aten_output3 = at::empty({0}, options);
-
- testValidate(
- &fusion,
- {cg_output2, cg_output3},
- {input0, input1},
- {aten_output2, aten_output3},
- __LINE__,
- __FILE__);
+ auto aten_output = input.sum({1});
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
}
-TEST(NVFuserTest, FusionZeroSizeTensorReduction_CUDA) {
+TEST(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
+ TensorView* tv1 = makeDummyTensor(2);
+ TensorView* tv2 = makeDummyTensor(2);
+ TensorView* tv3 = makeDummyTensor(2);
+ TensorView* tv4 = sub(tv2, tv3);
+ TensorView* tv5 = add(tv1, tv4);
+ TensorView* tv6 = sub(tv5, tv0);
fusion.addInput(tv0);
-
- auto tv1 = makeConcreteTensor({0});
fusion.addInput(tv1);
+ fusion.addInput(tv2);
+ fusion.addInput(tv3);
+ fusion.addOutput(tv6);
+ // t6 = ((t1 + (t2 - t3)) - t0)
- auto tv2 = sum(tv0, {1});
- fusion.addOutput(tv2);
-
- auto tv3 = makeConcreteTensor({0});
- fusion.addOutput(tv3);
+ tv4->setMemoryType(MemoryType::Global);
+ tv5->setMemoryType(MemoryType::Global);
+ tv6->setMemoryType(MemoryType::Global);
+ constexpr int M = 32, N = 810;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor in0 = at::rand({M, N}, options);
+ at::Tensor in1 = at::rand({M, N}, options);
+ at::Tensor in2 = at::rand({M, N}, options);
+ at::Tensor in3 = at::rand({M, N}, options);
- at::Tensor input0 = at::randn({2, 4}, options);
- at::Tensor input1 = at::randn({0}, options);
- at::Tensor cg_output2 = at::empty({2}, options);
- at::Tensor cg_output3 = at::empty({0}, options);
-
- auto reduction_params = getReductionHeuristics(&fusion, {input0, input1});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
- scheduleReduction(&fusion, reduction_params.value());
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-
- auto lparams = reduction_params.value().lparams;
FusionExecutor fe;
fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input0, input1}, lparams);
- auto aten_output2 = input0.sum({1});
- at::Tensor aten_output3 = at::empty({0}, options);
-
- testValidate(
- &fusion,
- cg_outputs,
- {input0, input1},
- {aten_output2, aten_output3},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-TEST(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) {
+ auto outputs = fe.runFusion({in0, in1, in2, in3});
+
+ at::Tensor aten_output = (in1 + (in2 - in3)) - in0;
+ TORCH_CHECK(
+ aten_output.allclose(outputs[0], 1e-5, 1e-5),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().sum());
+}
+
+TEST(NVFuserTest, FusionConstCheck_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- auto tv1 = makeConcreteTensor({0});
- fusion.addInput(tv1);
+ auto one = new Int(1);
+ TORCH_CHECK(one->isConstScalar());
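+ // isConstScalar propagates through expressions whose inputs are all
+ // constant, as the chained products below verify.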
- auto tv2 = sum(tv0, {0});
- auto tv3 = broadcast(tv2, {true, false});
- auto tv4 = add(tv0, tv3);
- fusion.addOutput(tv4);
+ auto one_x2 = mul(one, one);
+ TORCH_CHECK(one_x2->isConstScalar());
- auto tv5 = makeConcreteTensor({0});
- fusion.addOutput(tv5);
+ auto one_x3 = mul(one_x2, one);
+ TORCH_CHECK(one_x3->isConstScalar());
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ auto one_x4 = mul(one_x3, one);
+ TORCH_CHECK(one_x4->isConstScalar());
+}
- at::Tensor input0 = at::randn({2, 4}, options);
- at::Tensor input1 = at::randn({0}, options);
- at::Tensor cg_output2 = at::empty({2, 4}, options);
- at::Tensor cg_output3 = at::empty({0}, options);
+TEST(NVFuserTest, FusionUnrollWithAlloc_CUDA) {
+ const std::vector<int64_t> tensor_dims_in = {128, 128};
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto reduction_params = getNormalizationHeuristics(&fusion, {input0, input1});
- TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
- scheduleNormalization(&fusion, reduction_params.value());
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(tensor_dims_in.size());
+ fusion.addInput(tv0);
- auto lparams = reduction_params.value().lparams;
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto cg_outputs = fe.runFusion({input0, input1}, lparams);
- auto aten_output2 = input0.sum({0}).add(input0);
- at::Tensor aten_output3 = at::empty({0}, options);
-
- testValidate(
- &fusion,
- cg_outputs,
- {input0, input1},
- {aten_output2, aten_output3},
- __LINE__,
- __FILE__,
- "",
- lparams);
-}
-
-TEST(NVFuserTest, FusionSegmentIoAlias_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+ TensorView* tv1 = add(tv0, new Float(0));
+ TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv1);
+ fusion.addOutput(tv2);
- TensorView* tv0 = makeSymbolicTensor(2);
- TensorView* tv1 = makeSymbolicTensor(1);
- TensorView* tv2 = makeSymbolicTensor(2);
+ const auto options =
+ at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand(tensor_dims_in, options);
+ at::Tensor cg_output = at::empty({tensor_dims_in[0]}, options);
- fusion->addInput(tv0);
- fusion->addInput(tv1);
- fusion->addInput(tv2);
- TensorView* tv3 = add(tv0, new Double(1)); // Group 0
- TensorView* tv4 =
- max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues)
- TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce,
- // keeps normalization scheduler away)
- TensorView* tv6 = add(tv5, tv2); // Group 1 (Broadcast after reduce)
+ // Schedule
+ tv2->split(1, 32);
+ tv2->split(1, 4); // unroll
- fusion->addOutput(tv6);
- // Note: test alias;
- fusion->aliasOutputToInput(tv6, tv0);
+ auto tv2_rf = tv2->rFactor({-3, -2});
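+ // tv2: [I0, R1] -> [I0, R1oo, R1oi{4}, R1i{32}]; rFactor({-3, -2})
+ // keeps the serial and unrolled axes in the first stage, leaving only
+ // the TIDx-bound reduction for tv2 itself.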
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({128, 65}, options);
- at::Tensor t1 = at::randn({65}, options);
- at::Tensor t2 = at::randn({128, 65}, options);
+ tv2->axis(0)->parallelize(ParallelType::BIDx);
+ tv2->axis(-1)->parallelize(ParallelType::TIDx);
- auto t3 = t0.add(1.0);
- auto t4 = std::get<0>(at::max(t3, 0));
- auto t5 = t4.add(t1);
- auto t6 = t5.add(t2);
+ tv2_rf->axis(0)->parallelize(ParallelType::BIDx);
+ tv2_rf->axis(-1)->parallelize(ParallelType::TIDx);
+ tv2_rf->axis(-2)->parallelize(ParallelType::Unroll);
- FusionExecutorCache executor_cache(std::move(fusion));
+ tv1->computeAt(tv2_rf, -1);
- auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2});
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({input});
- // validating aliasing
- TORCH_INTERNAL_ASSERT(outputs[0].data_ptr() == t0.data_ptr());
+ auto aten_output = (input + 0).sum(1);
TORCH_CHECK(
- executor_cache.getMostRecentKernelRuntime()->isSegmented(),
- "segmentation didn't happen");
- TORCH_CHECK(
- executor_cache.getMostRecentKernelRuntime()
- ->fusionSegments()
- ->groups()
- .size() == 2,
- "segmentation didn't happen as expected");
-
- testValidate(
- executor_cache.fusion(), outputs, {t0, t1, t2}, {t6}, __LINE__, __FILE__);
+ aten_output.allclose(outputs[0]),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
}
-TEST(NVFuserTest, FusionWelford1Output_CUDA) {
- auto fusion_ptr = std::make_unique<Fusion>();
- auto fusion = fusion_ptr.get();
- FusionGuard fg(fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion->addInput(tv0);
-
- auto tvs = Welford(tv0, {1});
- fusion->addOutput(tvs.var_sum);
- FusionExecutorCache executor_cache(std::move(fusion_ptr));
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({128, 65}, options);
- auto outputs = executor_cache.runFusionWithInputs({t0});
+// Test isZeroInt
+TEST(NVFuserTest, FusionIsZeroInt_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto t1 = t0.var({1}, false) * 65;
- testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__);
+ Int* x = new Int(0);
+ Int* y = new Int(1);
+ Val* z = mul(x, y);
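+ // isZeroInt checks the value itself; no constant folding is applied,
+ // so the product 0 * 1 is not recognized as zero.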
+ TORCH_CHECK(x->isZeroInt());
+ TORCH_CHECK(!y->isZeroInt());
+ TORCH_CHECK(!z->isZeroInt());
}
-TEST(NVFuserTest, FusionTranslate1Welford_CUDA) {
- auto fusion_ptr = std::make_unique<Fusion>();
- auto fusion = fusion_ptr.get();
- FusionGuard fg(fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion->addInput(tv0);
-
- auto tvs = Welford(tv0, {1});
- fusion->addOutput(tvs.var_sum);
- FusionExecutorCache executor_cache(std::move(fusion_ptr));
-
- auto run_test = [&executor_cache,
- fusion](auto inner_size) -> FusionKernelRuntime* {
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({128, inner_size}, options);
- auto outputs = executor_cache.runFusionWithInputs({t0});
- // Square sums does not fit well in the testValidate assumptions,
- // so we just compare the divided output here.
- outputs[0] /= inner_size;
- auto t1 = t0.var({1}, false);
- testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__);
-
- return executor_cache.getMostRecentKernelRuntime();
- };
-
- // Run a translated welford
- auto runtime1 = run_test(64);
- // Check it was translated
- TORCH_CHECK(runtime1->singleKernelFusion()->unordered_exprs().size() > 2);
- TORCH_CHECK(
- runtime1->schedulerHeuristics()->singleKernelHeuristics()->heuristc() ==
- ScheduleHeuristic::Normalization);
+// Test isOneInt
+TEST(NVFuserTest, FusionIsOneInt_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- // Run an un-translated welford
- auto runtime2 = run_test(65536);
- // Check it was not translated
- TORCH_CHECK(runtime2->singleKernelFusion()->unordered_exprs().size() == 1);
- TORCH_CHECK(
- runtime2->schedulerHeuristics()->singleKernelHeuristics()->heuristc() ==
- ScheduleHeuristic::Reduction);
+ Int* x = new Int(1);
+ Int* y = new Int(1);
+ Val* z = mul(x, y);
+ TORCH_CHECK(x->isOneInt());
+ TORCH_CHECK(y->isOneInt());
+ TORCH_CHECK(!z->isOneInt());
}
-TEST(NVFuserTest, FusionTranslate2Welford_CUDA) {
- auto fusion_ptr = std::make_unique<Fusion>();
- auto fusion = fusion_ptr.get();
- FusionGuard fg(fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion->addInput(tv0);
-
- auto tvs1 = Welford(tv0, {1});
- auto tvs2 = Welford(tv0, {1});
-
- fusion->addOutput(tvs1.var_sum);
- fusion->addOutput(tvs2.var_sum);
-
- FusionExecutorCache executor_cache(std::move(fusion_ptr));
-
- auto run_test = [&executor_cache,
- fusion](auto inner_size) -> FusionKernelRuntime* {
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({128, inner_size}, options);
- auto outputs = executor_cache.runFusionWithInputs({t0});
-
- // Square sums does not fit well in the testValidate assumptions,
- // so we just compare the divided output here.
- outputs[0] /= inner_size;
- outputs[1] /= inner_size;
- auto t1 = t0.var({1}, false);
- testValidate(fusion, outputs, {t0}, {t1, t1}, __LINE__, __FILE__);
-
- return executor_cache.getMostRecentKernelRuntime();
- };
+// This is to verify no cycle of computeAt is created. A more complex
+// variation of this pattern appears in one of the Python tests
+// (test_random_topo).
+TEST(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- // Run a translated welford
- auto runtime1 = run_test(64);
- // Check it was translated
- TORCH_CHECK(runtime1->singleKernelFusion()->unordered_exprs().size() > 4);
- TORCH_CHECK(
- runtime1->schedulerHeuristics()->singleKernelHeuristics()->heuristc() ==
- ScheduleHeuristic::Normalization);
+ TensorView* tv0 = makeDummyTensor(1);
+ fusion.addInput(tv0);
- // Run an un-translated welford
- auto runtime2 = run_test(65536);
- // // Check it was not translated
- TORCH_CHECK(runtime2->singleKernelFusion()->unordered_exprs().size() == 2);
-}
+ // Common intermediate tensor
+ auto tv1 = add(tv0, new Float(1));
+ // tv1 -> tv2
+ auto tv2 = add(tv1, new Float(2));
+ // tv1 -> tv3 -> tv4
+ auto tv3 = add(tv1, new Float(3));
+ auto tv4 = add(tv3, new Float(4));
-TEST(NVFuserTest, FusionLargeWelfordNormalization_CUDA) {
- auto fusion_ptr = std::make_unique<Fusion>();
- auto fusion = fusion_ptr.get();
- FusionGuard fg(fusion);
+ // NOTE: This should no longer occur as of PR #201.
+ // The order of adding outputs matters. If tv3 is added before tv4,
+ // it should be fine. However, if tv4 is added before tv3, there
+ // will be a cycle of tv3->tv4 and tv4->tv3. tv3->tv4 is created
+ // first, and then tv4->tv3 is created at the final phase of
+ // computeAt (ComputeAt::setupOutputs).
+ fusion.addOutput(tv2);
+ fusion.addOutput(tv4);
+ fusion.addOutput(tv3);
- auto tv0 = makeSymbolicTensor(2);
- fusion->addInput(tv0);
+ tv0->computeAt(tv2, -1);
- auto tvs1 = Welford(tv0, {1});
- auto sum_of_tv0 = sum(tv0, {1});
- auto sum_plus_avg = add(tvs1.avg, sum_of_tv0);
+ TORCH_CHECK(
+ !(tv3->getComputeAtView() == tv4 && tv4->getComputeAtView() == tv3),
+ "ComputeAt cycle detected between tv3 and tv4");
- fusion->addOutput(sum_plus_avg);
+ const auto options =
+ at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand(100, options);
- FusionExecutorCache executor_cache(std::move(fusion_ptr));
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion({input});
- auto run_test = [&executor_cache,
- fusion](auto inner_size) -> FusionKernelRuntime* {
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({128, inner_size}, options);
- auto outputs = executor_cache.runFusionWithInputs({t0});
+ auto& output_tv2 = outputs[0];
+ auto& output_tv4 = outputs[1];
+ auto& output_tv3 = outputs[2];
- auto t1 = t0.mean({1}) + t0.sum({1});
- testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__);
+ auto aten_t1 = input + 1;
+ auto aten_t2 = aten_t1 + 2;
+ auto aten_t3 = aten_t1 + 3;
+ auto aten_t4 = aten_t3 + 4;
- return executor_cache.getMostRecentKernelRuntime();
- };
+ TORCH_CHECK(
+ aten_t2.allclose(output_tv2),
+ "Error of: ",
+ aten_t2.sub(output_tv2).abs().max());
+ TORCH_CHECK(
+ aten_t3.allclose(output_tv3),
+ "Error of: ",
+ aten_t3.sub(output_tv3).abs().max());
+ TORCH_CHECK(
+ aten_t4.allclose(output_tv4),
+ "Error of: ",
+ aten_t4.sub(output_tv4).abs().max());
- auto runtime = run_test(65536);
- TORCH_CHECK(!runtime->isSegmented());
}
-TEST(NVFuserTest, FusionWelfordOtherPersistence_CUDA) {
- auto fusion_ptr = std::make_unique<Fusion>();
- auto fusion = fusion_ptr.get();
- FusionGuard fg(fusion);
+TEST(NVFuserTest, FusionTraversalOrder1_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
- fusion->addInput(tv0);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
- auto tvs1 = Welford(tv0, {1});
- auto sum_of_tv0 = sum(tv0, {1});
- auto sum_bcasted = broadcast(sum_of_tv0, {false, true});
- auto avg_bcasted = broadcast(tvs1.avg, {false, true});
- auto tv0_plus_sum = add(tv0, sum_bcasted);
- auto tv0_plus_avg = add(tv0, avg_bcasted);
+ TensorView* tv1 = add(tv0, new Float(1));
+ TensorView* tv2 = add(tv0, new Float(2));
+ TensorView* tv3 = add(tv1, new Float(3));
+ TensorView* tv4 = add(tv1, new Float(4));
- fusion->addOutput(tv0_plus_sum);
- fusion->addOutput(tv0_plus_avg);
+ fusion.addOutput(tv2);
+ fusion.addOutput(tv3);
+ fusion.addOutput(tv4);
- FusionExecutorCache executor_cache(std::move(fusion_ptr));
+ tv1->computeAt(tv3, -1);
- auto run_test = [&executor_cache,
- fusion](auto inner_size) -> FusionKernelRuntime* {
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({128, inner_size}, options);
- auto outputs = executor_cache.runFusionWithInputs({t0});
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
- auto t1 = t0.mean({1}).unsqueeze(1) + t0;
- auto t2 = t0.sum({1}).unsqueeze(1) + t0;
- testValidate(fusion, outputs, {t0}, {t2, t1}, __LINE__, __FILE__);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({10, 10}, options);
+ at::Tensor cg_output_tv2 = at::empty_like(input, options);
+ at::Tensor cg_output_tv3 = at::empty_like(input, options);
+ at::Tensor cg_output_tv4 = at::empty_like(input, options);
+ fe.runFusion({input}, {cg_output_tv2, cg_output_tv3, cg_output_tv4});
- return executor_cache.getMostRecentKernelRuntime();
- };
+ auto t1 = input + 1;
+ auto t2 = input + 2;
+ auto t3 = t1 + 3;
+ auto t4 = t1 + 4;
- for (auto inner_size : {4096, 8192, 32768}) {
- auto runtime = run_test(4096);
- TORCH_CHECK(!runtime->isSegmented());
- }
+ TORCH_CHECK(
+ t2.allclose(cg_output_tv2),
+ "tv2 error of: ",
+ t2.sub(cg_output_tv2).abs().max());
+ TORCH_CHECK(
+ t3.allclose(cg_output_tv3),
+ "tv5 error of: ",
+ t3.sub(cg_output_tv3).abs().max());
+ TORCH_CHECK(
+ t4.allclose(cg_output_tv4),
+ "tv4 error of: ",
+ t4.sub(cg_output_tv4).abs().max());
}
-TEST(NVFuserTest, TestSegmentIslands_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
-
- auto tv0 = makeSymbolicTensor(2);
- auto tv1 = makeSymbolicTensor(2);
- fusion->addInput(tv0);
- fusion->addInput(tv1);
-
- auto tv2 = sum(tv0, {0});
- auto tv3 = sum(tv1, {1});
- fusion->addOutput(tv2);
- fusion->addOutput(tv3);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({16, 16}, options);
- at::Tensor t1 = at::randn({16, 16}, options);
+TEST(NVFuserTest, FusionTraversalOrder2_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- FusionExecutorCache fusion_executor_cache(std::move(fusion));
- fusion_executor_cache.runFusionWithInputs({t0, t1});
-}
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(2);
+ fusion.addInput(tv0);
-TEST(NVFuserTest, TestBackOffInnerBroadcast_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+ TensorView* tv1 = add(tv0, new Float(1));
+ TensorView* tv2 = add(tv1, new Float(2));
- auto tv0 = makeSymbolicTensor(1);
- auto tv1 = makeSymbolicTensor(2);
- auto tv2 = makeSymbolicTensor(4);
- fusion->addInput(tv0);
- fusion->addInput(tv1);
+ TensorView* tv3 = add(tv0, new Float(3));
+ TensorView* tv4 = add(tv3, new Float(4));
- auto tv3 = broadcast(tv0, {false, true, true, true});
- auto tv4 = broadcast(tv1, {false, false, true, true});
- auto tv5 = unaryOp(UnaryOpType::Rsqrt, tv2);
+ TensorView* tv5 = add(tv1, tv3);
- auto tv6 = add(tv3, tv5);
- auto tv7 = add(tv4, tv5);
- auto tv8 = add(tv3, tv4);
+ fusion.addOutput(tv2);
+ fusion.addOutput(tv4);
+ fusion.addOutput(tv5);
- auto tv9 = add(tv6, tv7);
- auto tv10 = add(tv9, tv8);
+ tv1->computeAt(tv5, -1);
+ tv3->computeAt(tv5, -1);
- fusion->addOutput(tv10);
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
- tv0->computeAt(tv10, -2);
- tv1->computeAt(tv10, -2);
- tv2->computeAt(tv10, -2);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({10, 10}, options);
+ at::Tensor cg_output_tv2 = at::empty_like(input, options);
+ at::Tensor cg_output_tv4 = at::empty_like(input, options);
+ at::Tensor cg_output_tv5 = at::empty_like(input, options);
+ fe.runFusion({input}, {cg_output_tv2, cg_output_tv4, cg_output_tv5});
- TORCH_CHECK(tv3->getComputeAtPosition() == 1);
- TORCH_CHECK(tv4->getComputeAtPosition() == 2);
- TORCH_CHECK(tv5->getComputeAtPosition() == 3);
+ auto t1 = input + 1;
+ auto t2 = t1 + 2;
+ auto t3 = input + 3;
+ auto t4 = t3 + 4;
+ auto t5 = t1 + t3;
- TORCH_CHECK(tv6->getMaxProducerPosition() == 3);
- TORCH_CHECK(tv7->getMaxProducerPosition() == 3);
- TORCH_CHECK(tv8->getMaxProducerPosition() == 2);
+ TORCH_CHECK(
+ t2.allclose(cg_output_tv2),
+ "tv2 error of: ",
+ t2.sub(cg_output_tv2).abs().max());
+ TORCH_CHECK(
+ t4.allclose(cg_output_tv4),
+ "tv4 error of: ",
+ t4.sub(cg_output_tv4).abs().max());
+ TORCH_CHECK(
+ t5.allclose(cg_output_tv5),
+ "tv5 error of: ",
+ t5.sub(cg_output_tv5).abs().max());
}
-TEST(NVFuserTest, TestBackOffInnerBroadcast2_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+TEST(NVFuserTest, FusionTraversalOrder3_CUDA) {
+ for (int i = 0; i < 2; ++i) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
- auto tv1 = makeSymbolicTensor(3);
- fusion->addInput(tv0);
- fusion->addInput(tv1);
- auto tv2 = broadcast(tv0, {false, false, true});
- auto tv3 = add(tv2, tv1);
+ TensorView* tv0 = makeDummyTensor(1);
+ fusion.addInput(tv0);
- fusion->addOutput(tv3);
- tv3->split(-2, 4);
- tv3->reorder({{-1, -2}});
- tv0->computeAt(tv3, -2);
- tv1->computeAt(tv3, -2);
- TORCH_CHECK(tv2->getComputeAtPosition() == 2);
- TORCH_CHECK(tv3->getMaxProducerPosition() == 2);
-}
+ TensorView* tv1 = add(tv0, new Float(1));
+ TensorView* tv2 = add(tv1, new Float(2));
-TEST(NVFuserTest, TestBackOffInnerBroadcast3_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+ TensorView* tv3 = add(tv0, new Float(3));
+ TensorView* tv4 = add(tv3, new Float(4));
- auto tv0 = makeSymbolicTensor(2);
- auto tv1 = makeSymbolicTensor(4);
- fusion->addInput(tv0);
- fusion->addInput(tv1);
- auto tv2 = broadcast(tv0, {false, false, true});
- auto tv3 = broadcast(tv2, {false, true, false, false});
- auto tv4 = add(tv3, tv1);
+ TensorView* tv5 = add(tv1, tv3);
- fusion->addOutput(tv4);
- tv0->computeAt(tv4, -1);
- tv1->computeAt(tv4, -1);
- TORCH_CHECK(tv2->getComputeAtPosition() == 2);
- TORCH_CHECK(tv3->getMaxProducerPosition() == 3);
-}
+ fusion.addOutput(tv2);
+ fusion.addOutput(tv4);
+ fusion.addOutput(tv5);
-TEST(NVFuserTest, FusionSegfaultReduction_CUDA) {
- std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
- Fusion& fusion = *fusion_ptr.get();
- FusionGuard fg(&fusion);
+ const int tile = 32;
+
+ tv1->split(-1, tile);
+ tv2->split(-1, tile);
+ tv3->split(-1, tile);
+ tv4->split(-1, tile);
+ tv5->split(-1, tile);
- int batch = 2;
- int c = 1;
- int h = 1;
- int w = 1;
- int numDims = 4;
-
- auto input = makeConcreteTensor({-1, 1, 1, 1});
- fusion.addInput(input);
- auto bcast_bias = makeConcreteTensor({-1, 1, 1, 1});
- fusion.addInput(bcast_bias);
-
- std::vector<int64_t> at_sum_axes;
- std::vector<int> outer_reduction_axes;
- std::vector<bool> outer_broadcast_mask(numDims, false);
- Val* N = new Double(1);
- for (size_t axis = 0; axis < numDims; ++axis) {
- if (axis != 1) {
- outer_reduction_axes.push_back(axis);
- at_sum_axes.push_back(axis);
- outer_broadcast_mask[axis] = true;
- N = mul(N, input->domain()->domain()[axis]->extent());
+ auto compute_at_outer = tv1;
+ auto compute_at_inner = tv3;
+ if (i == 1) {
+ std::swap(compute_at_inner, compute_at_outer);
}
- }
- auto output0 = mul(input, bcast_bias);
- fusion.addOutput(output0);
- auto output1 = sum(output0, outer_reduction_axes);
- fusion.addOutput(output1);
+ compute_at_outer->computeAt(tv5, -2);
+ compute_at_inner->computeAt(tv5, -1);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input0 = at::randn({batch, c, h, w}, options);
- at::Tensor input1 = at::randn({batch, c, h, w}, options);
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
- auto at_output0 = input0.mul(input1);
- auto at_output1 = at_output0.sum(at_sum_axes);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor input = at::rand({100}, options);
+ at::Tensor cg_output_tv2 = at::empty_like(input, options);
+ at::Tensor cg_output_tv4 = at::empty_like(input, options);
+ at::Tensor cg_output_tv5 = at::empty_like(input, options);
+ fe.runFusion({input}, {cg_output_tv2, cg_output_tv4, cg_output_tv5});
- FusionExecutorCache fec(std::move(fusion_ptr));
- std::vector<IValue> inputs = {input0, input1};
- auto outputs = fec.runFusionWithInputs(inputs);
+ auto t1 = input + 1;
+ auto t2 = t1 + 2;
+ auto t3 = input + 3;
+ auto t4 = t3 + 4;
+ auto t5 = t1 + t3;
- testValidate(
- &fusion, outputs, inputs, {at_output0, at_output1}, __LINE__, __FILE__);
+ TORCH_CHECK(
+ t2.allclose(cg_output_tv2),
+ "tv2 error of: ",
+ t2.sub(cg_output_tv2).abs().max());
+ TORCH_CHECK(
+ t4.allclose(cg_output_tv4),
+ "tv4 error of: ",
+ t4.sub(cg_output_tv4).abs().max());
+ TORCH_CHECK(
+ t5.allclose(cg_output_tv5),
+ "tv5 error of: ",
+ t5.sub(cg_output_tv5).abs().max());
+ }
}
-TEST(NVFuserTest, FusionPredicateElimination_CUDA) {
+TEST(NVFuserTest, FusionTraversalOrder4_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
+ // First tree
+ TensorView* tv0 = makeDummyTensor(1);
fusion.addInput(tv0);
+ TensorView* tv1 = add(tv0, new Float(1));
+ TensorView* tv2 = add(tv1, new Float(2));
+ TensorView* tv3 = add(tv1, new Float(3));
+ fusion.addOutput(tv2);
+ fusion.addOutput(tv3);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(2));
- auto tv3 = add(tv2, new Double(3));
+ // Second tree
+ TensorView* tv4 = makeDummyTensor(1);
+ fusion.addInput(tv4);
+ TensorView* tv5 = add(tv4, new Float(5));
+ TensorView* tv6 = add(tv5, new Float(6));
+ TensorView* tv7 = add(tv5, new Float(7));
+ fusion.addOutput(tv6);
+ fusion.addOutput(tv7);
- fusion.addOutput(tv3);
+ tv1->computeAt(tv2, -1);
+ tv5->computeAt(tv6, -1);
- tv3->split(0, 32);
- tv0->computeAt(tv3, 1);
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
- tv2->axis(1)->parallelize(ParallelType::Unswitch);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor t0 = at::rand({100}, options);
+ at::Tensor t4 = at::rand_like(t0, options);
+ at::Tensor cg_output_tv2 = at::empty_like(t0, options);
+ at::Tensor cg_output_tv3 = at::empty_like(t0, options);
+ at::Tensor cg_output_tv6 = at::empty_like(t0, options);
+ at::Tensor cg_output_tv7 = at::empty_like(t0, options);
- {
- GpuLower gpulw(&fusion);
- TORCH_CHECK(!isPredicated(tv2, gpulw));
- }
+ fe.runFusion(
+ {t0, t4}, {cg_output_tv2, cg_output_tv3, cg_output_tv6, cg_output_tv7});
- tv2->axis(1)->parallelize(ParallelType::Serial);
- tv2->split(1, 5);
+ auto t1 = t0 + 1;
+ auto t2 = t1 + 2;
+ auto t3 = t1 + 3;
+ auto t5 = t4 + 5;
+ auto t6 = t5 + 6;
+ auto t7 = t5 + 7;
- {
- GpuLower gpulw(&fusion);
- TORCH_CHECK(isPredicated(tv2, gpulw));
- }
+ TORCH_CHECK(
+ t2.allclose(cg_output_tv2),
+ "tv2 error of: ",
+ t2.sub(cg_output_tv2).abs().max());
+ TORCH_CHECK(
+ t3.allclose(cg_output_tv3),
+ "tv3 error of: ",
+ t3.sub(cg_output_tv3).abs().max());
+ TORCH_CHECK(
+ t6.allclose(cg_output_tv6),
+ "tv6 error of: ",
+ t6.sub(cg_output_tv6).abs().max());
+ TORCH_CHECK(
+ t7.allclose(cg_output_tv7),
+ "tv7 error of: ",
+ t7.sub(cg_output_tv7).abs().max());
}
-TEST(NVFuserTest, ForceFp16Simple_CUDA) {
- std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
- auto fusion = fusion_ptr.get();
- FusionGuard fg(fusion);
+TEST(NVFuserTest, FusionTraversalOrder5_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
- auto tv1 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(1);
+ fusion.addInput(tv0);
+ TensorView* tv1 = add(tv0, new Float(1));
+ TensorView* tv2 = add(tv1, new Float(2));
+ TensorView* tv3 = add(tv0, new Float(3));
+ TensorView* tv4 = add(tv3, new Float(4));
+ TensorView* tv5 = add(tv2, tv4);
- fusion->addInput(tv0);
- fusion->addInput(tv1);
+ fusion.addOutput(tv1);
+ fusion.addOutput(tv3);
+ fusion.addOutput(tv5);
- // Group 1
- auto tv2 = sum(tv0, {1});
- auto tv3 = broadcast(tv2, {false, true});
+ tv2->computeAt(tv5, -1);
+ tv4->computeAt(tv5, -1);
- // Group 2
- auto tv4 = add(tv3, tv1); // Edge: tv3: expect cast
- auto tv5 = castOp(DataType::Half, tv4);
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
- fusion->addOutput(tv5);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor t0 = at::rand({100}, options);
+ at::Tensor cg_output_tv1 = at::empty_like(t0, options);
+ at::Tensor cg_output_tv3 = at::empty_like(t0, options);
+ at::Tensor cg_output_tv5 = at::empty_like(t0, options);
- FusionExecutorCache fec(std::move(fusion_ptr));
+ fe.runFusion({t0}, {cg_output_tv1, cg_output_tv3, cg_output_tv5});
- std::vector<int64_t> shape{15, 16};
+ auto t1 = t0 + 1;
+ auto t2 = t1 + 2;
+ auto t3 = t0 + 3;
+ auto t4 = t3 + 4;
+ auto t5 = t2 + t4;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto in0 = at::randn(shape, options);
- auto in1 = at::randn(shape, options);
- fec.runFusionWithInputs({in0, in1});
-
- // Check the segmented edge is fp16
- auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments();
- for (auto edge : segmented_fusion->edges()) {
- auto edge_tv = edge->val->as<TensorView>();
- TORCH_CHECK(edge_tv->getDataType() == DataType::Half);
- }
+ TORCH_CHECK(
+ t1.allclose(cg_output_tv1),
+ "tv1 error of: ",
+ t1.sub(cg_output_tv1).abs().max());
+ TORCH_CHECK(
+ t3.allclose(cg_output_tv3),
+ "tv3 error of: ",
+ t3.sub(cg_output_tv3).abs().max());
+ TORCH_CHECK(
+ t5.allclose(cg_output_tv5),
+ "tv5 error of: ",
+ t5.sub(cg_output_tv5).abs().max());
}
-TEST(NVFuserTest, ForceFp16NotAllCast_CUDA) {
- std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
- auto fusion = fusion_ptr.get();
- FusionGuard fg(fusion);
+TEST(NVFuserTest, FusionTraversalOrder6_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(3);
- auto tv1 = makeSymbolicTensor(3);
+ TensorView* tv0 = makeDummyTensor(1);
+ fusion.addInput(tv0);
+ TensorView* tv1 = add(tv0, new Float(1));
+ TensorView* tv2 = add(tv0, new Float(2));
+ TensorView* tv3 = add(tv1, tv2);
+ TensorView* tv4 = add(tv3, new Float(4));
- fusion->addInput(tv0);
- fusion->addInput(tv1);
+ fusion.addOutput(tv4);
- // Group 1
- auto tv3 = sum(tv0, {1});
- auto tv4 = broadcast(tv3, {false, true, false});
- auto tv5 = sum(tv0, {1});
+ tv1->split(0, 32);
+ tv2->split(0, 32);
+ tv3->split(0, 32);
+ tv4->split(0, 32);
- // Group 2
- auto tv6 = add(tv4, tv1); // edge tv4, expect cast
- auto tv7 = castOp(DataType::Half, tv6);
+ tv3->computeAt(tv4, -2);
+ tv1->computeAt(tv3, -1);
+ tv2->computeAt(tv3, -2);
- // Group 3
- auto tv8 = sum(tv5, {1}); // edge tv5, don't expect cast
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
- fusion->addOutput(tv7);
- fusion->addOutput(tv8);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+ at::Tensor t0 = at::rand({100}, options);
+ at::Tensor cg_output_tv4 = at::empty_like(t0, options);
- FusionExecutorCache fec(std::move(fusion_ptr));
+ fe.runFusion({t0}, {cg_output_tv4});
- std::vector<int64_t> shape{16, 16, 16};
+ auto t1 = t0 + 1;
+ auto t2 = t0 + 2;
+ auto t3 = t1 + t2;
+ auto t4 = t3 + 4;
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto in0 = at::randn(shape, options);
- auto in1 = at::randn(shape, options);
- fec.runFusionWithInputs({in0, in1});
-
- auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments();
- auto complete_fusion = segmented_fusion->completeFusion();
-
- // Check that the edge that wasn't fp16 is the producer of the
- // reduction op, i.e. tv8 = sum(tv5,{1});.
- for (auto edge : segmented_fusion->edges()) {
- auto edge_tv = edge->val->as<TensorView>();
- if (edge_tv->getDataType() == DataType::Float) {
- auto consumer = *(complete_fusion->unordered_uses(edge_tv).begin());
- TORCH_CHECK(consumer->isA<ReductionOp>());
- }
- }
+ TORCH_CHECK(
+ t4.allclose(cg_output_tv4),
+ "tv4 error of: ",
+ t4.sub(cg_output_tv4).abs().max());
}
-TEST(NVFuserTest, FusionIssue970_CUDA) {
+TEST(NVFuserTest, FusionTraversalOrder7_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
- const int nelm = 10;
-
- // tv3 = tv0 + sum(tv0)
- auto tv0 = makeConcreteTensor({nelm, nelm});
+ TensorView* tv0 = makeDummyTensor(1);
fusion.addInput(tv0);
- auto tv1 = sum(tv0, {1});
- auto tv2 = broadcast(tv1, {false, true});
- auto tv3 = add(tv2, tv0);
- fusion.addOutput(tv3);
+ TensorView* tv1 = add(tv0, new Float(1));
+ TensorView* tv2 = add(tv1, new Float(2));
+ TensorView* tv3 = add(tv0, new Float(3));
+ TensorView* tv4 = add(tv3, new Float(4));
+ TensorView* tv5 = add(tv2, tv4);
- tv1->split(1, 4);
+ fusion.addOutput(tv5);
+
+ TensorView* tvs[] = {tv1, tv2, tv3, tv4, tv5};
+ for (auto tv : tvs) {
+ tv->split(0, 2);
+ tv->split(0, 4);
+ tv->split(0, 8);
+ }
+
+ // computeAt into inner loop nests
+ tv1->computeAt(tv2, -1);
+ tv3->computeAt(tv4, -2);
+
+ tv2->computeAt(tv5, -4);
+ tv4->computeAt(tv5, -3);
FusionExecutor fe;
fe.compileFusion(&fusion);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor t0 = at::randn({nelm, nelm}, options);
-
- auto outputs = fe.runFusion({t0});
+ at::Tensor t0 = at::rand({100}, options);
+ at::Tensor cg_output_tv5 = at::empty_like(t0, options);
+ fe.runFusion({t0}, {cg_output_tv5});
- auto ref = sum(t0, {1}).unsqueeze(-1).expand({nelm, nelm}) + t0;
+ auto t1 = t0 + 1;
+ auto t2 = t1 + 2;
+ auto t3 = t0 + 3;
+ auto t4 = t3 + 4;
+ auto t5 = t2 + t4;
- testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__);
+ TORCH_CHECK(
+ t5.allclose(cg_output_tv5),
+ "tv5 error of: ",
+ t5.sub(cg_output_tv5).abs().max());
}
-// Reproducer of #1016
-TEST(NVFuserTest, FusionIssue1016_CUDA) {
+// Test predication of grid reduction
+TEST(NVFuserTest, FusionThreadPredicate_CUDA) {
+ const int gdimx = 4;
+ const int bdimx = 128;
+
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(2);
+ TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(2));
+ TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+ TensorView* tv2 = unaryOp(UnaryOpType::Neg, tv1);
+ TensorView* tv3 = add(tv0, new Float(2));
+ fusion.addOutput(tv3);
fusion.addOutput(tv2);
- tv1->setMemoryType(MemoryType::Shared);
+ tv1->split(1, bdimx);
+ tv1->split(1, gdimx);
+ tv3->split(1, bdimx);
+ tv3->split(1, gdimx);
+
+ TensorView* tv1_rf = tv1->rFactor({1});
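+ // (rFactor splits the reduction in two stages: tv1_rf performs a
+ // partial reduction over the factored axis, and tv1 completes it over
+ // the remaining, parallelized axes.)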
- tv2->split(-1, 8);
+ tv1->computeAt(tv2, -1);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
+ tv1->axis(0)->parallelize(ParallelType::BIDy);
+ tv1_rf->axis(0)->parallelize(ParallelType::BIDy);
+ tv2->axis(0)->parallelize(ParallelType::BIDy);
+ tv1->axis(-2)->parallelize(ParallelType::BIDx);
+ tv1_rf->axis(-2)->parallelize(ParallelType::BIDx);
+ tv1->axis(-1)->parallelize(ParallelType::TIDx);
+ tv1_rf->axis(-1)->parallelize(ParallelType::TIDx);
+
+ tv3->axis(3)->parallelize(ParallelType::TIDx);
+ tv3->axis(2)->parallelize(ParallelType::BIDx);
+ tv3->axis(0)->parallelize(ParallelType::BIDy);
- int numel_x = 10;
- int numel_y = 11;
+ int numel_x = 100;
+ int numel_y = 1000;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
+ at::Tensor input = at::rand({numel_x, numel_y}, options);
+ at::Tensor cg_output_tv2 = at::empty({numel_x}, options);
+ at::Tensor cg_output_tv3 = at::empty_like(input, options);
- auto ref = t0 + 1 + 2;
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ fe.runFusion({input}, {cg_output_tv3, cg_output_tv2});
- testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__);
+ auto aten_output_tv2 = -input.sum({1});
+ TORCH_CHECK(aten_output_tv2.allclose(cg_output_tv2));
+ auto aten_output_tv3 = input + 2.0;
+ TORCH_CHECK(aten_output_tv3.allclose(cg_output_tv3));
}
-// Reproducer of #1021
-TEST(NVFuserTest, FusionIssue1021_CUDA) {
+TEST(NVFuserTest, FusionLSTMCell_CUDA) {
+ const int hidden_features = 512;
+ const int batch_size = 64;
+
Fusion fusion;
FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = broadcast(tv1, {false, true});
- fusion.addOutput(tv2);
-
- auto tv3 = tv2->cache_before();
-
- tv2->split(0, 2);
-
- tv1->computeAt(tv2, 1);
-
- tv2->axis(0)->parallelize(ParallelType::TIDx);
- tv2->axis(1)->parallelize(ParallelType::Vectorize);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({10}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto ref = (t0 + 1).unsqueeze(-1);
+ TensorView* tvs[16];
+ for (auto& tv : tvs) {
+ tv = makeDummyTensor(2);
+ fusion.addInput(tv);
+ }
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
+ auto ingate = unaryOp(
+ UnaryOpType::Sigmoid, add(add(add(tvs[0], tvs[1]), tvs[2]), tvs[3]));
-// Reproducer of issue #1053
-TEST(NVFuserTest, FusionNonUniqueThreadDim_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+ auto forgetgate = unaryOp(
+ UnaryOpType::Sigmoid, add(add(add(tvs[4], tvs[5]), tvs[6]), tvs[7]));
- auto tv0 = makeSymbolicTensor(1);
- fusion->addInput(tv0);
- auto tv1 = sum(tv0, {0});
- fusion->addOutput(tv1);
+ auto cellgate = unaryOp(
+ UnaryOpType::Tanh, add(add(add(tvs[8], tvs[9]), tvs[10]), tvs[11]));
- auto tv2 = add(tv0, new Double(1));
- fusion->addOutput(tv2);
+ auto outgate = unaryOp(
+ UnaryOpType::Sigmoid, add(add(add(tvs[12], tvs[13]), tvs[14]), tvs[15]));
- tv1->split(0, 8);
- auto tv1_rf = tv1->rFactor({-1});
+ auto cx = makeContigTensor(2);
+ fusion.addInput(cx);
- tv1_rf->computeAt(tv1, 1);
+ auto cy = add(mul(forgetgate, cx), mul(ingate, cellgate));
- tv1_rf->axis(-1)->parallelize(ParallelType::TIDx);
+ auto hy = mul(outgate, unaryOp(UnaryOpType::Tanh, cy));
- tv2->axis(0)->parallelize(ParallelType::TIDx);
+ fusion.addOutput(cy);
+ fusion.addOutput(hy);
+ std::vector<c10::IValue> inputs;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input1 = at::randn({32}, options);
-
- auto at_tv1 = (input1).sum({0});
- auto at_tv2 = input1 + 1;
+ at::Tensor large_tensor0 =
+ at::randn({batch_size, hidden_features * 4}, options);
+ at::Tensor large_tensor1 =
+ at::randn({batch_size, hidden_features * 4}, options);
+ at::Tensor large_tensor2 =
+ at::randn({batch_size, hidden_features * 4}, options);
+ at::Tensor large_tensor3 =
+ at::randn({batch_size, hidden_features * 4}, options);
- FusionExecutor fe;
- fe.compileFusion(fusion.get());
- auto outputs = fe.runFusion({input1});
- testValidate(
- fusion.get(), outputs, {input1}, {at_tv1, at_tv2}, __LINE__, __FILE__);
-}
+ auto chunked0 = large_tensor0.chunk(4, 1);
+ auto chunked1 = large_tensor1.chunk(4, 1);
+ auto chunked2 = large_tensor2.chunk(4, 1);
+ auto chunked3 = large_tensor3.chunk(4, 1);
-TEST(NVFuserTest, FusionParallelDimensionMap1_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+ inputs.insert(inputs.end(), chunked0.begin(), chunked0.end());
+ inputs.insert(inputs.end(), chunked1.begin(), chunked1.end());
+ inputs.insert(inputs.end(), chunked2.begin(), chunked2.end());
+ inputs.insert(inputs.end(), chunked3.begin(), chunked3.end());
- auto tv0 = makeSymbolicTensor(1);
- fusion->addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv0, new Double(1));
- fusion->addOutput(tv1);
- fusion->addOutput(tv2);
-
- tv1->split(0, 8, false);
- tv1->axis(1)->parallelize(ParallelType::TIDx);
- tv2->split(0, 8, false);
- tv2->axis(1)->parallelize(ParallelType::TIDx);
-
- // The extents of tv1 and tv2 axes are equal even though their
- // actual values are not statically known
- GpuLower gpulw(fusion.get());
- const auto& pdmap = gpulw.parallelDimensionMap();
- auto kir_tv1 = gpulw.lowerValue(tv1)->as<kir::TensorView>();
- auto kir_tv2 = gpulw.lowerValue(tv2)->as<kir::TensorView>();
- for (size_t i = 0; i < kir_tv1->domain()->domain().size(); ++i) {
- auto dom1 = kir_tv1->domain()->domain()[i];
- auto dom2 = kir_tv2->domain()->domain()[i];
- TORCH_INTERNAL_ASSERT(pdmap.equalDim(dom1->extent(), dom2->extent()));
- }
+ auto at_ingate =
+ chunked0[0].add(chunked0[1]).add(chunked0[2]).add(chunked0[3]).sigmoid();
+ auto at_forgetgate =
+ chunked1[0].add(chunked1[1]).add(chunked1[2]).add(chunked1[3]).sigmoid();
+ auto at_cellgate =
+ chunked2[0].add(chunked2[1]).add(chunked2[2]).add(chunked2[3]).tanh();
+ auto at_outgate =
+ chunked3[0].add(chunked3[1]).add(chunked3[2]).add(chunked3[3]).sigmoid();
- TORCH_CHECK(pdmap.isExact(ParallelType::TIDx));
- TORCH_CHECK(
- pdmap.get(ParallelType::TIDx)->isA<kir::NamedScalar>() &&
- pdmap.get(ParallelType::TIDx)->as<kir::NamedScalar>()->name() ==
- "blockDim.x");
+ auto at_cx = at::randn({batch_size, hidden_features}, options);
+ inputs.push_back(at_cx);
+ auto at_cy = at_forgetgate.mul(at_cx).add(at_ingate.mul(at_cellgate));
+ auto at_hy = at_outgate.mul(at_cy.tanh());
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input1 = at::randn({32}, options);
+ scheduleFusion(&fusion, c10::ArrayRef<c10::IValue>(inputs));
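+ // (scheduleFusion picks a default schedule and launch configuration
+ // for the fusion based on the runtime inputs.)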
FusionExecutor fe;
- fe.compileFusion(fusion.get());
- auto outputs = fe.runFusion({input1});
+ fe.compileFusion(&fusion);
+ auto outputs = fe.runFusion(c10::ArrayRef<c10::IValue>(inputs));
- testValidate(
- fusion.get(),
- outputs,
- {input1},
- {input1 + 1, input1 + 1},
- __LINE__,
- __FILE__);
+ TORCH_CHECK(at_cy.allclose(outputs[0], 1e-4, 1e-7));
+ TORCH_CHECK(at_hy.allclose(outputs[1], 1e-4, 1e-7));
}
-TEST(NVFuserTest, FusionParallelDimensionMap2_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+TEST(NVFuserTest, FusionComputeAtMultiBCast_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto tv0 = makeSymbolicTensor(1);
- fusion->addInput(tv0);
- auto tv1 = makeSymbolicTensor(2);
- fusion->addInput(tv1);
- auto tv2 = broadcast(tv0, {false, true});
- auto tv3 = add(tv1, tv2);
- fusion->addOutput(tv3);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(1);
+ fusion.addInput(tv0);
- tv3->split(-1, 8, false);
- tv2->computeAt(tv3, -1);
+ TensorView* tv1 = mul(tv0, new Float(0.5));
+ TensorView* tv2 = broadcast(tv1, {true, false});
+ TensorView* tv3 = broadcast(tv1, {false, true});
+ TensorView* tv4 = add(tv2, tv3);
+ fusion.addOutput(tv4);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
+ // This is not supported and should throw an exception.
+ ASSERT_ANY_THROW(tv1->computeAt(tv3, -1));
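+ // (tv1 is consumed through broadcasts along opposite axes, so fully
+ // inlining it into tv3 would be inconsistent with its use in tv2.)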
+}
- GpuLower gpulw(fusion.get());
- const auto& pdmap = gpulw.parallelDimensionMap();
- TORCH_CHECK(pdmap.isExact(ParallelType::TIDx));
- TORCH_CHECK(
- pdmap.get(ParallelType::TIDx)->isA<kir::NamedScalar>() &&
- pdmap.get(ParallelType::TIDx)->as<kir::NamedScalar>()->name() ==
- "blockDim.x");
+TEST(NVFuserTest, FusionReductionHalf_CUDA) {
+ Fusion fusion;
+ FusionGuard fg(&fusion);
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input1 = at::randn({11}, options);
- at::Tensor input2 = at::randn({11, 13}, options);
+ // Set up your input tensor views
+ TensorView* tv0 = makeDummyTensor(3, DataType::Half);
+ fusion.addInput(tv0);
- FusionExecutor fe;
- fe.compileFusion(fusion.get());
- auto outputs = fe.runFusion({input1, input2});
+ auto tv1 = castOp(DataType::Float, tv0);
+ auto tv2 = add(tv1, new Float(1.0));
+ auto tv3 = sum(tv2, {2});
+ auto tv4 = castOp(DataType::Half, tv3);
- auto ref = input1.unsqueeze(-1) + input2;
+ fusion.addOutput(tv4);
- testValidate(
- fusion.get(), outputs, {input1, input2}, {ref}, __LINE__, __FILE__);
-}
+ const auto options =
+ at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+ at::Tensor input = at::randn({8, 8, 16}, options);
-// Mix symbolic and concrete tensors
-TEST(NVFuserTest, FusionParallelDimensionMap3_CUDA) {
- auto fusion = std::make_unique<Fusion>();
- FusionGuard fg(fusion.get());
+ auto reduction_tv = tv3;
- auto tv0 = makeSymbolicTensor(1);
- fusion->addInput(tv0);
+ auto outputsOfReduction = DependencyCheck::getAllOutputsOf({reduction_tv});
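+ // (Collects every value downstream of the reduction so the scheduler
+ // can transform the reduction and its consumers consistently.)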
- auto tv2 = add(tv0, new Double(1));
- fusion->addOutput(tv2);
- auto tv3 = add(tv0, new Double(1));
- fusion->addOutput(tv3);
+ // Grab only tensor views, though there shouldn't be any other type
+ auto tv_entries = ir_utils::filterByType<TensorView>(outputsOfReduction);
- tv2->split(0, 10);
- tv3->split(0, 20);
+ std::vector<TensorView*> tvOutputsOfReduction(
+ tv_entries.begin(), tv_entries.end());
- auto tv4 = add(tv0, new Double(1));
- fusion->addOutput(tv4);
- auto tv5 = add(tv0, new Double(1));
- fusion->addOutput(tv5);
+ auto reduction_params =
+ getReductionHeuristics(&fusion, {input}, reduction_tv);
+ TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
+ scheduleReduction(
+ &fusion, reduction_params.value(), reduction_tv, tvOutputsOfReduction);
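+ // (getReductionHeuristics returns an optional ReductionParams; the
+ // same params also provide the launch parameters via .lparams below.)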
- // Not mapped but equal extent
- tv4->split(0, 10);
- tv5->split(0, 10);
+ TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
+ FusionExecutor fe;
+ fe.compileFusion(&fusion);
+ // no broadcasting needed, so the last optional argument is omitted
+ auto outputs = fe.runFusion({input}, reduction_params.value().lparams);
- tv4->axis(-1)->parallelize(ParallelType::TIDy);
- tv5->axis(-1)->parallelize(ParallelType::TIDy);
+ auto aten_output = input.to(c10::ScalarType::Float)
+ .add(1.0)
+ .sum({2})
+ .to(c10::ScalarType::Half);
- GpuLower gpulw(fusion.get());
- const auto& pdmap = gpulw.parallelDimensionMap();
- TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx));
- TORCH_CHECK(
- pdmap.get(ParallelType::TIDx)->isA<kir::NamedScalar>() &&
- pdmap.get(ParallelType::TIDx)->as<kir::NamedScalar>()->name() ==
- "blockDim.x");
- TORCH_CHECK(pdmap.isExact(ParallelType::TIDy));
TORCH_CHECK(
- pdmap.get(ParallelType::TIDy)->isConst() &&
- pdmap.get(ParallelType::TIDy)->as<kir::Int>()->value().value() == 10);
+ aten_output.allclose(outputs[0], 1e-04, 1e-04),
+ "Error of: ",
+ aten_output.sub(outputs[0]).abs().max());
+}
+TEST(NVFuserTest, FusionInputsIdLookup_CUDA) {
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input1 = at::randn({13}, options);
+ at::Tensor t0 = at::randn({16, 8, 8}, options);
+ at::Tensor t1 = at::randn({8, 8}, options);
+ at::Tensor t2 = at::randn({6, 4}, options);
- FusionExecutor fe;
- fe.compileFusion(fusion.get());
- auto outputs = fe.runFusion({input1});
+ // create a cache with max size 2
+ auto inputs_id_lookup = InputsIdLookup(2);
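+ // (Ids are keyed on tensor metadata and the input signature; once a
+ // third distinct signature arrives, the least recently used entry is
+ // evicted, which the checks below verify.)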
- testValidate(
- fusion.get(),
- outputs,
- {input1},
- {input1 + 1, input1 + 1, input1 + 1, input1 + 1},
- __LINE__,
- __FILE__);
-}
+ // testing basic function: identical tensor inputs map to the same
+ // encoding even when the scalar values differ
+ auto id_0 = inputs_id_lookup.lookupId({t0, t1, 5.0});
+ auto id_0_lookup = inputs_id_lookup.lookupId({t0, t1, 2.5});
+ TORCH_CHECK(id_0.id == id_0_lookup.id);
+ TORCH_CHECK(inputs_id_lookup.size() == 1);
+ TORCH_CHECK(id_0.eviction == false);
-// Parallelizing merged broadcast domains
-TEST(NVFuserTest, FusionParallelDimensionMap4_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+ // new input: same shapes, but a different signature because the
+ // scalar input is missing
+ auto id_1 = inputs_id_lookup.lookupId({t0, t1});
+ auto id_1_lookup = inputs_id_lookup.lookupId({t0, t1});
+ TORCH_CHECK(id_1.id == id_1_lookup.id);
+ TORCH_CHECK(inputs_id_lookup.size() == 2);
+ TORCH_CHECK(id_1.eviction == false);
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- auto tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
- auto tv2 = add(tv0, new Double(1));
- auto tv3 = broadcast(tv2, {true, false});
- auto tv4 = add(tv3, tv1);
- fusion.addOutput(tv4);
+ // eviction should happen at this point
+ auto id_2 = inputs_id_lookup.lookupId({t2, t1});
+ TORCH_CHECK(id_2.id != id_0.id);
+ TORCH_CHECK(id_2.id != id_1.id);
+ TORCH_CHECK(inputs_id_lookup.size() == 2);
+ TORCH_CHECK(id_2.eviction == true);
+ TORCH_CHECK(id_2.evict_id == id_0.id);
- tv4->split(1, 4);
- tv4->reorder({{1, 2}, {2, 1}});
- tv4->merge(0);
- tv0->computeAt(tv4, 1);
- tv1->computeAt(tv4, 1);
+ // look at input 1 again
+ auto id_1_relook = inputs_id_lookup.lookupId({t0, t1});
+ TORCH_CHECK(id_1_relook.id == id_1.id);
+ TORCH_CHECK(id_1_relook.eviction == false);
+}
- // TIDx is mapped to tv4.axis(0) as well as tv2.axis(0), so it's not
- // exact.
- tv4->axis(0)->parallelize(ParallelType::TIDx);
+TEST(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) {
+ std::vector<int64_t> sizes_vec({16, 8, 8});
+ std::vector<int64_t> strides_vec({64, 8, 1});
+ auto tensor_type = TensorType::create(
+ at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- tv2->setMemoryType(MemoryType::Shared);
- tv3->setMemoryType(MemoryType::Shared);
+ // pass with identical shape
+ auto t0 = at::randn({16, 8, 8}, options);
+ TORCH_CHECK(complyWith(t0, tensor_type));
- GpuLower gpulw(&fusion);
- const auto& pdmap = gpulw.parallelDimensionMap();
- TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx));
- TORCH_CHECK(
- pdmap.get(ParallelType::TIDx)->isA<kir::NamedScalar>() &&
- pdmap.get(ParallelType::TIDx)->as<kir::NamedScalar>()->name() ==
- "blockDim.x");
+ // pass with dynamic shape
+ auto t1 = at::randn({16, 16, 8}, options);
+ TORCH_CHECK(complyWith(t1, tensor_type));
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input1 = at::randn({13}, options);
- at::Tensor input2 = at::randn({15, 13}, options);
+ // rank failure
+ auto t5 = at::randn({16, 8, 8, 8}, options);
+ TORCH_CHECK(!complyWith(t5, tensor_type));
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion({input1, input2});
+ // broadcasting semantic change failure
+ auto t2 = at::randn({16, 1, 8}, options);
+ TORCH_CHECK(!complyWith(t2, tensor_type));
- auto ref = (input1 + 1).unsqueeze(0) + input2;
+ // contiguity failure via slicing
+ auto t3 = t0.slice(1, 0, 8, 2);
+ TORCH_CHECK(!complyWith(t3, tensor_type));
- testValidate(&fusion, outputs, {input1, input2}, {ref}, __LINE__, __FILE__);
+ // contiguity failure via slicing
+ auto t4 = t0.slice(2, 0, 8, 2);
+ TORCH_CHECK(!complyWith(t4, tensor_type));
}
-TEST(NVFuserTest, FusionParallelDimensionMap5_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
+TEST(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) {
+ std::vector<int64_t> sizes_vec({16, 1, 8});
+ std::vector<int64_t> strides_vec({8, 8, 1});
+ auto tensor_type = TensorType::create(
+ at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- auto tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
- auto tv3 = broadcast(tv0, {false, true});
- auto tv4 = add(tv3, tv1);
- fusion.addOutput(tv4);
+ // broadcasting semantic change
+ auto t0 = at::randn({16, 8, 8}, options);
+ TORCH_CHECK(!complyWith(t0, tensor_type));
- tv4->split(1, 4);
- tv0->computeAt(tv4, -1);
- tv1->computeAt(tv4, -1);
+ // dtype failure
+ auto t1 = at::randn({16, 1, 8}, options.dtype(at::kHalf));
+ TORCH_CHECK(!complyWith(t1, tensor_type));
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv4->axis(-2)->parallelize(ParallelType::TIDy);
- tv3->axis(-2)->parallelize(ParallelType::TIDy);
+ // passing with matching dtype
+ auto t2 = at::randn({16, 1, 8}, options);
+ TORCH_CHECK(complyWith(t2, tensor_type));
- GpuLower gpulw(&fusion);
- const auto& pdmap = gpulw.parallelDimensionMap();
- TORCH_CHECK(pdmap.isExact(ParallelType::TIDx));
- TORCH_CHECK(pdmap.isExact(ParallelType::TIDy));
- TORCH_CHECK(
- pdmap.get(ParallelType::TIDx)->isConst() &&
- pdmap.get(ParallelType::TIDx)->as<kir::Int>()->value().value() == 4);
- TORCH_CHECK(
- pdmap.get(ParallelType::TIDy)->isA<kir::NamedScalar>() &&
- pdmap.get(ParallelType::TIDy)->as<kir::NamedScalar>()->name() ==
- "blockDim.y");
+ // device inconsistency shouldn't fail
+ auto t3 = at::randn({16, 1, 8}, options.device(at::kCPU, 0));
+ TORCH_CHECK(complyWith(t3, tensor_type));
+}
+TEST(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) {
+ std::vector<int64_t> sizes_vec({16, 8, 8});
+ std::vector<int64_t> strides_vec({64, 1, 8});
+ auto tensor_type = TensorType::create(
+ at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor input1 = at::randn({13}, options);
- at::Tensor input2 = at::randn({13, 15}, options);
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion({input1, input2});
+ // failing permutation
+ auto t0 = at::randn({16, 8, 8}, options);
+ TORCH_CHECK(!complyWith(t0, tensor_type));
+
+ // passing permutation: strides now match the expected order
+ auto t1 = t0.permute({0, 2, 1});
+ TORCH_CHECK(complyWith(t1, tensor_type));
+}
+
+TEST(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) {
+ std::vector<int64_t> sizes_vec({16, 8, 8});
+ std::vector<int64_t> strides_vec({128, 16, 1});
+ auto tensor_type = TensorType::create(
+ at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt);
+ auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- auto ref = (input1).unsqueeze(-1) + input2;
+ // the relaxed contiguity check passes even though sizes and strides differ
+ auto t0 = at::randn({16, 16, 8}, options);
+ TORCH_CHECK(complyWith(t0, tensor_type));
- testValidate(&fusion, outputs, {input1, input2}, {ref}, __LINE__, __FILE__);
+ // passes after slicing, which restores the expected strides
+ auto t1 = t0.slice(1, 0, 16, 2);
+ TORCH_CHECK(complyWith(t1, tensor_type));
}
} // namespace jit
} // namespace torch
+
#endif // #if defined(USE_CUDA)
+++ /dev/null
-#if defined(USE_CUDA)
-#include <gtest/gtest.h>
-
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/codegen.h>
-#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
-#include <torch/csrc/jit/codegen/cuda/interface.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_graphviz.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/mutator.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
-
-// fuser and IR parser
-#include "test_gpu_validator.h"
-
-#include <ATen/cuda/Exceptions.h>
-#include <c10/cuda/CUDAStream.h>
-
-#include <algorithm>
-#include <iostream>
-
-// Tests go in torch::jit
-namespace torch {
-namespace jit {
-
-using namespace torch::jit::fuser::cuda;
-using namespace at::indexing;
-
-namespace {
-
-// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
- return TensorViewBuilder()
- .ndims(ndims)
- .dtype(dtype)
- .contiguity(std::vector<bool>(ndims, true))
- .build();
-}
-
-// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
- return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
-}
-
-// Make a non-contiguous tensor of compile-time known sizes
-TensorView* makeConcreteTensor(
- std::vector<int64_t> shape,
- DataType dtype = DataType::Float) {
- return TensorViewBuilder().shape(shape).dtype(dtype).build();
-}
-
-void checkIntValue(
- ExpressionEvaluator& evaluator,
- Val* val,
- Int::ScalarType expected_value) {
- TORCH_CHECK(val->isAnInt());
- const auto actual_value = evaluator.evaluate(val);
- TORCH_CHECK(actual_value.has_value());
- TORCH_CHECK(actual_value.value() == expected_value);
-}
-
-void checkIntValue(
- kir::ExpressionEvaluator& evaluator,
- const kir::Val* val,
- kir::Int::ScalarType expected_value) {
- const auto actual_value = evaluator.evaluate(val);
- TORCH_CHECK(actual_value.has_value());
- TORCH_CHECK(actual_value.value() == expected_value);
-}
-
-// ATen version of tensor shifting
-auto shift(at::Tensor tensor, const std::vector<int>& offsets) {
- TORCH_INTERNAL_ASSERT(tensor.ndimension() == offsets.size());
- at::Tensor t = tensor;
- for (size_t i = 0; i < offsets.size(); ++i) {
- const auto offset = offsets[i];
- if (offset == 0) {
- continue;
- }
- t = t.roll(offsets[i], i);
- std::vector<at::indexing::TensorIndex> indices(
- tensor.ndimension(), at::indexing::Slice(0, at::indexing::None));
- if (offset > 0) {
- indices[i] = at::indexing::Slice(0, offset);
- } else {
- indices[i] = at::indexing::Slice(offset, at::indexing::None);
- }
- t.index(indices) = 0;
- }
- return t;
-}
-
-// ATen version of tensor shifting
-auto gather(
- at::Tensor tensor,
- const std::vector<int>& window_shape,
- const std::vector<std::vector<int>>& pad_width) {
- TORCH_CHECK(
- tensor.ndimension() == window_shape.size(),
- "Invalid window shape: ",
- window_shape,
- ". Size of the window shape is different from the tensor dimension.");
- TORCH_CHECK(
- tensor.ndimension() == pad_width.size(),
- "Invalid pad width: ",
- pad_width,
- ". Size of the pad width is different from the tensor dimension.");
- at::Tensor t = tensor;
- for (size_t i = 0; i < window_shape.size(); ++i) {
- const auto w_size = window_shape[i];
- TORCH_CHECK(w_size != 0);
- const auto& pad = pad_width[i];
- TORCH_CHECK(pad.size() == 2);
- at::Tensor concat_tensor;
- for (int w = 0; w < w_size; ++w) {
- std::vector<int> shift_offsets(t.ndimension(), 0);
- shift_offsets[i] = pad[0] - w;
- auto shifted = shift(t, shift_offsets);
- shifted = shifted.unsqueeze(-1);
- if (w == 0) {
- concat_tensor = shifted;
- } else {
- concat_tensor = at::cat({concat_tensor, shifted}, -1);
- }
- }
- t = concat_tensor;
- }
- return t;
-}
-
-} // namespace
-
-// Shift an input tensor
-TEST(NVFuserTest, FusionShift1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- auto tv1 = shift(tv0, {-1, 0});
- fusion.addOutput(tv1);
-
- auto tv2 = shift(tv0, {0, 1});
- fusion.addOutput(tv2);
-
- auto tv3 = shift(tv0, {2, 2});
- fusion.addOutput(tv3);
-
- auto tv4 = shift(tv0, {-2, -2});
- fusion.addOutput(tv4);
-
- int numel_x = 9;
- int numel_y = 11;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = shift(t0, {-1, 0});
- TORCH_CHECK(t1.equal(outputs[0]));
-
- auto t2 = shift(t0, {0, 1});
- TORCH_CHECK(t2.equal(outputs[1]));
-
- auto t3 = shift(t0, {2, 2});
- TORCH_CHECK(t3.equal(outputs[2]));
-
- auto t4 = shift(t0, {-2, -2});
- TORCH_CHECK(t4.equal(outputs[3]));
-}
-
-// Shifts an intermediate tensor
-TEST(NVFuserTest, FusionShift2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = shift(tv1, {-1, 0});
- fusion.addOutput(tv2);
-
- // make it a little more complex
- auto tv3 = add(tv0, new Double(3));
- auto tv4 = add(tv3, new Double(4));
- auto tv5 = shift(tv4, {-1, 0});
- auto tv6 = shift(tv4, {0, -1});
- auto tv7 = shift(tv4, {1, 0});
- auto tv8 = shift(tv4, {0, 0});
- auto tv9 = add(tv5, tv6);
- auto tv10 = add(tv9, tv7);
- auto tv11 = add(tv10, tv8);
- fusion.addOutput(tv11);
-
- for (auto tv : {tv1, tv2, tv3, tv4, tv5, tv6, tv7, tv8, tv9, tv10, tv11}) {
- tv->setMemoryType(MemoryType::Global);
- }
-
- // t1 allocation: (t1.size[0] + 1) * (t1.size[1])
- // t3 allocation: (t3.size[0] + 2) * (t3.size[1] + 1)
- // t4 allocation: (t3.size[0] + 2) * (t3.size[1] + 1)
- GpuLower gpulw(&fusion);
-
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == 1 || tensor_name == 3 || tensor_name == 4) {
- TORCH_CHECK(alloc->shape().size() == 2);
- for (int i = 0; i < 2; ++i) {
- if (tensor_name == 1 && i == 1) {
- TORCH_CHECK(alloc->shape().at(i)->isA<kir::NamedScalar>());
- continue;
- }
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
- TORCH_CHECK(def != nullptr && def->operation() == BinaryOpType::Add);
- TORCH_CHECK(def->as<kir::BinaryOp>()->lhs()->isA<kir::NamedScalar>());
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- if (tensor_name == 1) {
- TORCH_CHECK(i == 0);
- TORCH_CHECK(rhs_value == 1);
- } else {
- if (i == 0) {
- TORCH_CHECK(rhs_value == 2);
- } else {
- TORCH_CHECK(rhs_value == 1);
- }
- }
- }
- }
- }
- }
-
- int numel_x = 9;
- int numel_y = 11;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = shift(t1, {-1, 0});
-
- auto t3 = t0 + 3;
- auto t4 = t3 + 4;
- auto t5 = shift(t4, {-1, 0});
- auto t6 = shift(t4, {0, -1});
- auto t7 = shift(t4, {1, 0});
- auto t8 = shift(t4, {0, 0});
- auto t9 = t5 + t6;
- auto t10 = t9 + t7;
- auto t11 = t10 + t8;
-
- testValidate(&fusion, outputs, inputs, {t2, t11}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftRightOfCA_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = shift(tv1, {0, 1});
- fusion.addOutput(tv2);
-
- tv0->computeAt(tv2, -2);
-
- tv1->setMemoryType(MemoryType::Global);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 100;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = shift(t1, {0, 1});
-
- TORCH_CHECK(t2.allclose(outputs[0]));
-}
-
-TEST(NVFuserTest, FusionShiftLeftOfCA_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(1));
- auto tv3 = shift(tv2, {-1, 0});
- auto tv4 = add(tv3, new Double(1));
- fusion.addOutput(tv4);
-
- tv0->computeAt(tv4, -1);
-
- // Lowering should trigger an assertion failure as a shifted axis is
- // found inside an allocation position.
- ASSERT_ANY_THROW(fusion.printKernel());
-}
-
-TEST(NVFuserTest, FusionShiftSplit1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = shift(tv1, {0, 1});
- auto tv3 = shift(tv1, {0, -2});
- fusion.addOutput(tv2);
- fusion.addOutput(tv3);
-
- int split_factor = 4;
- tv2->split(-1, split_factor);
- tv3->split(-1, split_factor);
-
- tv0->computeAt(tv2, -2);
- tv0->computeAt(tv3, -2);
-
- // t1 allocation: (4 + 3)
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == 1) {
- TORCH_CHECK(alloc->shape().size() == 1);
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor && rhs_value == 3);
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 9;
- int numel_y = 11;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = shift(t1, {0, 1});
- auto t3 = shift(t1, {0, -2});
-
- testValidate(&fusion, outputs, inputs, {t2, t3}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftSplit2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(1));
- auto tv3 = shift(tv2, {0, -1});
- auto tv4 = shift(tv2, {0, 1});
- auto tv5 = add(tv3, tv4);
- fusion.addOutput(tv5);
-
- auto tv6 = add(tv0, new Double(1));
- auto tv7 = shift(tv6, {0, 0});
- auto tv8 = add(tv7, new Double(1));
- fusion.addOutput(tv8);
-
- int split_factor = 4;
-
- tv5->split(-1, split_factor);
- tv8->split(-1, split_factor);
-
- tv0->computeAt(tv5, -2);
- tv0->computeAt(tv8, -2);
-
- // t1 and t2 allocation: (4 + 2)
- // t4 allocation: (4)
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == 1 || tensor_name == 2) {
- TORCH_CHECK(alloc->shape().size() == 1);
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor && rhs_value == 2);
- } else if (tensor_name == 4) {
- TORCH_CHECK(alloc->shape().size() == 1);
- auto size = dynamic_cast<kir::Int*>(alloc->shape().at(0));
- TORCH_CHECK(size != nullptr && size->isConst());
- int size_value = *size->value();
- TORCH_CHECK(size_value == split_factor);
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 9;
- int numel_y = 11;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 2;
- auto t3 = shift(t1, {0, -1});
- auto t4 = shift(t1, {0, 1});
- auto t5 = t3 + t4;
-
- auto t6 = t0 + 1;
- auto t7 = t6;
- auto t8 = t7 + 1;
-
- testValidate(&fusion, outputs, inputs, {t5, t8}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftDoubleSplit_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(2));
- auto tv3 = shift(tv2, {0, 1});
- fusion.addOutput(tv3);
-
- int split_factor1 = 8;
- int split_factor2 = 4;
-
- tv3->split(-1, split_factor1);
-
- tv0->computeAt(tv3, -2);
-
- tv1->split(-1, split_factor2);
-
- // t1: [i1, i2/8, 8/4, 4]
- // t2: [i1, i2/8, 8]
- // t3: [i1, i2/8, 8]
-
- // t1 and t2 allocation: (split_factor1 + 1)
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == 1 || tensor_name == 2) {
- TORCH_CHECK(alloc->shape().size() == 1);
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1);
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 3;
- auto ref = shift(t1, {0, 1});
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift3ptStencil_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // 3-pt stencil
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
-
- std::vector<std::vector<int>> offsets = {{-1}, {1}};
-
- std::vector<TensorView*> tvs;
- for (const auto& offset : offsets) {
- tvs.push_back(shift(tv0, offset));
- }
-
- auto tv_out = tv0;
-
- for (auto tv : tvs) {
- tv_out = add(tv_out, tv);
- }
-
- tv_out = div(tv_out, new Double(tvs.size() + 1));
-
- fusion.addOutput(tv_out);
-
- int split_factor = 4;
-
- tv_out->split(0, split_factor);
-
- // This seems fine but not verified yet
- // tv_out->axis(-1)->parallelize(ParallelType::Unswitch);
-
- auto cache = tv0->cache_after();
-
- tv0->computeAt(tv_out, 1);
-
- // Inline completely except for the cache
- for (auto tv : tvs) {
- tv->computeAt(tv_out, -1);
- }
-
- // cache allocation: (split_factor + 2)
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == cache->name()) {
- TORCH_CHECK(alloc->shape().size() == 1);
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor && rhs_value == 2);
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto ref = (t0 + shift(t0, {-1}) + shift(t0, {1})) / 3;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift5ptStencil_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // 5-pt stencil
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
-
- std::vector<TensorView*> tvs;
- for (const auto& offset : offsets) {
- tvs.push_back(shift(tv0, offset));
- }
-
- auto tv_out = tv0;
-
- for (auto tv : tvs) {
- tv_out = add(tv_out, tv);
- }
-
- tv_out = div(tv_out, new Double(tvs.size() + 1));
-
- fusion.addOutput(tv_out);
-
- std::vector<int> split_factor({4, 8});
-
- tv_out->split(-1, split_factor[1]);
- tv_out->split(0, split_factor[0]);
- tv_out->reorder({{1, 2}, {2, 1}});
-
- auto cache = tv0->cache_after();
-
- tv0->computeAt(tv_out, 2);
-
- // Inline completely except for the cache
- for (auto tv : tvs) {
- tv->computeAt(tv_out, -1);
- }
-
- // cache allocation: (split_factor + 2) * (split_factor + 2)
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == cache->name()) {
- TORCH_CHECK(alloc->shape().size() == 2);
- for (int i = 0; i < 2; ++i) {
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2);
- }
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto ref = t0;
- for (const auto& offset : offsets) {
- ref = ref + shift(t0, offset);
- }
- ref = ref / int(offsets.size() + 1);
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift9ptStencil_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // 9-pt stencil
- std::vector<std::vector<int>> offsets;
- for (int i = -1; i < 2; ++i) {
- for (int j = -1; j < 2; ++j) {
- if (i == 0 && j == 0) {
- continue;
- }
- offsets.push_back({i, j});
- }
- }
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- std::vector<TensorView*> tvs;
- for (const auto& offset : offsets) {
- tvs.push_back(shift(tv0, offset));
- }
-
- auto tv_out = tv0;
-
- for (auto tv : tvs) {
- tv_out = add(tv_out, tv);
- }
-
- tv_out = div(tv_out, new Double(tvs.size() + 1));
-
- fusion.addOutput(tv_out);
-
- std::vector<int> split_factor({4, 8});
- tv_out->split(-1, split_factor[1]);
- tv_out->split(0, split_factor[0]);
- tv_out->reorder({{1, 2}, {2, 1}});
-
- auto cache = tv0->cache_after();
-
- tv0->computeAt(tv_out, 2);
-
- // Inline completely except for the cache
- for (auto tv : tvs) {
- tv->computeAt(tv_out, -1);
- }
-
- // This seems fine but not yet verified
- // tv_out->axis(-1)->parallelize(ParallelType::Unswitch);
-
- // cache allocation: (split_factor + 2) * (split_factor + 2)
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == cache->name()) {
- TORCH_CHECK(alloc->shape().size() == 2);
- for (int i = 0; i < 2; ++i) {
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2);
- }
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto ref = t0;
- for (const auto& offset : offsets) {
- ref = ref + shift(t0, offset);
- }
- ref = ref / int(offsets.size() + 1);
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftSmemBlocking_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = shift(tv1, {0, 1});
- fusion.addOutput(tv2);
-
- int smem_block_factor = 32;
-
- tv2->split(-1, smem_block_factor);
-
- tv0->computeAt(tv2, -2);
-
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv1->setMemoryType(MemoryType::Shared);
-
- // tv1 allocation: (split_factor + 1)
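-  // Only the {0, 1} shift reads a neighbor, and only on one side, so the
-  // halo is a single element: 32 + 1 = 33 elements per shared-memory tile.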
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == tv1->name()) {
- TORCH_CHECK(alloc->shape().size() == 1);
- for (int i = 0; i < 1; ++i) {
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == smem_block_factor && rhs_value == 1);
- }
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 100;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = shift(t1, {0, 1});
- auto ref = t2;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift3ptStencilParallel_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // 3-pt stencil
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- std::vector<TensorView*> tvs;
- tvs.push_back(shift(tv0, {-1}));
- tvs.push_back(shift(tv0, {1}));
-
- auto tv_out = tv0;
-
- for (auto tv : tvs) {
- tv_out = add(tv_out, tv);
- }
-
- tv_out = div(tv_out, new Double(tvs.size() + 1));
-
- fusion.addOutput(tv_out);
-
- int smem_block_factor = 32;
-
- tv_out->split(0, smem_block_factor);
- // tv_out->axis(-1)->parallelize(ParallelType::Unswitch);
-
- auto tv0_cache = tv0->cache_after();
-
- tv0->computeAt(tv_out, 1);
-
- for (auto tv : tvs) {
- tv->computeAt(tv_out, -1);
- }
-
- tv0_cache->setMemoryType(MemoryType::Shared);
- tv_out->axis(-1)->parallelize(ParallelType::TIDx);
- tv0_cache->axis(-1)->parallelize(ParallelType::TIDx);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto ref = (t0 + shift(t0, {-1}) + shift(t0, {1})) / 3;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift5ptStencilParallel_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // 5-pt stencil
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
-
- std::vector<TensorView*> tvs;
- for (const auto& offset : offsets) {
- tvs.push_back(shift(tv0, offset));
- }
-
- auto tv_out = tv0;
-
- for (auto tv : tvs) {
- tv_out = add(tv_out, tv);
- }
-
- tv_out = div(tv_out, new Double(tvs.size() + 1));
-
- fusion.addOutput(tv_out);
-
- int smem_block_factor = 32;
-
- tv_out->split(-1, smem_block_factor);
- tv_out->split(0, smem_block_factor);
-
- tv_out->reorder({{1, 2}, {2, 1}});
-
- auto tv0_cache = tv0->cache_after();
-
- tv0->computeAt(tv_out, 2);
-
- for (auto tv : tvs) {
- tv->computeAt(tv_out, -1);
- }
-
- tv_out->axis(-1)->parallelize(ParallelType::TIDx);
- tv_out->axis(-2)->parallelize(ParallelType::TIDy);
- tv_out->axis(-3)->parallelize(ParallelType::BIDx);
- tv_out->axis(-4)->parallelize(ParallelType::BIDy);
-
- tv0_cache->setMemoryType(MemoryType::Shared);
- tv0_cache->axis(-1)->parallelize(ParallelType::TIDx);
- tv0_cache->axis(-2)->parallelize(ParallelType::TIDy);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto ref = t0;
- for (const auto& offset : offsets) {
- ref = ref + shift(t0, offset);
- }
- ref = ref / int(offsets.size() + 1);
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftMerge1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = shift(tv1, {-1, 1});
- fusion.addOutput(tv2);
-
- int split_factor = 4;
-
- tv2->split(-1, split_factor);
- tv2->split(0, split_factor);
- tv2->reorder({{1, 2}, {2, 1}});
- tv2->merge(2, 3);
-
- tv0->computeAt(tv2, 2);
-
- // t1 allocation: (split_factor + 1) * (split_factor + 1)
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == 1) {
- TORCH_CHECK(alloc->shape().size() == 2);
- for (int i = 0; i < 2; ++i) {
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor && rhs_value == 1);
- }
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = shift(t1, {-1, 1});
- auto ref = t2;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftMerge2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = shift(tv1, {1, -1});
- auto tv3 = shift(tv1, {-1, 1});
- auto tv4 = add(tv2, tv3);
- fusion.addOutput(tv4);
-
- int split_factor = 4;
-
- tv4->split(-1, split_factor);
- tv4->split(0, split_factor);
- tv4->reorder({{1, 2}, {2, 1}});
- tv4->merge(2, 3);
-
- tv0->computeAt(tv4, -2);
-
- // t1 allocation: (split_factor + 2) * (split_factor + 2)
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == 1) {
- TORCH_CHECK(alloc->shape().size() == 2);
- for (int i = 0; i < 2; ++i) {
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor && rhs_value == 2);
- }
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = shift(t1, {1, -1});
- auto t3 = shift(t1, {-1, 1});
- auto t4 = t2 + t3;
-
- TORCH_CHECK(t4.allclose(outputs[0]));
-}
-
-TEST(NVFuserTest, FusionShiftGlobal_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = shift(tv1, {0, 1});
- auto tv3 = shift(tv1, {-1, 0});
- auto tv4 = add(tv2, tv3);
- fusion.addOutput(tv4);
-
- tv1->split(-1, 4);
- tv2->split(-1, 8);
- tv3->split(-1, 2);
- tv4->split(-1, 3);
-
- tv1->merge(-2, -1);
-
- tv1->setMemoryType(MemoryType::Global);
- tv2->setMemoryType(MemoryType::Global);
- tv3->setMemoryType(MemoryType::Global);
-
- // t1 allocation: (t1.size[0] + 1) * (t1.size[1] + 1)
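-  // tv1 is in global memory and not inlined into a tile, so its allocation
-  // is the full symbolic extent plus a halo of 1 in each dimension.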
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == 1) {
- TORCH_CHECK(alloc->shape().size() == 2);
- for (int i = 0; i < 2; ++i) {
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
- TORCH_CHECK(def != nullptr && def->operation() == BinaryOpType::Add);
- TORCH_CHECK(def->as<kir::BinaryOp>()->lhs()->isA<kir::NamedScalar>());
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(rhs_value == 1);
- }
- }
- }
- }
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = shift(t1, {0, 1});
- auto t3 = shift(t1, {-1, 0});
- auto t4 = t2 + t3;
- auto ref = t4;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftDoubleSplitMerge1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(2));
- auto tv3 = shift(tv2, {0, 1});
- fusion.addOutput(tv3);
-
- int split_factor1 = 8;
- int split_factor2 = 4;
-
- tv3->split(-1, split_factor1);
-
- tv0->computeAt(tv3, -2);
-
- tv1->split(-1, split_factor2);
- tv1->merge(-2, -1);
-
- // t1 and t2 allocation: (split_factor1 + 1)
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == 1 || tensor_name == 2) {
- TORCH_CHECK(alloc->shape().size() == 1);
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(0)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1);
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 3;
- auto ref = shift(t1, {0, 1});
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(2));
- auto tv3 = shift(tv2, {1, 1});
- fusion.addOutput(tv3);
-
- auto out = tv3;
-
- int split_factor1 = 32;
- int split_factor2 = 4;
-
- out->split(-1, split_factor1);
- out->split(-1, split_factor2);
- out->split(0, split_factor1);
- out->split(1, split_factor2);
- out->reorder({{3, 1}, {1, 2}, {4, 3}, {2, 4}});
- out->merge(2, 3);
- out->merge(2, 3);
- out->merge(2, 3);
- out->merge(0, 1);
-
- TransformPropagator::from(out);
-
- tv0->computeAt(out, 1);
-
- out->axis(0)->parallelize(ParallelType::BIDx);
- out->axis(1)->parallelize(ParallelType::TIDx);
-
- scheduler_utils::parallelizeAllLike(out, {tv1, tv2});
-
- for (auto tv : {tv1, tv2}) {
- tv->setMemoryType(MemoryType::Shared);
- }
-
- // t1 and t2 allocation: (split_factor1 + 1) * (split_factor1 + 1)
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == 1 || tensor_name == 2) {
- TORCH_CHECK(alloc->shape().size() == 2);
- for (int i = 0; i < 2; ++i) {
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor1 && rhs_value == 1);
- }
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto ref = shift(t0 + 1 + 2, {1, 1});
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift5ptStencilParallel1DThreadBlock_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // 5-pt stencil
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
-
- std::vector<TensorView*> tvs;
- for (const auto& offset : offsets) {
- tvs.push_back(shift(tv0, offset));
- }
-
- auto tv_out = tv0;
-
- for (auto tv : tvs) {
- tv_out = add(tv_out, tv);
- }
-
- tv_out = div(tv_out, new Double(tvs.size() + 1));
-
- fusion.addOutput(tv_out);
-
- std::vector<int> split_factor({4, 32});
-
- tv_out->split(-1, split_factor[1]);
- tv_out->split(0, split_factor[0]);
- tv_out->reorder({{1, 2}, {2, 1}});
-
- auto tv0_cache = tv0->cache_after();
-
-  // Merge the inner-most two axes and create
-  // a 1D thread block of split_factor[0] * split_factor[1] threads
- tv_out->merge(-2, -1);
-
- tv0->computeAt(tv_out, 2);
-
- // Inline completely except for the cache
- for (auto tv : tvs) {
- tv->computeAt(tv_out, -1);
- }
-
- tv0_cache->merge(-2, -1);
-
- tv_out->axis(-1)->parallelize(ParallelType::TIDx);
- tv_out->axis(1)->parallelize(ParallelType::BIDx);
- tv_out->axis(0)->parallelize(ParallelType::BIDy);
-
- tv0_cache->setMemoryType(MemoryType::Shared);
- tv0_cache->axis(-1)->parallelize(ParallelType::TIDx);
-
-  // cache allocation: (split_factor[0] + 2) * (split_factor[1] + 2)
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == tv0_cache->name()) {
- TORCH_CHECK(alloc->shape().size() == 2);
- for (int i = 0; i < 2; ++i) {
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor[i] && rhs_value == 2);
- }
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto ref = t0;
- for (const auto& offset : offsets) {
- ref = ref + shift(t0, offset);
- }
- ref = ref / int(offsets.size() + 1);
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftChain1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = shift(tv0, {0, 1});
- auto tv2 = shift(tv1, {0, 1});
- fusion.addOutput(tv2);
-
- int split_factor = 4;
- tv2->split(-1, split_factor);
-
- tv0->computeAt(tv2, -2);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto ref = shift(shift(t0, {0, 1}), {0, 1});
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftChain2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = shift(tv0, {0, 1});
- auto tv2 = shift(tv1, {0, -1});
- fusion.addOutput(tv2);
-
- tv2->split(-1, 4);
-
- tv0->computeAt(tv2, -2);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto ref = shift(shift(t0, {0, 1}), {0, -1});
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftChain3_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = shift(tv1, {0, 1});
- auto tv3 = shift(tv2, {0, 1});
- fusion.addOutput(tv3);
-
- int split_factor = 4;
- tv3->split(-1, split_factor);
-
- tv0->computeAt(tv3, -2);
-
-  // The halo size of tv1 is 2 as it needs to account for both of the
-  // two shift operations, while that of tv2 is still just 1
-
- // tv1: (split_factor + 2)
- // tv2: (split_factor + 1)
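-  // That is, halo sizes accumulate along a shift chain: tv2 is read with a
-  // shift of 1, so its halo is 1, and tv1 must also cover tv2's halo,
-  // giving 1 + 1 = 2.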
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == 1 || tensor_name == 2) {
- TORCH_CHECK(alloc->shape().size() == 1);
- for (int i = 0; i < 1; ++i) {
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor);
- if (tensor_name == 1) {
- TORCH_CHECK(rhs_value == 2);
- } else if (tensor_name == 2) {
- TORCH_CHECK(rhs_value == 1);
- }
- }
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = shift(t1, {0, 1});
- auto t3 = shift(t2, {0, 1});
- auto ref = t3;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftChain4_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = shift(tv0, {1, -1});
- auto tv2 = shift(tv1, {2, -2});
- auto tv3 = shift(tv2, {3, -3});
- auto tv4 = shift(tv3, {4, -4});
- auto tv_out = tv4;
-
- fusion.addOutput(tv_out);
-
- int split_factor = 4;
-
- tv_out->split(-1, split_factor);
- tv_out->split(0, split_factor);
- tv_out->reorder({{1, 2}, {2, 1}});
-
- tv0->computeAt(tv_out, 2);
-
- tv1->merge(-2, -1);
- tv2->merge(-2, -1);
- tv3->merge(-2, -1);
-
- // tv1: (split_factor + 9) * (split_factor + 9)
- // tv2: (split_factor + 7) * (split_factor + 7)
- // tv3: (split_factor + 4) * (split_factor + 4)
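-  // The halos accumulate from the end of the chain: tv3 needs 4 for the
-  // last shift, tv2 needs 4 + 3 = 7, and tv1 needs 7 + 2 = 9 per dimension.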
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
-    if (tensor_name == 1 || tensor_name == 2 || tensor_name == 3) {
- TORCH_CHECK(alloc->shape().size() == 2);
- for (int i = 0; i < 2; ++i) {
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor);
- if (tensor_name == 1) {
- TORCH_CHECK(rhs_value == 9);
- } else if (tensor_name == 2) {
- TORCH_CHECK(rhs_value == 7);
- } else if (tensor_name == 3) {
- TORCH_CHECK(rhs_value == 4);
- }
- }
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = shift(t0, {1, -1});
- auto t2 = shift(t1, {2, -2});
- auto t3 = shift(t2, {3, -3});
- auto t4 = shift(t3, {4, -4});
- auto ref = t4;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShift5ptStencilChain_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- std::vector<std::vector<int>> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};
-
- // First stencil: 5pt stencil
- // stencil1 = (tv0 + tv0[+1][0] + tv0[-1][0] + tv0[0][+1] + tv0[0][-1]) / 5
- std::vector<TensorView*> tv_stencil1_shifts;
- for (const auto& offset : offsets) {
- tv_stencil1_shifts.push_back(shift(tv0, offset));
- }
-
- auto tv_stencil1 = tv0;
- for (auto tv : tv_stencil1_shifts) {
- tv_stencil1 = add(tv_stencil1, tv);
- }
-
- tv_stencil1 = div(tv_stencil1, new Double(tv_stencil1_shifts.size() + 1));
-
- // Second stencil: Same 5pt stencil
- std::vector<TensorView*> tv_stencil2_shifts;
- for (const auto& offset : offsets) {
- tv_stencil2_shifts.push_back(shift(tv_stencil1, offset));
- }
-
- auto tv_stencil2 = tv_stencil1;
- for (auto tv : tv_stencil2_shifts) {
- tv_stencil2 = add(tv_stencil2, tv);
- }
-
- tv_stencil2 = div(tv_stencil2, new Double(tv_stencil2_shifts.size() + 1));
-
- auto tv_out = tv_stencil2;
-
- fusion.addOutput(tv_out);
-
- auto tv0_cache = tv0->cache_after();
-
- std::vector<int> split_factor({16, 16});
-
- tv_out->split(-1, split_factor[1]);
- tv_out->split(0, split_factor[0]);
- tv_out->reorder({{1, 2}, {2, 1}});
-
- tv0->computeAt(tv_out, 2);
-
- // Inline completely all inputs to the first stencil output, except for the
- // tv0 cache
- for (auto tv : tv_stencil1_shifts) {
- tv->computeAt(tv_stencil1, -1);
- }
-
- // Inline completely all inputs to the second stencil output, except
- // for the first stencil output
- for (auto tv : tv_stencil2_shifts) {
- tv->computeAt(tv_stencil2, -1);
- }
-
- tv_out->axis(1)->parallelize(ParallelType::BIDx);
- tv_out->axis(0)->parallelize(ParallelType::BIDy);
-
- auto all_values = DependencyCheck::getAllValsBetween(
- {fusion.inputs().begin(), fusion.inputs().end()}, fusion.outputs());
- for (auto tv : ir_utils::filterByType<TensorView>(all_values)) {
- tv->axis(-1)->parallelize(ParallelType::TIDx);
- tv->axis(-2)->parallelize(ParallelType::TIDy);
- }
-
- tv0_cache->setMemoryType(MemoryType::Shared);
- tv_stencil1->setMemoryType(MemoryType::Shared);
-
- // tv0_cache: (split_factor + 4) * (split_factor + 4)
- // tv_stencil1: (split_factor + 2) * (split_factor + 2)
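-  // Each 5-pt stencil adds a halo of 1 on each side: tv_stencil1 needs a
-  // 16 + 2 extent to feed the second stencil, and tv0_cache needs 16 + 4
-  // to cover the first stencil over tv_stencil1's haloed region.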
- GpuLower gpulw(&fusion);
- for (const auto& kir_node : gpulw.kernel()->irNodes()) {
- if (auto alloc = dynamic_cast<kir::Allocate*>(kir_node.get())) {
- auto tensor_name = alloc->buffer()->name();
- if (tensor_name == tv0_cache->name() ||
- tensor_name == tv_stencil1->name()) {
- TORCH_CHECK(alloc->shape().size() == 2);
- for (int i = 0; i < 2; ++i) {
- auto def =
- dynamic_cast<kir::BinaryOp*>(alloc->shape().at(i)->definition());
- auto lhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->lhs());
- TORCH_CHECK(lhs != nullptr && lhs->isConst());
- int lhs_value = *lhs->value();
- auto rhs = dynamic_cast<kir::Int*>(def->as<kir::BinaryOp>()->rhs());
- TORCH_CHECK(rhs != nullptr && rhs->isConst());
- int rhs_value = *rhs->value();
- TORCH_CHECK(lhs_value == split_factor[i]);
- if (tensor_name == tv0_cache->name()) {
- TORCH_CHECK(rhs_value == 4);
- } else if (tensor_name == tv_stencil1->name()) {
- TORCH_CHECK(rhs_value == 2);
- }
- }
- }
- }
- }
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto stencil1 = t0;
- for (const auto& offset : offsets) {
- stencil1 = stencil1 + shift(t0, offset);
- }
- stencil1 = stencil1 / int(offsets.size() + 1);
- auto stencil2 = stencil1;
- for (const auto& offset : offsets) {
- stencil2 = stencil2 + shift(stencil1, offset);
- }
- stencil2 = stencil2 / int(offsets.size() + 1);
- auto ref = stencil2;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-// Shift a reduced tensor
-TEST(NVFuserTest, FusionShiftReduction1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = sum(tv1, {1});
- auto tv3 = shift(tv2, {1});
- fusion.addOutput(tv3);
-
- tv3->split(0, 4);
- tv0->computeAt(tv3, 1);
- tv0->computeAt(tv2, -1);
-
- const int numel_x = 9;
- const int numel_y = 11;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = sum(t1, {1});
- auto t3 = shift(t2, {1});
- auto ref = t3;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-// Parallelized version of FusionShiftReduction1
-TEST(NVFuserTest, FusionShiftReduction2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = sum(tv1, {1});
- auto tv3 = shift(tv2, {1});
- fusion.addOutput(tv3);
-
- tv3->split(0, 4);
- tv0->computeAt(tv3, 1);
-
- tv2->split(-1, 32);
- tv0->computeAt(tv2, -1);
-
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv2->setMemoryType(MemoryType::Shared);
-
- const int numel_x = 201;
- const int numel_y = 301;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = sum(t1, {1});
- auto t3 = shift(t2, {1});
- auto ref = t3;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftRfactor1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = sum(tv1, {1});
- auto tv3 = shift(tv2, {1});
- fusion.addOutput(tv3);
-
- tv3->split(0, 4);
- tv0->computeAt(tv3, 1);
-
- tv2->split(-1, 32);
- auto rf = tv2->rFactor({-2});
- tv0->computeAt(tv2, -1);
- tv0->computeAt(rf, -1);
-
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
-
- tv2->setMemoryType(MemoryType::Shared);
-
- const int numel_x = 201;
- const int numel_y = 301;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = sum(t1, {1});
- auto t3 = shift(t2, {1});
- auto ref = t3;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftBcast1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- auto tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
- auto tv2 = broadcast(tv0, {false, true});
- auto tv3 = shift(tv2, {0, 1});
- auto tv4 = add(tv3, tv1);
- fusion.addOutput(tv4);
-
- tv0->computeAt(tv4, -1);
- tv1->computeAt(tv4, -1);
-
- const int numel_x = 9;
- const int numel_y = 11;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x}, options);
- at::Tensor t1 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0, t1};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(inputs);
-
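-  // The {0, 1} shift is along the broadcast dimension, so it drops out of
-  // the reference computation below.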
- auto t4 = t0.unsqueeze(-1).expand({numel_x, numel_y}) + t1;
- auto ref = t4;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftBcast2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- auto tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
- auto tv2 = broadcast(tv0, {false, true});
- auto tv3 = shift(tv2, {1, 0});
- auto tv4 = add(tv3, tv1);
- fusion.addOutput(tv4);
-
- tv4->split(0, 4);
- tv0->computeAt(tv4, 1);
-
- const int numel_x = 9;
- const int numel_y = 11;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x}, options);
- at::Tensor t1 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0, t1};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(inputs);
-
- auto t2 = t0.unsqueeze(-1).expand({numel_x, numel_y});
- auto t3 = shift(t2, {1, 0});
- auto ref = t3 + t1;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-// Combine ShiftBcast1 and ShiftBcast2 with parallelization
-TEST(NVFuserTest, FusionShiftBcast3_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- auto tv1 = makeSymbolicTensor(2);
- fusion.addInput(tv1);
- auto tv2 = broadcast(tv0, {false, true});
- auto tv3 = shift(tv2, {1, 0});
- auto tv4 = shift(tv2, {0, 1});
- auto tv5 = shift(tv2, {-1, -1});
- auto tv6 = add(tv3, tv4);
- auto tv7 = add(tv6, tv5);
- auto tv8 = add(tv7, tv1);
- fusion.addOutput(tv8);
-
- tv8->split(0, 4);
- tv8->split(-1, 4);
- tv0->computeAt(tv8, 1);
-
- tv8->axis(-1)->parallelize(ParallelType::TIDx);
- for (auto tv : {tv8, tv7, tv6, tv5, tv4, tv3, tv2}) {
- tv->axis(1)->parallelize(ParallelType::TIDy);
- }
-
- tv2->setMemoryType(MemoryType::Shared);
-
- const int numel_x = 101;
- const int numel_y = 201;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x}, options);
- at::Tensor t1 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0, t1};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(inputs);
-
- auto t2 = t0.unsqueeze(-1).expand({numel_x, numel_y});
- auto t3 = shift(t2, {1, 0});
- auto t4 = t2;
- auto t5 = shift(t2, {-1, 0});
- auto ref = t3 + t4 + t5 + t1;
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-// See issue #893
-TEST(NVFuserTest, FusionShiftSyncPlacement1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv0, new Double(2));
- auto tv3 = add(tv1, tv2);
- auto tv4 = shift(tv3, {0, 1});
- fusion.addOutput(tv4);
-
- tv4->split(1, 8);
- tv0->computeAt(tv4, 2);
-
- tv2->computeAt(tv3, -1);
-
- tv1->setMemoryType(MemoryType::Shared);
- tv3->setMemoryType(MemoryType::Shared);
-
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
- int numel_y = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x, numel_y}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = t0 + 2;
- auto t3 = add(t1, t2);
- auto t4 = shift(t3, {0, 1});
-
- testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__);
-}
-
-// See issue #893. Top-level placement.
-TEST(NVFuserTest, FusionShiftSyncPlacement2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv0, new Double(2));
- auto tv3 = add(tv1, tv2);
- auto tv4 = shift(tv3, {1});
- fusion.addOutput(tv4);
-
- tv2->computeAt(tv3, -1);
-
- tv1->setMemoryType(MemoryType::Shared);
- tv3->setMemoryType(MemoryType::Shared);
-
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
- tv4->axis(-1)->parallelize(ParallelType::TIDx);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 99;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({numel_x}, options);
- std::vector<IValue> inputs = {t0};
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = t0 + 2;
- auto t3 = add(t1, t2);
- auto t4 = shift(t3, {1});
-
- testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionShiftSyncPlacement3_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(1);
- fusion.addInput(tv0);
- auto tv1 = add(tv0, new Double(1));
- auto tv2 = add(tv1, new Double(2));
- auto tv3 = shift(tv2, {1});
- fusion.addOutput(tv3);
-
-  // This doesn't work. A syncthreads is needed between tv1 and tv2, but
-  // the loop extents of both tv1 and tv2 have halo, so the loop is not
-  // eliminated even though it is parallelized. Moving the syncthreads
-  // out of the loop would place it before tv1, which would make it
-  // meaningless.
-  // Ideally, an exception should be thrown at this computeAt, but at
-  // this point the fusion is neither parallelized nor has its memory
-  // types set, so this computeAt itself is not an error yet.
- tv1->computeAt(tv2, -1);
-
- tv1->setMemoryType(MemoryType::Shared);
- tv2->setMemoryType(MemoryType::Shared);
-
- tv1->axis(-1)->parallelize(ParallelType::TIDx);
- tv2->axis(-1)->parallelize(ParallelType::TIDx);
- tv3->axis(-1)->parallelize(ParallelType::TIDx);
-
- // The error should be detected when the fusion is lowered.
- ASSERT_ANY_THROW(fusion.printKernel());
-}
-
-// Based on the original CUDA code provided by Vishal Mehta.
-// Major differences from the original version:
-// - Boundary processing. We always pad by zero. The original version
-// is only defined for the interior domain.
-// - The original version uses 2 additional warps to load the halos
-// along the Y dimension. The other 10 warps are used to load a 32x10
-// tile, and all warps do coalesced loads. No such optimization is
-// done in the fuser version.
-TEST(NVFuserTest, FusionHorizontalDiffusion_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto inp = makeSymbolicTensor(3);
- fusion.addInput(inp);
- auto coeff = makeSymbolicTensor(3);
- fusion.addInput(coeff);
-
- std::vector<std::vector<int>> offsets{
- {0, 1, 0}, {0, -1, 0}, {0, 0, 1}, {0, 0, -1}};
-
- // T2, T3, T4, T5
- std::vector<TensorView*> inp_neighbors;
- for (const auto& offset : offsets) {
- inp_neighbors.push_back(shift(inp, offset));
- }
-
- // T8
- TensorView* sum_of_neighbors = nullptr;
- for (auto inp_neighbor : inp_neighbors) {
- if (sum_of_neighbors == nullptr) {
- sum_of_neighbors = inp_neighbor;
- } else {
- sum_of_neighbors = add(sum_of_neighbors, inp_neighbor);
- }
- }
-
- // T9 = T0 * 4
- // T10 = T9 - T8
- auto lap = sub(mul(inp, new Double(4)), sum_of_neighbors);
-
- // T11 = shift(T10)
- // T12 = T11 - T10
- auto flx = sub(shift(lap, {0, 0, -1}), lap);
- // T14 = T13 - T0
- // T15 = T12 * T14
- // T16 = T15 > 0
- // T17 = T16 ? 0 : T12
- auto flx_cond = gt(mul(flx, sub(shift(inp, {0, 0, -1}), inp)), new Double(0));
- auto flx0 = where(flx_cond, new Double(0), flx);
-
- // T18 = shift(T10)
- // T19 = T18 - T10
- auto fly = sub(shift(lap, {0, -1, 0}), lap);
- // T20 = shift(T0)
- // T21 = T20 - T0
- // T22 = T19 * T21
- // T23 = T22 > 0
- auto fly_cond = gt(mul(fly, sub(shift(inp, {0, -1, 0}), inp)), new Double(0));
- // T24 = T23 ? 0 : T19
- auto fly0 = where(fly_cond, new Double(0), fly);
-
- // T25 = shift(flx0)
- // T26 = T17 - T25
- // T27 = shift(fly0)
- // T28 = T24 - T27
- // T29 = T26 + T28
- // T30 = T1 * T29
- // T31 = T0 - T30
- auto out =
- sub(inp,
- mul(coeff,
- add(sub(flx0, shift(flx0, {0, 0, 1})),
- sub(fly0, shift(fly0, {0, 1, 0})))));
-
- fusion.addOutput(out);
-
- /////////////////////////////////
- // Scheduling
- /////////////////////////////////
-
- // Step 1: 2D Tiling
-
- const int tile_x = 32;
- const int tile_y = 8;
-
- out->split(-1, tile_x);
- out->split(-3, tile_y);
- out->reorder({{-2, -3}});
- inp->computeAt(out, -3);
- coeff->computeAt(out, -3);
-
- // Step 2: Inlining
-
- // Inline inputs to lap
- auto lap_vals = DependencyCheck::getAllValsBetween({inp}, {lap});
- for (auto val : ir_utils::filterByType<TensorView>(lap_vals)) {
- if (val != lap && val != inp) {
- val->computeAt(lap, -1);
- }
- }
-
- // Inline inputs to flx0
- auto flx0_vals = DependencyCheck::getAllValsBetween({lap, inp}, {flx0});
- for (auto val : ir_utils::filterByType<TensorView>(flx0_vals)) {
- if (val != lap && val != flx0 && val != inp) {
- val->computeAt(flx0, -1);
- }
- }
-
- // Inline inputs to fly0
- auto flxy_vals = DependencyCheck::getAllValsBetween({lap, inp}, {fly0});
- for (auto val : ir_utils::filterByType<TensorView>(flxy_vals)) {
- if (val != lap && val != fly0 && val != inp) {
- val->computeAt(fly0, -1);
- }
- }
-
- // Inline inputs to out
- auto out_vals = DependencyCheck::getAllValsBetween({flx0, fly0}, {out});
- for (auto val : ir_utils::filterByType<TensorView>(out_vals)) {
- if (val != flx0 && val != fly0 && val != out) {
- val->computeAt(out, -1);
- }
- }
-
- // Step 3: Parallelization
-
- // Block parallelization
- out->axis(0)->parallelize(ParallelType::BIDz);
- out->axis(1)->parallelize(ParallelType::BIDy);
- out->axis(2)->parallelize(ParallelType::BIDx);
-
- // Thread parallelization
- for (auto tv : {out, flx0, fly0, lap}) {
- tv->axis(3)->parallelize(ParallelType::TIDy);
- tv->axis(4)->parallelize(ParallelType::TIDx);
- if (tv != out) {
- tv->setMemoryType(MemoryType::Shared);
- }
- }
-
- /////////////////////////////////
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- int numel_x = 101;
- int numel_y = 99;
- int numel_z = 10;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor inp_at = at::randn({numel_z, numel_y, numel_x}, options);
- at::Tensor coeff_at = at::randn({numel_z, numel_y, numel_x}, options);
- std::vector<IValue> inputs = {inp_at, coeff_at};
- auto outputs = fe.runFusion(inputs);
-
- {
- at::Tensor zeros = at::zeros({numel_z, numel_y, numel_x}, options);
- auto lap = inp_at * 4 -
- (shift(inp_at, {0, 1, 0}) + shift(inp_at, {0, -1, 0}) +
- shift(inp_at, {0, 0, 1}) + shift(inp_at, {0, 0, -1}));
- auto flx = shift(lap, {0, 0, -1}) - lap;
- auto flx_cond = (flx * (shift(inp_at, {0, 0, -1}) - inp_at)) > 0;
- auto flx0 = at::where(flx_cond, zeros, flx);
- auto fly = shift(lap, {0, -1, 0}) - lap;
- auto fly_cond = (fly * (shift(inp_at, {0, -1, 0}) - inp_at)) > 0;
- auto fly0 = at::where(fly_cond, zeros, fly);
-
- auto ref = inp_at -
- coeff_at *
- ((flx0 - shift(flx0, {0, 0, 1})) + (fly0 - shift(fly0, {0, 1, 0})));
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
- }
-}
-
-// 3x3 max pooling
-TEST(NVFuserTest, FusionMaxPooling_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Format: CHW
- auto inp = makeSymbolicTensor(3);
- fusion.addInput(inp);
-
- // 3x3 pooling of the HW spatial domain
- std::vector<std::vector<int>> offsets;
- for (int i = -1; i <= 1; ++i) {
- for (int j = -1; j <= 1; ++j) {
- if (i == 0 && j == 0) {
- continue;
- }
- offsets.push_back({i, j});
- }
- }
-
- std::vector<TensorView*> inp_tile({inp});
- for (auto offset : offsets) {
- offset.insert(offset.begin(), 0);
- inp_tile.push_back(shift(inp, offset));
- }
-
- TensorView* max_tensor = nullptr;
- for (auto tv : inp_tile) {
- if (max_tensor == nullptr) {
- max_tensor = tv;
- } else {
- max_tensor = binaryOp(BinaryOpType::Max, max_tensor, tv);
- }
- }
-
- fusion.addOutput(max_tensor);
-
- ////////////////////////////////////
-
- // Cache the input and weight tensors
- auto inp_cache = inp->cache_after();
-
- // Tiling the spatial domain
- const int tile_x = 32;
- const int tile_y = 8;
-
- max_tensor->split(-2, tile_y);
- max_tensor->axis(-2)->parallelize(ParallelType::TIDy);
- max_tensor->split(-1, tile_x);
- max_tensor->axis(-1)->parallelize(ParallelType::TIDx);
- max_tensor->reorder({{-3, -2}});
-
- inp_cache->computeAt(max_tensor, 3);
- inp_cache->axis(-2)->parallelize(ParallelType::TIDy);
- inp_cache->axis(-1)->parallelize(ParallelType::TIDx);
- inp_cache->setMemoryType(MemoryType::Shared);
-
- auto max_tensor_dep =
- DependencyCheck::getAllValsBetween({inp_cache}, {max_tensor});
- for (auto tv : ir_utils::filterByType<TensorView>(max_tensor_dep)) {
- if (tv == inp_cache || tv == max_tensor) {
- continue;
- }
- tv->computeAt(max_tensor, -1);
- }
-
- max_tensor->axis(0)->parallelize(ParallelType::BIDx);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- const int hw = 50;
- const int num_channels = 20;
- const int pooling_window = 3;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor aten_inp = at::randn({num_channels, hw, hw}, options);
-  // shift always pads by zero, so if all surrounding values are
-  // negative, max pooling would pick a padded value, which isn't the
-  // correct behavior. We need to be able to choose the value of
-  // padding; in this case, padding by the minimum value would not
-  // have this problem. For now, avoid the problem by making sure all
-  // values are non-negative.
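-  // A border window whose in-bounds values were all negative would
-  // otherwise report the zero pad value as its maximum.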
- aten_inp = at::abs(aten_inp);
- std::vector<IValue> inputs = {aten_inp};
-
- auto outputs = fe.runFusion(inputs);
-
- auto ref = at::max_pool2d(
- aten_inp, {pooling_window, pooling_window}, {1, 1}, {1, 1});
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionGatherPadding1_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- const std::vector<int> window_shape = {1, 3};
- const std::vector<std::vector<int>> padding_width = {{0, 0}, {1, 1}};
-
- auto tv1 = gather(tv0, window_shape, padding_width);
-
- fusion.addOutput(tv1);
-
- const int s1 = 11;
- const int s2 = 13;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({s1, s2}, options);
-
- auto ref = gather(t0, window_shape, padding_width);
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion({t0});
-
- TORCH_CHECK(ref.equal(outputs[0]));
-}
-
-TEST(NVFuserTest, FusionGatherPadding2_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- const std::vector<int> window_shape = {1, 3};
- const std::vector<std::vector<int>> padding_width = {{0, 0}, {1, 1}};
-
- auto tv0 = makeSymbolicTensor(2);
- fusion.addInput(tv0);
-
- auto tv1 = add(tv0, new Double(1));
-
- auto tv2 = gather(tv1, window_shape, padding_width);
-
- auto tv3 = sum(tv2, {-1});
-
- fusion.addOutput(tv3);
-
- tv3->split(1, 32);
- tv0->computeAt(tv3, 2);
- tv2->computeAt(tv3, -1);
-
- tv3->axis(0)->parallelize(ParallelType::BIDy);
- tv3->axis(1)->parallelize(ParallelType::BIDx);
- tv3->axis(2)->parallelize(ParallelType::TIDx);
- tv1->axis(2)->parallelize(ParallelType::TIDx);
-
- tv1->setMemoryType(MemoryType::Shared);
-
- const int s1 = 99;
- const int s2 = 101;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::Tensor t0 = at::randn({s1, s2}, options);
- std::vector<IValue> inputs = {t0};
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
- auto outputs = fe.runFusion(inputs);
-
- auto t1 = t0 + 1;
- auto t2 = gather(t1, window_shape, padding_width);
- auto ref = sum(t2, {-1});
-
- testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionConv2DStatic_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Input: [C, H, W]
- auto inp = makeSymbolicTensor(3);
- fusion.addInput(inp);
-
- // Weights: [K, C, 3, 3]
- auto w = makeSymbolicTensor(4);
- fusion.addInput(w);
-
- // Gather a neighbor tile of [3, 3] with padding size of 1 for each
- // side of the spatial dimensions
- auto inp_tile = gather(inp, {1, 3, 3}, {{0, 0}, {1, 1}, {1, 1}});
- // inp_tile: [C, H, W, 1, 3, 3]
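-  // gather appends the window dimensions after the original ones, and the
-  // {1, 1} padding on each side of H and W keeps the spatial extents
-  // unchanged, hence the shape above.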
-
- auto inp_bc =
- broadcast(inp_tile, {true, false, false, false, false, false, false});
- auto w_bc = broadcast(w, {false, false, true, true, true, false, false});
-
- auto inp_times_w = mul(inp_bc, w_bc);
-
- // Reduce the channel and neighbor tile dimensions
- auto out = sum(inp_times_w, {1, 4, 5, 6});
-
- fusion.addOutput(out);
-
- ////////////////////////////////////
-
- // Cache the input and weight tensors
- auto inp_cache = inp->cache_after();
-
- // Blocking the spatial dimensions
- const int block_w = 16;
- const int block_h = 4;
- // Blocking the channel dimension
- const int block_c = 8;
-
- out->split(2, block_h);
- out->split(4, block_w);
- out->reorder({{3, 4}});
- // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3]
-
- out->split(1, block_c);
- // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3]
-
- auto out_rf = out->rFactor({1, -3, -2, -1});
- // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3]
-  // out: [K, Ci, Ho, Wo, Hi, Wi]
-
- // Create a [block_x, block_y] tile on smem
- inp_cache->computeAt(out, 4);
- // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
- inp_cache->setMemoryType(MemoryType::Shared);
-
- // Move Ci forward
- out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}});
- inp_cache->computeAt(out_rf, 5);
-
- inp_tile->computeAt(out_rf, -1);
- w->computeAt(out_rf, -1);
-
- out->axis(0)->parallelize(ParallelType::BIDx);
- out->axis(1)->parallelize(ParallelType::TIDz);
- out->axis(4)->parallelize(ParallelType::TIDy);
- out->axis(5)->parallelize(ParallelType::TIDx);
-
- scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- const int dim_h = 99;
- const int dim_w = 101;
- const int dim_c = 10;
- const int dim_f = 20;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
- at::Tensor at_w = at::randn({dim_f, dim_c, 3, 3}, options);
- std::vector<IValue> inputs = {at_inp, at_w};
-
- auto cg_outputs = fe.runFusion(inputs);
-
- at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
- auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1);
- at_out = at_out.squeeze(0); // drop the N axis
-
- testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
-}
-
-// Mostly the same as the static conv test, but the shape of the weights,
-// 3x3 in this case, is given dynamically
-TEST(NVFuserTest, FusionConv2DDynamic_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Input: [C, H, W]
- auto inp = makeSymbolicTensor(3);
- fusion.addInput(inp);
-
- // Weights: [K, C, S, T]
- auto w = makeSymbolicTensor(4);
- fusion.addInput(w);
-
- auto w_h = new Int();
- fusion.addInput(w_h);
- auto w_w = new Int();
- fusion.addInput(w_w);
-
- auto pad_h = new Int();
- fusion.addInput(pad_h);
- auto pad_w = new Int();
- fusion.addInput(pad_w);
-
- // Gather a neighbor tile of [w_dim_h, w_dim_w] with padding
- auto inp_tile = gather(
- inp,
- {new Int(1), w_h, w_w},
- {{new Int(0), new Int(0)}, {pad_h, pad_h}, {pad_w, pad_w}});
-  // inp_tile: [C, H - w_h + 1 + 2 * pad_h, W - w_w + 1 + 2 * pad_w, 1, w_h, w_w]
-
- auto inp_bc =
- broadcast(inp_tile, {true, false, false, false, false, false, false});
- auto w_bc = broadcast(w, {false, false, true, true, true, false, false});
-
- auto inp_times_w = mul(inp_bc, w_bc);
-
- // Reduce the channel and neighbor tile dimensions
- auto out = sum(inp_times_w, {1, 4, 5, 6});
-
- fusion.addOutput(out);
-
- ////////////////////////////////////
- // Cache the input and weight tensors
- auto inp_cache = inp->cache_after();
-
- // Blocking the spatial dimensions
- const int block_w = 16;
- const int block_h = 4;
- // Blocking the channel dimension
- const int block_c = 8;
-
- out->split(2, block_h);
- out->split(4, block_w);
- out->reorder({{3, 4}});
- // out: [K, C, Ho, Wo, Hi, Wi, 1, 3, 3]
-
- out->split(1, block_c);
- // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 3, 3]
-
- auto out_rf = out->rFactor({1, -3, -2, -1});
- // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 3, 3]
-  // out: [K, Ci, Ho, Wo, Hi, Wi]
-
- // Create a [block_x, block_y] tile on smem
- inp_cache->computeAt(out, 4);
- // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
- inp_cache->setMemoryType(MemoryType::Shared);
-
- // Move Ci forward
- out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}});
- inp_cache->computeAt(out_rf, 5);
-
- inp_tile->computeAt(out_rf, -1);
- w->computeAt(out_rf, -1);
-
- out->axis(0)->parallelize(ParallelType::BIDx);
- out->axis(1)->parallelize(ParallelType::TIDz);
- out->axis(4)->parallelize(ParallelType::TIDy);
- out->axis(5)->parallelize(ParallelType::TIDx);
-
- scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- const int dim_h = 99;
- const int dim_w = 101;
- const int dim_c = 10;
- const int dim_f = 20;
- const int dim_w_h = 3;
- const int dim_w_w = 3;
- const int dim_pad_h = (dim_w_h - 1) / 2;
- const int dim_pad_w = (dim_w_w - 1) / 2;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
- at::Tensor at_w = at::randn({dim_f, dim_c, dim_w_h, dim_w_w}, options);
- std::vector<IValue> inputs = {
- at_inp, at_w, dim_w_h, dim_w_w, dim_pad_h, dim_pad_w};
-
- auto cg_outputs = fe.runFusion(inputs);
-
- at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
- auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1);
- at_out = at_out.squeeze(0); // drop the N axis
-
- testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
-}
-
-// 5x5 followed by 3x3
-TEST(NVFuserTest, FusionConv2DDynamicChain_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Input: [K1, H, W]
- auto inp = makeSymbolicTensor(3);
- fusion.addInput(inp);
-
- // Weights: [K2, K1, S1, T1]
- auto w1 = makeSymbolicTensor(4);
- fusion.addInput(w1);
-
- // Weights: [K3, K2, S2, T2]
- auto w2 = makeSymbolicTensor(4);
- fusion.addInput(w2);
-
- auto w1_h = new Int();
- fusion.addInput(w1_h);
- auto w1_w = new Int();
- fusion.addInput(w1_w);
-
- auto w2_h = new Int();
- fusion.addInput(w2_h);
- auto w2_w = new Int();
- fusion.addInput(w2_w);
-
- auto pad_h1 = new Int();
- fusion.addInput(pad_h1);
- auto pad_w1 = new Int();
- fusion.addInput(pad_w1);
-
- auto pad_h2 = new Int();
- fusion.addInput(pad_h2);
- auto pad_w2 = new Int();
- fusion.addInput(pad_w2);
-
- // Gather a neighbor tile of [w1_h, w1_w] with padding
- auto inp_tile = gather(
- inp,
- {new Int(1), w1_h, w1_w},
- {{new Int(0), new Int(0)}, {pad_h1, pad_h1}, {pad_w1, pad_w1}});
-  // inp_tile: [K1, H - w1_h + 1 + 2 * pad_h1, W - w1_w + 1 + 2 * pad_w1, 1, w1_h, w1_w]
-
- auto inp_bc =
- broadcast(inp_tile, {true, false, false, false, false, false, false});
- auto w1_bc = broadcast(w1, {false, false, true, true, true, false, false});
-
- auto inp_times_w1 = mul(inp_bc, w1_bc);
-
- // Reduce the channel and neighbor tile dimensions
- auto out1 = sum(inp_times_w1, {1, 4, 5, 6});
-
- // Second conv
- auto out1_tile = gather(
- out1,
- {new Int(1), w2_h, w2_w},
- {{new Int(0), new Int(0)}, {pad_h2, pad_h2}, {pad_w2, pad_w2}});
-
- auto out1_bc =
- broadcast(out1_tile, {true, false, false, false, false, false, false});
- auto w2_bc = broadcast(w2, {false, false, true, true, true, false, false});
-
- auto out1_times_w2 = mul(out1_bc, w2_bc);
-
- auto out2 = sum(out1_times_w2, {1, 4, 5, 6});
-
- fusion.addOutput(out2);
-
- ////////////////////////////////////
- // Cache the input and weight tensors
- auto inp_cache = inp->cache_after();
-
- // Blocking the spatial dimensions
- const int block_w = 16;
- const int block_h = 4;
-
- out2->split(2, block_h);
- out2->split(4, block_w);
- out2->reorder({{3, 4}});
- // out2: [K3, K2, Ho, Wo, Hi, Wi, 1, 3, 3]
-
- // Create a [block_x, block_y] tile on smem
- inp_cache->computeAt(out2, 4);
- // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
- inp_cache->setMemoryType(MemoryType::Shared);
-
- // Move Ci forward
- out1->reorder({{5, 3}, {3, 4}, {4, 5}});
- out1->setMemoryType(MemoryType::Shared);
-
- inp_cache->computeAt(out1, 4);
-
- inp_tile->computeAt(out1, -1);
- w1->computeAt(out1, -1);
-
- out1_tile->computeAt(out2, -1);
- w2->computeAt(out2, -1);
-
- out2->axis(0)->parallelize(ParallelType::BIDx);
- out2->axis(4)->parallelize(ParallelType::TIDy);
- out2->axis(5)->parallelize(ParallelType::TIDx);
-
- scheduler_utils::parallelizeAllLike(out2, {inp_cache, out1});
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- const int dim_h = 99;
- const int dim_w = 101;
- const int dim_k1 = 3;
- const int dim_k2 = 5;
- const int dim_k3 = 7;
- const int dim_w1_h = 5;
- const int dim_w1_w = 5;
- const int dim_pad1_h = (dim_w1_h - 1) / 2;
- const int dim_pad1_w = (dim_w1_w - 1) / 2;
- const int dim_w2_h = 3;
- const int dim_w2_w = 3;
- const int dim_pad2_h = (dim_w2_h - 1) / 2;
- const int dim_pad2_w = (dim_w2_w - 1) / 2;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor at_inp = at::randn({dim_k1, dim_h, dim_w}, options);
- at::Tensor at_w1 = at::randn({dim_k2, dim_k1, dim_w1_h, dim_w1_w}, options);
- at::Tensor at_w2 = at::randn({dim_k3, dim_k2, dim_w2_h, dim_w2_w}, options);
- std::vector<IValue> inputs = {
- at_inp,
- at_w1,
- at_w2,
- dim_w1_h,
- dim_w1_w,
- dim_w2_h,
- dim_w2_w,
- dim_pad1_h,
- dim_pad1_w,
- dim_pad2_h,
- dim_pad2_w};
-
- auto cg_outputs = fe.runFusion(inputs);
-
- at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
- auto at_out1 = at::conv2d(at_inp, at_w1, {}, 1, 2);
- auto at_out2 = at::conv2d(at_out1, at_w2, {}, 1, 1);
- at_out2 = at_out2.squeeze(0); // drop the N axis
-
- testValidate(&fusion, cg_outputs, inputs, {at_out2}, __LINE__, __FILE__);
-}
-
-TEST(NVFuserTest, FusionConv2DStaticEvenSizedWindow_CUDA) {
- Fusion fusion;
- FusionGuard fg(&fusion);
-
- // Input: [C, H, W]
- auto inp = makeSymbolicTensor(3);
- fusion.addInput(inp);
-
- // Weights: [K, C, 2, 2]
- auto w = makeSymbolicTensor(4);
- fusion.addInput(w);
-
- // Gather a [2, 2] neighbor tile, padding only the right side of the
- // spatial dimensions by 1. The left padding is zero, so the output
- // extents stay the same as the input's.
- auto inp_tile = gather(inp, {1, 2, 2}, {{0, 0}, {0, 1}, {0, 1}});
- // inp_tile: [C, H, W, 1, 2, 2]
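-
- // A small worked example of the asymmetric padding (a sketch): with
- // W = 4 and a window of 2, output position w gathers {w, w + 1}; only
- // the last window, at w = 3, reads the single zero-padded element on
- // the right, so the output extent stays 4, matching the input.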
-
- auto inp_bc =
- broadcast(inp_tile, {true, false, false, false, false, false, false});
- auto w_bc = broadcast(w, {false, false, true, true, true, false, false});
-
- auto inp_times_w = mul(inp_bc, w_bc);
-
- // Reduce the channel and neighbor tile dimensions
- auto out = sum(inp_times_w, {1, 4, 5, 6});
-
- fusion.addOutput(out);
-
- ////////////////////////////////////
-
- // Cache the input and weight tensors
- auto inp_cache = inp->cache_after();
-
- // Blocking the spatial dimensions
- const int block_w = 16;
- const int block_h = 4;
- // Blocking the channel dimension
- const int block_c = 8;
-
- out->split(2, block_h);
- out->split(4, block_w);
- out->reorder({{3, 4}});
- // out: [K, C, Ho, Wo, Hi, Wi, 1, 2, 2]
-
- out->split(1, block_c);
- // out: [K, Co, Ci, Ho, Wo, Hi, Wi, 1, 2, 2]
-
- auto out_rf = out->rFactor({1, -3, -2, -1});
- // out_rf: [K, rCo, Ci, Ho, Wo, Hi, Wi, 1, 2, 2]
- // out: [K, Ci, Ho, Wo, Hi, Wi]
-
- // Create a [block_x, block_y] tile on smem
- inp_cache->computeAt(out, 4);
- // inp_cache: [Co, Ho, Wo, Ci, Hi, Wi]
- inp_cache->setMemoryType(MemoryType::Shared);
-
- // Move Ci forward
- out_rf->reorder({{-4, -6}, {-5, -4}, {-6, -5}});
- inp_cache->computeAt(out_rf, 5);
-
- inp_tile->computeAt(out_rf, -1);
- w->computeAt(out_rf, -1);
-
- out->axis(0)->parallelize(ParallelType::BIDx);
- out->axis(1)->parallelize(ParallelType::TIDz);
- out->axis(4)->parallelize(ParallelType::TIDy);
- out->axis(5)->parallelize(ParallelType::TIDx);
-
- scheduler_utils::parallelizeAllLike(out, {inp_cache, out_rf});
-
- FusionExecutor fe;
- fe.compileFusion(&fusion);
-
- const int dim_h = 99;
- const int dim_w = 101;
- const int dim_c = 10;
- const int dim_f = 20;
-
- auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
- at::manual_seed(0);
- at::Tensor at_inp = at::randn({dim_c, dim_h, dim_w}, options);
- at::Tensor at_w = at::randn({dim_f, dim_c, 2, 2}, options);
- std::vector<IValue> inputs = {at_inp, at_w};
-
- auto cg_outputs = fe.runFusion(inputs);
-
- at_inp = at_inp.unsqueeze(0); // at::conv2d needs the N axis
- auto at_out = at::conv2d(at_inp, at_w, {}, 1, 1);
- at_out = at_out.squeeze(0); // drop the N axis
- // The spatial domain of at_out is (dim_h+1) x (dim_w+1), whereas the
- // fuser output is dim_h x dim_w. Drop the first row and column so the
- // two match.
- std::vector<at::indexing::TensorIndex> indices{
- at::indexing::Slice(0, at::indexing::None),
- at::indexing::Slice(1, at::indexing::None),
- at::indexing::Slice(1, at::indexing::None)};
- at_out = at_out.index(indices);
-
- testValidate(&fusion, cg_outputs, inputs, {at_out}, __LINE__, __FILE__);
-}
-
-} // namespace jit
-} // namespace torch
-#endif // #if defined(USE_CUDA)
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-struct ValidationConstants {
- // Tolerances generated from randn + add + sum fusion
- // compared against double precision
- std::array<std::array<double, 2>, 20> sum_tolerances_float = {
- {{4, 1.51992e-06}, {8, 2.23704e-06}, {16, 2.95788e-06},
- {32, 4.4778e-06}, {64, 6.75395e-06}, {128, 8.57934e-06},
- {256, 1.30594e-05}, {512, 2.19122e-05}, {1024, 3.3451e-05},
- {2048, 5.78476e-05}, {4096, 0.000108292}, {8192, 0.00012207},
- {16384, 0.000136882}, {32768, 0.000248561}, {65536, 0.000407594},
- {131072, 0.000500901}, {262144, 0.000923019}, {524288, 0.00156909},
- {1048576, 0.00223107}, {2097152, 0.00343043}}};
-
- // Tolerances generated from randn + add + sum fusion
- // compared against double precision
- std::array<std::array<double, 2>, 20> sum_tolerances_half = {
- {{4, 0.00390625}, {8, 0.0078125}, {16, 0.0078125},
- {32, 0.0155334}, {64, 0.0156269}, {128, 0.0312042},
- {256, 0.0312548}, {512, 0.0619979}, {1024, 0.0625103},
- {2048, 0.124686}, {4096, 0.12501}, {8192, 0.24945},
- {16384, 0.250049}, {32768, 0.498946}, {65536, 0.500071},
- {131072, 0.985087}, {262144, 1.00006}, {524288, 1.99234},
- {1048576, 2.00032}, {2097152, 3.99073}}};
-
- double base_half_abs_tol = -1;
- double base_half_rel_tol = -1;
- double base_float_abs_tol = -1;
- double base_float_rel_tol = -1;
-};
-
-namespace {
-
-// Returns abs and relative values to use for validation
-std::pair<double, double> getTolerance(
- DataType dtype,
- int64_t reduction_size,
- const ValidationConstants& tolerances) {
- switch (dtype) {
- case DataType::Float:
- // TODO: Pull new tolerances for Double; for now we just use the float
- // tolerances, which should be no worse.
- case DataType::Double: {
- const auto& sum_tolerance_entry = tolerances.sum_tolerances_float;
- const auto& base_abs = tolerances.base_float_abs_tol;
- const auto& base_rel = tolerances.base_float_rel_tol;
-
- if (reduction_size <= 1) {
- // No reduction case
- if (base_abs == -1 || base_rel == -1) {
- return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]};
- } else {
- return {base_abs, base_rel};
- }
- } else {
- // Reduction case
- size_t entry = 0;
- // Bounds-check first to avoid reading past the end of the table
- while (entry < sum_tolerance_entry.size() &&
- sum_tolerance_entry[entry][0] < reduction_size) {
- entry++;
- }
- double abs_tol = 0.0;
- if (entry + 1 < sum_tolerance_entry.size()) {
- // Grab the next entry up so we have some margin
- abs_tol = sum_tolerance_entry[entry + 1][1];
- } else {
- // If we hit the end of the list, return twice the max error we
- // measured
- abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.;
- }
- // Set the relative tolerance to 1% of the absolute tolerance, just to
- // allow a small margin of relative error.
- return {abs_tol, abs_tol * 0.01};
- }
- }
- case DataType::Half: {
- // Copied from float case
- const auto& sum_tolerance_entry = tolerances.sum_tolerances_half;
- const auto& base_abs = tolerances.base_half_abs_tol;
- const auto& base_rel = tolerances.base_half_rel_tol;
-
- if (reduction_size <= 1) {
- // No reduction case
- if (base_abs == -1 || base_rel == -1) {
- return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]};
- } else {
- return {base_abs, base_rel};
- }
- } else {
- // Reduction case
- size_t entry = 0;
- // Bounds-check first to avoid reading past the end of the table
- while (entry < sum_tolerance_entry.size() &&
- sum_tolerance_entry[entry][0] < reduction_size) {
- entry++;
- }
- double abs_tol = 0.0;
- if (entry + 1 < sum_tolerance_entry.size()) {
- // Grab the next entry up so we have some margin
- abs_tol = sum_tolerance_entry[entry + 1][1];
- } else {
- // If we hit the end of the list, return twice the max error we
- // measured
- abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.;
- }
- // Set the relative tolerance to 1% of the absolute tolerance, just to
- // allow a small margin of relative error.
- return {abs_tol, abs_tol * 0.01};
- }
- }
- case DataType::Int:
- return {0.0, 0.0};
- case DataType::Int32:
- return {0.0, 0.0};
- case DataType::Bool:
- return {0.0, 0.0};
- default:
- TORCH_INTERNAL_ASSERT(
- false, "Do not have tolerance computation for type ", dtype, ".");
- }
-}
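-
-// A worked example of the lookup above, using the sum_tolerances_float
-// table (a sketch, not part of the original code): for a float reduction
-// of size 1000, the scan stops at the {1024, 3.3451e-05} entry and the
-// next entry up, {2048, 5.78476e-05}, becomes the absolute tolerance so
-// there is some margin. The relative tolerance is 1% of that:
-//
-//   auto tol = getTolerance(DataType::Float, 1000, ValidationConstants());
-//   // tol.first == 5.78476e-05 (abs), tol.second == 5.78476e-07 (rel)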
-
-class ReductionSizeMapper : private IterVisitor {
- public:
- //! Runs through the fusion and determines how many reductions were performed
- //! to compute each tensorview.
- static std::unordered_map<TensorView*, int64_t> computeReductionSizes(
- Fusion* fusion,
- ExpressionEvaluator& expr_eval) {
- ReductionSizeMapper mapper(fusion, expr_eval);
- return mapper.reduction_map;
- }
-
- private:
- ReductionSizeMapper(Fusion* fusion, ExpressionEvaluator& expr_eval)
- : expr_eval_(expr_eval) {
- // Initialize input values
- for (auto inp : fusion->inputs()) {
- if (inp->isA<TensorView>()) {
- auto tv = inp->as<TensorView>();
- // Shouldn't have any reductions, but run it through the analysis anyway.
- reduction_map[tv] = getReductionSize(tv);
- }
- }
-
- IterVisitor::traverse(fusion);
-
- // Catch up with dangling outputs.
- for (auto out : fusion->outputs()) {
- if (out->isA<TensorView>()) {
- auto tv = out->as<TensorView>();
- // It's possible to have a dangling output that's not generated by any
- // expression, e.g., a zero-size workspace or a null tensor.
- if (reduction_map.count(tv) == 0) {
- // Shouldn't have any reductions, but run it through the analysis anyway.
- reduction_map[tv] = getReductionSize(tv);
- }
- }
- }
- }
-
- int64_t getReductionSize(const TensorView* tv) {
- int64_t reduction_elements = 1;
- for (auto id : tv->getMaybeRFactorDomain()) {
- if (id->isReduction()) {
- auto inferred_extent = expr_eval_.evaluate(id->extent());
- TORCH_INTERNAL_ASSERT(
- inferred_extent.has_value(),
- "Couldn't figure out what the dimensions of a tensorview is in evaluation for validation. ",
- id,
- " in ",
- tv);
- reduction_elements = reduction_elements * inferred_extent.value();
- }
- }
- return reduction_elements;
- }
-
- void handle(Expr* expr) override {
- if (!ir_utils::isTVOp(expr)) {
- return;
- }
-
- int64_t inp_reduction_elements = 1;
- for (auto inp : expr->inputs()) {
- if (inp->isA<TensorView>()) {
- if (auto tv = inp->as<TensorView>()) {
- inp_reduction_elements =
- std::max(inp_reduction_elements, reduction_map.at(tv));
- }
- }
- }
-
- for (auto out : expr->outputs()) {
- if (out->isA<TensorView>()) {
- auto tv = out->as<TensorView>();
- reduction_map[tv] = getReductionSize(tv) * inp_reduction_elements;
- }
- }
- }
-
- private:
- using IterVisitor::handle;
-
- std::unordered_map<TensorView*, int64_t> reduction_map;
- ExpressionEvaluator& expr_eval_;
-};
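-
-// A minimal sketch of what the mapper yields (illustrative shapes, not
-// from the original code): for a fusion that adds two [N, K] inputs and
-// sums over the K axis, the output TensorView maps to K, since K elements
-// were reduced per output element. Sizes compose through the graph, so a
-// second sum over the remaining axis would map to N * K:
-//
-//   auto sizes = ReductionSizeMapper::computeReductionSizes(fusion, ee);
-//   // sizes[sum_tv] == K, sizes[sum_of_sum_tv] == N * K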
-
-ExpressionEvaluator bindInputsAndLaunchParams(
- Fusion* fusion,
- const at::ArrayRef<IValue>& aten_inputs,
- const LaunchParams& launch_constraints) {
- auto expr_eval = executor_utils::bindFusionInputs(aten_inputs, fusion);
- for (auto val : fusion->vals()) {
- if (!val->isA<TensorView>()) {
- continue;
- }
-
- // Roughly taken from executor.cpp/computeLaunchParams
- auto tv = val->as<TensorView>();
- for (auto id : tv->domain()->domain()) {
- if (!(id->isThread() && id->extent()->definition() == nullptr)) {
- continue;
- }
-
- if (id->isBroadcast()) {
- continue;
- }
-
- auto extent = id->extent();
- auto inferred_extent = expr_eval.evaluate(extent);
- auto p_type = id->getParallelType();
-
- if (inferred_extent.has_value()) {
- // This value could have been inferred; make sure it was set correctly.
- TORCH_CHECK(
- inferred_extent.value() == launch_constraints.getDim(p_type) ||
- launch_constraints.getRawVal(p_type) == -1,
- "inferred that ",
- p_type,
- " should be set to ",
- inferred_extent.value(),
- " but launch constraints specified ",
- launch_constraints.getRawVal(p_type));
- } else {
- // Bind the launch constraint into our evaluation context
- if (launch_constraints.hasDim(id->getParallelType())) {
- expr_eval.bind(extent, launch_constraints.getDim(p_type));
- }
- }
- }
- }
- return expr_eval;
-}
-
-} // namespace
-
-// Validation will look through the fusion and figure out how many elements
-// were reduced to create each output. It will then compute a tolerance to
-// use for allclose based on experimental results. The experimental results
-// were based on adding two tensors, then summing them. This of course
-// assumes that we're always summing values between -2 and 2. If we start
-// summing larger values, this approach might not hold.
-inline void testValidate(
- Fusion* fusion,
- const std::vector<at::Tensor>& fusion_outputs,
- const at::ArrayRef<IValue>& aten_inputs,
- const std::vector<at::Tensor>& aten_outputs,
- int line_number,
- const char* file_name,
- std::string err_msg = "",
- const LaunchParams& lparams = LaunchParams(),
- const ValidationConstants& tolerances = ValidationConstants()) {
- FusionGuard fg(fusion);
-
- auto expr_eval = bindInputsAndLaunchParams(fusion, aten_inputs, lparams);
-
- auto reduction_sizes =
- ReductionSizeMapper::computeReductionSizes(fusion, expr_eval);
-
- TORCH_INTERNAL_ASSERT(
- fusion_outputs.size() == aten_outputs.size() &&
- aten_outputs.size() == fusion->outputs().size(),
- "Number of outputs don't match.");
-
- TORCH_INTERNAL_ASSERT(
- fusion->inputs().size() == aten_inputs.size(),
- "Number of inputs don't match.");
-
- for (size_t i = 0; i < fusion->inputs().size(); i++) {
- if (fusion->inputs()[i]->isA<TensorView>()) {
- TORCH_INTERNAL_ASSERT(
- aten_inputs[i].isTensor(), "Mismatch of tensor inputs.");
-
- auto fusion_input_tv = fusion->inputs()[i]->as<TensorView>();
- auto at_tensor = aten_inputs[i].toTensor();
-
- TORCH_INTERNAL_ASSERT(
- at_tensor.dim() ==
- TensorDomain::noReductions(
- fusion_input_tv->getMaybeRFactorDomain())
- .size(),
- "Dimensionality mismatch in inputs.");
- }
- }
-
- for (size_t i = 0; i < fusion->outputs().size(); i++) {
- TORCH_INTERNAL_ASSERT(
- fusion->outputs()[i]->isA<TensorView>(), "Mismatch of tensor outputs.");
-
- auto fusion_output_tensor = fusion_outputs[i];
- auto fusion_output_tv = fusion->outputs()[i]->as<TensorView>();
- auto aten_output_tensor = aten_outputs[i];
-
- TORCH_INTERNAL_ASSERT(
- reduction_sizes.count(fusion_output_tv),
- "Missed reduction size count on fusion output at index: ",
- i);
-
- int64_t reduction_size = reduction_sizes.at(fusion_output_tv);
-
- TORCH_INTERNAL_ASSERT(
- aten_output_tensor.dim() == fusion_output_tensor.dim() &&
- fusion_outputs[i].dim() ==
- TensorDomain::noReductions(
- fusion_output_tv->getMaybeRFactorDomain())
- .size(),
- "Dimensionality mismatch in inputs.");
-
- auto tolerance_values = getTolerance(
- fusion_output_tv->getDataType().value(), reduction_size, tolerances);
-
- if (aten_output_tensor.is_floating_point()) {
- TORCH_INTERNAL_ASSERT(
- aten_output_tensor.allclose(
- fusion_output_tensor.to(aten_output_tensor.dtype()),
- tolerance_values.second,
- tolerance_values.first),
- "\n",
- err_msg,
- "\nValidation error in output ",
- i,
- " on line ",
- line_number,
- " in file ",
- file_name,
- ".\n Detected abs error of: ",
- aten_output_tensor.sub(fusion_output_tensor)
- .abs()
- .max()
- .item()
- .to<double>(),
- "\n absolute tolerance was set to ",
- tolerance_values.first,
- "\n and relative tolerance set to ",
- tolerance_values.second);
- } else {
- TORCH_INTERNAL_ASSERT(
- aten_output_tensor.equal(
- fusion_output_tensor.to(aten_output_tensor.dtype())),
- "\n",
- err_msg,
- ".\n Validation error in output ",
- i,
- " on line ",
- line_number,
- " in file ",
- file_name,
- ".\n Values are not equal and are not a floating type.");
- }
- }
-}
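-
-// Tying the pieces together (a sketch): an output produced by summing
-// 16384 float elements gets reduction_size 16384 from the mapper, so
-// getTolerance scans to the {16384, 0.000136882} entry and allclose runs
-// with an absolute tolerance of 0.000248561 (the next entry up) and a
-// relative tolerance of 1% of that.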
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
import unittest
import os
-import random
import torch
-from torch.nn import functional
-from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR # TEST_WITH_ROCM
-from torch.testing._internal.common_cuda import TEST_MULTIGPU
+from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR
from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed
-from torch.testing import FileCheck
from test_jit import JitTestCase, RUN_CUDA
import itertools
import numpy as np
-import math
-from typing import List
-
-CUDA_MAJOR, CUDA_MINOR = (int(x) for x in torch.version.cuda.split('.'))
-
-os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1'
-os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1'
-os.environ['PYTORCH_NVFUSER_DISABLE_FASTMATH'] = '1'
-os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0'
-os.environ['PYTORCH_NVFUSER_DISABLE_RNG_UNROLL'] = '1'
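+# Presumably these flags keep eager and fused results comparable: disabling
+# the fallback surfaces fusion failures instead of silently running eager
+# code, while disabling FMA and setting opt level 0 keeps the generated
+# kernels' numerics close to eager mode.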
+os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1'
+os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1'
+os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0'
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
torch._C._jit_set_texpr_fuser_enabled(False)
FUSION_GROUP = 'prim::CudaFusionGroup'
FUSION_GUARD = 'prim::CudaFusionGuard'
-def is_pre_volta():
- prop = torch.cuda.get_device_properties(torch.cuda.current_device())
- return prop.major < 7
-
class TestCudaFuser(JitTestCase):
- special_values = torch.tensor(
- [float("-inf"), -10, -math.pi,
- -1, -0.5, 0, 1, 0.5,
- math.pi, 10, float("inf"),
- float("nan")], dtype=torch.float, device='cuda')
-
- int_types = [
- torch.int8,
- torch.uint8,
- torch.int16,
- torch.int32,
- torch.int64
- ]
-
- support_tensor_dtypes = [
- torch.int32,
- torch.int64,
- torch.float16,
- torch.float32,
- torch.float64,
- torch.bool
- ]
-
def _getSubgraphInFusion(self, graph):
num_node = 0
subgraph = None
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
self.old_guard = torch._C._jit_set_nvfuser_guard_mode(False)
- torch._C._debug_set_autodiff_subgraph_inlining(False)
if(RUN_CUDA):
self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True)
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse)
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse)
torch._C._jit_set_nvfuser_guard_mode(self.old_guard)
- torch._C._debug_set_autodiff_subgraph_inlining(True)
super(TestCudaFuser, self).tearDown()
def _run_helper(self, jit_op, op, *args):
torch.cuda.manual_seed_all(123)
o = op(*args)
self.assertEqual(o, jit_o)
- self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, 1, consider_subgraphs=True)
+ self.assertGraphContains(jit_op.graph_for(*args), FUSION_GUARD)
def _run_training_helper(self, jit_op, op, grads, *args):
torch.cuda.manual_seed_all(123)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GUARD)
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_reduction_dtypes(self):
-
- for op in [torch.sum, torch.mean]:
- for dtype in [torch.float16, torch.float32, torch.double]:
- def make_func(op):
- def func(x: torch.Tensor):
- o = torch.mul(x, 1.0)
- o = op(o, dim=[2])
- return o
- return func
-
- x = torch.randn(8, 4, 16, dtype=dtype, device="cuda")
- t = make_func(op)
- t_jit = torch.jit.trace(t, x)
- jit_o = t_jit(x)
- jit_o = t_jit(x)
- o = t(x)
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
- self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
-
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
o = t(x, y, z)
self.assertEqual(o, jit_o)
subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z))
- self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False)
+ self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False)
@unittest.skipIf(True, "Broadcast with different output not supported yet")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
# Currently cannot fuse this
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
+ def _binary_test_helper(self, operation):
+ def t(x: torch.Tensor, y: torch.Tensor, z: float):
+ o = x + z
+ o = operation(o, y)
+ return o
+ t_jit = torch.jit.script(t)
+ x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
+ y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
+ jit_o = t_jit(x, y, 2.0)
+ jit_o = t_jit(x, y, 2.0)
+ o = t(x, y, 2.0)
+ self.assertEqual(o, jit_o)
+ self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD)
+
def _unary_test_helper(self, operation):
def t(x: torch.Tensor, z: float):
o = x + z
torch.relu,
torch.sigmoid,
torch.tanh,
- torch.nn.functional.silu]
+ torch.nn.functional.gelu]
for op in operations:
self._unary_test_helper(op)
- def _unary_type_test_helper(self, operation, dtype, random_data=True):
- shape = (4, 8, 32, 32)
-
- # need additional def of t for boolean ops
- def t(x: torch.Tensor, y: torch.Tensor):
- o = x * y
- o = operation(o)
- return o
-
- y = torch.tensor([1], device="cuda").to(dtype)
-
- if random_data:
- x = torch.randn(shape, dtype=torch.float32, device="cuda")
- if dtype in self.int_types:
- # prefer a larger variance for integer types
- x *= 5
- x = x.to(dtype=dtype)
- else:
- x = self.special_values.to(dtype=dtype)
- try:
- ref = t(x, y)
- except Exception:
- # as with the TE checker: if eager mode throws, skip this case
- return
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, y)
- jit_o = t_jit(x, y)
- if dtype in self.support_tensor_dtypes:
- self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
- o = t(x, y)
- self.assertEqual(o, jit_o, msg=f"""
- failing case:
- {dtype} {operation} {x}
- """)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_data_compatibility(self):
- dtypes = [
- *self.int_types,
- torch.float16,
- torch.float32,
- torch.float64
- ]
- operations = [torch.neg,
- torch.abs,
- torch.log,
- torch.log10,
- torch.log1p,
- torch.log2,
- torch.lgamma,
- torch.exp,
- torch.expm1,
- torch.erf,
- torch.erfc,
- torch.cos,
- torch.acos,
- torch.cosh,
- torch.sin,
- torch.asin,
- torch.tan,
- torch.atan,
- torch.sqrt,
- torch.rsqrt,
- torch.ceil,
- torch.floor,
- torch.round,
- torch.trunc,
- torch.frac,
- torch.reciprocal,
- torch.relu,
- torch.sigmoid,
- torch.tanh,
- torch.nn.functional.silu]
- prev_fallback = os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK']
- os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '0'
- for op, dtype in itertools.product(operations, dtypes):
- self._unary_type_test_helper(op, dtype, False) # test special numbers
- self._unary_type_test_helper(op, dtype) # test random data
- os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = prev_fallback
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_category_rule(self):
- def run_tensor(x, z):
- def t(x: torch.Tensor, z: torch.Tensor):
- o = x + z
- o = torch.abs(o)
- return o
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, z)
- jit_o = t_jit(x, z)
- o = t(x, z)
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o, jit_o)
- self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD)
-
- def run_scalar(x, z):
- def t(x: torch.Tensor, z: float):
- o = x + z
- o = torch.abs(o)
- return o
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, z)
- jit_o = t_jit(x, z)
- o = t(x, z)
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o, jit_o)
- self.assertGraphContains(t_jit.graph_for(x, z), FUSION_GUARD)
-
- # n-dim with 0-dim (no type-promote)
- x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
- z = torch.tensor(2.0, dtype=torch.double, device="cuda")
- run_tensor(x, z)
-
- # n-dim with 0-dim (type-promote)
- x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long)
- z = torch.tensor(2.0, dtype=torch.double, device="cuda")
- run_tensor(x, z)
-
- # n-dim with n-dim (type-promote)
- x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
- z = torch.randn(4, 8, 32, 32, dtype=torch.double, device="cuda")
- run_tensor(x, z)
-
- # n-dim with scalar (no type-promote)
- x = torch.randn(4, 8, 32, 32, dtype=torch.float16, device="cuda")
- z = torch.tensor(3., dtype=torch.double)
- run_scalar(x, z)
-
- # n-dim with scalar (type-promote)
- x = torch.randn(4, 8, 32, 32, device="cuda").to(dtype=torch.long)
- z = torch.tensor(3., dtype=torch.double)
- run_scalar(x, z)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_unary_bitwise(self):
- def bit_not(x: torch.Tensor):
- return ~(x + 0)
-
- jitted = torch.jit.script(bit_not)
- x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(torch.long)
- jit_o = bit_not(x)
- jit_o = bit_not(x)
- o = bit_not(x)
- self.assertEqual(o, jit_o)
- jitted.graph_for(x) # the fusion only shows up on the second invocation, not the first
- self.assertGraphContains(jitted.graph_for(x), FUSION_GUARD)
-
- def bool_not(x: torch.Tensor, y: torch.Tensor):
- return ~(x & y)
-
- jitted = torch.jit.script(bool_not)
- x = torch.rand(4, 8, 32, 32, dtype=torch.float, device="cuda").round().to(torch.bool)
- y = torch.rand(4, 8, 32, 32, dtype=torch.float, device="cuda").round().to(torch.bool)
- jit_o = bool_not(x, y)
- jit_o = bool_not(x, y)
- o = bool_not(x, y)
- self.assertEqual(o, jit_o)
- jitted.graph_for(x, y) # the fusion only shows up on the second invocation, not the first
- self.assertGraphContains(jitted.graph_for(x, y), FUSION_GUARD)
-
- def _binary_test_helper(self, operation, dtype):
- def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
- o = x + z
- o = operation(o, y)
- return o
- x = (torch.randn(4, 32, 32, dtype=torch.float, device="cuda") * 5).to(dtype)
- y = (torch.randn(4, 32, 32, dtype=torch.float, device="cuda") * 5).to(dtype)
- # Avoid division by zero for integer tensors
- div_like = [torch.div, torch.fmod, torch.remainder]
- if operation in div_like and (dtype == torch.int32 or dtype == torch.int64):
- y[y == 0] = 1
- z = torch.tensor([2], device="cuda").to(dtype)
- o = t(x, y, z)
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, y, z)
- jit_o = t_jit(x, y, z)
-
- self.assertEqual(o, jit_o)
- self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
-
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_binary_ops(self):
- data_types = [
- torch.float32,
- torch.float64,
- torch.int32,
- torch.int64
- ]
- # The ops below need some extra support to handle integer inputs,
- # and they don't look like popular integer ops in models.
- # TODO: insert assertions in cpp if we decide not to fuse these on
- # int.
- skip_for_integer = [
- torch.atan2,
- torch.fmod,
- torch.pow,
- torch.div
- ]
operations = [torch.div,
torch.mul,
torch.atan2,
torch.gt,
torch.le,
torch.lt]
- for op, dtype in itertools.product(operations, data_types):
- if (dtype not in self.int_types) or (op not in skip_for_integer):
- self._binary_test_helper(op, dtype)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_binary_bitwise(self):
- def jit_or(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
- return (x & y) | z
-
- def jit_xor(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
- return (x & y) ^ z
-
- def jit_lshift(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
- return (x & y) << z
-
- def jit_rshift(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
- return (x & y) >> z
-
- for jit_func in [jit_or, jit_xor, jit_lshift, jit_rshift]:
- x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(torch.long)
- y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(5).to(torch.long)
- z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda").mul(2).to(torch.long)
-
- jitted = torch.jit.script(jit_func)
- jit_o = jitted(x, y, z)
- jit_o = jitted(x, y, z)
- o = jit_func(x, y, z)
- self.assertEqual(o, jit_o)
- self.assertGraphContains(jitted.graph_for(x, y, z), FUSION_GUARD)
-
- # We shouldn't need to redefine these functions, but without it they won't be recompiled for the new input type
- def jit_or(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
- return (x & y) | z
-
- def jit_xor(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
- return (x & y) ^ z
-
- for jit_func in [jit_or, jit_xor]:
- x = torch.rand(4, 2, dtype=torch.float, device="cuda").round().to(torch.bool)
- y = torch.rand(4, 2, dtype=torch.float, device="cuda").round().to(torch.bool)
- z = torch.rand(4, 2, dtype=torch.float, device="cuda").round().to(torch.bool)
-
- jitted = torch.jit.script(jit_func)
- jit_o = jitted(x, y, z)
- jit_o = jitted(x, y, z)
- o = jit_func(x, y, z)
- self.assertEqual(o, jit_o)
- self.assertGraphContains(jitted.graph_for(x, y, z), FUSION_GUARD)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_type_as_op(self):
- def t(x: torch.Tensor, y: torch.Tensor, z: float):
- o = torch.lt(x, z)
- o = o.type_as(y)
- return o
- t_jit = torch.jit.script(t)
- x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
- y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
- jit_o = t_jit(x, y, 0.5)
- jit_o = t_jit(x, y, 0.5)
- o = t(x, y, 0.5)
- self.assertEqual(o, jit_o)
- self.assertGraphContains(t_jit.graph_for(x, y, 0.5), FUSION_GUARD)
+ for op in operations:
+ self._binary_test_helper(op)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
# legacy fuser does not work for rand_like, see issue #34361
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_random_topo(self):
- os.environ["PYTORCH_NVFUSER_DISABLE_FALLBACK"] = "1"
+ os.environ["PYTORCH_CUDA_FUSER_DISABLE_FALLBACK"] = "1"
self.assertTrue(runDefaultTestWithSeed(28449))
def _compare(self, desc, inp1, inp2, error):
o = torch.relu(o)
return o
- x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute(
- [perm0.index(i) for i in range(len(sizes))])
+ x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute([perm0.index(i) for i in range(len(sizes))])
if broadcast_axis >= 0:
sizes[broadcast_axis] = 1
- y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute(
- [perm1.index(i) for i in range(len(sizes))])
+ y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute([perm1.index(i) for i in range(len(sizes))])
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
x = [7, 8, 12]
self._permutation_helper(x, b_axis, torch.float32, "cuda", perm0, perm1)
- def _reduction_helper(self, sizes, reduction_axis, dtype, device, perm0, perm1, keepdim=False):
+ def _reduction_helper(self, sizes, reduction_axis, dtype, device, perm0, perm1):
class MyReduction(torch.nn.Module):
- __constants__ = ['reduction_axis', 'keepdim']
+ __constants__ = ['reduction_axis']
def __init__(self):
super(MyReduction, self).__init__()
self.reduction_axis = reduction_axis
- self.keepdim = keepdim
def forward(self, x: torch.Tensor, y: torch.Tensor):
o = torch.add(x, y)
- o = torch.sum(o, dim=self.reduction_axis, keepdim=self.keepdim)
+ o = torch.sum(o, dim=self.reduction_axis)
return o
t = MyReduction()
- x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute(
- [perm0.index(i) for i in range(len(sizes))])
- y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute(
- [perm1.index(i) for i in range(len(sizes))])
+ x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute([perm0.index(i) for i in range(len(sizes))])
+ y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute([perm1.index(i) for i in range(len(sizes))])
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
# to single element (codegen limitation at this moment)
for num_reduce_dim in range(1, len(x)):
for axes in itertools.combinations(range(len(x)), num_reduce_dim):
- for keepdim in (True, False):
- perm0 = range(len(x))
- perm1 = range(len(x))
- self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1, keepdim)
-
- def _layer_norm_autodiff_helper(self, model, grad, shapes, args):
- jit_model = torch.jit.script(model)
-
- eps = np.random.random() * 1e-4
- use_cudnn = bool(np.random.randint(0, 2))
-
- # profile/optimization runs
- for i in range(3):
- jit_o = jit_model(shapes, *args, eps, use_cudnn)
- jit_o.backward(grad)
-
- ref_args = [t.detach().clone().requires_grad_() for t in args]
- for t in args:
- t.grad.zero_()
- jit_o = jit_model(shapes, *args, eps, use_cudnn)
- jit_o.backward(grad)
-
- o = model(shapes, *ref_args, eps, use_cudnn)
- o.backward(grad)
- self.assertEqual(jit_o, o)
- for arg, ref_arg in zip(args, ref_args):
- self.assertEqual(arg.grad, ref_arg.grad)
-
- # check fusion in fw & bw
- g = jit_model.graph_for(shapes, *args, eps, use_cudnn)
- for node in g.nodes():
- n = node
- dbg_state = jit_model.get_debug_state()
- # walk to the last execution plan, then through its gradient executor,
- # to reach the backward graph
- for val in dbg_state.execution_plans.values():
- v = val
- state2 = v.code.grad_executor_states()
- for val in state2[0].execution_plans.values():
- v2 = val
- FileCheck().check(FUSION_GUARD).run(g)
- FileCheck().check(FUSION_GUARD).run(v2.graph)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_layer_norm_autodiff(self):
- def t_wb(shapes: List[int], x, w, b, eps: float, cudnn: bool):
- o = torch.layer_norm(x, shapes, w, b, eps, cudnn)
- o = torch.relu(o)
- return o
-
- def t_w(shapes: List[int], x, w, eps: float, cudnn: bool):
- o = torch.layer_norm(x, shapes, w, None, eps, cudnn)
- o = torch.relu(o)
- return o
-
- def t_b(shapes: List[int], x, b, eps: float, cudnn: bool):
- o = torch.layer_norm(x, shapes, None, b, eps, cudnn)
- o = torch.relu(o)
- return o
-
- def t(shapes: List[int], x, eps: float, cudnn: bool):
- o = torch.layer_norm(x, shapes, None, None, eps, cudnn)
- o = torch.relu(o)
- return o
-
- model = {3: t_wb, 2: t_w, 1: t_b, 0: t}
+ perm0 = range(len(x))
+ perm1 = range(len(x))
+ self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1)
- for w, b in itertools.product([True, False], repeat=2):
- batch = [4]
- shapes = [2, 3, 4]
- m = model[w * 2 + b]
-
- grad = torch.randn(batch + shapes, dtype=torch.float32, device="cuda")
- args = [torch.randn(batch + shapes, dtype=torch.float32, device="cuda").requires_grad_()]
- if w:
- args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
- if b:
- args.append(torch.randn(shapes, dtype=torch.float32, device="cuda").requires_grad_())
- self._layer_norm_autodiff_helper(m, grad, shapes, args)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_layer_norm_parser(self):
- dtype = torch.float32
- device = "cuda"
- x = torch.randn([4, 4, 2], dtype=dtype, device=device)
- w = torch.randn([4, 2], dtype=dtype, device=device)
- b = torch.randn([4, 2], dtype=dtype, device=device)
-
- def t(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
- o = torch.relu(x)
- o = torch.layer_norm(o, [4, 2], w, b, 1e-5)
- return o
-
- o = t(x, w, b)
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, w, b)
- jit_o = t_jit(x, w, b)
- o = t(x, w, b)
- self.assertGraphContains(t_jit.graph_for(x, w, b), FUSION_GUARD)
-
- def _native_layer_norm_helper(self, shape, norm_shape, dtype, device, error, affine=True):
- class MyLayerNorm(torch.nn.Module):
- __constants__ = ['norm_shape']
-
- def __init__(self, elementwise_affine=True):
- super(MyLayerNorm, self).__init__()
- self.norm_shape = norm_shape
- if elementwise_affine:
- self.weight = torch.randn(norm_shape, dtype=dtype, device=device)
- self.bias = torch.randn(norm_shape, dtype=dtype, device=device)
- with torch.no_grad():
- self.weight.fill_(1)
- self.bias.fill_(0)
- else:
- self.weight = None
- self.bias = None
-
- def forward(self, x: torch.Tensor):
- o = torch.relu(x)
- o = torch.native_layer_norm(o, self.norm_shape, self.weight, self.bias, 1e-5)
- return o
-
- t = MyLayerNorm(affine)
-
- x = torch.randn(shape, dtype=dtype, device=device)
- t_jit = torch.jit.script(t)
- jit_o, jit_mean, jit_rstd = t_jit(x)
- jit_o, jit_mean, jit_rstd = t_jit(x)
- o, mean, rstd = t(x)
- self.assertEqual(o.dtype, jit_o.dtype)
- # numerical issues here due to our scheduling.
- # can't use `self.assertEqual(o, jit_o)`
- self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
- self.assertTrue(self._compare("comparing mean failed", mean, jit_mean, error))
- self.assertTrue(self._compare("comparing rstd failed", rstd, jit_rstd, error))
- self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_native_layer_norm(self):
- dims = 4
- rnds = 3
- for idx in range(rnds):
- for offset in range(1, dims):
- for affine in (True, False):
- input_shape = [random.randint(10, 30) for idx in range(dims)]
- norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
- self._native_layer_norm_helper(input_shape, norm_shape, torch.float32, "cuda", 1e-4, affine)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_native_layer_norm_half(self):
- dims = 4
- rnds = 3
- for idx in range(rnds):
- for offset in range(1, dims):
- input_shape = [random.randint(10, 30) for idx in range(dims)]
- norm_shape = [input_shape[idx] for idx in range(dims - offset, dims)]
- self._native_layer_norm_helper(input_shape, norm_shape, torch.float16, "cuda", 5e-3)
-
- def _norm_helper(self, shape, dtype, device, error, is_batch_norm_else_instance_norm):
- class MyBatchNorm(torch.nn.Module):
- def __init__(self):
- super(MyBatchNorm, self).__init__()
-
- def forward(self, x: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
- o = torch.nn.functional.batch_norm(x, r_mean, r_var, training=True)
- o = torch.relu(o)
- return o
-
- class MyInstanceNorm(torch.nn.Module):
- def __init__(self):
- super(MyInstanceNorm, self).__init__()
-
- def forward(self, x: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
- o = torch.nn.functional.instance_norm(x, r_mean, r_var, use_input_stats=True)
- o = torch.relu(o)
- return o
-
- t = MyBatchNorm() if is_batch_norm_else_instance_norm else MyInstanceNorm()
-
- x = torch.randn(shape, dtype=dtype, device=device)
- running_mean = torch.zeros(shape[1], dtype=torch.float32, device=device)
- running_var = torch.ones(shape[1], dtype=torch.float32, device=device)
- t_jit = torch.jit.script(t)
-
- eager_running_mean = running_mean.clone()
- eager_running_var = running_var.clone()
- jit_running_mean = running_mean.clone()
- jit_running_var = running_var.clone()
-
- jit_o = t_jit(x, running_mean.clone(), running_var.clone())
-
- self.assertTrue(self._compare("prerun comparing running_mean failed", eager_running_mean, jit_running_mean, error))
- self.assertTrue(self._compare("prerun comparing running_var failed", eager_running_var, jit_running_var, error))
-
- jit_o = t_jit(x, jit_running_mean, jit_running_var)
- o = t(x, eager_running_mean, eager_running_var)
- self.assertEqual(o.dtype, jit_o.dtype)
- # numerical issues here due to our scheduling.
- # can't use `self.assertEqual(o, jit_o)`
- self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
- self.assertTrue(self._compare("comparing running_mean failed", eager_running_mean, jit_running_mean, error))
- self.assertTrue(self._compare("comparing running_var failed", eager_running_var, jit_running_var, error))
- self.assertGraphContains(t_jit.graph_for(x, running_mean, running_var), FUSION_GUARD)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_norm(self):
- output_elements = 10000
- channel_sizes = [67, 457, 1024, 4096]
-
- with torch.backends.cudnn.flags(enabled=False):
- for is_batch_norm_else_instance_norm in [False, True]:
- for dims in range(3, 6):
- output_size = int(pow(output_elements, 1. / (dims - 1)))
- for C in channel_sizes:
- x = [output_size for idx in range(dims)]
- x[1] = C
- self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_norm_large(self):
- output_elements = 262144
- channel_sizes = 67, 457, 1024
-
- for is_batch_norm_else_instance_norm in [True, False]:
- for dims in range(3, 6):
- output_size = int(pow(output_elements, 1. / (dims - 1)))
- for C in channel_sizes:
- x = [output_size for idx in range(dims)]
- x[1] = C
- self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_norm_half(self):
- output_elements = 10000
- channel_sizes = [67, 457, 1024, 4096]
-
- with torch.backends.cudnn.flags(enabled=False):
- for is_batch_norm_else_instance_norm in [False, True]:
- for dims in range(3, 6):
- output_size = int(pow(output_elements, 1. / (dims - 1)))
- for C in channel_sizes:
- x = [output_size for idx in range(dims)]
- x[1] = C
- self._norm_helper(x, torch.float16, "cuda", 5e-3, is_batch_norm_else_instance_norm)
-
- def _softmax_helper(self, shape, reduction_axis, dtype, device, error):
- class MySoftmax(torch.nn.Module):
- __constants__ = ['reduction_axis']
-
- def __init__(self):
- super(MySoftmax, self).__init__()
- self.reduction_axis = reduction_axis
-
- def forward(self, x: torch.Tensor, y: torch.Tensor):
- o = torch.add(x, y)
- o = torch.nn.functional.softmax(o, dim=self.reduction_axis)
- return o
-
- t = MySoftmax()
-
- x = torch.randn(shape, dtype=dtype, device=device)
- y = torch.randn(shape, dtype=dtype, device=device)
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, y)
- jit_o = t_jit(x, y)
- o = t(x, y)
- self.assertEqual(o.dtype, jit_o.dtype)
- # numerical issues here due to our scheduling.
- # can't use `self.assertEqual(o, jit_o)`
- self.assertTrue(self._compare("comparing output failed", o, jit_o, error))
- self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_softmax(self):
- output_size = 10000
- dims = 4
- output_size = int(pow(output_size, 1. / dims))
- reduction_sizes = [67, 256, 1024, 4096]
-
- for reduction_dim in range(dims):
- for reduction_size in reduction_sizes:
- x = [output_size for idx in range(dims)]
- x[reduction_dim] = reduction_size
- self._softmax_helper(x, reduction_dim, torch.float32, "cuda", 1e-4)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_softmax_half(self):
- output_size = 10000
- dims = 4
- output_size = int(pow(output_size, 1. / dims))
- reduction_sizes = [67, 256, 1024, 4096]
-
- for reduction_dim in range(dims):
- for reduction_size in reduction_sizes:
- x = [output_size for idx in range(dims)]
- x[reduction_dim] = reduction_size
- self._softmax_helper(x, reduction_dim, torch.float16, "cuda", 5e-3)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
for perm1 in itertools.permutations(range(len(x))):
self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1)
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
- def test_channels_last_with_broadcast(self):
- # setting this to True forces a new graph to be generated for a new
- # input with a different broadcast shape
- torch._C._jit_set_nvfuser_guard_mode(True)
-
- def t(x: torch.Tensor, y: torch.Tensor):
- o = torch.mul(x, y)
- o = o + 2.0
+ def test_reduction_dtype(self):
+ def t(x: torch.Tensor):
+ o = torch.mul(x, 1.0)
+ o = torch.sum(o, dim=[2], dtype=torch.float32)
return o
t_jit = torch.jit.script(t)
- # Single Channel broadcasts
- # Test 1
- x = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda")
- x = x.to(memory_format=torch.channels_last)
-
- y = torch.randn(8, 4, 10, 1, dtype=torch.float, device="cuda")
- y = y.to(memory_format=torch.channels_last)
-
- jit_o = t_jit(x, y)
- jit_o = t_jit(x, y)
- o = t(x, y)
-
+ x = torch.randn(8, 4, 16, dtype=torch.float, device="cuda")
+ jit_o = t_jit(x)
+ jit_o = t_jit(x)
+ o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
- jit_o.is_contiguous(memory_format=torch.channels_last))
- self.assertEqual(o, jit_o)
-
- # Test 2
- y = torch.randn(8, 4, 1, 16, dtype=torch.float, device="cuda")
- y = y.to(memory_format=torch.channels_last)
+ self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
+ self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
- jit_o = t_jit(x, y)
- jit_o = t_jit(x, y)
- o = t(x, y)
+ @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+ @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
+ "Requires fusion optimization pass to be effective")
+ def test_reduction_half(self):
+ def t(x: torch.Tensor):
+ o = torch.mul(x, 1.0)
+ o = torch.sum(o, dim=[2])
+ return o
+ t_jit = torch.jit.script(t)
+ x = torch.randn(8, 4, 16, dtype=torch.float16, device="cuda")
+ jit_o = t_jit(x)
+ jit_o = t_jit(x)
+ o = t(x)
self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
- jit_o.is_contiguous(memory_format=torch.channels_last))
- self.assertEqual(o, jit_o)
+ self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4))
+ self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
- # Test 3
- y = torch.randn(8, 1, 10, 16, dtype=torch.float, device="cuda")
- y = y.to(memory_format=torch.channels_last)
-
- jit_o = t_jit(x, y)
- jit_o = t_jit(x, y)
- o = t(x, y)
-
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
- jit_o.is_contiguous(memory_format=torch.channels_last))
- self.assertEqual(o, jit_o)
-
- # Test 4
- y = torch.randn(1, 4, 10, 16, dtype=torch.float, device="cuda")
- y = y.to(memory_format=torch.channels_last)
-
- jit_o = t_jit(x, y)
- jit_o = t_jit(x, y)
- o = t(x, y)
-
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
- jit_o.is_contiguous(memory_format=torch.channels_last))
- self.assertEqual(o, jit_o)
-
- '''
- Currently, the JIT doesn't have the tensor merge logic to handle adding
- a broadcast tensor with more than one broadcast dimension into a
- non-broadcast tensor. Therefore, either of these tests can fail depending
- on the sort implementation. The second test is known to fail.
-
- # Two Channel broadcasts
- # Test 1
- y = torch.randn(8, 4, 1, 1, dtype=torch.float, device="cuda")
- y = y.to(memory_format=torch.channels_last)
-
- jit_o = t_jit(x, y)
- jit_o = t_jit(x, y)
- o = t(x, y)
-
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
- jit_o.is_contiguous(memory_format=torch.channels_last))
- self.assertEqual(o, jit_o)
-
- # Test 2
- y = torch.randn(8, 4, 1, 1, dtype=torch.float, device="cuda")
- y = y.to(memory_format=torch.channels_last).transpose(2,3)
- x = x.transpose(2,3)
-
- jit_o = t_jit(x, y)
- jit_o = t_jit(x, y)
- o = t(x, y)
-
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o.is_contiguous(memory_format=torch.channels_last),
- jit_o.is_contiguous(memory_format=torch.channels_last))
- self.assertEqual(o, jit_o)
- '''
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_pw_single_reduction_partition(self):
- sizes = [2, 2, 2]
- dtype = torch.float
- device = "cuda"
- x = torch.randn(sizes, dtype=dtype, device=device)
- y = torch.randn(sizes, dtype=dtype, device=device)
- z = torch.randn(sizes, dtype=dtype, device=device)
+ @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+ @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
+ "Requires fusion optimization pass to be effective")
+ def test_pw_single_reduction_partition(self):
+ sizes = [8, 8, 8]
+ dtype = torch.float
+ device = "cuda"
+ x = torch.randn(sizes, dtype=dtype, device=device)
+ y = torch.randn(sizes, dtype=dtype, device=device)
+ z = torch.randn(sizes, dtype=dtype, device=device)
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = torch.add(x, y)
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_permutation_preservation(self):
- sizes = [2, 2, 2, 2]
- dtype = torch.float
- device = "cuda"
- x = torch.randn(sizes, dtype=dtype, device=device).to(memory_format=torch.channels_last)
-
- def t(x: torch.Tensor):
- o = torch.relu(x)
- o = torch.sum(o, dim=[0])
- return o
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x)
- jit_o = t_jit(x)
- o = t(x)
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o, jit_o)
- self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
- # we should preserve permutation to inputs
- self.assertEqual(jit_o.stride(), (1, 4, 2))
-
- def t(x: torch.Tensor):
- o = torch.relu(x)
- o = torch.add(o, 1.0)
- return o
-
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x)
- jit_o = t_jit(x)
- o = t(x)
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o, jit_o)
- self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
- self.assertTrue(jit_o.is_contiguous(memory_format=torch.channels_last))
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_normalization_partition(self):
- sizes = [8, 8, 8]
- dtype = torch.float
- device = "cuda"
- x = torch.randn(sizes, dtype=dtype, device=device)
- y = torch.randn(sizes, dtype=dtype, device=device)
- z = torch.randn(sizes, dtype=dtype, device=device)
- r_m = torch.randn(8, dtype=dtype, device=device)
- r_v = torch.randn(8, dtype=dtype, device=device)
-
- def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, r_mean: torch.Tensor, r_var: torch.Tensor):
- o = torch.add(x, y)
- o = torch.nn.functional.softmax(o, dim=0)
- o = torch.add(o, z)
- o = torch.nn.functional.batch_norm(o, r_mean, r_var, training=True)
- return o
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, y, z, r_m, r_v)
- jit_o = t_jit(x, y, z, r_m, r_v)
- o = t(x, y, z, r_m, r_v)
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o, jit_o)
- self.assertGraphContains(t_jit.graph_for(x, y, z, r_m, r_v), FUSION_GUARD)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_sum_to_one(self):
- dtype = torch.float
- device = "cuda"
- x = torch.randn([4, 5, 6], dtype=dtype, device=device)
-
- def t(x: torch.Tensor):
- o = torch.add(x, 0)
- o = torch.sum(o, dim=[0, 1, 2])
- return o
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x)
- jit_o = t_jit(x)
- o = t(x)
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o, jit_o)
- self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
self.assertEqual(o, jit_o)
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD)
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_trivial_reduction(self):
- dtype = torch.float
- device = "cuda"
- x = torch.randn([1, 4, 8], dtype=dtype, device=device)
-
- def t(x: torch.Tensor):
- o = torch.add(x, 0)
- o = torch.sum(o, dim=[0])
- o = torch.sum(o, dim=[0])
- return o
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x)
- jit_o = t_jit(x)
- o = t(x)
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o, jit_o)
- self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD)
-
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
repro_jit = torch.jit.script(repro)
self._run_helper(repro_jit, repro, x, 0.6)
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
# have been optimized away
self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 0)
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_profile_ivalue(self):
- dtype = torch.float
- device = "cuda"
- x = torch.randn([7, 4, 7], dtype=dtype, device=device)
- y = torch.randn([7, 4, 7], dtype=dtype, device=device)
-
- def t(x: torch.Tensor, y: torch.Tensor, dim: List[int], keepdim: bool):
- o = torch.add(x, y)
- o = o.sum(dim, keepdim=keepdim)
- return o
-
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, y, (0, 1), False)
- jit_o = t_jit(x, y, (0, 1), False)
- o = t(x, y, (0, 1), False)
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o, jit_o)
- self.assertGraphContains(t_jit.graph_for(x, y, (0, 1), False), FUSION_GUARD)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_sum_to_size(self):
- dtype = torch.float
- device = "cuda"
- x = torch.randn([2, 4, 4], dtype=dtype, device=device)
- y = torch.randn([2, 4, 4], dtype=dtype, device=device)
-
- def t(x: torch.Tensor, y: torch.Tensor, new_size: List[int]):
- o = torch.add(x, y)
- o = o.sum_to_size(new_size)
- return o
-
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, y, (4, 1))
- jit_o = t_jit(x, y, (4, 1))
- o = t(x, y, (4, 1))
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o, jit_o)
- self.assertGraphContains(t_jit.graph_for(x, y, (4, 1)), FUSION_GUARD)
-
- # update shape: old kernel should handle dynamic shape well without
- # recompilation
- x = torch.randn([2, 5, 8], dtype=dtype, device=device)
- y = torch.randn([2, 5, 8], dtype=dtype, device=device)
- # (TODO) check executed kernel, should extend autograd.profiler to fused
- # kernels
- jit_o = t_jit(x, y, (5, 1))
- o = t(x, y, (5, 1))
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o, jit_o)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_grad_sum_to_size(self):
- dtype = torch.float
- device = "cuda"
- x = torch.randn([2, 4, 4], dtype=dtype, device=device).requires_grad_()
- y = torch.randn([4], dtype=dtype, device=device).requires_grad_()
- grad = torch.randn([2, 4, 4], dtype=dtype, device=device)
-
- ref_x = x.detach().clone().requires_grad_()
- ref_y = y.detach().clone().requires_grad_()
-
- def t(x: torch.Tensor, y: torch.Tensor):
- o = torch.add(x, y)
- o = torch.relu(o)
- return o
-
- # profiling runs for forward & backward
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, y)
- jit_o.backward(grad)
- jit_o = t_jit(x, y)
- jit_o.backward(grad)
-
- x.grad = None
- y.grad = None
- jit_o = t_jit(x, y)
- jit_o.backward(grad)
- o = t(ref_x, ref_y)
- o.backward(grad)
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o, jit_o)
- self.assertEqual(x.grad, ref_x.grad)
- self.assertEqual(y.grad, ref_y.grad)
- bwd_graph = list(
- list(t_jit.get_debug_state().execution_plans.values())[
- 0].code.grad_executor_states()[0].execution_plans.values()
- )[0].graph
- FileCheck().check(FUSION_GUARD).run(bwd_graph)
-
- # update shape: old kernel should handle dynamic shape well without
- # recompilation
- x = torch.randn([2, 5, 8], dtype=dtype, device=device).requires_grad_()
- y = torch.randn([8], dtype=dtype, device=device).requires_grad_()
- ref_x = x.detach().clone().requires_grad_()
- ref_y = y.detach().clone().requires_grad_()
- grad = torch.randn([2, 5, 8], dtype=dtype, device=device)
- jit_o = t_jit(x, y)
- # (TODO) check executed kernel, should extend autograd.profiler to fused
- # kernels
- jit_o.backward(grad)
- o = t(ref_x, ref_y)
- o.backward(grad)
- self.assertEqual(o.dtype, jit_o.dtype)
- self.assertEqual(o, jit_o)
- self.assertEqual(x.grad, ref_x.grad)
- self.assertEqual(y.grad, ref_y.grad)
-
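For reference, the backward that test_grad_sum_to_size exercises is essentially a sum-to-size of the incoming gradient over the broadcast dimensions. A minimal eager-mode sketch of the add's contribution (plain torch, no fuser assumed):

# Sketch: the gradient of a broadcast add reduces back to the operand's shape.
import torch

x = torch.randn(2, 4, 4)
y = torch.randn(4)                    # broadcast over dims 0 and 1
grad = torch.randn(2, 4, 4)

grad_y = grad.sum_to_size(y.shape)    # what the fused backward must compute for y
assert grad_y.shape == y.shape
assert torch.allclose(grad_y, grad.sum(dim=(0, 1)))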
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_add_backward_with_alpha(self):
- x = torch.randn(4, 2, dtype=torch.float32, device='cuda', requires_grad=True)
- y = torch.randn(4, 2, dtype=torch.float32, device='cuda', requires_grad=True)
- grad = torch.randn(4, 2, dtype=torch.float32, device='cuda')
-
- # Test that a mul is not generated when not needed
- # Alpha=1.0 or is not used
- def test1(x: torch.Tensor, y: torch.Tensor):
- o = torch.add(x, y, alpha=1.0)
- o = o + 1.0
- return o
-
- test1_jit = torch.jit.script(test1)
- for i in range(3):
- jit_o = test1_jit(x, y)
- jit_o.backward(grad)
-
- bwd1_graph = list(
- list(test1_jit.get_debug_state().execution_plans.values())[
- 0].code.grad_executor_states()[0].execution_plans.values()
- )[0].graph
- FileCheck().check_not("aten::mul_").run(bwd1_graph)
-
- # Alpha is set to something other than 1.0
- def test2(x: torch.Tensor, y: torch.Tensor):
- o = torch.add(x, y, alpha=2.0)
- o = o + 1.0
- return o
-
- test2_jit = torch.jit.script(test2)
- for i in range(3):
- jit_o = test2_jit(x, y)
- jit_o.backward(grad)
-
- bwd2_graph = list(
- list(test2_jit.get_debug_state().execution_plans.values())[
- 0].code.grad_executor_states()[0].execution_plans.values()
- )[0].graph
- FileCheck().check("aten::mul_").run(bwd2_graph)
-
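The two FileCheck assertions above follow from the chain rule: for out = x + alpha * y, grad_y = grad * alpha, so the backward only needs a multiply when alpha != 1.0. A hedged eager-mode sketch of that rule:

# Sketch: d(x + alpha*y)/dy = alpha, so backward needs a mul only when alpha != 1.
import torch

x = torch.randn(4, 2, requires_grad=True)
y = torch.randn(4, 2, requires_grad=True)
grad = torch.randn(4, 2)

out = torch.add(x, y, alpha=2.0)
out.backward(grad)
assert torch.allclose(x.grad, grad)          # grad_x = grad * 1
assert torch.allclose(y.grad, grad * 2.0)    # grad_y = grad * alpha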
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_dropout_inference_fusion(self):
- dtype = torch.float
- device = "cuda"
- x = torch.randn([10, 4, 8], dtype=dtype, device=device)
-
- def t(x: torch.Tensor, p: float, train: bool):
- o = torch.nn.functional.dropout(x, p, training=train)
- o = o + 1.0
- return o
-
- t_jit = torch.jit.script(t)
-
- self._run_helper(t_jit, t, x, 0.15, False)
-
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
- def test_dropout_train_nograd_fusion(self):
+ def test_gelu_fusion(self):
dtype = torch.float
device = "cuda"
- x = torch.randn([10, 4, 8], dtype=dtype, device=device)
-
- def t(x: torch.Tensor, p: float, train: bool):
- o = torch.nn.functional.dropout(x, p, training=train)
- o = o + 1.0
- return o
-
- t_jit = torch.jit.script(t)
-
- self._run_helper(t_jit, t, x, 0.0, True)
+ x = torch.randn([64, 128, 1024], dtype=dtype, device=device, requires_grad=True)
+ grads = torch.randn([64, 128, 1024], dtype=dtype, device=device)
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_dropout_train_nograd_prob_check(self):
- dtype = torch.float
- device = "cuda"
- x = torch.randn([1024, 1024], dtype=dtype, device=device)
-
- def t(x: torch.Tensor, p: float, train: bool):
- o = torch.nn.functional.dropout(x, p, training=train)
- o = o + 0.0
- return o
-
- t_jit = torch.jit.script(t)
-
- for prob in [0.0, 0.15, 0.5, 0.85, 1.]:
- torch.cuda.manual_seed_all(123)
- jit_o = t_jit(x, prob, True)
- torch.cuda.manual_seed_all(123)
- jit_o = t_jit(x, prob, True)
-
- self.assertTrue(jit_o.detach().isfinite().all().item())
-
- num_elems = x.numel()
- num_zeros = num_elems - jit_o.detach().count_nonzero().item()
- percent_zeros = num_zeros / num_elems
-
- self.assertTrue((percent_zeros >= (prob - 0.01)) and (percent_zeros <= (prob + 0.01)))
- self.assertGraphContainsExactly(t_jit.graph_for(x, prob, True), FUSION_GUARD, 1, consider_subgraphs=True)
-
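The probability check above is statistical: with ~10^6 elements, the observed zero fraction of the dropout output should sit within +/-0.01 of p (the binomial standard error is roughly 5e-4 here). The same check in plain eager mode, as a sketch:

# Sketch: the observed zero fraction of dropout output should approximate p.
import torch

x = torch.randn(1024, 1024)
for p in [0.0, 0.15, 0.5, 0.85, 1.0]:
    out = torch.nn.functional.dropout(x, p, training=True)
    zero_frac = 1.0 - out.count_nonzero().item() / out.numel()
    assert abs(zero_frac - p) <= 0.01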
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_dropout_training_fusion(self):
- dtype = torch.float
- device = "cuda"
- x = torch.randn([10, 4, 8], dtype=dtype, device=device, requires_grad=True)
- grads = torch.randn([10, 4, 8], dtype=dtype, device=device)
-
- def t(x: torch.Tensor, p: float, train: bool):
- o = torch.nn.functional.dropout(x, p, training=train)
- o = o * 1.0
- return o
-
- t_jit = torch.jit.script(t)
-
- # The drop probability needs to be set to zero given that the order of picking random
- # numbers between eager mode and the jit is different
- self._run_training_helper(t_jit, t, grads, x, 0.0, True)
-
- def t2(x: torch.Tensor, p: float, train: bool):
- o = torch.nn.functional.softmax(x, dim=-1)
- o = torch.nn.functional.dropout(o, p, training=train)
- return o
-
- t2_jit = torch.jit.script(t2)
-
- # The drop probability needs to be set to zero given that the order of picking random
- # numbers between eager mode and the jit is different
- self._run_training_helper(t2_jit, t2, grads, x, 0.0, True)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_gelu(self):
- dtype = torch.float
- device = "cuda"
- x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True)
- grads = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=False)
-
- def t(x: torch.Tensor, fast : bool):
- o = torch.nn.functional.gelu(x, fast)
- o = o * 1.0
- return o
-
- t_jit = torch.jit.script(t)
-
- for approximate in [False, True]:
- self._run_training_helper(t_jit, t, grads, x, approximate)
-
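Both settings of `fast` above exercise GELU; the flag selects the tanh approximation over the exact erf form. As a sketch of the two definitions being compared:

# Sketch: exact GELU (erf form) vs. the tanh approximation toggled by `fast`.
import math
import torch

x = torch.randn(1024)

gelu_exact = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
gelu_tanh = 0.5 * x * (1.0 + torch.tanh(
    math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

assert torch.allclose(gelu_exact, torch.nn.functional.gelu(x), atol=1e-6)
# The approximation tracks the exact form to within roughly 1e-3.
assert (gelu_exact - gelu_tanh).abs().max() < 1e-2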
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_dropout_training_prob_check(self):
- dtype = torch.float
- device = "cuda"
- x = torch.randn([1024, 1024], dtype=dtype, device=device, requires_grad=True)
- x_nograd = torch.randn([1024, 1024], dtype=dtype, device=device)
-
- def t(x: torch.Tensor, p: float, train: bool):
- o = torch.nn.functional.dropout(x, p, training=train)
- o = o + 0.0
- return o
-
- t_jit = torch.jit.script(t)
-
- for prob in [0.0, 0.15, 0.5, 0.85, 1.]:
- torch.cuda.manual_seed_all(123)
- jit_o = t_jit(x, prob, True)
- torch.cuda.manual_seed_all(123)
- jit_o = t_jit(x, prob, True)
- torch.cuda.manual_seed_all(123)
- jit_o = t_jit(x, prob, True)
-
- self.assertTrue(jit_o.detach().isfinite().all().item())
-
- num_elems = x.numel()
- num_zeros = num_elems - jit_o.detach().count_nonzero().item()
- percent_zeros = num_zeros / num_elems
-
- self.assertTrue((percent_zeros >= (prob - 0.01)) and (percent_zeros <= (prob + 0.01)))
- self.assertGraphContainsExactly(t_jit.graph_for(x, prob, True), FUSION_GUARD, 1, consider_subgraphs=True)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_linear(self):
- in_feature = 2
- out_feature = 8
- x = torch.randn(4, in_feature, dtype=torch.float32, device='cuda')
- weight = torch.randn(out_feature, in_feature, dtype=torch.float32, device='cuda')
- bias = torch.randn(out_feature, dtype=torch.float32, device='cuda')
-
- def t(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
- o = torch.nn.functional.linear(x, weight, bias)
- o = torch.relu(o)
- return o
-
- # bias set to true.
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, weight, bias)
- jit_o = t_jit(x, weight, bias)
- o = t(x, weight, bias)
- self.assertEqual(o, jit_o)
- # since the output value is not used at all, the fusion operator should
- # have been optimized away
- self.assertGraphContainsExactly(t_jit.graph_for(x, weight, bias), FUSION_GUARD, 1)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_backward_type(self):
- # not super useful to check gradient of integer/bool, so skipping here
- type_pairs = [
- (torch.float, torch.half),
- (torch.double, torch.half),
- (torch.float, torch.double),
- ]
- for x_type, y_type in type_pairs:
- x = torch.randn(4, 2, dtype=x_type, device='cuda', requires_grad=True)
- y = torch.randn(4, 2, dtype=y_type, device='cuda', requires_grad=True)
- grad = torch.randn(4, 2, dtype=torch.float, device='cuda')
-
- def test1(x: torch.Tensor, y: torch.Tensor):
- o = torch.add(x, y)
- o = torch.add(o, y)
- o = torch.add(o, y)
- o = torch.add(o, y)
- o = o + 1.0
- return o
-
- test1_jit = torch.jit.script(test1)
- for i in range(3):
- jit_o = test1_jit(x, y)
- jit_o.backward(grad)
-
- bwd_graph = list(
- list(test1_jit.get_debug_state().execution_plans.values())[
- 0].code.grad_executor_states()[0].execution_plans.values()
- )[0].graph
-
- FileCheck().check(FUSION_GROUP).run(bwd_graph)
- self.assertEqual(x.grad.dtype, x.dtype)
- self.assertEqual(y.grad.dtype, y.dtype)
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_autocast_1(self):
- def t(x: torch.Tensor, y: torch.Tensor):
- o = x * 2.0
- o = torch.softmax(o, dim=-1)
- o = o * 3.0
- o = torch.matmul(o, y)
- return o
-
- x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True)
- y = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
- grad = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=False)
- t_jit = torch.jit.script(t)
-
- for i in range(3):
- with torch.cuda.amp.autocast():
- jit_o = t_jit(x, y)
- if i == 2 :
- fwd_graph = t_jit.graph_for(x, y)
- jit_o.backward(grad)
-
- self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
-
- with torch.cuda.amp.autocast():
- bwd_graph = list(
- list(t_jit.get_debug_state().execution_plans.values())[
- 0].code.grad_executor_states()[0].execution_plans.values()
- )[0].graph
- FileCheck().check(FUSION_GROUP).run(bwd_graph)
-
- self.assertEqual(jit_o.dtype, torch.half)
- self.assertEqual(x.grad.dtype, x.dtype)
- self.assertEqual(y.grad.dtype, y.dtype)
-
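The dtype assertions in test_autocast_1 come straight from autocast's per-op rules: matmul runs in fp16, softmax in fp32. A minimal eager sketch of those casts, assuming a CUDA device:

# Sketch: autocast runs matmul in half but keeps softmax in float.
import torch

x = torch.randn(8, 4, dtype=torch.half, device="cuda")
y = torch.randn(4, 4, dtype=torch.float, device="cuda")

with torch.cuda.amp.autocast():
    s = torch.softmax(x, dim=-1)
    o = torch.matmul(s, y)

assert s.dtype == torch.float   # softmax is on autocast's fp32 list
assert o.dtype == torch.half    # matmul is on autocast's fp16 list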
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_autocast_2(self):
def t(x: torch.Tensor):
- o = x * 2.0
- o = torch.softmax(o, dim=-1)
- o = o * 3.0
- o = torch.softmax(o, dim=-1)
- o = o * 4.0
- return o
-
- x = torch.randn(8, 4, dtype=torch.half, device='cuda', requires_grad=True)
- grad = torch.randn(8, 4, dtype=torch.float, device='cuda', requires_grad=False)
- t_jit = torch.jit.script(t)
-
- for i in range(3):
- with torch.cuda.amp.autocast() :
- jit_o = t_jit(x)
- if i == 2 :
- fwd_graph = t_jit.graph_for(x)
- jit_o.backward(grad)
-
- self.assertGraphContainsExactly(fwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
-
- with torch.cuda.amp.autocast():
- bwd_graph = list(
- list(t_jit.get_debug_state().execution_plans.values())[
- 0].code.grad_executor_states()[0].execution_plans.values()
- )[0].graph
- FileCheck().check(FUSION_GROUP).run(bwd_graph)
-
- self.assertEqual(jit_o.dtype, torch.float)
- self.assertEqual(x.grad.dtype, x.dtype)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_to_dtype_fp32_to_fp16(self):
- def t(x: torch.Tensor):
- o = x * 2.0
- o = o.to(dtype=torch.half)
- o = o * 3.0
- return o
-
- x = torch.randn(8, 4, dtype=torch.float, device='cuda')
- t_jit = torch.jit.script(t)
-
- for i in range(3):
- jit_o = t_jit(x)
-
- self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
- self.assertEqual(jit_o.dtype, torch.half)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_to_dtype_fp16_to_fp32(self):
- def t(x: torch.Tensor):
- o = x * 2.0
- o = o.to(dtype=torch.float)
- o = o * 3.0
- return o
-
- x = torch.randn(8, 4, dtype=torch.half, device='cuda')
- t_jit = torch.jit.script(t)
-
- for i in range(3):
- jit_o = t_jit(x)
-
- self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
- self.assertEqual(jit_o.dtype, torch.float)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_to_dtype_fp16_to_fp16(self):
- def t(x: torch.Tensor):
- o = x * 2.0
- o = o.to(dtype=torch.half)
- o = o * 3.0
- return o
-
- x = torch.randn(8, 4, dtype=torch.half, device='cuda')
- t_jit = torch.jit.script(t)
-
- for i in range(3):
- jit_o = t_jit(x)
-
- self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
- self.assertEqual(jit_o.dtype, torch.half)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(not TEST_MULTIGPU, "requires multiple CUDA device")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_multiple_device_pw(self):
-
- def t(x):
- o = x + 1.0
- o = torch.relu(o)
- return o
-
- x = torch.randn(2, dtype=torch.float32, device="cuda")
- t_jit = torch.jit.script(t)
-
- for i in range(3):
- jit_o = t_jit(x)
-
- self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
- torch.cuda.device(1)
- x = x.to("cuda:1")
- jit_o = t_jit(x)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_graph_for_with_missing_optimized_engine(self):
- x = torch.randn(8, 4, 2, dtype=torch.float, device="cuda").requires_grad_()
-
- def t(x: torch.Tensor, flag: bool):
- x = x + 1.0
- x = torch.relu(x)
- if flag:
- o = x + 1.0
- o = torch.relu(o)
- else:
- o = x + 2.0
- o = torch.relu(o)
- return o
-
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, False)
- jit_o = t_jit(x, False)
- jit_o = t_jit(x, True)
- o = t(x, True)
- self.assertEqual(o, jit_o)
- # since the output value is not used at all, the fusion operator should
- # have been optimized away
- self.assertGraphContainsExactly(t_jit.graph_for(x, True), FUSION_GUARD, 1, True)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_branches(self):
- in_feature = 2
- out_feature = 4
- x = torch.randn(4, in_feature, dtype=torch.float32, device='cuda')
- weight = torch.randn(out_feature, in_feature, dtype=torch.float32, device='cuda')
- bias = torch.randn(out_feature, dtype=torch.float32, device='cuda')
-
- def t(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, flag: bool):
- if flag:
- o = torch.nn.functional.linear(x, weight, bias)
- o = o + 1.0
- o = torch.relu(o)
- else:
- o = x.sum()
- o = o + 2.0
- o = torch.relu(o)
- return o
-
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x, weight, bias, True)
- jit_o = t_jit(x, weight, bias, True)
- o = t(x, weight, bias, True)
- self.assertEqual(o, jit_o)
- # since the output value is not used at all, the fusion operator should
- # have been optimized away
- self.assertGraphContainsExactly(t_jit.graph_for(x, weight, bias, True), FUSION_GUARD, 1)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_scalar_tensor(self):
- x = torch.empty([], device="cuda", dtype=torch.float32)
-
- def t(x: torch.Tensor):
- o = x + 1.0
- o = torch.nn.functional.relu(o)
- return o
-
- # bias set to true.
- t_jit = torch.jit.script(t)
- jit_o = t_jit(x)
- jit_o = t_jit(x)
- o = t(x)
- self.assertEqual(o, jit_o)
- # since the output value is not used at all, the fusion operator should
- # have been optimized away
- self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1)
-
- @unittest.skipIf(os.environ.get('PYTORCH_NO_CUDA_MEMORY_CACHING') is not None,
- "skipping graph_rng when caching allocator is disabled")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(CUDA_MAJOR < 11, "requires CUDA11 or above")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_graph_rng(self):
- self.assertTrue(torch._C._jit_nvfuser_enabled())
- size = 10000
- a = torch.randn((size,), device="cuda", dtype=torch.float)
-
- def t(x):
- o = x + 1.0
- o = torch.nn.functional.dropout(o, p=0.1)
- o = o + 1.0
- o = torch.nn.functional.dropout(o, p=0.1)
+ o = torch.nn.functional.gelu(x)
+ o = o * 1.0
return o
t_jit = torch.jit.script(t)
- for _ in range(3):
- t_jit(a)
-
- self.assertGraphContainsExactly(t_jit.graph_for(a), FUSION_GUARD, 1)
-
- # Control (jitted, ungraphed)
- torch.cuda.manual_seed(5)
- eager_out = a.clone()
- for _ in range(3):
- eager_out = t_jit(eager_out)
-
- graph_in = a.clone()
- g = torch.cuda._Graph()
- s = torch.cuda.Stream()
- s.wait_stream(torch.cuda.current_stream())
- with torch.cuda.stream(s):
- torch.cuda.manual_seed(5)
- g.capture_begin()
- graph_out = t_jit(graph_in)
- g.capture_end()
- torch.cuda.current_stream().wait_stream(s)
- # g is now a jitted, graphed version of t.
-
- # Runs a (jitted, graphed) -> (jitted, ungraphed) -> (jitted, graphed) sequence.
- # The ops in the overall sequence should be the same as Control.
- g.replay()
- # graph_out is now filled with g's result. Use it as ungraphed input.
- out = t_jit(graph_out)
- graph_in.copy_(out)
- g.replay()
-
- # If replay() updated RNG state correctly, graph_out should now equal eager_out
- self.assertEqual(graph_out, eager_out)
-
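The replay interleaving above verifies that each graph replay advances the Philox RNG offset just like an ungraphed launch; otherwise the second replay would reuse the first one's random numbers. A condensed, self-contained sketch of the capture/replay pattern, using the same private torch.cuda._Graph API the test does (current releases expose torch.cuda.CUDAGraph instead):

# Sketch: capture a scripted dropout under a CUDA graph, then replay it twice.
# Each replay must consume fresh RNG offsets, not repeat the captured ones.
import torch

@torch.jit.script
def f(x):
    return torch.nn.functional.dropout(x + 1.0, p=0.1)

inp = torch.randn(10000, device="cuda")
for _ in range(3):          # warm up so the fused kernels exist before capture
    f(inp)

g = torch.cuda._Graph()     # private, era-specific API, as in the test above
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    g.capture_begin()
    out = f(inp)
    g.capture_end()
torch.cuda.current_stream().wait_stream(s)

g.replay()
first = out.clone()
g.replay()                  # should draw different random numbers
assert not torch.equal(first, out)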
- def _test_batch_norm_impl_index_helper(self, batch, c, hw, affine=True, track_running_stats=True, train=True):
- # enabling inlining to avoid counter increment in BN forward
- torch._C._debug_set_autodiff_subgraph_inlining(True)
- dtype = torch.float32
-
- class MyModule(torch.nn.Module):
- def __init__(self, num_features=10, affine=True, track_running_stats=True):
- super(MyModule, self).__init__()
- self.bn = torch.nn.BatchNorm2d(num_features,
- 1e-5,
- affine=affine,
- track_running_stats=track_running_stats).to(dtype=dtype)
-
- def forward(self, x):
- o = x * 1.0
- o = self.bn(o)
- return o
-
- x = torch.randn(batch, c, hw, hw, dtype=torch.float, device="cuda").to(dtype=dtype).requires_grad_()
- grad = torch.randint(-20, 20, (batch, c, hw, hw), device="cuda").to(dtype=dtype).div(-10)
-
- my_module = MyModule(c, affine, track_running_stats).cuda()
- ref_module = MyModule(c, affine, track_running_stats).cuda()
-
- if not train:
- my_module.eval()
- ref_module.eval()
-
- t_jit = torch.jit.script(my_module)
- ref_module.load_state_dict(my_module.state_dict())
-
- ref_x = x.detach().requires_grad_()
-
- for i in range(0, 3):
- jit_o = t_jit(x)
- jit_o.backward(grad)
-
- # TODO: remove this run?
- o = ref_module(ref_x)
- o.backward(grad)
-
- has_affine = ref_module.bn.weight is not None
- has_running_stats = ref_module.bn.running_mean is not None
-
- if has_running_stats:
- my_module.bn.running_mean.zero_()
- my_module.bn.running_var.fill_(1.0)
- ref_module.bn.running_mean.zero_()
- ref_module.bn.running_var.fill_(1.0)
-
- # Verify that when train is False, we don't have grad for weight/bias.
- if has_affine and train:
- my_module.bn.weight.grad.zero_()
- my_module.bn.bias.grad.zero_()
- ref_module.bn.weight.grad.zero_()
- ref_module.bn.bias.grad.zero_()
-
- x.grad.zero_()
- ref_x.grad.zero_()
-
- # real runs
- jit_o = t_jit(x)
- jit_o.backward(grad)
-
- o = ref_module(ref_x)
- o.backward(grad)
-
- # assert forward graph fusion
- self.assertGraphContainsExactly(t_jit.graph_for(x), FUSION_GUARD, 1, consider_subgraphs=True)
- # assert backward graph fusion
- bwd_graph = list(
- list(t_jit.get_debug_state().execution_plans.values())[0].code.grad_executor_states()[0]
- .execution_plans.values())[0].graph
- self.assertGraphContainsExactly(bwd_graph, FUSION_GUARD, 1, consider_subgraphs=True)
-
- self.assertTrue(self._compare("comparing output failed", jit_o, o, 1e-5))
- self.assertTrue(self._compare("comparing input grad failed", x.grad, ref_x.grad, 1e-4))
- # TODO: switch to welford and reduce this to 1e-5
- # The 1e-3 looks bad, but we don't have welford in codegen, so numeric
- # is very different between reference and codegen.
- if has_affine and train:
- self.assertTrue(self._compare("comparing weight grad failed",
- my_module.bn.weight.grad,
- ref_module.bn.weight.grad,
- 1e-3))
- self.assertTrue(self._compare("comparing bias grad failed",
- my_module.bn.bias.grad,
- ref_module.bn.bias.grad,
- 1e-4))
- if has_running_stats:
- self.assertTrue(self._compare("comparing running_mean failed",
- my_module.bn.running_mean,
- ref_module.bn.running_mean,
- 1e-5))
- self.assertTrue(self._compare("comparing running_var failed",
- my_module.bn.running_var,
- ref_module.bn.running_var,
- 1e-5))
-
- @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_batch_norm_impl_index_correctness(self):
- with torch.backends.cudnn.flags(enabled=True):
- batch = [2, 7, 16]
- channels = [4, 89, 19, 32]
- hw = [1, 8, 17, 32]
-
- # avoid tolerance failure in CI
- torch.cuda.manual_seed_all(211)
-
- # failing sizes (2, 1, 1, 1)
- # failing sizes (2, 89, 8, 8) training False, track True, affine: False
- for b, c, hw in itertools.product(batch, channels, hw):
- setups = [
- [True, True],
- [False, False],
- [True, False],
- [False, True]]
- for training_and_track, affine in itertools.product(setups, [True, False]):
- training, track_running_stats = training_and_track
- self._test_batch_norm_impl_index_helper(b, c, hw, affine, track_running_stats, training)
-
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
- "Requires fusion optimization pass to be effective")
- def test_softplus_fuser(self):
- def shifted_softplus(x: torch.Tensor, shift: float):
- return functional.softplus(x) - shift
-
- jitted = torch.jit.script(shifted_softplus)
- inp = torch.randn(4, 2, dtype=torch.float32, device="cuda").requires_grad_()
- inp_ref = inp.detach().clone().requires_grad_()
- grad = torch.randn(4, 2, dtype=torch.float32, device="cuda")
-
- aten_o = shifted_softplus(inp_ref, 0.693147)
- aten_o.backward(grad)
- aten_grad = inp_ref.grad
-
- for i in range(3):
- jit_o = jitted(inp, 0.693147)
- inp.grad = None # avoid accumulation on grad
- jit_o.backward(grad)
- jit_grad = inp.grad
-
- assert torch.allclose(jit_o, aten_o)
- assert torch.allclose(jit_grad, aten_grad)
- self.assertGraphContains(jitted.graph_for(inp, 0.693147), FUSION_GROUP, True)
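A side note on the constant in the softplus test above: 0.693147 is ln 2, and softplus(0) = ln(1 + e^0) = ln 2, so the shifted activation passes through the origin. A quick check:

# Sketch: shift = ln(2) makes shifted softplus zero at the origin.
import math
import torch

x = torch.zeros(1)
ssp = torch.nn.functional.softplus(x) - math.log(2.0)
assert torch.allclose(ssp, torch.zeros(1), atol=1e-6)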
+ self._run_training_helper(t_jit, t, grads, x)
class TestPassManagerCudaFuser(JitTestCase):
# NVFuser runtime library
libtorch_nvfuser_runtime_sources = [
"torch/csrc/jit/codegen/cuda/runtime/block_reduction.cu",
- "torch/csrc/jit/codegen/cuda/runtime/block_sync_atomic.cu",
- "torch/csrc/jit/codegen/cuda/runtime/block_sync_default.cu",
"torch/csrc/jit/codegen/cuda/runtime/broadcast.cu",
"torch/csrc/jit/codegen/cuda/runtime/fp16_support.cu",
"torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu",
"torch/csrc/jit/codegen/cuda/runtime/helpers.cu",
"torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu",
"torch/csrc/jit/codegen/cuda/runtime/tensor.cu",
- "torch/csrc/jit/codegen/cuda/runtime/welford.cu",
"aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh",
"aten/src/ATen/cuda/detail/UnpackRaw.cuh",
]
"torch/csrc/autograd/functions/comm.cpp",
"torch/csrc/jit/codegen/cuda/arith.cpp",
"torch/csrc/jit/codegen/cuda/compute_at.cpp",
- "torch/csrc/jit/codegen/cuda/compute_at_map.cpp",
"torch/csrc/jit/codegen/cuda/codegen.cpp",
"torch/csrc/jit/codegen/cuda/dispatch.cpp",
"torch/csrc/jit/codegen/cuda/expr_evaluator.cpp",
"torch/csrc/jit/codegen/cuda/fusion.cpp",
"torch/csrc/jit/codegen/cuda/graph_fuser.cpp",
"torch/csrc/jit/codegen/cuda/index_compute.cpp",
- "torch/csrc/jit/codegen/cuda/index_reference_replay.cpp",
"torch/csrc/jit/codegen/cuda/instrumentation.cpp",
"torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp",
"torch/csrc/jit/codegen/cuda/ir_cloner.cpp",
"torch/csrc/jit/codegen/cuda/ir_graphviz.cpp",
"torch/csrc/jit/codegen/cuda/ir_nodes.cpp",
"torch/csrc/jit/codegen/cuda/ir_iostream.cpp",
- "torch/csrc/jit/codegen/cuda/ir_utils.cpp",
"torch/csrc/jit/codegen/cuda/iter_visitor.cpp",
"torch/csrc/jit/codegen/cuda/kernel.cpp",
"torch/csrc/jit/codegen/cuda/kernel_cache.cpp",
- "torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp",
"torch/csrc/jit/codegen/cuda/kernel_ir.cpp",
"torch/csrc/jit/codegen/cuda/kernel_ir_builder.cpp",
"torch/csrc/jit/codegen/cuda/kernel_ir_printer.cpp",
- "torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp",
- "torch/csrc/jit/codegen/cuda/lower_allocation.cpp",
- "torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp",
"torch/csrc/jit/codegen/cuda/lower_index.cpp",
- "torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp",
"torch/csrc/jit/codegen/cuda/lower_loops.cpp",
- "torch/csrc/jit/codegen/cuda/lower_magic_zero.cpp",
- "torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp",
- "torch/csrc/jit/codegen/cuda/lower_predicate.cpp",
- "torch/csrc/jit/codegen/cuda/lower_shift.cpp",
- "torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp",
- "torch/csrc/jit/codegen/cuda/lower_trivial_reductions.cpp",
+ "torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp",
+ "torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp",
"torch/csrc/jit/codegen/cuda/lower_unroll.cpp",
+ "torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp",
"torch/csrc/jit/codegen/cuda/lower_utils.cpp",
"torch/csrc/jit/codegen/cuda/lower_validation.cpp",
"torch/csrc/jit/codegen/cuda/lower2device.cpp",
"torch/csrc/jit/codegen/cuda/manager.cpp",
"torch/csrc/jit/codegen/cuda/mutator.cpp",
- "torch/csrc/jit/codegen/cuda/ops/composite.cpp",
- "torch/csrc/jit/codegen/cuda/ops/normalization.cpp",
- "torch/csrc/jit/codegen/cuda/parallel_dimension_map.cpp",
- "torch/csrc/jit/codegen/cuda/parallel_type_bitmap.cpp",
"torch/csrc/jit/codegen/cuda/parser.cpp",
"torch/csrc/jit/codegen/cuda/partition.cpp",
"torch/csrc/jit/codegen/cuda/predicate_compute.cpp",
"torch/csrc/jit/codegen/cuda/register_interface.cpp",
- "torch/csrc/jit/codegen/cuda/root_domain_map.cpp",
- "torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp",
- "torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp",
- "torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp",
- "torch/csrc/jit/codegen/cuda/scheduler/registry.cpp",
- "torch/csrc/jit/codegen/cuda/scheduler/utils.cpp",
+ "torch/csrc/jit/codegen/cuda/scheduler.cpp",
"torch/csrc/jit/codegen/cuda/shape_inference.cpp",
- "torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp",
"torch/csrc/jit/codegen/cuda/tensor_view.cpp",
"torch/csrc/jit/codegen/cuda/transform_iter.cpp",
"torch/csrc/jit/codegen/cuda/transform_replay.cpp",
"torch/csrc/jit/codegen/cuda/transform_rfactor.cpp",
"torch/csrc/jit/codegen/cuda/type.cpp",
- "torch/csrc/jit/codegen/cuda/utils.cpp",
"torch/csrc/jit/tensorexpr/cuda_codegen.cpp",
"torch/csrc/jit/runtime/register_cuda_ops.cpp",
]
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
-#include <cfloat>
namespace torch {
namespace jit {
switch (dtype) {
case DataType::Bool:
return new Bool();
- case DataType::Double:
case DataType::Float:
+ return new Float();
case DataType::Half:
- return new Double();
+ return new Half();
case DataType::Int:
return new Int();
default:
TORCH_CHECK(
false,
- "Cannot handle ValType: ",
+ "Was expecting a scalar type, but received ValType: ",
vtype,
" with DataType:",
- dtype,
- " in newScalar.");
+ dtype);
}
TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
continue;
if (dom[i]->isBroadcast())
continue;
- out_domain[i] = dom[i]->clone();
+ out_domain[i] = new IterDomain(dom[i]->start(), dom[i]->extent());
}
}
for (const auto dim_i : c10::irange(out_domain.size())) {
return out_vals;
}
+Val* newOutputVal(const std::vector<Val*>& vals) {
+ ValType out_vtype = vals[0]->getValType().value();
+ DataType out_dtype = vals[0]->getDataType().value();
+
+ for (auto val : vals) {
+ TORCH_CHECK(val->isVal(), "Invalid statement found during promotion.");
+ TORCH_CHECK(
+ val->getDataType().value() != DataType::Null,
+ "Invalid datatype found during prmotion.");
+ out_vtype = promote_type(out_vtype, val->getValType().value());
+ out_dtype = promote_type(out_dtype, val->getDataType().value());
+ }
+
+ if (out_vtype == ValType::TensorView)
+ return newOutputTV(vals, out_dtype);
+
+ return newScalar(out_vtype, out_dtype);
+}
+
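newOutputVal above folds promote_type over every input's ValType and DataType before constructing the output. The same fold, sketched against torch's eager promotion table (an analogy for the shape of the computation, not the codegen rule itself):

# Sketch: fold a promotion function over all input dtypes, as newOutputVal
# does for ValType/DataType (analogy via torch.promote_types).
import functools
import torch

dtypes = [torch.int32, torch.float16, torch.float32]
out_dtype = functools.reduce(torch.promote_types, dtypes)
assert out_dtype == torch.float32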
Val* newValLike(Val* val, DataType dtype) {
+ TORCH_CHECK(val->isVal(), "Invalid statement provided to create new value.");
TORCH_CHECK(
dtype != DataType::Null, "Invalid datatype provided for new value.");
- const ValType vtype = val->getValType().value();
+ ValType vtype = val->getValType().value();
if (vtype == ValType::TensorView)
return newOutputTV({val}, dtype);
// UNARY OPERATIONS
Val* unaryOp(UnaryOpType type, Val* v1) {
- TORCH_INTERNAL_ASSERT(
- type != UnaryOpType::Address,
- "The reference operator & is not accessible in the Fusion IR");
- Val* out = newValLike(v1, v1->getDataType().value());
- // TODO: We should add the following, but we need to go through shchedulers
- // and make sure all calls to "fusion->inputs" includes the output of RandLike
- //
- // If rand like, there isn't a real dependency on the input value, so map it
- // to a dummy scalar. if
- //
- // (type == UnaryOpType::RandLike) {
- // v1 = new NamedScalar("__rnd", v1->getDataType().value());
- // }
-
+ Val* out = newOutputVal({v1});
new UnaryOp(type, out, v1);
return out;
}
Val* neg(Val* v) {
return unaryOp(UnaryOpType::Neg, v);
}
-
TensorView* neg(TensorView* v) {
return unaryOp(UnaryOpType::Neg, v);
}
return func(v1->template as<Val>(), v2->template as<Val>())
->template as<TensorView>();
}
-
template <typename T1, typename T2>
TensorView* arithOpOverloads(BinaryOpType type, T1* v1, T2* v2) {
return binaryOp(type, v1->template as<Val>(), v2->template as<Val>())
->template as<TensorView>();
}
-
template <typename T1, typename T2, typename T3>
TensorView* arithOpOverloads(
Val* (*func)(Val*, Val*, Val*),
vals[2]->template as<Val>())
->template as<TensorView>();
}
-
template <typename T1, typename T2, typename T3, typename T4>
TensorView* arithOpOverloads(
Val* (*func)(Val*, Val*, Val*, Val*),
vals[3]->template as<Val>())
->template as<TensorView>();
}
-
-namespace {
-enum class Category { Scalar, ZeroDimTensor, DimTensor };
-
-inline Category getCategory(const Val* v) {
- if (v->isA<TensorView>()) {
- if (v->as<TensorView>()->nDims() > 0) {
- return Category::DimTensor;
- } else {
- return Category::ZeroDimTensor;
- }
- } else {
- return Category::Scalar;
- }
-}
-
-// replicated logic from Aten/native/TypeProperties.cpp, minus complex support
-DataType getCommonType(DataType higher, DataType lower) {
- if (isFloatingPointType(higher)) {
- return higher;
- }
- if (higher == DataType::Bool || isFloatingPointType(lower)) {
- return promote_type(higher, lower);
- }
- if (higher != DataType::Null) {
- return higher;
- }
- return lower;
-}
-} // namespace
-
-// Type promotion logic for binary operators
-DataType getOutputType(BinaryOpType op_type, Val* v1, Val* v2) {
- DataType v1_dtype = v1->getDataType().value();
- DataType v2_dtype = v2->getDataType().value();
-
- const bool floating_input =
- isFloatingPointType(v1_dtype) || isFloatingPointType(v2_dtype);
-
- const bool integer_input =
- isIntegralType(v1_dtype) || isIntegralType(v2_dtype);
-
- const bool all_integer_input =
- isIntegralType(v1_dtype) && isIntegralType(v2_dtype);
-
- if (all_integer_input) {
- TORCH_INTERNAL_ASSERT(
- !(noFullIntegerSupport(op_type)) || (v1->isScalar() && v2->isScalar()),
- "unsupported op with all integer tensor inputs");
- }
-
- // Combine categories
- const auto v1_cat = getCategory(v1);
- const auto v2_cat = getCategory(v2);
- if (v1_cat != v2_cat) {
- const DataType higher = v1_cat > v2_cat ? v1_dtype : v2_dtype;
- const DataType lower = v1_cat > v2_cat ? v2_dtype : v1_dtype;
- const DataType common_type = getCommonType(higher, lower);
- v1_dtype = common_type;
- v2_dtype = common_type;
- }
-
- if (isIntegerOp(op_type) || (alsoBooleanOperator(op_type) && integer_input)) {
- // If integer op or maybe bool op with integer inputs meaning binary op
- if (integer_input && all_integer_input) {
- return promote_type(v1_dtype, v2_dtype);
- } else if (integer_input && !all_integer_input) {
- TORCH_CHECK(
- !floating_input,
- "Operator ",
- op_type,
- " not supported with floating point inputs.");
- return isIntegralType(v1_dtype) ? v1_dtype : v2_dtype;
- } else {
- TORCH_INTERNAL_ASSERT(
- false,
- "Currently no support for float inputs to int operations. ",
- "Inputs should be manually casted first.");
- }
- } else if (isLogicalOp(op_type)) {
- return DataType::Bool;
- } else if (alsoBooleanOperator(op_type)) {
- // If boolean op that can't have floating inputs (& or |)
- TORCH_CHECK(
- !floating_input,
- "Operator ",
- op_type,
- " not supported with floating point inputs.");
- return DataType::Bool;
- } else {
- // Otherwise do normal type promotion
- return promote_type(v1_dtype, v2_dtype);
- }
-}
-
} // namespace
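The getOutputType logic removed above mirrored ATen's category-based promotion: when a dimensioned tensor meets a scalar (or 0-dim tensor), the tensor's dtype wins unless the scalar belongs to a higher category. torch.result_type shows the same behavior eagerly, as a sketch:

# Sketch: ATen-style category promotion that getOutputType replicated.
import torch

t_int = torch.ones(3, dtype=torch.int32)

# Dimensioned tensor beats a same-category scalar: int tensor + int scalar -> int.
assert torch.result_type(t_int, 3) == torch.int32
# A higher-category scalar still promotes: int tensor + float scalar -> float.
assert torch.result_type(t_int, 3.0) == torch.get_default_dtype()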
TORCH_CUDA_CU_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) {
- const auto out_dtype = getOutputType(type, v1, v2);
- const auto out_vtype =
- promote_type(v1->getValType().value(), v2->getValType().value());
auto vals = maybeBroadcast({v1, v2});
- Val* out = nullptr;
- if (out_vtype == ValType::TensorView) {
- out = newOutputTV(vals, out_dtype);
- } else {
- out = newScalar(out_vtype, out_dtype);
+ Val* out = newOutputVal({vals[0], vals[1]});
+ if (is_logical_op(type)) {
+ if (out->getDataType().value() != DataType::Bool)
+ out = newValLike(out, DataType::Bool);
+ } else if (type >= BinaryOpType::Mod) {
+ if (out->getDataType().value() != DataType::Int)
+ out = newValLike(out, DataType::Int);
}
+
new BinaryOp(type, out, vals[0], vals[1]);
return out;
}
-
TensorView* binaryOp(BinaryOpType type, TensorView* v1, Val* v2) {
return arithOpOverloads(type, v1, v2);
}
-
TensorView* binaryOp(BinaryOpType type, Val* v1, TensorView* v2) {
return arithOpOverloads(type, v1, v2);
}
-
TensorView* binaryOp(BinaryOpType type, TensorView* v1, TensorView* v2) {
return arithOpOverloads(type, v1, v2);
}
TensorView* add(TensorView* v1, TensorView* v2) {
return arithOpOverloads(add, v1, v2);
}
-
// sub
Val* sub(Val* v1, Val* v2) {
return binaryOp(BinaryOpType::Sub, v1, v2);
TensorView* sub(TensorView* v1, TensorView* v2) {
return arithOpOverloads(sub, v1, v2);
}
-
// mul
Val* mul(Val* v1, Val* v2) {
return binaryOp(BinaryOpType::Mul, v1, v2);
TensorView* mul(TensorView* v1, TensorView* v2) {
return arithOpOverloads(mul, v1, v2);
}
-
// div
Val* div(Val* v1, Val* v2) {
return binaryOp(BinaryOpType::Div, v1, v2);
TensorView* div(TensorView* v1, TensorView* v2) {
return arithOpOverloads(div, v1, v2);
}
-
// mod
Val* mod(Val* v1, Val* v2) {
return binaryOp(BinaryOpType::Mod, v1, v2);
TensorView* mod(TensorView* v1, TensorView* v2) {
return arithOpOverloads(mod, v1, v2);
}
-
// lt
Val* lt(Val* v1, Val* v2) {
return binaryOp(BinaryOpType::LT, v1, v2);
TensorView* lt(TensorView* v1, TensorView* v2) {
return arithOpOverloads(lt, v1, v2);
}
-
-// gt
-Val* gt(Val* v1, Val* v2) {
- return binaryOp(BinaryOpType::GT, v1, v2);
-}
-TensorView* gt(TensorView* v1, Val* v2) {
- return arithOpOverloads(gt, v1, v2);
-}
-TensorView* gt(Val* v1, TensorView* v2) {
- return arithOpOverloads(gt, v1, v2);
-}
-TensorView* gt(TensorView* v1, TensorView* v2) {
- return arithOpOverloads(gt, v1, v2);
-}
// eq
Val* eq(Val* v1, Val* v2) {
return binaryOp(BinaryOpType::Eq, v1, v2);
TensorView* eq(TensorView* v1, TensorView* v2) {
return arithOpOverloads(eq, v1, v2);
}
-
// ceilDiv
Val* ceilDiv(Val* v1, Val* v2) {
return binaryOp(BinaryOpType::CeilDiv, v1, v2);
TensorView* ceilDiv(TensorView* v1, TensorView* v2) {
return arithOpOverloads(ceilDiv, v1, v2);
}
-
// andOp
Val* andOp(Val* v1, Val* v2) {
TORCH_CHECK(
- !isFloatingPointType(v1->getDataType().value()),
- "Input1 should not be a floating point type, but received: ",
+ v1->getDataType().value() == DataType::Bool,
+ "Input1 should be of type bool, not ",
v1->getDataType().value());
TORCH_CHECK(
- !isFloatingPointType(v2->getDataType().value()),
- "Input2 should not be a floating point type, but received: ",
+ v2->getDataType().value() == DataType::Bool,
+ "Input2 should be of type bool, not ",
v2->getDataType().value());
return binaryOp(BinaryOpType::And, v1, v2);
}
// TODO: How do we adjust this so we can reduce to a single scalar value?
static TensorView* newForReduction(
TensorView* tv,
- const std::vector<unsigned int>& axes,
- DataType data_type = DataType::Null) {
+ const std::vector<unsigned int>& axes) {
auto orig_domain = TensorDomain::noReductions(tv->getRootDomain());
std::set<unsigned int> axes_set(axes.begin(), axes.end());
const IterDomain* id = orig_domain[dim];
TORCH_CHECK(
- !(isReduction && id->isBroadcast() && !id->isImplicitBroadcast()),
+ !(isReduction && id->isBroadcast()),
"Cannot reduce an axis that is marked as broadcasted as it has an undetermined size. Tried to reduce ID = ",
id,
" of tensor ",
TensorDomain* td =
new TensorDomain(new_domain, std::vector<bool>(new_domain.size(), true));
-
- data_type =
- data_type == DataType::Null ? tv->getDataType().value() : data_type;
- return new TensorView(td, data_type);
+ return new TensorView(td, tv->getDataType().value());
}
TensorView* reductionOp(
BinaryOpType reduction_op_type,
const std::vector<int>& axes,
Val* init,
- TensorView* tv,
- bool keep_dim /*=false*/) {
+ TensorView* tv) {
TORCH_CHECK(
init->isConstScalar(),
"Cannot create a reduction operation where the initial value is not a const scalar.");
}
TensorView* out = newForReduction(tv, uint_axes);
- const auto out_type = out->getDataType().value();
- const auto init_type = init->getDataType().value();
- TORCH_CHECK(
- (isFloatingPointType(out_type) && isFloatingPointType(init_type)) ||
- (isIntegralType(out_type) && isIntegralType(init_type)) ||
- (out_type == DataType::Bool && init_type == DataType::Bool),
- "Types should match for reduction ops but received: ",
- out_type,
- " and ",
- init_type);
+ if (init->getDataType().value() != tv->getDataType().value())
+ init = castOp(tv->getDataType().value(), init);
new ReductionOp(reduction_op_type, init, out, tv);
-
- if (keep_dim) {
- auto tv_root = TensorDomain::noReductions(tv->getRootDomain());
- std::vector<bool> is_broadcast(tv_root.size(), false);
- for (int axis : axes) {
- is_broadcast[axis] = true;
- }
-
- out = broadcast(out, is_broadcast);
- }
return out;
}
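The keep_dim path removed in this hunk implements keepdim the same way eager PyTorch defines it: reduce, then re-insert the reduced axes as size-1 broadcast dims. Equivalently, in eager terms:

# Sketch: keepdim == reduce, then broadcast the reduced axes back as size-1 dims.
import torch

x = torch.randn(2, 4, 4)
kept = x.sum(dim=1, keepdim=True)       # shape (2, 1, 4)
rebuilt = x.sum(dim=1).unsqueeze(1)     # reduce, then re-insert the axis
assert torch.equal(kept, rebuilt)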
-TensorView* sum(
- TensorView* v1,
- const std::vector<int>& axes,
- bool keep_dim /*=false*/) {
- Val* init = nullptr;
- auto dtype = v1->getDataType().value();
- if (isFloatingPointType(dtype)) {
- init = new Double(0.0);
- } else if (isIntegralType(dtype)) {
- init = new Int(0);
- } else {
- TORCH_CHECK(
- false,
- "Could not generate a sum op for tensor with type: ",
- v1->getDataType().value());
- }
-
- return reductionOp(BinaryOpType::Add, axes, init, v1, keep_dim);
-}
-
-TensorView* max(
- TensorView* v1,
- const std::vector<int>& axes,
- bool keep_dim /*=false*/) {
- Val* init = nullptr;
- switch (v1->getDataType().value()) {
- case (DataType::Double):
- init = new Double(DBL_MIN);
- break;
- case (DataType::Float):
- init = new Double(FLT_MIN);
- break;
- case (DataType::Int):
- init = new Int(INT_MIN);
- break;
- default:
- TORCH_CHECK(
- false,
- "Could not generate a max op for tensor with type: ",
- v1->getDataType().value());
- }
-
- return reductionOp(BinaryOpType::Max, axes, init, v1, keep_dim);
-}
-
-TensorView* min(
- TensorView* v1,
- const std::vector<int>& axes,
- bool keep_dim /*=false*/) {
- Val* init = nullptr;
+TensorView* sum(TensorView* v1, const std::vector<int>& axes) {
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ Val* init;
switch (v1->getDataType().value()) {
- case (DataType::Double):
- init = new Double(DBL_MAX);
- break;
case (DataType::Float):
- init = new Double(FLT_MAX);
+ init = new Float(0.0);
break;
case (DataType::Int):
- init = new Int(INT_MAX);
+ init = new Int(0);
break;
default:
TORCH_CHECK(
false,
- "Could not generate a min op for tensor with type: ",
+ "Could not generate a sum op for tensor with type: ",
v1->getDataType().value());
}
- return reductionOp(BinaryOpType::Min, axes, init, v1, keep_dim);
+ return reductionOp(BinaryOpType::Add, axes, init, v1);
}
TensorView* broadcast(
}
std::vector<IterDomain*> out_domain;
- // Don't propagate reduction IDs through arith ops.
auto inp_domain = TensorDomain::noReductions(inp->getRootDomain());
size_t iinp = 0, ibdim = 0;
while (ibdim < is_broadcast_dim.size()) {
ParallelType::Serial,
IterType::BroadcastWithoutStride));
} else {
- out_domain.push_back(inp_domain[iinp]->clone());
+ // Don't propagate reduction IDs through arith ops.
+ out_domain.push_back(inp_domain[iinp]);
iinp++;
}
ibdim++;
TensorView* out_tensor = new TensorView(
new TensorDomain(out_domain, std::vector<bool>(out_domain.size(), true)),
inp->getDataType().value());
- new BroadcastOp(out_tensor, inp, is_broadcast_dim);
- return out_tensor;
-}
-
-WelfordResult Welford(
- TensorView* tv,
- const std::vector<int>& axes,
- TensorView* init_avg,
- TensorView* init_var,
- Int* init_N) {
- TORCH_CHECK(
- TensorDomain::sameAs(tv->getRootDomain(), tv->domain()->domain()),
- "Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");
-
- TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor");
- TORCH_CHECK(axes.size() > 0, "No reduction axis specified");
-
- // Initial values for welford op are tensors, so their dims have to match the
- // output dim,
- // i.e. original_dims - dims_to_be_reduced
- Val* init_avg_val = nullptr;
- Val* init_var_val = nullptr;
- if (!init_N->isZeroInt()) {
- TORCH_CHECK(
- init_avg != nullptr && init_var != nullptr && init_N != nullptr,
- "welford op: all init values need to be provided");
- TORCH_CHECK(
- (axes.size() + init_avg->getRootDomain().size()) ==
- tv->getRootDomain().size(),
- "welford op: initial tensor mismatch");
- TORCH_CHECK(
- (axes.size() + init_var->getRootDomain().size()) ==
- tv->getRootDomain().size(),
- "welford op: initial tensor mismatch");
- init_avg_val = init_avg;
- init_var_val = init_var;
- } else {
- init_avg_val = new Double(0);
- init_var_val = new Double(0);
- }
-
- // Check and collect reduction axes
- std::vector<unsigned int> uint_axes;
- for (int axis : axes) {
- if (axis < 0)
- axis += int(tv->nDims());
-
- TORCH_CHECK(
- axis >= 0 && (unsigned int)axis < tv->nDims(),
- "Reduction on invalid axis, recieved: ",
- axis,
- " however tensor view only has ",
- tv->nDims(),
- " dims.");
-
- uint_axes.push_back((unsigned int)axis);
- }
-
- // Create tensor outputs
- TensorView* out_avg = newForReduction(tv, uint_axes);
- TensorView* out_var = newForReduction(tv, uint_axes);
- TensorView* out_N = newForReduction(tv, uint_axes, DataType::Int);
-
- new WelfordOp(
- out_avg,
- out_var,
- out_N, /*out var/avg/count */
- init_avg_val,
- init_var_val,
- init_N, /*init var/avg/count */
- tv,
- nullptr,
- new Int(1)); /*in var/avg/count */
-
- return WelfordResult(out_avg, out_var, out_N);
-}
-
-WelfordResult::WelfordResult(
- TensorView* in_avg,
- TensorView* in_var_sum,
- TensorView* in_n)
- : avg(in_avg), var_sum(in_var_sum), n(in_n) {
- TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(var_sum->definition()));
- TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(n->definition()));
-}
-
-WelfordResult WelfordResult::rFactor(const std::vector<int>& axes) {
- auto o_tv = avg->definition()->as<WelfordOp>()->out()->as<TensorView>();
- return o_tv->rFactor(axes, avg, var_sum, n);
-}
-
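The Welford op removed here computes mean and the sum of squared deviations in a single pass; its three outputs (avg, var_sum, n) correspond directly to Welford's running state. A plain-Python sketch of the update rule:

# Sketch: Welford's one-pass update for mean and M2 (sum of squared deviations).
def welford(xs):
    count, avg, var_sum = 0, 0.0, 0.0
    for x in xs:
        count += 1
        delta = x - avg
        avg += delta / count
        var_sum += delta * (x - avg)   # uses both the old and the new mean
    return avg, var_sum, count

avg, m2, n = welford([1.0, 2.0, 3.0, 4.0])
assert abs(avg - 2.5) < 1e-12 and abs(m2 / n - 1.25) < 1e-12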
-TensorView* transpose(
- TensorView* inp,
- const std::unordered_map<int, int>& old2new) {
- auto inp_domain = TensorDomain::noReductions(inp->getRootDomain());
- std::vector<IterDomain*> out_domain(inp_domain.size());
-
- auto new2old = ir_utils::normalizeOld2New(old2new, inp_domain.size());
-
- for (size_t i = 0; i < out_domain.size(); ++i) {
- auto in_id = inp_domain[new2old[i]];
- out_domain[i] = in_id->clone();
- }
-
- TensorView* out_tensor = new TensorView(
- new TensorDomain(out_domain, std::vector<bool>(out_domain.size(), true)),
- inp->getDataType().value());
- new TransposeOp(out_tensor, inp, new2old);
+ new BroadcastOp(out_tensor, inp);
return out_tensor;
}
}
// TERNARY OPERATIONS
-// where (c ? v1 : v2)
+// where
Val* where(Val* c, Val* v1, Val* v2) {
TORCH_CHECK(
c->getDataType().value() == DataType::Bool,
"Condition should be of DataType Bool, not ",
c->getDataType().value());
- // Not actually an add, but need to send a binary op to get output type
- auto out_dtype = getOutputType(BinaryOpType::Add, v1, v2);
- auto out_vtype =
- promote_type(v1->getValType().value(), v2->getValType().value());
auto vals = maybeBroadcast({c, v1, v2});
- Val* out = nullptr;
- if (out_vtype == ValType::TensorView) {
- out = newOutputTV(vals, out_dtype);
- } else {
- out = newScalar(out_vtype, out_dtype);
- }
+ Val* out = newOutputVal({vals[1], vals[2]});
new TernaryOp(TernaryOpType::Where, out, vals[0], vals[1], vals[2]);
return out;
}
-
TensorView* where(TensorView* v1, Val* v2, Val* v3) {
return arithOpOverloads(where, v1, v2, v3);
}
// TERNARY OPERATIONS
Val* threshold(Val* in, Val* thresh, Val* value) {
- const auto in_type = in->getDataType().value();
- const auto thresh_type = thresh->getDataType().value();
- const auto value_type = value->getDataType().value();
- if (isFloatingPointType(in_type)) {
- TORCH_CHECK(
- isFloatingPointType(thresh_type) && isFloatingPointType(value_type),
- "All input DataType values should match the input type ",
- in_type,
- " vs ",
- thresh_type,
- " and ",
- value_type);
- } else if (isIntegralType(in_type)) {
- TORCH_CHECK(
- isIntegralType(thresh_type) && isIntegralType(value_type),
- "All input DataType values should match the input ",
- in_type,
- " vs ",
- thresh_type,
- " and ",
- value_type);
- }
TORCH_CHECK(
- (thresh->getValType().value() == ValType::Scalar ||
- thresh->getValType().value() == ValType::NamedScalar) &&
- (value->getValType().value() == ValType::Scalar ||
- value->getValType().value() == ValType::NamedScalar),
- "For Threshold operation: Thresh and Value values should be Scalars.");
+ in->getDataType().value() == thresh->getDataType().value() &&
+ in->getDataType().value() == value->getDataType().value(),
+ "All input DataType values should match the input ",
+ in->getDataType().value());
+ TORCH_CHECK(
+ thresh->getValType().value() == ValType::Scalar &&
+ value->getValType().value() == ValType::Scalar,
+ "Thresh and Value values should be Scalars");
- Val* out = newValLike(in, in_type);
+ Val* out = newOutputVal({in});
new TernaryOp(TernaryOpType::Threshold, out, in, thresh, value);
return out;
}
Val* clamp(Val* in, Val* min_val, Val* max_val) {
- const auto in_type = in->getDataType().value();
- const auto min_type = min_val->getDataType().value();
- const auto max_type = max_val->getDataType().value();
- if (isFloatingPointType(in_type)) {
- TORCH_CHECK(
- isFloatingPointType(min_type) && isFloatingPointType(max_type),
- "All input DataType values should match the input type ",
- in_type,
- " vs ",
- min_type,
- " and ",
- max_type);
- } else if (isIntegralType(in_type)) {
- TORCH_CHECK(
- isIntegralType(min_type) && isIntegralType(max_type),
- "All input DataType values should match the input ",
- in_type,
- " vs ",
- min_type,
- " and ",
- max_type);
- }
TORCH_CHECK(
- (min_val->getValType().value() == ValType::Scalar ||
- min_val->getValType().value() == ValType::NamedScalar) &&
- (max_val->getValType().value() == ValType::Scalar ||
- max_val->getValType().value() == ValType::NamedScalar),
- "For Threshold operation: Thresh and Value values should be Scalars.");
+ in->getDataType().value() == min_val->getDataType().value() &&
+ in->getDataType().value() == max_val->getDataType().value(),
+ "All input DataType values should match the input ",
+ in->getDataType().value());
+ TORCH_CHECK(
+ min_val->getValType().value() == ValType::Scalar &&
+ max_val->getValType().value() == ValType::Scalar,
+ "Min and Max values should be Scalars");
- Val* out = newValLike(in, in_type);
+ Val* out = newOutputVal({in});
new TernaryOp(TernaryOpType::Clamp, out, in, min_val, max_val);
return out;
return clamp(in->as<Val>(), min_val, max_val)->as<TensorView>();
}
-// sum_to operator
-
-TensorView* sum_to(TensorView* in, const std::vector<Int*>& sum_to_size) {
- const auto& root = TensorDomain::noReductions(in->getRootDomain());
-
- TORCH_CHECK(
- root.size() >= sum_to_size.size(),
- "sum_to: Error trying to reduce",
- in,
- "into a shape of size",
- sum_to_size.size());
-
- // If no reduction is needed sum_to returns the input tv
- TensorView* out = in;
-
- const int64_t leading_dims = root.size() - sum_to_size.size();
-
- // Generate reduction axes for leading dims
- std::vector<int> reduce_dims(leading_dims);
- std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
-
- // Generate reduction axes for dims within sum_to_size
- std::vector<bool> inner_red_dims(sum_to_size.size(), false);
- bool reduction_within_shape = false;
-
- // Reduce rest of the dims with keep_dim
- for (int i = leading_dims; i < int(root.size()); i++) {
- if (sum_to_size[i - leading_dims]->isOneInt() &&
- !root[i]->extent()->isOneInt()) {
- inner_red_dims[i - leading_dims] = true;
- reduce_dims.push_back(i);
- reduction_within_shape = true;
- }
- }
-
- // Reduction step
- if (!reduce_dims.empty()) {
- out = sum(in, reduce_dims);
- }
-
- // Broadcast back reduced dims within shape
- if (reduction_within_shape) {
- out = broadcast(out, inner_red_dims);
- }
-
- return out;
-}
-
-TensorView* sum_to(TensorView* in, const std::vector<int64_t>& sum_to_size) {
- const auto& root = TensorDomain::noReductions(in->getRootDomain());
-
- TORCH_CHECK(
- root.size() >= sum_to_size.size(),
- "sum_to: Error trying to reduce",
- in,
- "into a shape of size",
- sum_to_size.size());
-
- // If no reduction is needed sum_to returns the input tv
- TensorView* out = in;
-
- const int64_t leading_dims = root.size() - sum_to_size.size();
-
- // Generate reduction axes for leading dims
- std::vector<int> reduce_dims(leading_dims);
- std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
-
- // Generate reduction axes for dims within sum_to_size
- std::vector<bool> inner_red_dims(sum_to_size.size(), false);
- bool reduction_within_shape = false;
-
- // Reduce rest of the dims with keep_dim
- for (int i = leading_dims; i < int(root.size()); i++) {
- if (sum_to_size[i - leading_dims] == 1 && !root[i]->extent()->isOneInt()) {
- inner_red_dims[i - leading_dims] = true;
- reduce_dims.push_back(i);
- reduction_within_shape = true;
- }
- }
-
- // Reduction step
- if (!reduce_dims.empty()) {
- out = sum(in, reduce_dims);
- }
-
- // Broadcast back reduced dims within shape
- if (reduction_within_shape) {
- out = broadcast(out, inner_red_dims);
- }
-
- return out;
-}
-
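Both sum_to overloads removed above follow the autograd sum-to-size recipe: reduce every leading dim beyond the target rank, then keepdim-reduce any dim the target sets to 1 while the input does not. A sketch of the same logic on an eager tensor (sum_to here is a hypothetical helper, checked against Tensor.sum_to_size):

# Sketch: sum_to = reduce leading dims, then keepdim-reduce dims the target sets to 1.
import torch

def sum_to(t, shape):
    leading = t.dim() - len(shape)
    dims = list(range(leading))                        # reduce extra leading dims
    dims += [leading + i for i, s in enumerate(shape)
             if s == 1 and t.shape[leading + i] != 1]  # reduce within-shape dims
    if dims:
        t = t.sum(dim=dims, keepdim=True)
    return t.reshape(shape) if leading else t

x = torch.randn(2, 4, 4)
assert sum_to(x, (4, 1)).shape == (4, 1)
assert torch.allclose(sum_to(x, (4, 1)), x.sum_to_size(4, 1))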
-TensorView* shift(TensorView* inp, const std::vector<int>& offsets) {
- TORCH_CHECK(
- TensorDomain::noReductions(inp->getRootDomain()).size() == offsets.size(),
- "Invalid shift offsets, number of entries in offsets expected to be ",
- TensorDomain::noReductions(inp->getRootDomain()).size(),
- " but received ",
- offsets.size());
-
- auto out = newValLike(inp, inp->getDataType().value())->as<TensorView>();
- new ShiftOp(out, inp, offsets);
- return out;
-}
-
-namespace {
-std::vector<Int*> convertToIntVector(const std::vector<int>& x) {
- std::vector<Int*> converted;
- std::transform(x.begin(), x.end(), std::back_inserter(converted), [](int x) {
- return new Int(x);
- });
- return converted;
-}
-} // namespace
-
-TensorView* gather(
- TensorView* inp,
- const std::vector<int>& window_shape,
- const std::vector<std::vector<int>>& pad_width) {
- std::vector<Int*> window_shape_int = convertToIntVector(window_shape);
- std::vector<std::vector<Int*>> pad_width_int;
- std::transform(
- pad_width.begin(),
- pad_width.end(),
- std::back_inserter(pad_width_int),
- [](const std::vector<int>& x) { return convertToIntVector(x); });
- return gather(inp, window_shape_int, pad_width_int);
-}
-
-TensorView* gather(
- TensorView* inp,
- const std::vector<Int*>& window_shape,
- const std::vector<std::vector<Int*>>& pad_width) {
- auto inp_dom = TensorDomain::noReductions(inp->getRootDomain());
- const auto ndims = inp_dom.size();
-
- TORCH_CHECK(
- ndims == window_shape.size(),
- "Invalid window shape: number of entries expected to be ",
- ndims,
- " but received ",
- window_shape.size());
-
- TORCH_CHECK(
- ndims == pad_width.size(),
- "Invalid pad width: number of entries expected to be ",
- ndims,
- " but received ",
- pad_width.size());
-
- std::for_each(pad_width.begin(), pad_width.end(), [](const auto& p) {
- TORCH_CHECK(
- p.size() == 2,
- "Each entry of pad_width must have two non-negative integers.");
- });
-
- std::vector<IterDomain*> out_dom;
- std::vector<IterDomain*> out_gather_dom;
-
- for (size_t i = 0; i < ndims; ++i) {
- const auto inp_axis = inp_dom[i];
- const auto window_dim = window_shape[i];
- const auto pad_left = pad_width[i][0];
- const auto pad_right = pad_width[i][1];
- TORCH_INTERNAL_ASSERT(inp_axis->start()->isZeroInt());
- Val* out_axis_dim = nullptr;
- if (window_dim->isConst() && pad_left->isConst() && pad_right->isConst()) {
- const int64_t extent_adjustment =
- -(-window_dim->value().value() + 1 + pad_left->value().value() +
- pad_right->value().value());
- out_axis_dim = extent_adjustment == 0
- ? inp_axis->extent()
- : sub(inp_axis->extent(), new Int(extent_adjustment));
- } else {
- out_axis_dim =
- add(add(sub(inp_axis->extent(), window_dim), new Int(1)),
- add(pad_left, pad_right));
- }
- out_dom.push_back(new IterDomain(
- new Int(0),
- out_axis_dim,
- ParallelType::Serial,
- inp_axis->getIterType()));
- // create a new axis for the gathered domain
- out_gather_dom.push_back(new IterDomain(
- new Int(0), window_dim, ParallelType::Serial, IterType::Gather));
- }
-
- out_dom.insert(out_dom.end(), out_gather_dom.begin(), out_gather_dom.end());
-
- auto out = new TensorView(
- new TensorDomain(out_dom, std::vector<bool>(out_dom.size(), true)),
- inp->getDataType().value());
-
- new GatherOp(out, inp, window_shape, pad_width);
- return out;
-}
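-
-// Illustrative extent math (assuming zero-start axes): for an input axis of
-// extent N, window size w, and pads (pl, pr), the output axis extent is
-// N - w + 1 + pl + pr. E.g. w = 3 with pl = pr = 1 keeps the extent at N.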
-
} // namespace cuda
} // namespace fuser
} // namespace jit
BinaryOpType reduction_op_type,
const std::vector<int>& axes,
Val* init,
- TensorView* v1,
- bool keep_dim = false);
-
-//! Auxiliary struct holding the result of
-//! a single Welford op on a TensorView
-class TORCH_CUDA_CU_API WelfordResult {
- public:
- TensorView* avg;
- TensorView* var_sum;
- TensorView* n;
-
- explicit WelfordResult(
- TensorView* in_avg,
- TensorView* in_var_sum,
- TensorView* in_n);
-
- WelfordResult rFactor(const std::vector<int>& axes);
-};
-
-//! Welford operator on specified axes. This is currently the only scan op with
-//! multiple outputs that is supported. May consider generalization if more scan
-//! ops are added.
-TORCH_CUDA_CU_API WelfordResult Welford(
- TensorView* tv,
- const std::vector<int>& axes,
- TensorView* init_avg = nullptr,
- TensorView* init_var = nullptr,
- Int* init_N = new Int(0));
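-
-// Usage sketch (illustrative, assuming a 2D input tv):
-//   WelfordResult res = Welford(tv, {1}); // reduce along axis 1
-//   TensorView* mean = res.avg; // running average
-//   TensorView* var_sum = res.var_sum; // sum of squared deviations
-//   TensorView* count = res.n; // reduced element count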
+ TensorView* v1);
// UNARY OPERATIONS
TORCH_CUDA_CU_API Val* neg(Val* v);
TensorView* inp,
const std::vector<bool>& is_broadcast_dim);
-//! Transpose a tensor as specified by axis mappings.
-//!
-//! The transposition mapping is specified with a list of pairs from
-//! old to new positions. Positions are relative to the noReduction
-//! domain.
-//!
-//! \param inp Tensor to transpose
-//! \param old2new Pairs of mapping from old to new positions.
-TORCH_CUDA_CU_API TensorView* transpose(
- TensorView* inp,
- const std::unordered_map<int, int>& old2new);
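-
-// Usage sketch (illustrative): swap the first and last axes of a 3D tensor,
-// leaving axis 1 in place:
-//   TensorView* t1 = transpose(t0, {{0, 2}, {2, 0}});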
-
// BINARY OPERATIONS
// add
TORCH_CUDA_CU_API Val* add(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* lt(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* lt(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* lt(TensorView* v1, TensorView* v2);
-// gt
-TORCH_CUDA_CU_API Val* gt(Val* v1, Val* v2);
-TORCH_CUDA_CU_API TensorView* gt(TensorView* v1, Val* v2);
-TORCH_CUDA_CU_API TensorView* gt(Val* v1, TensorView* v2);
-TORCH_CUDA_CU_API TensorView* gt(TensorView* v1, TensorView* v2);
// eq
TORCH_CUDA_CU_API Val* eq(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* eq(TensorView* v1, Val* v2);
// REDUCTION OPERATIONS
TORCH_CUDA_CU_API TensorView* sum(
TensorView* v1,
- const std::vector<int>& reduction_axes,
- bool keep_dim = false);
-
-TORCH_CUDA_CU_API TensorView* max(
- TensorView* v1,
- const std::vector<int>& reduction_axes,
- bool keep_dim = false);
-
-TORCH_CUDA_CU_API TensorView* min(
- TensorView* v1,
- const std::vector<int>& reduction_axes,
- bool keep_dim = false);
+ const std::vector<int>& reduction_axes);
// COMPOUND OPERATIONS
// add_alpha
TORCH_CUDA_CU_API Val* clamp(Val* in, Val* min_val, Val* max_val);
TORCH_CUDA_CU_API TensorView* clamp(TensorView* in, Val* min_val, Val* max_val);
-//! Internal operator for supporting backward graphs
-//!
-//! example:
-//! v1 = T1 [I0(10),I1(20),I2(30),I3(40)]
-//! v2 = sum_to(v1,{30,1}) ------> v2 = T2[I2,R3 (keep_dim)]
-//!
-//! This operator returns v1 directly if the sizes of the v1 root domain
-//! already match the given shape.
-//!
-//! The name sum_to differs from the usual nvfuser naming convention;
-//! it is chosen to align with the operator name of at::sum_to.
-
-TORCH_CUDA_CU_API TensorView* sum_to(
- TensorView* v1,
- const std::vector<Int*>& sum_to_size);
-
-TORCH_CUDA_CU_API TensorView* sum_to(
- TensorView* v1,
- const std::vector<int64_t>& sum_to_size);
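-
-// Usage sketch (illustrative), matching the example above:
-//   TensorView* v2 = sum_to(v1, std::vector<int64_t>{30, 1});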
-
-//! Shift a tensor to a direction specified by offsets.
-//!
-//! Example:
-//! t0: 2D tensor of size N by M
-//! t1 = shift(t0, {1, -1});
-//!
-//! then:
-//! t1[i, j] = t0[i-1, j+1] for 1 <= i < N and 0 <= j < M-1.
-//! t1[i, j] = 0, otherwise
-TORCH_CUDA_CU_API TensorView* shift(
- TensorView* inp,
- const std::vector<int>& offsets);
-
-//! Gather a window of nearby elements for each element.
-//!
-//! Each window of size window_shape is stored as an additional
-//! innermost domain, meaning that the number of dimensions of the
-//! output tensor doubles. The pad_width parameter specifies the
-//! padding width of each side of each axis.
-//!
-//! Example:
-//! t0: 2D tensor of [N, M]
-//! t1 = gather(t0, {1, 3}, {{0, 0}, {1, 1}});
-//!
-//! then:
-//! t1: [N, M, 1, 3]
-//! t1[i, j, k, l] = The value at the window position of [k, l]
-//! for t0[i, j]
-TORCH_CUDA_CU_API TensorView* gather(
- TensorView* inp,
- const std::vector<int>& window_shape,
- const std::vector<std::vector<int>>& pad_width);
-
-//! Gather a window of nearby elements for each element.
-//!
-//! Same as the other gather interface but with Int* parameters.
-TORCH_CUDA_CU_API TensorView* gather(
- TensorView* inp,
- const std::vector<Int*>& window_shape,
- const std::vector<std::vector<Int*>>& pad_width);
-
} // namespace cuda
} // namespace fuser
} // namespace jit
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/codegen.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
-#include <array>
#include <sstream>
#include <vector>
namespace {
-class CudaKernelGenerator : private kir::IrVisitor {
- static constexpr const char* kTab = " ";
+class CudaKernelGenerator : private OptInConstDispatch {
+ static constexpr char* kTab = " ";
public:
static std::string generateKernelDefinition(
- const kir::Kernel* kernel,
+ const Kernel* kernel,
const std::string& kernel_name) {
CudaKernelGenerator codegen(kernel);
codegen.genDeclaration(kernel_name);
}
private:
- explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) {}
+ explicit CudaKernelGenerator(const Kernel* kernel) : kernel_(kernel) {}
// Generates the kernel function declaration
void genDeclaration(const std::string& kernel_name) {
code_ << "__global__ void " << kernel_name << "(";
- std::vector<kir::Val*> params;
+ std::vector<Val*> params;
- // Inputs & Outputs
+ // Inputs
for (auto val : kernel_->inputs()) {
params.push_back(val);
}
+
+ // Outputs
for (auto val : kernel_->outputs()) {
params.push_back(val);
}
+ // Global buffers
+ for (auto allocate : kernel_summary.global_allocations) {
+ params.push_back(allocate->buffer());
+ }
+
// Generate parameter declarations
- for (kir::Val* val : params) {
- if (const auto tv = dynamic_cast<kir::TensorView*>(val)) {
- code_ << "Tensor<" << val->dtype() << ", "
- << TensorDomain::noReductions(
- tv->fuserTv()->getMaybeRFactorDomain())
- .size()
- << "> " << varName(tv);
- } else {
- TORCH_INTERNAL_ASSERT(val->isScalar()); // NOLINT (LLVM bug 48525)
- TORCH_INTERNAL_ASSERT(val->definition() == nullptr);
- code_ << val->dtype() << " " << gen(val);
+ for (Val* val : params) {
+ switch (val->getValType().value()) {
+ case ValType::KirTensorView: {
+ // TODO(kir): review this
+ const auto tv = val->as<kir::TensorView>();
+ code_ << "Tensor<" << val->getDataType().value() << ", "
+ << TensorDomain::noReductions(
+ tv->fuserTv()->getMaybeRFactorDomain())
+ .size()
+ << "> " << gen(tv);
+ break;
+ }
+ case ValType::KirScalar:
+ code_ << val->getDataType().value() << " " << gen(val);
+ break;
+ default:
+ TORCH_CHECK(!"Unexpected parameter type");
}
if (val != params.back()) {
}
}
- // Global buffers
- for (auto allocate : kernel_summary.global_allocations) {
- TORCH_INTERNAL_ASSERT(allocate->buffer()->isA<kir::TensorView>());
- const auto tv = allocate->buffer()->as<kir::TensorView>();
- const auto& maybe_rfactor_domain = tv->domain()->hasRFactor()
- ? tv->domain()->rfactorDomain()
- : tv->domain()->rootDomain();
- const auto nDims = std::count_if(
- maybe_rfactor_domain.begin(),
- maybe_rfactor_domain.end(),
- [](const kir::IterDomain* id) {
- return !id->isReduction() &&
- id->iterType() != IterType::BroadcastWithoutStride;
- });
- code_ << ", Tensor<" << tv->dtype() << ", " << nDims << "> "
- << varName(tv);
- }
-
// Kernels generating random numbers take extra (seed, offset) arguments
if (kernel_summary.is_stochastic) {
- code_ << ", at::PhiloxCudaState philox_args";
+ code_ << ", unsigned long long seed, unsigned long long offset";
}
code_ << ") ";
// Random number generator (optional)
if (kernel_summary.is_stochastic) {
indent() << "const int idx = blockIdx.x*blockDim.x + threadIdx.x;\n";
- indent() << "auto offset = philox_args.captured_ ?\n";
- indent()
- << " static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n";
- indent() << " philox_args.offset_.val;\n";
- indent() << "Philox rnd(philox_args.seed_, idx, offset);\n";
+ indent() << "Philox rnd(seed, idx, offset);\n";
}
// Do we have any dynamic shared memory buffers?
// Do we have any reductions?
const bool has_reductions = kernel_summary.has_block_reductions ||
- kernel_summary.number_of_grid_reductions > 0;
-
- const bool has_parallel_welford =
- kernel_summary.has_block_welford || kernel_summary.has_grid_welford;
+ kernel_summary.has_grid_reductions;
// Shared memory
- if (has_dynamic_smem || has_reductions || has_parallel_welford) {
+ if (has_dynamic_smem || has_reductions) {
indent() << "alignas("
#ifndef __HIP_PLATFORM_HCC__
<< dataTypeSize(kernel_summary.largest_smem_data_type)
indent() << "unsigned offset = 0;\n";
}
- if (has_reductions || has_parallel_welford) {
+ if (has_reductions) {
indent() << "void* shared_mem = array;\n";
if (has_dynamic_smem) {
- if (has_parallel_welford) {
- indent() << "offset += "
- << "((blockDim.x * blockDim.y * blockDim.z) * 3 * sizeof("
- << kernel_summary.largest_smem_data_type << "));\n";
- } else {
- indent() << "offset += "
- << "((blockDim.x * blockDim.y * blockDim.z) * sizeof("
- << kernel_summary.largest_smem_data_type << "));\n";
- }
- }
-
- if (has_parallel_welford) {
- // Unpack shared mem pointer
- auto space_type = kernel_summary.largest_smem_data_type;
- indent()
- << "nvfuser_index_t block_size = blockDim.x*blockDim.y*blockDim.z;\n";
- indent() << space_type << " *shared_mem_var = "
- << "static_cast<" << space_type << "*>("
- << "shared_mem);\n";
- indent() << space_type
- << " *shared_mem_avg = shared_mem_var + block_size;\n";
- indent() << space_type
- << " *shared_mem_n = shared_mem_avg + block_size;\n";
+ indent() << "offset += "
+ << "((blockDim.x * blockDim.y * blockDim.z) * sizeof("
+ << kernel_summary.largest_smem_data_type << "));\n";
}
}
}
-
- // Call the initialization function if using a custom block sync
- if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
- indent() << "block_sync::init();\n";
- }
}
void genBody() {
for (auto expr : kernel_->topLevelExprs()) {
- expr->accept(this);
+ OptInConstDispatch::handle(expr);
}
}
return code_;
}
- std::string gen(const kir::Node* node) {
+ std::string gen(const Statement* stmt) {
std::stringstream tmp_code;
std::swap(tmp_code, code_);
- auto replacement = replacement_map_.find(node);
- if (replacement != replacement_map_.end()) {
- node = replacement->second;
- }
- node->accept(this);
+ handle(stmt);
std::swap(tmp_code, code_);
return tmp_code.str();
}
- // TODO(kir): consider automatic var naming
- std::string varName(const kir::Val* val) {
- std::string prefix = "";
- if (val->isA<kir::TensorView>()) {
- prefix = "T";
- } else {
- prefix = typePrefix(val->dtype());
- }
-
- std::stringstream value_name;
- if (val->name() != kInvalidStmName) {
- value_name << prefix << val->name();
- } else {
- value_name << "k" << prefix << val->id();
- }
- return value_name.str();
+ std::string gen(const kir::TensorView* tv) {
+ std::stringstream tv_name;
+ tv_name << "T" << tv->name();
+ return tv_name.str();
}
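
  // For example (illustrative), a kir::TensorView with name() == 7 is
  // referenced in the generated kernel as "T7".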
- std::string genInline(const kir::Node* node) {
+ std::string genInline(const Statement* stmt) {
const bool saved_inline = print_inline_;
print_inline_ = true;
- auto result = gen(node);
+ const auto result = gen(stmt);
print_inline_ = saved_inline;
// NOLINTNEXTLINE(performance-no-automatic-move)
return result;
}
- void visit(const kir::Predicate* node) final {
- TORCH_INTERNAL_ASSERT(node->hasValue());
- code_ << gen(node->value());
+ void handle(const Statement* node) final {
+ OptInConstDispatch::handle(node);
+ }
+
+ void handle(const Expr* node) final {
+ OptInConstDispatch::handle(node);
}
- void visit(const kir::Bool* node) final {
- const auto def = node->definition();
+ void handle(const Val* node) final {
+ OptInConstDispatch::handle(node);
+ }
+
+ void handle(const kir::Bool* node) final {
+ const auto def = node->getOrigin();
if (print_inline_ && def != nullptr) {
code_ << "(" << gen(def) << ")";
- } else if (node->isConst()) {
- code_ << (*node->value() ? "true" : "false");
+ } else if (node->isSymbolic()) {
+ code_ << "b" << node->name();
} else {
- code_ << varName(node);
+ code_ << *node->value();
}
}
- void visit(const kir::Double* node) final {
- const auto def = node->definition();
+ void handle(const kir::Float* node) final {
+ const auto def = node->getOrigin();
if (print_inline_ && def != nullptr) {
code_ << "(" << gen(def) << ")";
- } else if (node->isConst()) {
- const int digits = std::numeric_limits<Double::ScalarType>::max_digits10;
- code_ << std::setprecision(digits) << *node->value();
+ } else if (node->isSymbolic()) {
+ code_ << "f" << node->name();
} else {
- code_ << varName(node);
+ const int digits = std::numeric_limits<Float::ScalarType>::max_digits10;
+ code_ << "float(" << std::setprecision(digits) << *node->value() << ")";
}
}
- void visit(const kir::Int* node) final {
- const auto def = node->definition();
+ void handle(const kir::Half* node) final {
+ const auto def = node->getOrigin();
if (print_inline_ && def != nullptr) {
code_ << "(" << gen(def) << ")";
- } else if (node->isConst()) {
- code_ << *node->value();
+ } else if (node->isSymbolic()) {
+ code_ << "h" << node->name();
} else {
- code_ << varName(node);
+ code_ << "__float2half(" << *node->value() << ")";
}
}
- void visit(const kir::NamedScalar* node) final {
- // dim3 components are unsigned int. Cast to signed integer to
- // support negative indexing
- if (node->getParallelIndex().has_value() ||
- node->getParallelDim().has_value()) {
- code_ << "((nvfuser_index_t)" << node->name() << ")";
+ void handle(const kir::Int* node) final {
+ const auto def = node->getOrigin();
+ if (print_inline_ && def != nullptr) {
+ code_ << "(" << gen(def) << ")";
+ } else if (node->isSymbolic()) {
+ code_ << "i" << node->name();
} else {
- code_ << node->name();
+ code_ << *node->value();
}
}
- void visit(const kir::TensorIndex* node) final {
- code_ << varName(node->view()) << "[";
+ void handle(const kir::NamedScalar* node) final {
+ code_ << node->name();
+ }
+
+ void handle(const kir::TensorIndex* node) final {
+ code_ << gen(node->view()) << "[";
bool first = true;
for (auto* ind : node->indices()) {
code_ << "]";
}
- void visit(const kir::IterDomain* node) final {
+ void handle(const kir::IterDomain* node) final {
TORCH_INTERNAL_ASSERT(!"Unreachable");
}
- void visit(const kir::TensorDomain* node) final {
+ void handle(const kir::TensorDomain* node) final {
TORCH_INTERNAL_ASSERT(!"Unreachable");
}
- void visit(const kir::TensorView* tv) final {
+ void handle(const kir::TensorView* node) final {
TORCH_INTERNAL_ASSERT(!"Unreachable");
}
- void visit(const kir::UnaryOp* node) final {
- bool is_vector_op = false;
- size_t vector_word_size = 1;
-
- if (vectorize_scope_ && node->out()->isA<kir::TensorIndex>()) {
- auto ti = node->out()->as<kir::TensorIndex>();
-
- bool vectorize_op = false;
- bool misaligned_op = false;
-
- for (auto id : ti->view()->fuserTv()->domain()->domain()) {
- if (!isParallelTypeVectorize(id->getParallelType())) {
- continue;
- }
-
- ExpressionEvaluator expr_eval(id->fusion());
- auto vector_size_optional = expr_eval.evaluate(id->extent());
-
- TORCH_INTERNAL_ASSERT(
- vector_size_optional.has_value(),
- "Could not evaluate constant value bound to vectorized dim.");
-
- vector_word_size = vector_size_optional.value();
-
- vectorize_op = id->getParallelType() == ParallelType::Vectorize;
- misaligned_op =
- id->getParallelType() == ParallelType::MisalignedVectorize;
- break;
- }
-
- if (vectorize_op) {
- TORCH_INTERNAL_ASSERT(
- node->operation() == UnaryOpType::Set,
- "Cannot vectorize operations that are not sets. ",
- "Use cache_before and cache_after to store/load with vectorized reads into buffers.");
- is_vector_op = true;
- }
-
- if (misaligned_op) {
- is_vector_op = (node->operation() == UnaryOpType::Set);
- }
-
- if (is_vector_op && !node->in()->isScalar()) {
- TORCH_INTERNAL_ASSERT(
- node->out()->dtype() == node->in()->dtype(),
- "Vectorized store/load requires input and output datatypes match.");
- }
- }
-
- if (is_vector_op) {
- if (node->in()->isScalar()) {
- indent() << "reinterpret_cast<"
- << "Array<" << node->out()->dtype() << ", " << vector_word_size
- << ">*>"
- << "(&" << gen(node->out()) << ")->set(" << gen(node->in())
- << ");\n";
- } else {
- indent() << "*reinterpret_cast<"
- << "Array<" << node->out()->dtype() << ", " << vector_word_size
- << ">*>"
- << "(&" << gen(node->out()) << ")"
- << " = *reinterpret_cast<"
- << "Array<" << node->in()->dtype() << ", " << vector_word_size
- << ">*>"
- << "(&" << gen(node->in()) << ");\n";
- }
- return;
- }
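-
-    // Illustrative emission for a vectorized (width-4, float) set:
-    //   *reinterpret_cast<Array<float, 4>*>(&T1[...])
-    //     = *reinterpret_cast<Array<float, 4>*>(&T0[...]);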
-
- if (node->out()->isA<kir::NamedScalar>()) {
- const auto op_type = node->operation();
- if (auto op = inline_op_str(op_type)) {
- indent() << gen(node->out()) << " = " << *op << genInline(node->in())
- << ";\n";
- }
- return;
- }
-
+ void handle(const kir::UnaryOp* node) final {
if (!print_inline_) {
indent() << gen(node->out());
if (!node->out()->isScalar() && !node->in()->isScalar()) {
code_ << " = ";
}
- const auto op_type = node->operation();
- if (auto op = inline_op_str(op_type)) {
- if (alsoBooleanOperator(op_type) &&
- node->out()->dtype() == DataType::Bool) {
- code_ << stringifyBooleanOp(op_type) << gen(node->in());
- } else {
- code_ << *op << gen(node->in());
- }
+ if (auto op = inline_op_str(node->getUnaryOpType())) {
+ code_ << *op << gen(node->in());
} else {
- if (op_type == UnaryOpType::Cast) {
- const auto cast_str =
- cast_func_str({node->in()->dtype(), node->out()->dtype()});
- TORCH_INTERNAL_ASSERT(
- cast_str.has_value(),
- "Invalid cast. Input type: ",
- node->in()->dtype(),
- ", output type: ",
- node->out()->dtype());
+ if (node->getUnaryOpType() == UnaryOpType::Cast) {
+ const auto cast_str = cast_func_str(
+ {node->in()->getDataType().value(),
+ node->out()->getDataType().value()});
code_ << cast_str.value();
} else {
- code_ << op_type;
- if (needFloatSuffix(op_type) &&
- node->out()->dtype() == DataType::Float) {
- code_ << "f";
- }
+ code_ << node->getUnaryOpType();
}
code_ << "(";
- if (op_type == UnaryOpType::RandLike) {
+ if (node->getUnaryOpType() == UnaryOpType::RandLike) {
code_ << "rnd";
} else {
code_ << gen(node->in());
std::string genBinaryOp(
BinaryOpType op_type,
- kir::Val* out,
const std::string& lhs,
const std::string& rhs) {
std::stringstream expr;
if (auto op = inline_op_str(op_type)) {
- expr << lhs << " ";
- if (alsoBooleanOperator(op_type) && out->dtype() == DataType::Bool) {
- expr << stringifyBooleanOp(op_type);
- } else {
- expr << *op;
- }
- expr << " " << rhs;
+ expr << lhs << " " << *op << " " << rhs;
} else {
- expr << op_type;
- if (needFloatSuffix(op_type) && out->dtype() == DataType::Float) {
- expr << "f";
- }
- expr << "(" << lhs << ", " << rhs << ")";
+ expr << op_type << "(" << lhs << ", " << rhs << ")";
}
return expr.str();
}
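
  // For example (illustrative): an inline op such as Add renders as
  // "lhs + rhs", while a function-style op such as Fmod renders as
  // "fmod(lhs, rhs)".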
- // If one argument is a tensorview and the other is a scalar, make sure we
- // cast the scalar to the tensorview type
- std::string scalarCast(kir::Val* lhs, kir::Val* rhs) {
- // If neither are scalars return
- if (!((lhs->isScalar() || rhs->isScalar()) &&
- (lhs->isA<kir::TensorIndex>() || rhs->isA<kir::TensorIndex>()))) {
- return "";
- }
-
- // Looking for mixed tensorview scalar options where types don't match
- // but are either both floating or both int types. We should cast
- // scalar to tensorview type in these instances.
- auto lhs_t = lhs->dtype();
- auto rhs_t = rhs->dtype();
-
- // If same type, don't cast anything
- if (lhs_t == rhs_t) {
- return "";
- }
-
- // Don't do anything when dealing with bools
- if (lhs_t == DataType::Bool || rhs_t == DataType::Bool) {
- return "";
- }
-
- // Mixing floating and int combination
- if ((isFloatingPointType(lhs_t) != isFloatingPointType(rhs_t)) ||
- (isIntegralType(lhs_t) != isIntegralType(rhs_t))) {
- return "";
- }
-
- std::stringstream cast;
- cast << "(" << (lhs->isA<kir::TensorIndex>() ? lhs_t : rhs_t) << ") ";
- return cast.str();
- }
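-
-  // For example (illustrative): with a float TensorIndex on one side and a
-  // double scalar on the other, this returns "(float) ", so the scalar is
-  // cast to the tensor's type in the emitted expression.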
-
- void visit(const kir::BinaryOp* node) final {
- const auto op_type = node->operation();
+ void handle(const kir::BinaryOp* node) final {
+ const auto op_type = node->getBinaryOpType();
if (print_inline_) {
// Inline expression: `lhs op rhs`
- code_ << genBinaryOp(
- op_type, node->out(), gen(node->lhs()), gen(node->rhs()));
+ code_ << genBinaryOp(op_type, gen(node->lhs()), gen(node->rhs()));
} else {
indent() << gen(node->out());
if (node->out()->isScalar()) {
// Single line: `out = lhs op rhs;`
code_ << " = "
- << genBinaryOp(
- op_type, node->out(), gen(node->lhs()), gen(node->rhs()));
+ << genBinaryOp(op_type, gen(node->lhs()), gen(node->rhs()));
} else {
// Split TensorView expressions across multiple lines:
//
// = lhs
// op rhs;
//
-
- auto cast = scalarCast(node->lhs(), node->rhs());
if (auto op = inline_op_str(op_type)) {
code_ << "\n";
- indent() << kTab << "= " << (node->lhs()->isScalar() ? cast : "")
- << gen(node->lhs()) << "\n";
- indent() << kTab;
- if (alsoBooleanOperator(op_type) &&
- node->out()->dtype() == DataType::Bool) {
- code_ << stringifyBooleanOp(op_type);
- } else {
- code_ << *op;
- }
- code_ << " " << (node->rhs()->isScalar() ? cast : "")
- << gen(node->rhs());
+ indent() << kTab << "= " << gen(node->lhs()) << "\n";
+ indent() << kTab << *op << " " << gen(node->rhs());
} else {
- if (integer_op_str(op_type) && isIntegralType(node->out()->dtype())) {
- auto int_op = integer_op_str(op_type);
- code_ << " = " << *int_op << "(\n";
- } else {
- code_ << " = " << op_type << "(\n";
- }
- indent() << kTab << (node->lhs()->isScalar() ? cast : "")
- << gen(node->lhs()) << ",\n";
- indent() << kTab << (node->rhs()->isScalar() ? cast : "")
- << gen(node->rhs()) << ")";
+ code_ << " = " << op_type << "(\n";
+ indent() << kTab << gen(node->lhs()) << ",\n";
+ indent() << kTab << gen(node->rhs()) << ")";
}
}
code_ << ";\n";
}
}
- void visit(const kir::TernaryOp* node) final {
+ void handle(const kir::TernaryOp* node) final {
if (!print_inline_) {
indent() << gen(node->out());
if (!node->out()->isScalar()) {
code_ << " = ";
}
- code_ << node->operation() << "(" << gen(node->in1()) << ", ";
-
- // Make sure the two operands of where has the same
- // type. Note that compiling "where(0.0f, 0.0)" fails because of
- // the overloading ambiguity.
- if (node->operation() == TernaryOpType::Where) {
- auto cast = scalarCast(node->in2(), node->in3());
- code_ << (node->in2()->isScalar() ? cast : "") << gen(node->in2()) << ", "
- << (node->in3()->isScalar() ? cast : "") << gen(node->in3()) << ")";
- } else {
- code_ << gen(node->in2()) << ", " << gen(node->in3()) << ")";
- }
+ code_ << node->getTernaryOpType() << "(" << gen(node->in1()) << ", "
+ << gen(node->in2()) << ", " << gen(node->in3()) << ")";
if (!print_inline_) {
code_ << ";\n";
}
}
- std::string genReductionOp(BinaryOpType op_type, kir::Val* out) {
+ std::string genReductionOp(BinaryOpType op_type, DataType data_type) {
std::stringstream lambda;
- DataType data_type = out->dtype();
lambda << "[](" << data_type << " &a, " << data_type << " b) "
- << "{ a = " << genBinaryOp(op_type, out, "a", "b") << "; }";
+ << "{ a = " << genBinaryOp(op_type, "a", "b") << "; }";
return lambda.str();
}
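
  // For an Add reduction over float data, the generated lambda string is
  // (illustrative):
  //   [](float &a, float b) { a = a + b; }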
- void visit(const kir::BroadcastOp* node) final {
- TORCH_INTERNAL_ASSERT(node->out()->isA<kir::TensorIndex>());
- const auto tensor_index = node->out()->as<kir::TensorIndex>();
-
- const ParallelTypeBitmap domains =
- kernel_->predicateMap().getParallelBroadcastDomains(
- tensor_index->view()->fuserTv());
+ void handle(const kir::BroadcastOp* node) final {
+ const ir_utils::ParallelTypeBitmap domains =
+ ir_utils::getParallelBroadcastDomains(
+ node->out(), kernel_->predicateMap());
const bool thread_x = domains.get(ParallelType::TIDx);
const bool thread_y = domains.get(ParallelType::TIDy);
"Parallel broadcast across blocks not supported");
if (block_broadcast_needed) {
- const auto data_type = node->out()->dtype();
+ const auto data_type = node->out()->getDataType().value();
indent() << "broadcast::blockBroadcast<" << (thread_x ? "true" : "false")
<< ", " << (thread_y ? "true" : "false") << ", "
<< (thread_z ? "true" : "false") << ">(\n";
indent() << kTab << gen(node->out()) << ",\n";
indent() << kTab << gen(node->in()) << ",\n";
- indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
- TORCH_INTERNAL_ASSERT(
- node->predicate() != nullptr && node->predicate()->hasValue());
- indent() << kTab << genInline(node->predicate()) << ");\n";
+ indent() << kTab << "static_cast<" << data_type << "*>(shared_mem));\n";
} else {
indent() << gen(node->out()) << "\n";
indent() << kTab << " = " << gen(node->in()) << ";\n";
}
}
- void visit(const kir::ReductionOp* node) final {
- TORCH_INTERNAL_ASSERT(node->out()->isA<kir::TensorIndex>());
+ void handle(const kir::ReductionOp* node) final {
+ TORCH_CHECK(node->out()->getValType() == ValType::TensorIndex);
const auto out = node->out()->as<kir::TensorIndex>();
const auto domain = out->view()->domain();
if (!has_block_reduce && !has_grid_reduce) {
const auto gen_out = gen(out);
- const auto op_type = node->operation();
+ const auto op_type = node->getReductionOpType();
indent() << gen_out << " = "
- << genBinaryOp(op_type, out, gen_out, gen(node->in())) << ";\n";
+ << genBinaryOp(op_type, gen_out, gen(node->in())) << ";\n";
return;
}
const bool tidy = par_domains.find(ParallelType::TIDy) != par_domains.end();
const bool tidz = par_domains.find(ParallelType::TIDz) != par_domains.end();
- const auto data_type = node->out()->dtype();
- const auto op_type = node->operation();
+ const auto data_type = node->out()->getDataType().value();
+ const auto op_type = node->getReductionOpType();
if (has_block_reduce) {
if (has_grid_reduce) {
indent() << data_type << " "
- << "block_result_" << block_reduce_name_ << "="
- << gen(node->init()) << ";\n";
+ << "block_result"
+ << ";\n";
}
indent() << "blockReduce<" << (tidx ? "true" : "false") << ", "
<< (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
<< ">(\n";
if (has_grid_reduce) {
- indent() << kTab << "block_result_" << block_reduce_name_ << ",\n";
+ indent() << kTab << "block_result"
+ << ",\n";
} else {
indent() << kTab << gen(node->out()) << ",\n";
}
indent() << kTab << gen(node->in()) << ",\n";
- indent() << kTab << genReductionOp(op_type, node->out()) << ",\n";
+ indent() << kTab << genReductionOp(op_type, data_type) << ",\n";
indent() << kTab << "threadIdx,\n";
indent() << kTab << "blockDim,\n";
indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
- TORCH_INTERNAL_ASSERT(
- node->predicate() != nullptr && node->predicate()->hasValue());
- auto read_pred = genInline(node->predicate());
- indent() << kTab << read_pred << ",\n";
- // Pass the write predicate if available and different from the
- // default predicate. The blockReduce runtime function uses the
- // default predicate for both read and write when only the
- // default one is given.
- if (node->writePredicate() != nullptr) {
- TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue());
- auto write_pred = genInline(node->writePredicate());
- indent() << kTab << write_pred << ",\n";
- }
- indent() << kTab << data_type << "(" << genInline(node->init())
- << "));\n";
- }
- }
-
- void visit(const kir::WelfordOp* node) final {
- TORCH_INTERNAL_ASSERT(node->out()->isA<kir::TensorIndex>());
-
- const auto out = node->out()->as<kir::TensorIndex>();
- const auto domain = out->view()->domain();
-
- const auto out_var = node->outVar();
- const auto out_avg = node->outAvg();
- const auto out_N = node->outN();
-
- const auto in_var = node->inVar();
- const auto in_avg = node->inAvg();
- const auto in_N = node->inN();
-
- const bool has_block_reduce = domain->hasBlockReduction();
- const bool has_grid_reduce = domain->hasGridReduction();
-
- // Serial WelfordOp generation
- if (!has_block_reduce && !has_grid_reduce) {
- indent() << "welfordCombine ("
- << "\n";
- indent() << " " << gen(out_avg) << ",\n";
- indent() << " " << gen(out_var) << ",\n";
- indent() << " " << gen(out_N) << ",\n";
- indent() << " " << gen(in_avg) << ",\n";
- if (in_var) {
- indent() << " " << gen(in_var) << ",\n";
- } else {
- indent() << " (" << in_avg->dtype() << ") 0"
- << ",\n";
- }
- indent() << " (" << out_N->dtype() << ")" << gen(in_N) << ");\n";
- return;
- }
-
- const auto par_domains = node->getParallelReductionDomains();
- const bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end();
- const bool tidy = par_domains.find(ParallelType::TIDy) != par_domains.end();
- const bool tidz = par_domains.find(ParallelType::TIDz) != par_domains.end();
-
- const auto data_type = node->out()->dtype();
-
- if (has_block_reduce) {
- if (has_grid_reduce) {
- // allocate block result
- indent() << data_type << " "
- << "block_result_avg_" << block_reduce_name_ << " = "
- << gen(node->initAvg()) << ";\n";
- indent() << data_type << " "
- << "block_result_var_" << block_reduce_name_ << " = "
- << gen(node->initVar()) << ";\n";
- indent() << DataType::Int << " "
- << "block_result_n_" << block_reduce_name_ << " = "
- << gen(node->initN()) << ";\n";
- }
- indent() << "blockWelford<" << (tidx ? "true" : "false") << ", "
- << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
- << ">(\n";
- if (has_grid_reduce) {
- indent() << kTab << "block_result_avg_" << block_reduce_name_ << ",\n"
- << kTab << "block_result_var_" << block_reduce_name_ << ",\n"
- << kTab << "block_result_n_" << block_reduce_name_ << ",\n";
+ if (node->pred() == nullptr) {
+ indent() << kTab << "true,\n";
} else {
- indent() << kTab << gen(node->outAvg()) << ",\n";
- indent() << kTab << gen(node->outVar()) << ",\n";
- indent() << kTab << gen(node->outN()) << ",\n";
- }
- indent() << " " << gen(in_avg) << ",\n";
- if (in_var) {
- indent() << " " << gen(in_var) << ",\n";
- } else {
- indent() << " (" << in_avg->dtype() << ") 0"
- << ",\n";
+ indent() << kTab << genInline(node->pred()) << ",\n";
}
- indent() << out_N->dtype() << "(" << gen(in_N) << "),\n";
- indent() << kTab << "threadIdx,\n";
- indent() << kTab << "blockDim,\n";
- indent() << kTab << "reinterpret_cast<" << data_type
- << "*>(shared_mem_avg),\n";
- indent() << kTab << "reinterpret_cast<" << data_type
- << "*>(shared_mem_var),\n";
- indent() << kTab << "reinterpret_cast<" << DataType::Int
- << "*>(shared_mem_n),\n";
- TORCH_INTERNAL_ASSERT(node->predicate() != nullptr);
- TORCH_INTERNAL_ASSERT(
- node->predicate() != nullptr && node->predicate()->hasValue());
- auto read_pred = genInline(node->predicate());
- indent() << kTab << read_pred << ",\n";
- if (node->writePredicate() != nullptr) {
- TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue());
- auto write_pred = genInline(node->writePredicate());
- indent() << kTab << write_pred << ",\n";
- }
- indent() << kTab << data_type << "(0));\n";
+ indent() << kTab << genInline(node->init()) << ");\n";
}
}
- // Support ReductionOp and WelfordOp
- template <typename REDUCTION_OP>
- std::string generateGridReduceTemplateFlags(
- const REDUCTION_OP* rop,
- const ParallelTypeBitmap& thread_pred) {
- const auto par_domains = rop->getParallelReductionDomains();
- const std::array<ParallelType, 6> ptypes{
- ParallelType::BIDx,
- ParallelType::BIDy,
- ParallelType::BIDz,
- ParallelType::TIDx,
- ParallelType::TIDy,
- ParallelType::TIDz};
- std::stringstream flags;
- for (const ParallelType pt : ptypes) {
- const bool parallel_reduction = par_domains.find(pt) != par_domains.end();
- const bool pred = thread_pred.get(pt);
- TORCH_INTERNAL_ASSERT(
- !(parallel_reduction && pred), "Cannot reduce predicated axis: ", pt);
- bool flag = false;
- // Currently assumed that no dimensions parallelized with blocks
- // are predicated. This assumption may be lifted, but
- // gridReduction would need some changes.
- if (isParallelTypeBlockDim(pt)) {
- TORCH_INTERNAL_ASSERT(
- !pred, "Predication on block dimensions not allowed: ", pt);
- flag = parallel_reduction;
- } else {
- flag = !pred && !parallel_reduction;
- }
- if (pt != ptypes[0]) {
- flags << ", ";
- }
- flags << (flag ? "true" : "false");
- }
- return flags.str();
- }
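-
-  // For example (illustrative): a reduction parallelized over BIDx and TIDx
-  // with no thread predicates produces the template flag string
-  //   "true, false, false, false, true, true"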
-
- void visit(const kir::GridReduction* node) final {
+ void handle(const kir::GridReduction* node) final {
const auto rop = node->reduction_op();
- TORCH_INTERNAL_ASSERT(rop->out()->isA<kir::TensorIndex>());
+ TORCH_INTERNAL_ASSERT(rop->out()->getValType() == ValType::TensorIndex);
const auto out = rop->out()->as<kir::TensorIndex>();
const auto domain = out->view()->domain();
TORCH_INTERNAL_ASSERT(domain->hasGridReduction());
- const auto data_type = rop->out()->dtype();
- const auto op_type = rop->operation();
+ const auto par_domains = rop->getParallelReductionDomains();
+ const bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end();
+ const bool tidy = par_domains.find(ParallelType::TIDy) != par_domains.end();
+ const bool tidz = par_domains.find(ParallelType::TIDz) != par_domains.end();
+ const bool bidx = par_domains.find(ParallelType::BIDx) != par_domains.end();
+ const bool bidy = par_domains.find(ParallelType::BIDy) != par_domains.end();
+ const bool bidz = par_domains.find(ParallelType::BIDz) != par_domains.end();
+
+ const auto data_type = rop->out()->getDataType().value();
+ const auto op_type = rop->getReductionOpType();
TORCH_INTERNAL_ASSERT(
- node->reduction_buffer()->buffer()->isA<kir::TensorView>());
+ node->reduction_buffer()->buffer()->getValType().value() ==
+ ValType::KirTensorView);
TORCH_INTERNAL_ASSERT(
- node->sync_buffer()->buffer()->isA<kir::TensorView>());
+ node->sync_buffer()->buffer()->getValType().value() ==
+ ValType::KirTensorView);
const auto work_buffer =
node->reduction_buffer()->buffer()->as<kir::TensorView>();
const auto sync_buffer =
node->sync_buffer()->buffer()->as<kir::TensorView>();
- const std::string flags_str =
- generateGridReduceTemplateFlags(rop, node->threadPredicate());
-
// Since block-level reduction is already done, those dimensions
// with tidx/y/z being true do not participate in the grid reduction.
indent() << kir::GridReduction::getPredicateFlagName(out->view()) << " = "
- << "reduction::gridReduce<" << flags_str << ">(\n";
+ << "reduction::gridReduce<" << (bidx ? "true" : "false") << ", "
+ << (bidy ? "true" : "false") << ", " << (bidz ? "true" : "false")
+ << ", " << (!tidx ? "true" : "false") << ", "
+ << (!tidy ? "true" : "false") << ", " << (!tidz ? "true" : "false")
+ << ">(\n";
indent() << kTab << gen(rop->out()) << ",\n";
if (domain->hasBlockReduction()) {
- indent() << kTab << "block_result_" << block_reduce_name_ << ",\n";
- block_reduce_name_++;
+ indent() << kTab << "block_result"
+ << ",\n";
} else {
indent() << kTab << gen(rop->in()) << ",\n";
}
- indent() << kTab << genReductionOp(op_type, out) << ",\n";
- indent() << kTab << "&" << varName(work_buffer) << "[0],\n";
- indent() << kTab << varName(sync_buffer) << ",\n";
+ indent() << kTab << genReductionOp(op_type, data_type) << ",\n";
+ indent() << kTab << "&" << gen(work_buffer) << "[0],\n";
+ indent() << kTab << gen(sync_buffer) << ",\n";
indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
- TORCH_INTERNAL_ASSERT(
- node->predicate() != nullptr && node->predicate()->hasValue());
- auto read_pred = genInline(node->predicate());
- indent() << kTab << read_pred << ",\n";
- if (node->writePredicate() != nullptr) {
- TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue());
- auto write_pred = genInline(node->writePredicate());
- indent() << kTab << write_pred << ",\n";
+ if (node->pred() == nullptr) {
+ indent() << kTab << "true,\n";
} else {
- indent() << kTab << read_pred << ",\n";
+ indent() << kTab << genInline(node->pred()) << ",\n";
}
- indent() << kTab << data_type << "("
- << genInline(node->reduction_op()->init()) << "));\n";
+ indent() << kTab << genInline(node->reduction_op()->init()) << ");\n";
}
- void visit(const kir::GridWelford* node) final {
- const auto wop = node->welford_op();
- TORCH_INTERNAL_ASSERT(wop->outAvg()->isA<kir::TensorIndex>());
-
- const auto out = wop->out()->as<kir::TensorIndex>();
- const auto domain = out->view()->domain();
- TORCH_INTERNAL_ASSERT(domain->hasGridReduction());
-
- const auto data_type = out->dtype();
-
- TORCH_INTERNAL_ASSERT(node->var_buffer()->buffer()->isA<kir::TensorView>());
- TORCH_INTERNAL_ASSERT(
- node->sync_buffer()->buffer()->isA<kir::TensorView>());
-
- const auto avg_buffer = node->avg_buffer()->buffer()->as<kir::TensorView>();
- const auto var_buffer = node->var_buffer()->buffer()->as<kir::TensorView>();
- const auto n_buffer = node->N_buffer()->buffer()->as<kir::TensorView>();
- const auto sync_buffer =
- node->sync_buffer()->buffer()->as<kir::TensorView>();
-
- const std::string flags_str =
- generateGridReduceTemplateFlags(wop, node->threadPredicate());
-
- // Since block-level reduction is already done, those dimensions
- // with tidx/y/z being true do not participate in the grid reduction.
- indent() << kir::GridWelford::getPredicateFlagName(out->view()) << " = "
- << "welford::gridWelford<" << flags_str << ">(\n";
- indent() << kTab << gen(wop->outAvg()) << ",\n"
- << kTab << gen(wop->outVar()) << ",\n"
- << kTab << gen(wop->outN()) << ",\n";
- if (domain->hasBlockReduction()) {
- indent() << kTab << "block_result_avg_" << block_reduce_name_ << ",\n"
- << kTab << "block_result_var_" << block_reduce_name_ << ",\n"
- << kTab << "block_result_n_" << block_reduce_name_ << ",\n";
- block_reduce_name_++;
- } else {
- indent() << kTab << gen(wop->inAvg()) << ",\n";
- if (wop->inVar() == nullptr) {
- indent() << kTab << "(" << data_type << ") 0,\n";
- } else {
- indent() << kTab << gen(wop->inVar()) << ",\n";
- }
- indent() << kTab << "(" << wop->outN()->dtype() << ")" << gen(wop->inN())
- << ",\n";
- }
- indent() << kTab << "&" << varName(avg_buffer) << "[0],\n";
- indent() << kTab << "&" << varName(var_buffer) << "[0],\n";
- indent() << kTab << "&" << varName(n_buffer) << "[0],\n";
- indent() << kTab << varName(sync_buffer) << ",\n";
- indent() << kTab << "reinterpret_cast<" << data_type
- << "*>(shared_mem_avg),\n";
- indent() << kTab << "reinterpret_cast<" << data_type
- << "*>(shared_mem_var),\n";
- indent() << kTab << "reinterpret_cast<" << wop->outN()->dtype()
- << "*>(shared_mem_n),\n";
- TORCH_INTERNAL_ASSERT(
- node->predicate() != nullptr && node->predicate()->hasValue());
- auto read_pred = genInline(node->predicate());
- indent() << kTab << read_pred << ",\n";
- if (node->writePredicate() != nullptr) {
- TORCH_INTERNAL_ASSERT(node->writePredicate()->hasValue());
- auto write_pred = genInline(node->writePredicate());
- indent() << kTab << write_pred << ",\n";
- } else {
- indent() << kTab << read_pred << ",\n";
- }
- // TODO : init value support or remove.
- indent() << kTab << data_type << "(0));\n";
- }
-
- void handleScope(const kir::Scope& scope) {
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Woverloaded-virtual"
+ // TODO(kir): fix me
+ void handle(const kir::Scope& scope) {
for (auto expr : scope.exprs()) {
- expr->accept(this);
+ handle(expr);
}
}
+#pragma clang diagnostic pop
- void visit(const kir::ForLoop* node) final {
+ void handle(const kir::ForLoop* node) final {
// TODO(kir): handle this during lowering
- if (node->iter_domain()->isBroadcast()) {
- handleScope(node->body());
- return;
- } else if (node->vectorize()) {
- vectorize_scope_ = node->vectorize();
- handleScope(node->body());
- vectorize_scope_ = false;
- return;
- }
-
- // By default, a parallelized loop would look like:
- //
- // for (int x = threadIdx.x; x < stop; x += blockDim.x) {
- // do_some_comp(x);
- // }
- //
- // When stop is guaranteed to be smaller or equal to the number of
- // threads, the for-loop is not necessary. In the above case, we
- // would just generate the loop body without the for clause but
- // references to the loop index replaced by the loop start value.
- //
- // When the loop end is the same as the IterDomain extent, the
- // assumption can be safely made. This is more conservative than
- // necessary since the loop stop value just needs to be <= the
- // IterDomain extent. However, at this point, this conservative
- // analysis seems sufficient.
- if (node->stop() == node->iter_domain()->extent() &&
- node->iter_domain()->isThread()) {
- // Register a replacement of references to the loop index with
- // the loop start value.
- replacement_map_.insert({node->index(), node->start()});
- handleScope(node->body());
- replacement_map_.erase(node->index());
- return;
- }
-
- if (node->start()->isZeroInt() && node->stop()->isOneInt()) {
- indent() << "constexpr "
- << "nvfuser_index_t"
- << " " << gen(node->index()) << " = 0;\n";
- handleScope(node->body());
+ if (node->iter_domain()->isThread() || node->iter_domain()->isBroadcast()) {
+ handle(node->body());
return;
}
const auto gen_index = gen(node->index());
- const auto gen_start = genInline(node->start());
- const auto gen_stop = genInline(node->stop());
- const auto gen_step = genInline(node->step());
+ const auto gen_start = genInline(node->iter_domain()->start());
+ const auto gen_extent = genInline(node->iter_domain()->extent());
+ indent() << "for(size_t " << gen_index << " = " << gen_start << "; "
+ << gen_index << " < " << gen_extent << "; ++" << gen_index << ") ";
- std::stringstream step_code;
- if (node->step()->isOneInt()) {
- step_code << "++" << gen_index;
- } else {
- step_code << gen_index << " += " << gen_step;
- }
- if (node->isUnrollable()) {
- indent() << "#pragma unroll\n";
- } else {
- indent() << "#pragma unroll 1\n";
- }
- indent() << "for(nvfuser_index_t " << gen_index << " = " << gen_start
- << "; " << gen_index << " < " << gen_stop << "; "
- << step_code.str() << ") ";
startBlock(true);
- handleScope(node->body());
+ handle(node->body());
endBlock();
}
- void visit(const kir::IfThenElse* node) final {
- auto conditional = node->predicate()->value();
- if (conditional->isConst()) {
- // If the conditional is a constant, then the IfThenElse is not required
- if (conditional->value().value()) {
- handleScope(node->thenBody());
- } else {
- handleScope(node->elseBody());
- }
- return;
- }
-
- indent() << "if (" << genInline(conditional) << ") ";
+ void handle(const kir::IfThenElse* node) final {
+ indent() << "if (" << genInline(node->cond()) << ") ";
// "then" block
startBlock(true);
- handleScope(node->thenBody());
+ handle(node->thenBody());
// "else" block (optional)
if (node->hasElse()) {
endBlock(" else ");
startBlock(true);
- handleScope(node->elseBody());
+ handle(node->elseBody());
}
endBlock();
}
// TODO(kir): fold initialization into Allocate
- void visit(const kir::Allocate* node) final {
- const auto buffer_dtype = node->buffer()->dtype();
-
- if (!node->buffer()->isA<kir::TensorView>()) {
- indent() << buffer_dtype << " " << gen(node->buffer()) << ";\n";
+ void handle(const kir::Allocate* node) final {
+ if (node->buffer()->getValType().value() != ValType::KirTensorView) {
+ indent() << node->buffer_type() << " " << gen(node->buffer()) << ";\n";
return;
}
const auto tv = node->buffer()->as<kir::TensorView>();
-
- const auto size = node->size();
- TORCH_INTERNAL_ASSERT(size != nullptr);
+ TORCH_INTERNAL_ASSERT(tv->domain()->nDims() > 0);
+ TORCH_INTERNAL_ASSERT(node->size() != nullptr);
if (node->alias() != nullptr) {
// Allocate alias another Allocate node
const auto alias_tv = node->alias()->buffer()->as<kir::TensorView>();
- indent() << "// Alias Allocation - " << node->memoryType() << "\n";
- indent() << buffer_dtype << "* " << varName(tv) << " = "
- << varName(alias_tv) << ";\n";
+ indent() << "// Alias Allocation - " << node->getMemoryType() << "\n";
+ indent() << node->buffer_type() << "* " << gen(tv) << " = "
+ << gen(alias_tv) << ";\n";
} else {
// Standard Memory Allocation
switch (tv->memoryType()) {
case MemoryType::Global:
- indent() << "// Allocate global tensor " << varName(tv) << "\n";
+ indent() << "// Allocate global tensor " << gen(tv) << "\n";
break;
case MemoryType::Shared:
- if (kir::ExpressionEvaluator::isConst(size)) {
+ if (node->size()->isConstScalar()) {
// Static shared memory
- indent() << "__shared__ " << buffer_dtype << " " << varName(tv)
- << "[" << genInline(size) << "];\n";
+ indent() << "__shared__ " << node->buffer_type() << " " << gen(tv)
+ << "[" << genInline(node->size()) << "];\n";
} else {
// Align Offset Position
indent() << "offset = alignBufferSize(offset,"
- << dataTypeSize(buffer_dtype) << ");\n";
+ << dataTypeSize(node->buffer_type()) << ");\n";
// Shared Memory Pointer
- indent() << buffer_dtype << "* " << varName(tv)
- << " = reinterpret_cast<" << buffer_dtype << "*>"
+ indent() << node->buffer_type() << "* " << gen(tv)
+ << " = reinterpret_cast<" << node->buffer_type() << "*>"
<< "(array + offset);\n";
// Increment Offset Position
- indent() << "offset += (" << genInline(size) << " * sizeof("
- << buffer_dtype << "));\n";
+ indent() << "offset += (" << genInline(node->size()) << " * sizeof("
+ << node->buffer_type() << "));\n";
}
break;
case MemoryType::Local:
- indent() << buffer_dtype << " " << varName(tv) << "["
- << genInline(size) << "];\n";
+ indent() << node->buffer_type() << " " << gen(tv) << "["
+ << genInline(node->size()) << "];\n";
break;
default:
TORCH_INTERNAL_ASSERT(false, "Unexpected memory type");
}
}
- void visit(const kir::Sync* node) final {
- // Use a custom synchronization method if enabled
- if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
- indent() << "block_sync::sync();\n";
- } else {
- indent() << "__barrier_sync(0);\n";
- }
- }
-
- void visit(const kir::InitMagicZero* node) final {
- indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n";
- }
-
- void visit(const kir::UpdateMagicZero* node) final {
- indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
+ void handle(const kir::Sync* node) final {
+ indent() << "__syncthreads();\n";
}
private:
std::stringstream code_;
- const kir::Kernel* kernel_;
+ const Kernel* kernel_;
int block_nest_level_ = 0;
- int block_reduce_name_ = 0;
// TODO(kir): replace with explicit assignment statements
bool print_inline_ = false;
-
- // Mark when we are inside of a vectorized for-loop
- bool vectorize_scope_ = false;
-
- //! Holds active replacement mappings during codegen
- std::unordered_map<const kir::Node*, const kir::Node*> replacement_map_;
};
} // namespace
std::string generateCudaKernel(
- const kir::Kernel* kernel,
+ const Kernel* kernel,
const std::string& kernel_name) {
FUSER_PERF_SCOPE("generateCudaKernel");
return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name);
//! Generates a CUDA kernel definition for the given kernel
TORCH_CUDA_CU_API std::string generateCudaKernel(
- const kir::Kernel* kernel,
+ const Kernel* kernel,
const std::string& kernel_name = "CUDAGeneratedKernel");
} // namespace codegen
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
namespace fuser {
namespace cuda {
+ComputeAtData::ComputeAtData(TensorView* tv)
+ : tv_ref_(tv),
+ original_has_compute_at_(tv->hasComputeAt()),
+ original_compute_at_position(tv->getThisComputeAtAxis()),
+ original_domain_(tv->domain()),
+ new_compute_at_domain_(tv->domain()) {}
+
+// Clear per-pass data
+void ComputeAtData::clearPass() {
+ // If the last pass set a position, raise new_compute_at_position to it
+ // when it is greater than the previously recorded position.
+ if (current_traversal_position_set &&
+ current_traversal_position > new_compute_at_position) {
+ new_compute_at_position = current_traversal_position;
+ }
+
+ current_traversal_position_set = false;
+ current_traversal_position = 0;
+}
+
+void ComputeAtData::setPassPosition(unsigned int pos) {
+ if (current_traversal_position_set) {
+ // A single traversal cannot try to enforce more than one position on a
+ // TensorView, as that would produce incorrect code. If this is hit, the
+ // given tensor and the expressions producing it should be duplicated.
+ TORCH_CHECK(
+ pos == current_traversal_position,
+ "Error during computeAt. ComputeAt pass wanted to set position of ",
+ tv_ref_,
+ " at position ",
+ pos,
+ " but was already set to position ",
+ current_traversal_position,
+ ". This tensor would have to be recomputed to satsify the selected computeAt position.");
+ }
+
+ current_traversal_position = pos;
+ touched_ = true;
+ current_traversal_position_set = true;
+}
+
+unsigned int ComputeAtData::getNewPosition() const {
+ // If the current pass set a position greater than the recorded one,
+ // return it; otherwise return the previously recorded position.
+ if (current_traversal_position_set &&
+ current_traversal_position > new_compute_at_position) {
+ return current_traversal_position;
+ } else {
+ return new_compute_at_position;
+ }
+}
+
+void ComputeAtData::validateNewComputeAt() const {
+ FUSER_PERF_SCOPE("validateNewComputeAt");
+
+ TORCH_INTERNAL_ASSERT(
+ getNewPosition() >= original_compute_at_position,
+ "Invalid computeAt detected. This computeAt would invalidate the set computeAt on ",
+ tv_ref_,
+ " as the new computeAt position was found to be ",
+ getNewPosition(),
+ ".");
+ auto mismatch = BestEffortReplay::findFirstMismatchedID(
+ tv_ref_->domain(), original_domain_);
+ TORCH_CHECK(
+ mismatch >= (int)original_compute_at_position,
+ "Invalid computeAt detected. This computeAt call would invalidate the set computeAt on ",
+ tv_ref_,
+ " as the previous set computeAt was on the domain ",
+ original_domain_,
+ " with a computeAt position of ",
+ original_compute_at_position,
+ ".");
+}
+
+void ComputeAtData::setComputeAtDomain(TensorDomain* td) {
+ if (new_compute_at_domain_ != original_domain_) {
+ TORCH_INTERNAL_ASSERT(
+ *new_compute_at_domain_ == *td,
+ "TensorDomain, ",
+ td,
+ ", does not match with the previously set domain of ",
+ tv_ref_,
+ ", which is ",
+ new_compute_at_domain_);
+ }
+ new_compute_at_domain_ = td;
+}
+
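+// Illustrative lifecycle (assuming ComputeAtData data(tv)):
+//   data.setPassPosition(2); // one traversal proposes position 2
+//   data.clearPass();        // commit the pass; keeps the max position seen
+//   data.getNewPosition();   // -> 2, if 2 exceeds the prior position
+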
namespace {
// Wrapper around set_intersection
return intersection;
}
+// Convert an iterable of Val* into an iterable of TensorView*
+template <typename T1, typename T2>
+T1 tvIterable(const T2& val_iterable) {
+ T1 tv_iterable = T1();
+ std::transform(
+ val_iterable.begin(),
+ val_iterable.end(),
+ std::back_inserter(tv_iterable),
+ [](Val* v) {
+ TORCH_INTERNAL_ASSERT(
+ v->getValType().value() == ValType::TensorView,
+ "When following the computeAt dependency chain, a non TensorView value was found.");
+ return v->as<TensorView>();
+ });
+ return tv_iterable;
+}
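+
+// Usage sketch (illustrative):
+//   std::deque<Val*> vals = ...;
+//   auto tvs = tvIterable<std::deque<TensorView*>>(vals);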
+
std::deque<std::deque<TensorView*>> tvChains(
std::deque<std::deque<Val*>> val_chains) {
std::deque<std::deque<TensorView*>> tv_chains(val_chains.size());
for (const auto i : c10::irange(val_chains.size())) {
- auto tv_iterable = ir_utils::filterByType<TensorView>(val_chains[i]);
- tv_chains[i] =
- std::deque<TensorView*>(tv_iterable.begin(), tv_iterable.end());
+ tv_chains[i] = tvIterable<std::deque<TensorView*>>(val_chains[i]);
}
return tv_chains;
}
-bool validateDomain(TensorView* tv, TensorDomain* new_td) {
- auto first_mismatch =
- BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td);
- return first_mismatch >= (int)tv->getMaxProducerPosition() &&
- first_mismatch >= (int)tv->getComputeAtPosition();
-}
-
-// Return the max position in consumer that producer can be inlined to
-// Cannot inline:
-// Reduction dimensions in producer
-// Block broadcast dimensions in producer
-// Vectorized dimensions in producer or consumer
-// Unrolled dimensions in producer or consumer
-// Dimensions derived from root dimensions that exist in both but are
-// unmappable
-unsigned int getReplayablePosPasC(
- TensorView* producer,
- TensorView* consumer,
- const ComputeAtRootDomainMap& root_map_,
- ComputeAtMode mode) {
- // Grab dimensions in producer and consumer that are mappable to each other
- // based on the computeAtRootDomainMap. This tells us which dimensions can
- // be inlined without trying to inline across reduction structures.
- auto mappable_roots =
- root_map_.getMappableDims(producer->domain(), consumer->domain());
-
- // Check if any consumer dimensions are marked as vectorize as producer can
- // not be inlined to vectorized dimensions in consumer.
- auto c_dom = consumer->domain()->domain();
- auto vector_dim_it =
- std::find_if(c_dom.begin(), c_dom.end(), [&mode](IterDomain* id) {
- return isParallelTypeVectorize(id->getParallelType()) ||
- ((mode == ComputeAtMode::BestEffort ||
- mode == ComputeAtMode::MostInlined) &&
- id->getParallelType() == ParallelType::Unroll);
- });
-
- // Limit max position based on vectorized dims in consumer.
- auto max_consumer_pos = std::distance(c_dom.begin(), vector_dim_it);
-
- auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer);
- auto c2p_root_map =
- PairwiseRootDomainMap(producer, consumer)
- .mapConsumerToProducer(consumer->domain(), producer->domain());
-
- auto replay_PasC =
- BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_root_map);
-
- // Look for id's that map to a consumer id that's vectorized
- auto c2p_replay_map = replay_PasC.getReplay();
-
- for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0;
- consumer_pos--) {
- auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1));
- if (map_it != c2p_replay_map.end()) {
- auto p_id = map_it->second;
- // If we find a consumer dim that maps to a producer dim that's
- // vectorized or unrolled limit max compute at by it.
- if (isParallelTypeVectorize(p_id->getParallelType()) ||
- ((mode == ComputeAtMode::BestEffort ||
- mode == ComputeAtMode::MostInlined) &&
- p_id->getParallelType() == ParallelType::Unroll)) {
- max_consumer_pos = consumer_pos - 1;
- }
- }
- }
-
- // Start at max position and work backwards, try to find a location where
- // producer can be inlined.
- for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0;
- consumer_pos--) {
- // Grab all root dimensions of consumer as roots must be used to understand
- // inlining potential.
- auto consumer_root_dim_vals =
- IterVisitor::getInputsTo({c_dom.begin(), c_dom.begin() + consumer_pos});
- // convert to iter domains
- auto consumer_root_dim_ids =
- ir_utils::filterByType<IterDomain>(consumer_root_dim_vals);
-    // If any consumer root dimension maps into the producer but is not
-    // mappable, we can't inline to this position.
- if (std::any_of(
- consumer_root_dim_ids.begin(),
- consumer_root_dim_ids.end(),
- [&mappable_roots, &c2p_root_map](IterDomain* root_id) {
- return mappable_roots.find(root_id) == mappable_roots.end() &&
- c2p_root_map.find(root_id) != c2p_root_map.end();
- })) {
- continue;
- }
- return consumer_pos;
- }
-
- return 0;
-}
-
-// Return the max position in producer that can be inlined to consumer
-// Cannot inline:
-// Reduction dimensions in producer
-// Vectorized dimensions in producer or consumer
-// Unrolled dimensions in producer or consumer
-// Dimensions derived from root dimensions that exist in both but are
-// unmappable
-unsigned int getReplayablePosCasP(
- TensorView* consumer,
- TensorView* producer,
- const ComputeAtRootDomainMap& root_map_,
- ComputeAtMode mode) {
-  // Grab dimensions in producer and consumer that are mappable to each other
-  // based on the computeAtRootDomainMap. This will tell us which dimensions
-  // can be inlined while avoiding inlining into reduction structures.
- auto mappable_roots =
- root_map_.getMappableDims(producer->domain(), consumer->domain());
-
- auto p_dom = producer->domain()->domain();
- auto first_reduction =
- std::find_if(p_dom.begin(), p_dom.end(), [](IterDomain* id) {
- return id->isReduction();
- });
-
- auto first_vectorized_axis =
- std::find_if(p_dom.begin(), first_reduction, [&mode](IterDomain* id) {
- return isParallelTypeVectorize(id->getParallelType()) ||
- ((mode == ComputeAtMode::BestEffort ||
- mode == ComputeAtMode::MostInlined) &&
- id->getParallelType() == ParallelType::Unroll);
- });
-
- auto max_producer_pos = std::distance(p_dom.begin(), first_vectorized_axis);
-
- auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer);
- auto p2c_root_map = pairwise_root_map.mapProducerToConsumer(
- producer->domain(), consumer->domain());
-
- auto replay_CasP =
- BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map);
-
-  // Look for producer IDs that map to a consumer ID that's vectorized or unrolled
- auto p2c_replay_map = replay_CasP.getReplay();
-
- for (size_t producer_pos = max_producer_pos; producer_pos > 0;
- producer_pos--) {
- auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1));
- if (map_it != p2c_replay_map.end()) {
- auto c_id = map_it->second;
- // If we find a producer dim that maps to a consumer vectorized or
- // unrolled dim, limit max compute at by it
- if (isParallelTypeVectorize(c_id->getParallelType()) ||
- ((mode == ComputeAtMode::BestEffort ||
- mode == ComputeAtMode::MostInlined) &&
- c_id->getParallelType() == ParallelType::Unroll)) {
- max_producer_pos = producer_pos - 1;
- }
- }
- }
-
- for (size_t producer_pos = max_producer_pos; producer_pos > 0;
- producer_pos--) {
- auto all_vals = DependencyCheck::getAllValsBetween(
- {producer->getMaybeRFactorDomain().begin(),
- producer->getMaybeRFactorDomain().end()},
- {p_dom.begin(), p_dom.begin() + producer_pos});
-
- // If any root dims could have mapped to consumer, but don't, then we can't
- // compute at this point
- if (std::any_of(
- producer->getMaybeRFactorDomain().begin(),
- producer->getMaybeRFactorDomain().end(),
- [&mappable_roots, &all_vals](IterDomain* root_id) {
- return std::find(all_vals.begin(), all_vals.end(), root_id) !=
- all_vals.end() &&
- mappable_roots.find(root_id) == mappable_roots.end();
- })) {
- continue;
- }
-
- return producer_pos;
- }
- return 0;
-}
-
-unsigned int getInnermostNonBroadcastIdFrom(TensorView* tv) {
- unsigned int ret = tv->getComputeAtPosition();
-
- // Still assuming we only have block broadcast for now.
- // This part may change
- while (ret > 0 && tv->axis((int)ret - 1)->isBroadcast()) {
- ret--;
- }
-
- return ret;
-}
-
-// Try to find the aligned position on consumer's domain corresponding to the
-// compute at position of producer domain. Used in computeAt pass only. No
-// checking on actual producer-consumer relationship.
-unsigned int getConsumerPosAlignedToProducerCA(
- TensorView* consumer,
- TensorView* producer) {
- // Locate consumer's position that aligns with
- // the producer's new compute at axis. We need broadcast axes forwarded so we
-  // need to replay PasC as CasP will not forward broadcast dims. For example,
-  // if we have:
-  //   T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) produce_pos( 1 ) )
-  // CasP will have the mapping iS1{3} -> iS2{3}, and PasC will have the
-  // mapping iS22{( 3 * 1 )} <- iS1{3}. We need the latter. Refer to
- // NVFuserTest.FusionComplexBCast1_CUDA
-
- auto c2p_map =
- BestEffortReplay::replayPasC(
- producer,
- consumer,
- -1,
- // Compute at root domain may not be valid here, as all
- // producers don't have to be able to map into consumer at
- // max producer position. Since computeAt should be valid
- // and this mechanism is only intended to lower produce
- // position of consumer, we can simply use the pairwise map.
- PairwiseRootDomainMap(producer, consumer))
- .getReplay();
-
- // Find the innermost position of consumer that has
- // been mapped within the producer ca axis.
- unsigned int consumer_pos = consumer->nDims();
- while (consumer_pos > 0) {
- auto consumer_id = consumer->axis((int)consumer_pos - 1);
- auto p_dom = producer->domain()->domain();
- if (std::any_of(
- p_dom.begin(),
- p_dom.begin() + producer->getComputeAtPosition(),
- [&consumer_id, &c2p_map](IterDomain* p_id) {
- auto c_id_it = c2p_map.find(consumer_id);
- if (c_id_it != c2p_map.end()) {
- return c_id_it->second == p_id;
- }
- return false;
- })) {
- break;
- }
- consumer_pos--;
- }
-
- return consumer_pos;
-}
-
} // namespace
-void ComputeAt::runAt(
+void ComputeAt::run(
TensorView* producer,
TensorView* consumer,
- unsigned int consumer_position,
- ComputeAtMode mode) {
+ unsigned int consumer_position) {
FUSER_PERF_SCOPE("ComputeAt::run");
// Make sure the correct fusion is setup between this and consumer.
// Make sure Fusion Guard is set appropriately
FusionGuard fg(producer->fusion());
- TORCH_CHECK(
- DependencyCheck::isDependencyOf(producer, consumer),
- "Compute At expects ",
- producer->name(),
- " is a dependency of ",
- consumer->name(),
- ", however it is not.");
+ std::vector<TensorView*> producers;
- // Run computeAt on our potentially modified producer(s)
- ComputeAt ca(producer, consumer, consumer, consumer_position, mode);
- ca.runPass();
-}
+  // It doesn't make sense to set computeAt on an input as it's not generated,
+  // it's provided. If this was called, move the computeAt to the users of the
+  // producer that lie on a dependency path between producer and consumer.
+ if (producer->fusion()->hasInput(producer)) {
+ auto all_chains =
+ tvChains(DependencyCheck::getAllDependencyChains(producer, consumer));
-void ComputeAt::runWith(
- TensorView* producer,
- TensorView* consumer,
- unsigned int producer_position,
- ComputeAtMode mode) {
- FUSER_PERF_SCOPE("ComputeAt::runWith");
-
- // Make sure the correct fusion is setup between this and consumer.
- TORCH_CHECK(
- producer->fusion() == consumer->fusion(),
- producer,
- " and ",
- consumer,
- " are not in the same fusion.");
-
- TORCH_CHECK(
- DependencyCheck::isDependencyOf(producer, consumer),
- "Compute At expects ",
- producer->name(),
- " is a dependency of ",
- consumer->name(),
- ", however it is not.");
-
- // Make sure Fusion Guard is set appropriately
- FusionGuard fg(producer->fusion());
+ TORCH_CHECK(
+ !all_chains.empty(),
+ "Compute At expects ",
+ producer,
+ " is a dependency of ",
+ consumer,
+ ", however it is not.");
+
+ std::unordered_set<TensorView*> added_producers;
+
+ // Check all dependency chains, select the next TV after producer towards
+ // consumer. These are the TVs we're going to actually call computeAt on.
+ for (const auto& tv_chain : all_chains) {
+ // When a chain only has two tensors, they must be the producer,
+ // which is an input, and the consumer. There is nothing we need
+ // to do for such chains.
+ if (tv_chain.size() > 2) {
+        // Make sure we only add once, but we want to add in a deterministic
+        // order
+ if (added_producers.find(tv_chain[1]) == added_producers.end()) {
+ producers.push_back(tv_chain[1]);
+ added_producers.emplace(tv_chain[1]);
+ }
+ }
+ }
+ } else {
+ // If producer is not an input, it's the only one.
+ producers.push_back(producer);
+ }
- ComputeAt ca(producer, consumer, producer, producer_position, mode);
- ca.runPass();
+ // Run computeAt on our potentially modified producer(s)
+ if (!producers.empty()) {
+ for (auto producer_to_run : producers) {
+ ComputeAt ca(producer_to_run, consumer, consumer_position);
+ ca.runPass();
+ }
+ }
}
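+
+// Illustrative sketch of the input handling above, with hypothetical tensors:
+// given chains  in -> T1 -> T3  and  in -> T2 -> T3  where `in` is a fusion
+// input, run(in, T3, pos) never sets computeAt on `in` itself. It instead
+// runs the pass once with T1 and once with T2 as the producer, since those
+// are the first TensorViews past the input on each dependency chain.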
// Actually applies transformation
unsigned int ComputeAt::backwardComputeAt_impl(
TensorView* producer,
TensorView* consumer,
- unsigned int consumer_compute_at_pos) {
+ unsigned int consumer_compute_at_axis) {
FUSER_PERF_SCOPE("backwardComputeAt_impl");
- auto max_consumer_compute_at_pos =
- getReplayablePosPasC(producer, consumer, root_map_, mode_);
- if (mode_ == ComputeAtMode::BestEffort) {
- consumer_compute_at_pos =
- std::min(consumer_compute_at_pos, max_consumer_compute_at_pos);
- } else if (mode_ == ComputeAtMode::MostInlined) {
- consumer_compute_at_pos = max_consumer_compute_at_pos;
- } else {
- TORCH_INTERNAL_ASSERT(
- consumer_compute_at_pos <= max_consumer_compute_at_pos,
- "Invalid compute at position detected in compute at when trying to replay producer: ",
- producer,
- " as consumer: ",
- consumer,
- " tried to do this at position: ",
- consumer_compute_at_pos,
- " but max position that's allowed is ",
- max_consumer_compute_at_pos);
- }
-
- auto replay_producer_pair = TransformReplay::replayPasC(
- producer, consumer, (int)consumer_compute_at_pos, root_map_);
-
- if (replay_producer_pair.second == 0) {
- return 0;
- }
+ auto& producer_entry = tv_data.at(producer);
- if (replay_producer_pair.second >= producer->getComputeAtPosition()) {
- const TensorDomain* current_domain = producer->domain();
- TensorDomain* new_domain = replay_producer_pair.first;
+ // Use TensorDomain interface so it doesn't set computeAt automatically
+ auto replay = TransformReplay::replayPasC(
+ producer, consumer, (int)consumer_compute_at_axis);
- TORCH_INTERNAL_ASSERT(
- validateDomain(producer, new_domain),
- "Tried to set the domain of ",
- producer,
- " to ",
- new_domain,
- " but that would invalidate previously compute at position or max producer position.");
+ producer_entry.setPassPosition(replay.second);
- producer->setDomain(new_domain);
- if (!producer->isFusionInput()) {
- producer->setComputeAt(replay_producer_pair.second);
- }
-
- consumer->setMaxProducer(consumer_compute_at_pos);
- for (auto other_consumer : ir_utils::consumerTvsOf(producer)) {
- if (other_consumer != consumer) {
- auto max_consumer_pos =
- getConsumerPosAlignedToProducerCA(other_consumer, producer);
- other_consumer->setMaxProducer(max_consumer_pos);
- }
- }
- root_map_.setAlias(current_domain, new_domain);
+ if (producer_entry.shouldSetComputeAt(replay.second)) {
+ producer->setComputeAt(consumer, (int)consumer_compute_at_axis);
+ producer_entry.setComputeAtDomain(producer->domain());
}
- return replay_producer_pair.second;
+ return replay.second;
}
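+
+// Worked example (hypothetical shapes): with consumer domain [I0, I1, I2] and
+// consumer_compute_at_axis == 2, replayPasC reorganizes the producer so its
+// leading axes match the first two consumer axes, and replay.second reports
+// how many producer axes now line up. That position is recorded through
+// setPassPosition; only if it beats every position seen so far does the pass
+// call producer->setComputeAt(consumer, 2) and snapshot the resulting domain.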
-// Actually applies transformation, replay consumer based on producer, set
-// compute at of producer, set pass position of consumer, return position
-// relative to consumer
+// Actually applies transformation
unsigned int ComputeAt::forwardComputeAt_impl(
TensorView* producer,
TensorView* consumer,
- unsigned int producer_compute_at_pos) {
+ unsigned int producer_compute_at_axis) {
FUSER_PERF_SCOPE("forwardComputeAt_impl");
- auto max_producer_compute_at_pos =
- getReplayablePosCasP(consumer, producer, root_map_, mode_);
-
- if (mode_ == ComputeAtMode::BestEffort) {
- producer_compute_at_pos =
- std::min(producer_compute_at_pos, max_producer_compute_at_pos);
- } else if (mode_ == ComputeAtMode::MostInlined) {
- producer_compute_at_pos = max_producer_compute_at_pos;
- } else {
- TORCH_INTERNAL_ASSERT(
- producer_compute_at_pos <= max_producer_compute_at_pos,
- "Invalid compute at position detected in compute at when trying to replay consumer: ",
- consumer,
- " as producer: ",
- producer,
- " tried to do this at position: ",
- producer_compute_at_pos,
- " but max position that's allowed is ",
- max_producer_compute_at_pos);
- }
+ auto& consumer_entry = tv_data.at(consumer);
+ const auto& producer_entry = tv_data.at(producer);
- auto replay_consumer_pair = TransformReplay::replayCasP(
- consumer, producer, (int)producer_compute_at_pos, root_map_);
+ auto replay = TransformReplay::replayCasP(
+ consumer, producer, (int)producer_compute_at_axis);
- if (producer_compute_at_pos > producer->getComputeAtPosition()) {
- if (!producer->isFusionInput()) {
- producer->setComputeAt((int)producer_compute_at_pos);
- }
+ if (producer_entry.shouldSetComputeAt(producer_compute_at_axis)) {
+ producer->setComputeAt(consumer, replay.second);
}
- if (replay_consumer_pair.second > consumer->getMaxProducerPosition()) {
- const TensorDomain* current_domain = consumer->domain();
- TensorDomain* new_domain = replay_consumer_pair.first;
-
- TORCH_INTERNAL_ASSERT(
- validateDomain(consumer, new_domain),
- "Tried to set the domain of ",
- consumer,
- " to ",
- new_domain,
- " but that would invalidate previously compute at position or max producer position.");
-
- consumer->setDomain(new_domain);
- consumer->setMaxProducer(replay_consumer_pair.second);
- for (auto other_consumer : ir_utils::consumerTvsOf(producer)) {
- if (other_consumer != consumer) {
- auto max_consumer_pos =
- getConsumerPosAlignedToProducerCA(other_consumer, producer);
- other_consumer->setMaxProducer(max_consumer_pos);
- }
- }
- root_map_.setAlias(current_domain, new_domain);
+ consumer_entry.setPassPosition(replay.second);
+ if (consumer_entry.shouldSetComputeAt(replay.second) &&
+ consumer != consumer_) {
+ consumer_entry.setComputeAtDomain(consumer->domain());
}
- return replay_consumer_pair.second;
+ return replay.second;
}
void ComputeAt::setCommonConsumer() {
TORCH_CHECK(
!all_chains.empty(),
"Compute At expects ",
- producer_->name(),
+ producer_,
" is a dependency of ",
- consumer_->name(),
+ consumer_,
", however it is not.");
// Remove all TVs from producer to consumer as common consumer must be at or
// computeAt if it will increase computeAt positions.
void ComputeAt::traverseBackward() {
FUSER_PERF_SCOPE("ComputeAt::traverseBackward");
- if (reference_ == producer_) {
- // Forward compute at don't need to run backward traversal
- producer_position_ = reference_position_;
- return;
- }
// propagate *backward* through all *producer* use_chains or from *producer*
// to common_consumer if common_consumer exists. Only apply transform if
for (auto tv_chain : chains) {
TensorView* running_producer = tv_chain.back();
TensorView* running_consumer = nullptr;
- unsigned int running_consumer_pos = reference_position_;
+ unsigned int running_consumer_pos = consumer_position_;
tv_chain.pop_back();
TORCH_INTERNAL_ASSERT(running_producer == consumer_);
running_consumer_pos = backwardComputeAt_impl(
running_producer, running_consumer, running_consumer_pos);
}
-
- TORCH_INTERNAL_ASSERT(
- running_producer == producer_,
- "Compute at backward traversal ended up on something other than the producer.");
- producer_position_ = running_consumer_pos;
}
}
DependencyCheck::getAllDependencyChains(producer_, common_consumer_));
}
+ unsigned int producer_pos = tv_data.at(producer_).getNewPosition();
+
// propagate forward through all chains
for (auto tv_dep_chain : chains) {
TensorView* running_producer = nullptr;
TensorView* running_consumer = tv_dep_chain.front();
tv_dep_chain.pop_front();
- unsigned int running_producer_pos = producer_position_;
+ unsigned int running_producer_pos = producer_pos;
TORCH_INTERNAL_ASSERT(running_consumer == producer_);
running_producer = running_consumer;
running_consumer = tv_dep_chain.front();
tv_dep_chain.pop_front();
+
running_producer_pos = forwardComputeAt_impl(
running_producer, running_consumer, running_producer_pos);
}
}
}
-void ComputeAt::resetMaxProducerPos(TensorView* consumer_tv) {
- if (consumer_tv->definition() == nullptr) {
- consumer_tv->setMaxProducer(0, true);
+void ComputeAt::runPass() {
+ FUSER_PERF_SCOPE("ComputeAt::runPass");
+
+ // Initialize tv_data for all TensorViews we may modify
+ auto chains = producer_use_chains_;
+ if (common_consumer_ != nullptr) {
+ chains = tvChains(
+ DependencyCheck::getAllDependencyChains(producer_, common_consumer_));
}
- unsigned int new_consummer_pa_pos = 0;
-
- // Re-compute the max producer position as one or more
- // of the producers of this consumer have updated their
- // compute at position.
- for (auto inp : ir_utils::producerTvsOf(consumer_tv)) {
- if (!inp->isFusionInput()) {
- // Locate consumer's position that aligns with
- // the producer's new compute at axis.
- unsigned int inp_ca_pos_to_consumer =
- getConsumerPosAlignedToProducerCA(consumer_tv, inp);
-
- // Populate the max consumer position required by
- // producer compute at.
- new_consummer_pa_pos =
- std::max(new_consummer_pa_pos, inp_ca_pos_to_consumer);
+ for (const auto& tv_chain : chains) {
+ for (auto tv : tv_chain) {
+ if (tv_data.find(tv) == tv_data.end()) {
+ tv_data[tv] = ComputeAtData(tv);
+ }
}
}
- consumer_tv->setMaxProducer(new_consummer_pa_pos, true);
-}
+ // Traverse backward through all dep chains from producer to consumer
+ traverseBackward();
-void ComputeAt::hoistInnermostBroadcast() {
- auto fusion = producer_->fusion();
+ // Clear data from backward traversal:
+ for (auto& entry : tv_data) {
+ entry.second.clearPass();
+ }
- std::unordered_set<TensorView*> consumers_to_update;
+ // Start at producer and traverse forward through all chains
+ traverseForward();
- auto all_vals = fusion->usedMathVals();
- auto all_tvs = ir_utils::filterByType<TensorView>(all_vals);
+ setupOutputs();
- for (auto running_producer : all_tvs) {
- if (!running_producer->isFusionInput()) {
- auto producer_ca_pos = running_producer->getComputeAtPosition();
- // Find the innermost iterdomain that is not a broadcast
- auto new_ca_pos = getInnermostNonBroadcastIdFrom(running_producer);
- // Update the compute at pos of this producer if the original
- // compute at is within inner most broadcast axes
- if (new_ca_pos < producer_ca_pos) {
- running_producer->setComputeAt(new_ca_pos, true);
- }
- // Mark all consumers of this producer for later produce
- // position update.
- // This is safe with segmented fusion. TV uses will reset
-    // when FusionSegmentGuard tries to change the IO.
- auto tv_consumers = ir_utils::consumerTvsOf(running_producer);
- consumers_to_update.insert(tv_consumers.begin(), tv_consumers.end());
- }
+ for (const auto& entry : tv_data) {
+ entry.first->setDomain(entry.second.getComputeAtDomain());
+ entry.second.validateNewComputeAt();
}
- // Update the produce positions of all affected consumers
- for (auto running_consumer : consumers_to_update) {
- TORCH_INTERNAL_ASSERT(running_consumer->definition() != nullptr);
- resetMaxProducerPos(running_consumer);
- }
+ TORCH_INTERNAL_ASSERT(
+ BestEffortReplay::findFirstMismatchedID(
+ consumer_->domain(), tv_data.at(consumer_).getOriginalDomain()) ==
+ (int)consumer_->domain()->nDims(),
+ "ComputeAt logic changed the consumer domain which should not happen. Domain was ",
+ tv_data.at(consumer_).getOriginalDomain(),
+ " but is now: ",
+ consumer_->domain());
}
-void ComputeAt::updateSiblings() {
- // Track which consumers may have a wrong produce at position to update
- // later
- auto updateSiblingsOfTv = [&](TensorView* tv) {
- if (tv->definition() == nullptr) {
- return;
- }
-
- std::unordered_set<TensorView*> consumers_to_update;
+void ComputeAt::setupOutputs() {
+ FUSER_PERF_SCOPE("ComputeAt::setupOutputs");
- if (tv->definition()->outputs().size() > 1) {
- auto outs = tv->definition()->outputs();
- auto out_tvs = ir_utils::filterByType<TensorView>(outs);
- for (auto sibling_tv : out_tvs) {
- if (sibling_tv == tv) {
- continue;
- }
+ if (common_consumer_ != nullptr)
+ return;
- std::unordered_map<IterDomain*, IterDomain*> tv_to_sibling_map;
- TORCH_INTERNAL_ASSERT(
- tv->getRootDomain().size() == sibling_tv->getRootDomain().size(),
- "Error replaying multiple output expressions in computeAt.");
-
- // Propagate any root parallelization as fullSelfReplay expects it.
- for (size_t i = 0; i < sibling_tv->getRootDomain().size(); i++) {
- auto id = tv->getRootDomain()[i];
- auto sibling_id = sibling_tv->getRootDomain()[i];
- if (id->getParallelType() != ParallelType::Serial &&
- sibling_id->getParallelType() == ParallelType::Serial) {
- sibling_id->parallelize(id->getParallelType());
- } else if (
- id->getParallelType() == ParallelType::Serial &&
- sibling_id->getParallelType() != ParallelType::Serial) {
- id->parallelize(sibling_id->getParallelType());
- }
- }
- if (tv->getComputeAtPosition() > sibling_tv->getComputeAtPosition()) {
- auto sibling_domain = TransformReplay::fullSelfReplay(
- sibling_tv->domain(), tv->domain());
- validateDomain(sibling_tv, sibling_domain);
- sibling_tv->setDomain(sibling_domain);
- sibling_tv->setComputeAt(tv->getComputeAtPosition());
- sibling_tv->setMaxProducer(tv->getMaxProducerPosition());
- auto consumer_tvs = ir_utils::consumerTvsOf(sibling_tv);
- consumers_to_update.insert(consumer_tvs.begin(), consumer_tvs.end());
+ std::vector<TensorView*> touched_output_order;
+ const auto& terminating_outputs =
+ FusionGuard::getCurFusion()->getTerminatingOutputs();
+
+ for (auto out : ir_utils::filterByType<TensorView>(
+ FusionGuard::getCurFusion()->outputs())) {
+ if (tv_data.find(out) != tv_data.end()) {
+ if (tv_data[out].touched()) {
+ // No need to adjust computeAt when an output is not
+ // a terminating output.
+ if (std::find(
+ terminating_outputs.begin(), terminating_outputs.end(), out) !=
+ terminating_outputs.end()) {
+ touched_output_order.push_back(out);
}
}
}
- for (auto consumer : consumers_to_update) {
- this->resetMaxProducerPos(consumer);
- }
- };
-
- // Find all tensor views that may have been modified
- auto chains = producer_use_chains_;
- if (common_consumer_ != nullptr) {
- chains = tvChains(
- DependencyCheck::getAllDependencyChains(producer_, common_consumer_));
}
- std::unordered_set<TensorView*> participating_tvs;
- for (auto chain : chains) {
- participating_tvs.insert(chain.begin(), chain.end());
- }
-
- for (auto tv : participating_tvs) {
- updateSiblingsOfTv(tv);
- }
-}
-
-void ComputeAt::updateInputProduceAts() {
- std::unordered_set<TensorView*> consumers_to_check;
-
- // Find all tensor views that may have been modified
- auto chains = producer_use_chains_;
- if (common_consumer_ != nullptr) {
- chains = tvChains(
- DependencyCheck::getAllDependencyChains(producer_, common_consumer_));
- }
-
- for (auto chain : chains) {
- if (chain.size() > 1 && chain[0]->isFusionInput()) {
- consumers_to_check.emplace(chain[1]);
+ if (touched_output_order.size() > 0) {
+ for (size_t i = 0; i < touched_output_order.size() - 1; i++) {
+ touched_output_order[i]->setComputeAt(
+ touched_output_order[i + 1],
+ (int)tv_data.at(touched_output_order[i]).getNewPosition(),
+ (int)tv_data.at(touched_output_order[i + 1]).getNewPosition());
}
}
-
- for (auto tv : consumers_to_check) {
- resetMaxProducerPos(tv);
- }
-}
-
-void ComputeAt::runPass() {
- FUSER_PERF_SCOPE("ComputeAt::runPass");
-
- // Traverse backward through all dep chains from producer to consumer
- traverseBackward();
-
- // Start at producer and traverse forward through all chains
- traverseForward();
-
- // Back off on inlining the inner broadcast axes
- hoistInnermostBroadcast();
-
- // Clear max producer position of consumers from fusion inputs.
- updateInputProduceAts();
-
- // Update siblings of multi output expressions
- updateSiblings();
}
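+
+// Sketch of the resulting chaining (hypothetical outputs): if the pass touched
+// terminating outputs o0, o1, o2, in fusion output order, the loop above links
+// them pairwise, roughly:
+//
+//   o0->setComputeAt(o1, new_pos(o0), new_pos(o1));
+//   o1->setComputeAt(o2, new_pos(o1), new_pos(o2));
+//
+// where new_pos is shorthand for tv_data.at(...).getNewPosition(). This keeps
+// outputs positioned relative to each other when no common consumer exists.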
ComputeAt::ComputeAt(
TensorView* _producer,
TensorView* _consumer,
- TensorView* _reference,
- unsigned int _reference_position,
- ComputeAtMode _mode)
+ unsigned int _consumer_position)
: producer_(_producer),
consumer_(_consumer),
- reference_(_reference),
- reference_position_(_reference_position),
- mode_(_mode) {
+ consumer_position_(_consumer_position) {
TORCH_INTERNAL_ASSERT(
- reference_ == producer_ || reference_ == consumer_,
- "For compute at reference must be producer or consumer, it's neither.",
- " reference: ",
- reference_,
- " consumer: ",
- consumer_,
- " producer: ",
- producer_);
- TORCH_INTERNAL_ASSERT(
- reference_position_ >= 0 && reference_position_ <= reference_->nDims(),
+ consumer_position_ >= 0 && consumer_position_ <= consumer_->nDims(),
"Invalid computeAt axis, received ",
- reference_position_,
+ _consumer_position,
" but should be > -",
- reference_->nDims(),
+ consumer_->nDims(),
" and <= ",
- reference_->nDims(),
+ consumer_->nDims(),
".");
producer_use_chains_ = tvChains(DependencyCheck::getAllUseChains(producer_));
// consumer for all chains at or after the consumer specified in the computeAt
// call.
setCommonConsumer();
-
- root_map_.build();
}
} // namespace cuda
#pragma once
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-
#include <c10/util/Exception.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
class TensorDomain;
class TensorView;
-class ComputeAt {
+// We keep data related to the computeAt pass for each TensorView in this
+// structure, which allows us to keep a single entry in a map from a
+// TensorView to its pass state.
+// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+class ComputeAtData {
public:
- // Runs the compute at pass making producer look like consumer, computing
- // producer relative to consumer
- static void runAt(
- TensorView* producer,
- TensorView* consumer,
- unsigned int consumer_position,
- ComputeAtMode mode = ComputeAtMode::Standard);
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+ ComputeAtData() = default;
+ ComputeAtData(TensorView* tv);
+
+ // Clear after a given traversal. There will be more than one.
+ void clearPass();
+
+ // Makes sure value matches current_traversal_position if
+ // current_traversal_position_set is true. If this is not the case we're in
+ // an invalid compute_at that would require tensor replication.
+ void setPassPosition(unsigned int pos);
+
+  // Returns true if pos is past every previously recorded position and at
+  // least the position set by the current traversal.
+ bool shouldSetComputeAt(unsigned int pos) const {
+ return pos > original_compute_at_position &&
+ pos > new_compute_at_position && pos >= current_traversal_position;
+ }
+
+ // Will return new_compute_at_position, after making sure we cleared out the
+ // last pass
+ unsigned int getNewPosition() const;
+
+ // Will make sure we haven't invalidated previous computeAt calls by
+ // checking that any axes previously in computeAt are still there.
+ void validateNewComputeAt() const;
+
+ // Did we ever compute a value for this TV?
+ bool touched() const {
+ return touched_;
+ }
+
+ TensorDomain* getOriginalDomain() const {
+ return original_domain_;
+ }
+
+ // If we set computeAt, save the domain so we can reset it after traversal.
+ // Traversal state can deviate from the domain we will want to save after the
+ // entire computeAt pass.
+ void setComputeAtDomain(TensorDomain* td);
+
+ // Return domain set in setComputeAtDomain
+ TensorDomain* getComputeAtDomain() const {
+ return new_compute_at_domain_;
+ }
- // Runs the compute with pass making consumer look like producer, computing
- // producer relative to consumer
- static void runWith(
- TensorView* producer,
- TensorView* consumer,
- unsigned int producer_position,
- ComputeAtMode mode = ComputeAtMode::Standard);
+ private:
+ // Was the position ever modified?
+ bool touched_ = false;
+
+ // Hold onto the provided TensorView
+ TensorView* tv_ref_ = nullptr;
+
+ // Did this tv have computeAt set before calling this computeAt pass?
+ bool original_has_compute_at_ = false;
+
+ // What was the computeAt position before the computeAt pass started
+ unsigned int original_compute_at_position = 0;
+
+ // and what was the previous domain that position was set relative to.
+ TensorDomain* original_domain_ = nullptr;
+
+ // Position we can update during a traversal
+ unsigned int current_traversal_position = 0;
+
+ // Did this traversal set a position or not yet
+ bool current_traversal_position_set = false;
+
+ // Position to update after a traversal
+ unsigned int new_compute_at_position = 0;
+
+ // Domain when we actually set computeAt, will set back to this after the
+ // pass.
+ TensorDomain* new_compute_at_domain_;
+};
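+
+// Expected lifecycle (sketch, mirroring how ComputeAt::runPass drives it):
+//
+//   ComputeAtData data(tv);             // snapshot original position/domain
+//   data.setPassPosition(pos);          // record what one traversal computed
+//   if (data.shouldSetComputeAt(pos)) { // positions only ever move inward
+//     data.setComputeAtDomain(tv->domain());
+//   }
+//   data.clearPass();                   // reset between traversals
+//
+// After all traversals, getComputeAtDomain() restores the saved domain and
+// validateNewComputeAt() checks that previous computeAt calls still hold.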
+
+class ComputeAt {
+ public:
+ static void run(
+ TensorView* _producer,
+ TensorView* _consumer,
+ unsigned int _consumer_position);
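+
+  // Usage sketch (assuming the usual TensorView::computeAt entry point): a
+  // call like tv_producer->computeAt(tv_consumer, 2) is expected to forward
+  // here as run(tv_producer, tv_consumer, 2), positioning the producer
+  // relative to the consumer's first two axes.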
ComputeAt() = delete;
ComputeAt(ComputeAt&) = delete;
private:
TensorView* producer_;
TensorView* consumer_;
- TensorView* reference_;
- unsigned int reference_position_;
- ComputeAtMode mode_ = ComputeAtMode::Standard;
-
- unsigned int producer_position_ = 0;
- ComputeAtRootDomainMap root_map_;
+ unsigned int consumer_position_;
// Runs replayPasC and sets producer computeAt settings. Returns
- // producer_compute_at_pos.
+ // producer_compute_at_axis.
unsigned int backwardComputeAt_impl(
TensorView* producer,
TensorView* consumer,
- unsigned int consumer_compute_at_pos);
+ unsigned int consumer_compute_at_axis);
// Runs replayCasP and sets producer computeAt settings. Returns
- // consumer_compute_at_pos.
+ // consumer_compute_at_axis.
unsigned int forwardComputeAt_impl(
TensorView* producer,
TensorView* consumer,
- unsigned int producer_compute_at_pos);
+ unsigned int producer_compute_at_axis);
// Look through all the use chains of producer. Check if there's a single
// consumer for all chains at or after the consumer specified in the computeAt
// of producer
void traverseForward();
- // Looks at producer tensor views of consumer_tv, recomputes its max
- // producer position, and sets max producer position. This function can
- // only potentially lower the max producer position of consumer_tv.
- void resetMaxProducerPos(TensorView* consumer_tv);
-
- // Undo the inlining of block broadcast at the innermost positions
- // to avoid generating repeated block broadcasts
- void hoistInnermostBroadcast();
-
- // Update multi-output expressions. If one output is modified, all outputs
- // should be modified as well. Propagate transformations, compute at, and
- // produce at from tv to siblings. Run as final pass as it will invalidate the
- // computeAt map originally computed.
- void updateSiblings();
-
- // Compute at pass requires tracking "maxProducerPosition" even if set simply
- // from input tensor views. However, when lowering, we need a valid produce at
- // position of all tensors, so inputs should never actually set their
- // consumers maxProduceAt position.
- void updateInputProduceAts();
-
// Run the computeAt pass
void runPass();
+  // Set outputs relative to each other if there is not a common consumer
+ void setupOutputs();
+
// Common consumer if it exists
TensorView* common_consumer_ = nullptr;
// Producer use chains set in, used in a few spots.
std::deque<std::deque<TensorView*>> producer_use_chains_;
+ // All we need to know and keep track of for each TensorView in this pass.
+ std::unordered_map<TensorView*, ComputeAtData> tv_data;
+
ComputeAt(
TensorView* _producer,
TensorView* _consumer,
- TensorView* _reference,
- unsigned int _reference_position,
- ComputeAtMode _mode);
+ unsigned int _consumer_position);
~ComputeAt() = default;
};
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-namespace {
-
-//! Class to figure out how many non-broadcast axes and how many broadcast axes
-//! were used to produce an iter domain. This is important for figuring out
-//! what the correct broadcasted extent of an iteration domain is.
-//!
-//! When GpuLower is available, trivial reductions are not counted as
-//! concrete domains so that they should not be used to generate
-//! for-loops.
-class InputDomainCounter : public IterVisitor {
- public:
-  // Returns the number of {non-broadcast non-reduction iteration domains,
-  // broadcast and trivial reduction domains} used to generate the iteration
-  // domains in the provided target domain.
- static std::unordered_map<IterDomain*, std::pair<int, int>> produceCounts(
- const std::vector<IterDomain*>& domain,
- GpuLower* gpu_lower) {
- if (domain.empty()) {
- return std::unordered_map<IterDomain*, std::pair<int, int>>();
- }
-
- InputDomainCounter counter(domain);
-
- std::unordered_map<IterDomain*, std::pair<int, int>> count_map;
- for (auto entry : counter.domain_set_) {
- auto id = entry.first;
- auto input_id_set = entry.second;
- int concrete_counts = 0;
- int broadcast_counts = 0;
- for (auto input_id : input_id_set) {
- if (input_id->isBroadcast() ||
- (gpu_lower &&
- gpu_lower->trivialReductionInfo().isDerived(input_id))) {
- broadcast_counts++;
- } else {
- concrete_counts++;
- }
- }
- count_map[id] = {concrete_counts, broadcast_counts};
- }
-
- // Inputs may be root domains which wouldn't have any entries if no exprs
- // were traversed, so manually insert their count
- for (auto id : domain) {
- if (count_map.find(id) == count_map.end()) {
- count_map[id] =
- (id->isBroadcast() ||
- (gpu_lower && gpu_lower->trivialReductionInfo().isDerived(id)))
- ? std::make_pair(0, 1)
- : std::make_pair(1, 0);
- }
- }
- return count_map;
- }
-
- private:
- InputDomainCounter(const std::vector<IterDomain*>& domain_) {
- traverseFrom(
- domain_[0]->fusion(),
- std::vector<Val*>(domain_.begin(), domain_.end()));
- }
-
- private:
- std::unordered_set<IterDomain*>& getEntry(IterDomain* id) {
- auto domain_set_it = domain_set_.find(id);
- if (domain_set_it == domain_set_.end()) {
- domain_set_it =
- domain_set_
- .emplace(std::make_pair(id, std::unordered_set<IterDomain*>()))
- .first;
- domain_set_it->second.emplace(id);
- }
-
- return domain_set_it->second;
- }
-
- void handle(Expr* expr) override {
- // If we end up moving swizzle to an Expr it would be identity here, instead
- // of outputs being a function of all inputs
- switch (expr->getExprType().value()) {
- case (ExprType::Split):
- case (ExprType::Merge):
- break;
- default:
- TORCH_INTERNAL_ASSERT(
- false, "Invalid expr type found in transform traversal.");
- }
-
- // Gather all non-broadcast input domains
- std::unordered_set<IterDomain*> resulting_set;
- for (auto input_id : ir_utils::filterByType<IterDomain>(expr->inputs())) {
- auto input_entry = getEntry(input_id);
- resulting_set.insert(input_entry.begin(), input_entry.end());
- }
- for (auto output_id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
- domain_set_.emplace(std::make_pair(output_id, resulting_set));
- }
- }
-
- std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>> domain_set_;
-};
-
-// Only used once, consider removing.
-template <class T>
-std::deque<T*> deduplicateDeque(const std::deque<T*>& deque) {
- std::unordered_set<T*> used;
- std::deque<T*> deduped;
- for (auto entry : deque) {
- if (used.find(entry) == used.end()) {
- deduped.push_back(entry);
- used.emplace(entry);
- }
- }
- return deduped;
-}
-
-void assertLowered(bool lowered) {
- TORCH_INTERNAL_ASSERT(
- lowered,
- "Tried to accessed lowered values of compute at map,",
- " however a valid lowering was not set when compute at map was created.");
-}
-
-} // namespace
-
-void ComputeAtMap::mapIds(IterDomain* id0, IterDomain* id1) {
- auto set_it_0 = disjoint_iter_set_maps_.find(id0);
- auto set_it_1 = disjoint_iter_set_maps_.find(id1);
- if (set_it_0 == disjoint_iter_set_maps_.end() &&
- set_it_1 == disjoint_iter_set_maps_.end()) {
- // Neither iter domain has been mapped, so make a new disjoint set
- auto new_set = std::make_shared<std::deque<IterDomain*>>();
- new_set.get()->push_back(id0);
- new_set.get()->push_back(id1);
- disjoint_iter_set_maps_.emplace(std::make_pair(id0, new_set));
- disjoint_iter_set_maps_.emplace(std::make_pair(id1, new_set));
- disjoint_iter_sets_.push_back(new_set);
-
- // Update parallel type map
- if (mapping_mode_ == MappingMode::PARALLEL) {
- if (id0->isParallelized() && id1->isParallelized()) {
- // Both are parallelized, make sure they're the same, set entry for
- // parallel map
- TORCH_INTERNAL_ASSERT(id0->getParallelType() == id1->getParallelType());
- parallel_type_map_[new_set] = id0->getParallelType();
- } else if (id0->isParallelized() || id1->isParallelized()) {
- // Only one is parallelized, set entry for parallel map
- parallel_type_map_[new_set] = id0->isParallelized()
- ? id0->getParallelType()
- : id1->getParallelType();
- }
- }
-
- } else if (
- set_it_0 != disjoint_iter_set_maps_.end() &&
- set_it_1 != disjoint_iter_set_maps_.end()) {
- // Both iter domains have been mapped, so join their sets together
- auto set0_ptr = set_it_0->second;
- auto set1_ptr = set_it_1->second;
-
- // If the sets are already the same, do nothing
- if (set0_ptr == set1_ptr) {
- return;
- }
-
- // Place everything in set1 into set0 and remap all ID's in set1 to set0
- auto& set1 = *set1_ptr;
- for (auto id : set1) {
- set0_ptr->push_back(id);
- disjoint_iter_set_maps_[id] = set0_ptr;
- }
-
- // set1 no longer needed as its IDs are copied into set0
- disjoint_iter_sets_.erase(std::find(
- disjoint_iter_sets_.begin(), disjoint_iter_sets_.end(), set1_ptr));
-
- // Update parallel type map
- if (mapping_mode_ == MappingMode::PARALLEL) {
- auto parallel_type_0_it = parallel_type_map_.find(set0_ptr);
- auto parallel_type_1_it = parallel_type_map_.find(set1_ptr);
- if (parallel_type_0_it != parallel_type_map_.end() &&
- parallel_type_1_it != parallel_type_map_.end()) {
- // If both sets had a parallel type associated with them, make sure they
- // are the same
- TORCH_INTERNAL_ASSERT(
- parallel_type_0_it->second == parallel_type_1_it->second);
- } else if (parallel_type_1_it != parallel_type_map_.end()) {
- // Set 1 has a parallel type, set 0 does not, set parallel entry
- parallel_type_map_[set0_ptr] = parallel_type_1_it->second;
- }
- // Else set 0 already has the right parallel type set in the map, if at
- // all
-
- // Remove set1 from the parallel type map as it shouldn't exist anymore
- parallel_type_map_.erase(set1_ptr);
- }
-
- } else {
- auto existing_set = set_it_0 != disjoint_iter_set_maps_.end()
- ? set_it_0->second
- : set_it_1->second;
- auto missing_id = set_it_0 != disjoint_iter_set_maps_.end() ? id1 : id0;
- existing_set->push_back(missing_id);
- disjoint_iter_set_maps_[missing_id] = existing_set;
-
- // Update parallel type map
- if (mapping_mode_ == MappingMode::PARALLEL) {
- auto parallel_type_it = parallel_type_map_.find(existing_set);
- if (parallel_type_it != parallel_type_map_.end() &&
- missing_id->isParallelized()) {
- // existing_set has a parallel type already and missing_id has a
- // parallel type, make sure they match. No need to update map
- TORCH_INTERNAL_ASSERT(
- parallel_type_it->second == missing_id->getParallelType());
- } else if (
- parallel_type_it == parallel_type_map_.end() &&
- id1->isParallelized()) {
- // Set parallel type of existing_set as the newly added missing_id is
- // parallel
- parallel_type_map_[existing_set] = missing_id->getParallelType();
- }
- }
- }
-}
-
-void ComputeAtMap::build(Fusion* fusion, GpuLower* gpu_lower) {
- // Consumers can only show up once in an expression, keep track of all of them
- std::vector<TensorView*> consumer_tvs;
-
- for (auto expr : fusion->exprs()) {
- if (!expr->outputs()[0]->isA<TensorView>()) {
- continue;
- }
-
- auto tv_outputs = ir_utils::filterByType<TensorView>(expr->outputs());
- TensorView* first_output_tv = nullptr;
- for (auto c_tv : tv_outputs) {
- consumer_tvs.push_back(c_tv);
-
- if (first_output_tv == nullptr) {
- first_output_tv = c_tv;
- } else {
-        // Map the multiple outputs of an expression to each other. c is the
-        // current output and f is the first output. Keep this consistent with
-        // the later producer/consumer section, where the producer here is the
-        // "first output" and the consumer is still the consumer.
-
- TORCH_INTERNAL_ASSERT(
- c_tv->getRootDomain().size() ==
- first_output_tv->getRootDomain().size(),
- "Multiple outputs with mismatched dimensions is not supported. ",
- "Only supported case is welford op where all outputs tvs have idential domains.");
- // p->f, c->c
- std::unordered_map<IterDomain*, IterDomain*> c2f_root_map;
- for (size_t i = 0; i < first_output_tv->getRootDomain().size(); i++) {
- c2f_root_map.insert(std::make_pair(
- c_tv->getRootDomain()[i], first_output_tv->getRootDomain()[i]));
- }
-
- // Multi output mapping
- auto replay_FasC = BestEffortReplay(
- first_output_tv->domain()->domain(),
- c_tv->domain()->domain(),
- c2f_root_map);
-
- auto c2f_map = replay_FasC.getReplay();
-
- // If we're creating parallel map, only map the leaf
- // axes. Also, the producer axis must be left of the CA
- // point.
- // Otherwise, map the entire replay map.
- if (mapping_mode_ == MappingMode::PARALLEL) {
- // Mark axes left of compute at point for parallel type tracking
- std::unordered_set<IterDomain*> producer_axes_to_map(
- first_output_tv->domain()->domain().begin(),
- first_output_tv->domain()->domain().begin() +
- first_output_tv->getComputeAtPosition());
-
- for (auto c_id : c_tv->domain()->domain()) {
- auto it = c2f_map.find(c_id);
- if (it == c2f_map.end()) {
- continue;
- }
- auto f_id = it->second;
- if (producer_axes_to_map.find(f_id) == producer_axes_to_map.end()) {
- continue;
- }
- mapIds(f_id, c_id);
- }
- } else {
- for (auto entry : c2f_map) {
- auto c_id = entry.first;
- auto f_id = entry.second;
- // Map the id's together
- mapIds(f_id, c_id);
- }
- }
- }
-
- auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
-
- for (auto p_tv : tv_inputs) {
- // If outside computeAt axis, we don't want to directly map
- // consumer/producer as their thread mappings could change as long as
- // it's across shared/global memory.
- auto pairwise_map = PairwiseRootDomainMap(p_tv, c_tv);
- auto c2p_root_map =
- pairwise_map.mapConsumerToProducer(c_tv->domain(), p_tv->domain());
-
- // Look for matching ID transformations in producer and consumer, replay
- // producer as consumer. We want to replay producer as consumer instead
- // of the other way around since consumer may have some broadcasted axes
- // producer doesn't have merged into loops producer may use. If we did
- // consumer as producer we wouldn't have this information in the
- // mapping. If we're using this map for indexing, we do not want to
- // propagate broadcast mismatches. If we're using it to identify loop
- // nests, we do want to propagate mismatches.
- auto replay_PasC = mapping_mode_ == MappingMode::LOOP ||
- mapping_mode_ == MappingMode::PARALLEL
- ? BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map)
- : BestEffortReplay(
- p_tv->domain()->domain(),
- c_tv->domain()->domain(),
- c2p_root_map);
-
- auto c2p_map = replay_PasC.getReplay();
-
- // If we're creating parallel map, only map the leaf
- // axes. Also, the producer axis must be left of the CA
- // point.
- // Otherwise, map the entire replay map.
- if (mapping_mode_ == MappingMode::PARALLEL) {
- // Mark axes left of compute at point for parallel type tracking
- std::unordered_set<IterDomain*> producer_axes_to_map(
- p_tv->domain()->domain().begin(),
- p_tv->domain()->domain().begin() + p_tv->getComputeAtPosition());
-
- for (auto c_id : c_tv->domain()->domain()) {
- auto it = c2p_map.find(c_id);
- if (it == c2p_map.end()) {
- continue;
- }
- auto p_id = it->second;
- if (producer_axes_to_map.find(p_id) == producer_axes_to_map.end()) {
- continue;
- }
- mapIds(p_id, c_id);
- }
- } else {
- for (auto entry : c2p_map) {
- auto c_id = entry.first;
- auto p_id = entry.second;
- // Map the id's together
- mapIds(p_id, c_id);
- }
- }
- }
- }
- }
-
- // deduplicate iter domain entries in each set
- for (const auto& iter_set : disjoint_iter_sets_) {
- *iter_set = deduplicateDeque(*iter_set);
- }
-
- // For each IterDomain set we will track how many concrete root domains were
-  // used to generate the IterDomain. Used to populate concrete_id_map_. The
-  // concrete ID has the maximum concrete-id count; ties are decided by n_broadcast_ids.
- // Refer to AdvancedLowering5 for why we need to split ties with broadcast
- // dims.
- std::unordered_map<IterDomain*, int> n_concrete_ids_;
- std::unordered_map<IterDomain*, int> n_broadcast_ids_;
-
- for (auto c_tv : consumer_tvs) {
- auto counts =
- InputDomainCounter::produceCounts(c_tv->domain()->domain(), gpu_lower);
- std::transform(
- counts.begin(),
- counts.end(),
- std::inserter(n_concrete_ids_, n_concrete_ids_.end()),
- [](auto counts_entry) {
- return std::make_pair(counts_entry.first, counts_entry.second.first);
- });
- std::transform(
- counts.begin(),
- counts.end(),
- std::inserter(n_broadcast_ids_, n_broadcast_ids_.end()),
- [](auto counts_entry) {
- return std::make_pair(counts_entry.first, counts_entry.second.second);
- });
- }
-
- for (auto inp_tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
- auto counts = InputDomainCounter::produceCounts(
- inp_tv->domain()->domain(), gpu_lower);
- std::transform(
- counts.begin(),
- counts.end(),
- std::inserter(n_concrete_ids_, n_concrete_ids_.end()),
- [](auto counts_entry) {
- return std::make_pair(counts_entry.first, counts_entry.second.first);
- });
- std::transform(
- counts.begin(),
- counts.end(),
- std::inserter(n_broadcast_ids_, n_broadcast_ids_.end()),
- [](auto counts_entry) {
- return std::make_pair(counts_entry.first, counts_entry.second.second);
- });
- }
-
- // Populate concrete id map
- for (const auto& set : disjoint_iter_sets_) {
- int max_concrete_count = -1;
- int max_broadcast_count = -1;
- IterDomain* concrete_id = nullptr;
- for (auto id : *set) {
- int concrete_count = n_concrete_ids_.at(id);
- if (concrete_count >= max_concrete_count) {
- int broadcast_count = n_broadcast_ids_.at(id);
- if (concrete_count > max_concrete_count ||
- broadcast_count > max_broadcast_count) {
- max_concrete_count = concrete_count;
- max_broadcast_count = broadcast_count;
- concrete_id = id;
- }
- }
- }
-
- TORCH_INTERNAL_ASSERT(
- concrete_id != nullptr, "Could not concretize an IterDomain set.");
-
- for (auto id : *set) {
- concrete_id_map_[id] = concrete_id;
- if (mapping_mode_ == MappingMode::PARALLEL) {
- auto parallel_map_it = parallel_type_map_.find(set);
- // Parallelize all IterDomains to simplify lowering and codegen
- if (parallel_map_it != parallel_type_map_.end()) {
-          // Don't propagate vectorize like other parallel types
- if (parallel_map_it->second != ParallelType::Vectorize) {
- id->parallelize(parallel_map_it->second);
- }
- }
- }
- }
- }
-
- if (gpu_lower != nullptr) {
- convertToKir(fusion, gpu_lower);
- }
-}
-
-void ComputeAtMap::convertToKir(Fusion* fusion, GpuLower* gpu_lower) {
- TORCH_INTERNAL_ASSERT(fusion != nullptr);
- TORCH_INTERNAL_ASSERT(gpu_lower != nullptr);
-
- has_lowered_kir_ = true;
-
- std::unordered_map<
- std::shared_ptr<std::deque<IterDomain*>>,
- std::shared_ptr<std::deque<kir::IterDomain*>>>
- disjoint_set_2_kir;
-
- for (const auto& disjoint_iter_set : disjoint_iter_set_maps_) {
- auto fusion_set = disjoint_iter_set.second;
- auto kir_set_it = disjoint_set_2_kir.find(fusion_set);
- std::shared_ptr<std::deque<kir::IterDomain*>> kir_set;
- if (kir_set_it == disjoint_set_2_kir.end()) {
- kir_set = std::make_shared<std::deque<kir::IterDomain*>>();
- std::transform(
- fusion_set->begin(),
- fusion_set->end(),
- std::inserter(*kir_set, kir_set->begin()),
- [&gpu_lower](IterDomain* id) {
- return gpu_lower->lowerValue(id)->as<kir::IterDomain>();
- });
- disjoint_set_2_kir.emplace(std::make_pair(fusion_set, kir_set));
- } else {
- kir_set = kir_set_it->second;
- }
- kir_disjoint_iter_set_maps_.emplace(std::make_pair(
- gpu_lower->lowerValue(disjoint_iter_set.first)->as<kir::IterDomain>(),
- kir_set));
- }
-
- for (auto entry : concrete_id_map_) {
- kir_concrete_id_map_.emplace(std::make_pair(
- gpu_lower->lowerValue(entry.first)->as<kir::IterDomain>(),
- gpu_lower->lowerValue(entry.second)->as<kir::IterDomain>()));
- }
-
- for (const auto& entry : disjoint_iter_set_maps_) {
- kir_2_fusion_[gpu_lower->lowerValue(entry.first)->as<kir::IterDomain>()] =
- entry.first;
- }
-
- // Make sure we have all IterDomains that could be used to generate a ForLoop
- for (auto expr : fusion->exprs()) {
- if (!expr->outputs()[0]->isA<TensorView>()) {
- continue;
- }
-
- auto tv_outputs = ir_utils::filterByType<TensorView>(expr->outputs());
-
- for (auto out : tv_outputs) {
- for (auto entry : out->domain()->domain()) {
- kir_2_fusion_[gpu_lower->lowerValue(entry)->as<kir::IterDomain>()] =
- entry;
- }
- }
- }
-}
-
-bool ComputeAtMap::areMapped(IterDomain* id0, IterDomain* id1) const {
- if (id0 == id1) {
- return true;
- }
- auto set0_it = disjoint_iter_set_maps_.find(id0);
- auto set1_it = disjoint_iter_set_maps_.find(id1);
- if (set0_it == disjoint_iter_set_maps_.end() ||
- set1_it == disjoint_iter_set_maps_.end()) {
- return false;
- }
- return (set0_it->second.get() == set1_it->second.get());
-}
-
-bool ComputeAtMap::areMapped(kir::IterDomain* id0, kir::IterDomain* id1) const {
- assertLowered(has_lowered_kir_);
- if (id0 == id1) {
- return true;
- }
- auto set0_it = kir_disjoint_iter_set_maps_.find(id0);
- auto set1_it = kir_disjoint_iter_set_maps_.find(id1);
- if (set0_it == kir_disjoint_iter_set_maps_.end() ||
- set1_it == kir_disjoint_iter_set_maps_.end()) {
- return false;
- }
- return (set0_it->second.get() == set1_it->second.get());
-}
-
-IterDomain* ComputeAtMap::getConcreteMappedID(IterDomain* id) const {
- auto it = concrete_id_map_.find(id);
- if (it != concrete_id_map_.end()) {
- return it->second;
- }
- return id;
-}
-
-kir::IterDomain* ComputeAtMap::getConcreteMappedID(kir::IterDomain* id) const {
- assertLowered(has_lowered_kir_);
- auto it = kir_concrete_id_map_.find(id);
- if (it != kir_concrete_id_map_.end()) {
- return it->second;
- }
- return id;
-}
-
-IterDomain* ComputeAtMap::toFusion(kir::IterDomain* kir) const {
- assertLowered(has_lowered_kir_);
- auto kir_2_fusion_it = kir_2_fusion_.find(kir);
- TORCH_INTERNAL_ASSERT(
- kir_2_fusion_it != kir_2_fusion_.end(),
- "Kernel ir is not guarneteed to be reversible into fusion ir, could not find fusion entry. ",
- kir::toString(kir, false));
- return kir_2_fusion_it->second;
-}
-
-std::string ComputeAtMap::toString() const {
- std::stringstream ss;
-
- // We may not have cleaned up non active sets as this is intended for debug,
- // so first grab unique entries and iterate over them.
- std::unordered_set<std::shared_ptr<std::deque<IterDomain*>>> disjoint_sets;
-
- for (const auto& entry : disjoint_iter_set_maps_) {
- disjoint_sets.emplace(entry.second);
- }
-
- for (const auto& disjoint_set : disjoint_sets) {
- ss << " disjoint_set{ ";
- TORCH_INTERNAL_ASSERT(disjoint_set->size() > 0);
- auto concrete_id = concrete_id_map_.at(disjoint_set->front());
- for (auto it = disjoint_set->begin(); it != disjoint_set->end(); it++) {
- if (it != disjoint_set->begin()) {
- ss << ", ";
- }
- ss << (*it);
- if (*it == concrete_id) {
- ss << "*";
- }
- }
- ss << " }";
- if (mapping_mode_ == MappingMode::PARALLEL) {
- if (parallel_type_map_.find(disjoint_set) != parallel_type_map_.end()) {
- ss << " -> " << parallel_type_map_.at(disjoint_set);
- } else {
- ss << " -> " << ParallelType::Serial;
- }
- }
- ss << "\n";
- }
- return ss.str();
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <deque>
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class GpuLower;
-
-class TORCH_CUDA_CU_API ComputeAtMap {
- public:
-  // There are three modes of these iter domain mappings: for indexing, for
-  // loop nest mapping/generation, and for figuring out the parallelization
-  // strategy.
- //
- // For index/loop mode consider:
- //
- // consumer[i0, b1] = producer[i0]
- // consumer->merge(0) (consumer will now be [i0 * b1])
- // When producer is replayed as consumer (the direction we use for mapping)
- // with BestEffortReplay forward_bcast_mismatch = True the producer to
- // consumer map will have both a mapping of consumer(i0) to producer(i0) as
- // well as consumer(i0*b1) to producer(i0). This latter mapping is important
- // for loop nest mappings as the consumer will generate a loop based on i0*b1
- // and the producer may be computeAt inside this loop nest. However, for
- // indexing we do not want these two maps as producer may be indexed as i0*i1
- // depending on the loop nest structure and how it was built. Therefore we
- // really need to carry two sets of maps around for lowering.
- //
- // Parallel mode is important if we have something like:
- // consumer[i0o, threadIdx.x{i0i}] = producer[i0o, threadIdx.y{i0i}](computeAt
- // = 1) which can easily happen when using shared memory. We want to make sure
- // that the iteration domain used for loop construction (concreteId) has the
- // proper parallelization strategy. In parallel mode we do typical iteration
- // domain mapping, however we remove from it any iteration domains outside the
-  // computeAt of producer when mapping. This guarantees we won't map
- // IterDomains that could have different parallelization strategies. We also
- // propagate the parallel strategy in parallel mode so all mapped IDs that
- // must have the same parallel type, do.
- enum class MappingMode { PARALLEL, LOOP, INDEX };
-
- ComputeAtMap() = default;
- ComputeAtMap(MappingMode mapping_mode) : mapping_mode_(mapping_mode) {}
-
- //! Builds all valid mappings. When gpu_lower is not nullptr,
- //! equivalent mappings for KIR are also created.
- void build(Fusion* fusion, GpuLower* gpu_lower = nullptr);
-
-  //! Returns whether id0 and id1 are mapped to each other, meaning they
-  //! represent the same loop nest in the lowered code
- bool areMapped(IterDomain* id0, IterDomain* id1) const;
-
- bool areMapped(kir::IterDomain* id0, kir::IterDomain* id1) const;
-
- //! Returns an iter domain that is the maximum expanded size of all iter
- //! domains the one provided maps to. Useful for opening loops to the correct
-  //! iteration size. Not guaranteed to return the same ID every call, but is
-  //! guaranteed to return iter domains in the same disjoint set.
- IterDomain* getConcreteMappedID(IterDomain* id) const;
-
- kir::IterDomain* getConcreteMappedID(kir::IterDomain* id) const;
-
- // TODO: Would be great if we didn't need this, but we have nice functionality
- // in iter_visitor that isn't moved over. Use of this is limited to indexing
- // and this should definitely be removed by building out kernel ir to have
- // better parity with fusion ir.
- IterDomain* toFusion(kir::IterDomain* kir) const;
-
- // Prints mapping information via Fusion IR
- std::string toString() const;
-
- private:
- bool has_lowered_kir_ = false;
-
- void mapIds(IterDomain* id0, IterDomain* id1);
-
- //! Convert everything to lowered structures (kernel ir), as we will use
- //! this class frequently during lowering.
- void convertToKir(Fusion* fusion, GpuLower* gpu_lower);
-
- private:
- MappingMode mapping_mode_ = MappingMode::LOOP;
-
-  // Only used when mapping mode == LOOP, and only in expr sorting; holds the
-  // maximum position at which a loop is shared with any neighbor.
- std::unordered_map<TensorView*, unsigned int> produce_at_map_;
-
- // Disjoint sets of iter domains, only defined if iter domain is within
- // compute at of a tensor view. Maps these iter domains to a set containing
- // all other iter domains in the fusion that map to the same loop nest.
- std::unordered_map<IterDomain*, std::shared_ptr<std::deque<IterDomain*>>>
- disjoint_iter_set_maps_;
-
- std::unordered_map<
- kir::IterDomain*,
- std::shared_ptr<std::deque<kir::IterDomain*>>>
- kir_disjoint_iter_set_maps_;
-
- // Keep a list of disjoint_iter_sets that's deterministic to iterate over
- std::deque<std::shared_ptr<std::deque<IterDomain*>>> disjoint_iter_sets_;
-
- // Tracks if there's a parallel iter domain associated with a disjoint iter
- // domain set
- std::unordered_map<std::shared_ptr<std::deque<IterDomain*>>, ParallelType>
- parallel_type_map_;
-
- // For each IterDomain set we will track how many concrete root domains were
- // used to generate the IterDomain
- std::unordered_map<IterDomain*, IterDomain*> concrete_id_map_;
-
- std::unordered_map<kir::IterDomain*, kir::IterDomain*> kir_concrete_id_map_;
-
- // Map kir::IterDomain* back to the fusion IR IterDomain*.
- // TODO: Would be great if we didn't need this.
- std::unordered_map<kir::IterDomain*, IterDomain*> kir_2_fusion_;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
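
For reference, a minimal sketch of how the two-map distinction described in the removed ComputeAtMap header might look in use. This is illustrative only and not part of the diff; `fusion`, `gpu_lower`, `id0`, and `id1` are assumed to be in scope:

    // Build one map per mode; lowering carried separate LOOP and INDEX maps.
    ComputeAtMap loop_map(ComputeAtMap::MappingMode::LOOP);
    loop_map.build(fusion, gpu_lower);

    ComputeAtMap index_map(ComputeAtMap::MappingMode::INDEX);
    index_map.build(fusion, gpu_lower);

    // Two IterDomains may share a loop nest (LOOP map) yet still need to be
    // indexed independently (INDEX map), which is why both maps are carried
    // through lowering.
    if (loop_map.areMapped(id0, id1) && !index_map.areMapped(id0, id1)) {
      // Use the concrete ID to size the shared loop.
      IterDomain* concrete = loop_map.getConcreteMappedID(id0);
    }
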
+++ /dev/null
-#pragma once
-
-#include <c10/util/Exception.h>
-
-#include <algorithm>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Container class DisjointSet models equivalence relationships
-//!
-//! Each instance of this class keeps a set of equivalent classes
-//! DisjointSet::join(a,b) makes the full class of a and b equivalent
-//! DisjointSet::areEquivalent(a,b) checks if a and b belong to the same class
-template <typename T, typename Hash = std::hash<T>>
-class DisjointSet {
- public:
- DisjointSet() = default;
-
-  //! Joins the equivalent classes that a and b belong to;
-  //! areEquivalent(a',b') will then hold for any a' in the class of a
-  //! and any b' in the class of b
-  //!
-  //! \param a An element from an equivalent class;
-  //!          will create a new equivalent class if a does
-  //!          not belong to any
-  //! \param b An element from another equivalent class;
-  //!          will create a new equivalent class if b does
-  //!          not belong to any
- void join(T a, T b) {
-    // cases where either equiv class doesn't exist yet
- if (!entry_map.count(a) && !entry_map.count(b)) {
- createPoint(a);
- entry_map[b] = fixedPoint(a);
- } else if (!entry_map.count(a)) {
- entry_map[a] = fixedPoint(b);
- } else if (!entry_map.count(b)) {
- entry_map[b] = fixedPoint(a);
- } else {
- // case where both equiv classes exist and need to join
- const int i0 = fixedPoint(a);
- const int i1 = fixedPoint(b);
- int new_parent = 0;
- int new_child = 0;
-
-      // Either order here is correct, but joining the larger class to the
-      // smaller one tends to be faster
- std::tie(new_parent, new_child) = (weights[i0] < weights[i1])
- ? std::make_pair(i0, i1)
- : std::make_pair(i1, i0);
- weights[new_parent] += weights[new_child];
- set_map[new_child] = new_parent;
- }
- }
-
- //! Checks if a and b belong to the same equivalent class
- //!
-  //! \param a An element from an equivalent class
-  //! \param b An element from another equivalent class
-  //! \returns Boolean value representing if a and b are
-  //!          recorded to be in the same equivalent class;
-  //!          returns false if either a or b doesn't
-  //!          have an equivalent class recorded
- bool areEquivalent(T a, T b) const {
- if (!entry_map.count(a) || !entry_map.count(b)) {
- return false;
- }
- return fixedPoint(a) == fixedPoint(b);
- }
-
- //! Queries if an element exists in this set
- bool contains(T a) const {
- return entry_map.count(a) > 0;
- }
-
- //! Returns all elements added to this set
- std::vector<T> getAllElements() const {
- std::vector<T> elms(entry_map.size());
- std::transform(
- entry_map.begin(),
- entry_map.end(),
- elms.begin(),
- [](const auto& entry_map_kv) { return entry_map_kv.first; });
- return elms;
- }
-
- //! Clears the equivalence relationships
- void clear() {
- set_map.clear();
- weights.clear();
- entry_map.clear();
- next_index_ = 0;
- }
-
- //! Dumps the equivalent relationships
- std::ostream& print(std::ostream& os) const {
- std::unordered_map<int, std::unordered_set<T, Hash>> fixedPointMap;
- for (const auto& kv : entry_map) {
- int fixed_point = fixedPoint(kv.first);
- auto it = fixedPointMap.find(fixed_point);
- if (it == fixedPointMap.end()) {
- it = fixedPointMap.insert({fixed_point, {}}).first;
- }
- it->second.insert(kv.first);
- }
- os << "{\n";
- for (const auto& kv : fixedPointMap) {
- os << "\t{ ";
- for (const auto& val : kv.second) {
- os << toString(val) << " ";
- }
- os << "}\n";
- }
- os << "}\n";
- return os;
- }
-
- private:
- // Internal fixed point implementation:
- // Returns the equivalent class that e belongs to
- int getFixedPointForClass(int e) const {
- TORCH_INTERNAL_ASSERT(static_cast<int>(set_map.size()) > e);
- while (set_map[e] != e) {
- // Chasing to fixed point
- e = set_map[e];
- }
- return e;
- }
-
- //! Utility to check the class e belongs to:
- //!
- //! \param e element e to find the equiv class for
- //! \returns the equivalent class that e belongs to
- //!
- int fixedPoint(T e) const {
-    // Handles the case when e doesn't have an equivalence class
- TORCH_INTERNAL_ASSERT(entry_map.count(e));
-
- // Use fixed point as a representation for the equiv class
- return getFixedPointForClass(entry_map.at(e));
- }
-
- //! Utility to create a new equiv class for i
- //
- //! \param i Element i to create the equiv class for
- void createPoint(T i) {
- entry_map[i] = next_index_;
- set_map.push_back(next_index_++);
- weights.push_back(1);
- }
-
- private:
- // Internal representation of the equivalence class as integers
- // set_map implements the "parent" relationship
- std::vector<int> set_map;
-  // weights is used for a preliminary perf optimization
- std::vector<int> weights;
-
- // Map the input of type T to its equivalence class
- std::unordered_map<T, int, Hash> entry_map;
-
-  // Running counter for generating a new index when
-  // creating new equiv classes
- int next_index_ = 0;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
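
For orientation, a minimal, self-contained sketch of the DisjointSet contract removed above (instantiating with int keys is hypothetical; nvfuser instantiated it with IR node pointers):

    #include <cassert>

    DisjointSet<int> ds;
    ds.join(1, 2);            // classes: {1, 2}
    ds.join(3, 4);            // classes: {1, 2} {3, 4}
    assert(ds.areEquivalent(1, 2));
    assert(!ds.areEquivalent(2, 3));
    ds.join(2, 3);            // classes: {1, 2, 3, 4}
    assert(ds.areEquivalent(1, 4));
    assert(!ds.contains(5));  // untracked elements have no class
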
case DataType::Bool:
ptr(handler)->handle(val->as<Bool>());
return;
- case DataType::Double:
- ptr(handler)->handle(val->as<Double>());
+ case DataType::Float:
+ ptr(handler)->handle(val->as<Float>());
+ return;
+ case DataType::Half:
+ ptr(handler)->handle(val->as<Half>());
return;
case DataType::Int:
ptr(handler)->handle(val->as<Int>());
case ValType::NamedScalar:
ptr(handler)->handle(val->as<NamedScalar>());
return;
+
+ // TODO: remove once the Kernel IR has its own visitor
+ case ValType::TensorIndex:
+ ptr(handler)->handle(val->as<kir::TensorIndex>());
+ return;
+ case ValType::KirScalar:
+ switch (*(val->getDataType())) {
+ case DataType::Bool:
+ ptr(handler)->handle(val->as<kir::Bool>());
+ return;
+ case DataType::Float:
+ ptr(handler)->handle(val->as<kir::Float>());
+ return;
+ case DataType::Half:
+ ptr(handler)->handle(val->as<kir::Half>());
+ return;
+ case DataType::Int:
+ ptr(handler)->handle(val->as<kir::Int>());
+ return;
+ default:
+ break;
+ }
+ break;
+ case ValType::KirNamedScalar:
+ ptr(handler)->handle(val->as<kir::NamedScalar>());
+ return;
+ case ValType::KirIterDomain:
+ ptr(handler)->handle(val->as<kir::IterDomain>());
+ return;
+ case ValType::KirTensorDomain:
+ ptr(handler)->handle(val->as<kir::TensorDomain>());
+ return;
+ case ValType::KirTensorView:
+ ptr(handler)->handle(val->as<kir::TensorView>());
+ return;
+
default:
break;
}
case ExprType::ReductionOp:
ptr(handler)->handle(expr->as<ReductionOp>());
return;
- case ExprType::WelfordOp:
- ptr(handler)->handle(expr->as<WelfordOp>());
- return;
case ExprType::BroadcastOp:
ptr(handler)->handle(expr->as<BroadcastOp>());
return;
- case ExprType::TransposeOp:
- ptr(handler)->handle(expr->as<TransposeOp>());
+
+ case ExprType::KirUnaryOp:
+ ptr(handler)->handle(expr->as<kir::UnaryOp>());
+ return;
+ case ExprType::KirBinaryOp:
+ ptr(handler)->handle(expr->as<kir::BinaryOp>());
+ return;
+ case ExprType::KirTernaryOp:
+ ptr(handler)->handle(expr->as<kir::TernaryOp>());
return;
- case ExprType::ShiftOp:
- ptr(handler)->handle(expr->as<ShiftOp>());
+ case ExprType::KirReductionOp:
+ ptr(handler)->handle(expr->as<kir::ReductionOp>());
return;
- case ExprType::GatherOp:
- ptr(handler)->handle(expr->as<GatherOp>());
+ case ExprType::KirBroadcastOp:
+ ptr(handler)->handle(expr->as<kir::BroadcastOp>());
return;
+
+ case ExprType::GridReduction:
+ ptr(handler)->handle(expr->as<kir::GridReduction>());
+ return;
+ case ExprType::ForLoop:
+ ptr(handler)->handle(expr->as<kir::ForLoop>());
+ return;
+ case ExprType::IfThenElse:
+ ptr(handler)->handle(expr->as<kir::IfThenElse>());
+ return;
+ case ExprType::Allocate:
+ ptr(handler)->handle(expr->as<kir::Allocate>());
+ return;
+ case ExprType::Sync:
+ ptr(handler)->handle(expr->as<kir::Sync>());
+ return;
+
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
}
case DataType::Bool:
ptr(handler)->handle(val->as<Bool>());
return;
- case DataType::Double:
- ptr(handler)->handle(val->as<Double>());
+ case DataType::Float:
+ ptr(handler)->handle(val->as<Float>());
+ return;
+ case DataType::Half:
+ ptr(handler)->handle(val->as<Half>());
return;
case DataType::Int:
ptr(handler)->handle(val->as<Int>());
case ValType::NamedScalar:
ptr(handler)->handle(val->as<NamedScalar>());
return;
+
+ // TODO: remove once the Kernel IR has its own visitor
+ case ValType::TensorIndex:
+ ptr(handler)->handle(val->as<kir::TensorIndex>());
+ return;
+ case ValType::KirScalar:
+ switch (*(val->getDataType())) {
+ case DataType::Bool:
+ ptr(handler)->handle(val->as<kir::Bool>());
+ return;
+ case DataType::Float:
+ ptr(handler)->handle(val->as<kir::Float>());
+ return;
+ case DataType::Half:
+ ptr(handler)->handle(val->as<kir::Half>());
+ return;
+ case DataType::Int:
+ ptr(handler)->handle(val->as<kir::Int>());
+ return;
+ default:
+ break;
+ }
+ break;
+ case ValType::KirNamedScalar:
+ ptr(handler)->handle(val->as<kir::NamedScalar>());
+ return;
+ case ValType::KirIterDomain:
+ ptr(handler)->handle(val->as<kir::IterDomain>());
+ return;
+ case ValType::KirTensorDomain:
+ ptr(handler)->handle(val->as<kir::TensorDomain>());
+ return;
+ case ValType::KirTensorView:
+ ptr(handler)->handle(val->as<kir::TensorView>());
+ return;
+
default:
break;
}
case ExprType::ReductionOp:
ptr(handler)->handle(expr->as<ReductionOp>());
return;
- case ExprType::WelfordOp:
- ptr(handler)->handle(expr->as<WelfordOp>());
- return;
case ExprType::BroadcastOp:
ptr(handler)->handle(expr->as<BroadcastOp>());
return;
- case ExprType::TransposeOp:
- ptr(handler)->handle(expr->as<TransposeOp>());
+
+ case ExprType::KirUnaryOp:
+ ptr(handler)->handle(expr->as<kir::UnaryOp>());
+ return;
+ case ExprType::KirBinaryOp:
+ ptr(handler)->handle(expr->as<kir::BinaryOp>());
+ return;
+ case ExprType::KirTernaryOp:
+ ptr(handler)->handle(expr->as<kir::TernaryOp>());
return;
- case ExprType::ShiftOp:
- ptr(handler)->handle(expr->as<ShiftOp>());
+ case ExprType::KirReductionOp:
+ ptr(handler)->handle(expr->as<kir::ReductionOp>());
return;
- case ExprType::GatherOp:
- ptr(handler)->handle(expr->as<GatherOp>());
+ case ExprType::KirBroadcastOp:
+ ptr(handler)->handle(expr->as<kir::BroadcastOp>());
return;
+
+ case ExprType::GridReduction:
+ ptr(handler)->handle(expr->as<kir::GridReduction>());
+ return;
+ case ExprType::ForLoop:
+ ptr(handler)->handle(expr->as<kir::ForLoop>());
+ return;
+ case ExprType::IfThenElse:
+ ptr(handler)->handle(expr->as<kir::IfThenElse>());
+ return;
+ case ExprType::Allocate:
+ ptr(handler)->handle(expr->as<kir::Allocate>());
+ return;
+ case ExprType::Sync:
+ ptr(handler)->handle(expr->as<kir::Sync>());
+ return;
+
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
}
switch (*(val->getDataType())) {
case DataType::Bool:
return ptr(mutator)->mutate(val->as<Bool>());
- case DataType::Double:
- return ptr(mutator)->mutate(val->as<Double>());
+ case DataType::Float:
+ return ptr(mutator)->mutate(val->as<Float>());
+ case DataType::Half:
+ return ptr(mutator)->mutate(val->as<Half>());
case DataType::Int:
return ptr(mutator)->mutate(val->as<Int>());
default:
return ptr(mutator)->mutate(val->as<TensorDomain>());
case ValType::TensorView:
return ptr(mutator)->mutate(val->as<TensorView>());
+ case ValType::TensorIndex:
+ return ptr(mutator)->mutate(val->as<kir::TensorIndex>());
case ValType::NamedScalar:
return ptr(mutator)->mutate(val->as<NamedScalar>());
default:
return ptr(mutator)->mutate(expr->as<TernaryOp>());
case ExprType::ReductionOp:
return ptr(mutator)->mutate(expr->as<ReductionOp>());
- case ExprType::WelfordOp:
- return ptr(mutator)->mutate(expr->as<WelfordOp>());
+ case ExprType::GridReduction:
+ return ptr(mutator)->mutate(expr->as<kir::GridReduction>());
case ExprType::BroadcastOp:
return ptr(mutator)->mutate(expr->as<BroadcastOp>());
- case ExprType::TransposeOp:
- return ptr(mutator)->mutate(expr->as<TransposeOp>());
- case ExprType::ShiftOp:
- return ptr(mutator)->mutate(expr->as<ShiftOp>());
- case ExprType::GatherOp:
- return ptr(mutator)->mutate(expr->as<GatherOp>());
+ case ExprType::ForLoop:
+ return ptr(mutator)->mutate(expr->as<kir::ForLoop>());
+ case ExprType::IfThenElse:
+ return ptr(mutator)->mutate(expr->as<kir::IfThenElse>());
+ case ExprType::Allocate:
+ return ptr(mutator)->mutate(expr->as<kir::Allocate>());
+ case ExprType::Sync:
+ return ptr(mutator)->mutate(expr->as<kir::Sync>());
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
}
#pragma once
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-
#include <c10/util/Exception.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <unordered_map>
-// dispatch.h avoids the need to add manual dispatch in every class that
-// wants to define how to process a series of nodes. dispatch.h provides 4
-// classes that can be inherited providing a means to override functions on a
-// per-node basis. There are currently 4 provided dispatch mechanisms:
-//
-// OptOutDispatch:
-//
-// provides the functions:
-// virtual void handle(ValType* irnode){}
-//
-// This provides a mechanism to override this handle for particular node
-// types. For example if we only wanted to actually run a function on
-// BinaryOps, we could inherit OptOutDispatch and simply override: void
-// handle(BinaryOp*) { doSomething; } Then we could run through all our
-// Statement* and call OptOutDispatch::handle(statement). When a BinaryOp is
-// encountered our override function will be called. For every other node,
-// nothing will be done.
-//
-// OptInDispatch:
-//
-// This class is similar to OptOutDispatch, however if we encounter a node
-// that we haven't specified an override for in the derived class, an error
-// will be thrown. This is useful if we create a class that is expected to
-// handle any type of node it encounters.
-//
-// OptOutMutator:
-//
-// This class is similar to OptOutDispatch except the functions provided are of
-// type: virtual Statement* mutate(Statement*) this is useful for when we want
-// to have an IR node result from our overloaded functions.
-//
-// OptInMutator:
-//
-// This class is similar to OptInDispatch except the functions provided are of
-// type: virtual Statement* mutate(Statement*) this is useful for when we want
-// to have an IR node result from our overloaded functions.
+/*
+ * dispatch.h avoids the need to add manual dispatch in every class that
+ * wants to define how to process a series of nodes. dispatch.h provides 4
+ * classes that can be inherited providing a means to override functions on a
+ * per-node basis. There are currently 4 provided dispatch mechanisms:
+ *
+ * OptOutDispatch:
+ *
+ * provides the functions:
+ * virtual void handle(ValType* irnode){}
+ *
+ * This provides a mechanism to override this handle for particular node
+ * types. For example if we only wanted to actually run a function on
+ * BinaryOps, we could inherit OptOutDispatch and simply override: void
+ * handle(BinaryOp*) { doSomething; } Then we could run through all our
+ * Statement* and call OptOutDispatch::handle(statement). When a BinaryOp is
+ * encountered our override function will be called. For every other node,
+ * nothing will be done.
+ *
+ * OptInDispatch:
+ *
+ * This class is similar to OptOutDispatch, however if we encounter a node
+ * that we haven't specified an override for in the derived class, an error
+ * will be thrown. This is useful if we create a class that is expected to
+ * handle any type of node it encounters.
+ *
+ * OptOutMutator:
+ *
+ * This class is similar to OptOutDispatch except the functions provided are of
+ * type: virtual Statement* mutate(Statement*) this is useful for when we want
+ * to have an IR node result from our overloaded functions.
+ *
+ * OptInMutator:
+ *
+ * This class is similar to OptInDispatch except the functions provided are of
+ * type: virtual Statement* mutate(Statement*) this is useful for when we want
+ * to have an IR node result from our overloaded functions.
+ */
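
To make the pattern described above concrete, a minimal sketch of a derived visitor (CountBinaryOps is hypothetical and not part of this change):

    // Counts BinaryOps; every other node falls through to the empty default.
    class CountBinaryOps : public OptOutDispatch {
     public:
      using OptOutDispatch::handle; // keep the other handle() overloads visible
      void handle(BinaryOp* bop) override {
        ++count;
      }
      int count = 0;
    };

    // Usage: dispatch each Statement*; only BinaryOps increment the counter.
    //   CountBinaryOps counter;
    //   for (Statement* stmt : stmts) counter.handle(stmt);
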
namespace torch {
namespace jit {
class TensorDomain;
class TensorView;
class Bool;
-class Double;
+class Float;
+class Half;
class Int;
class NamedScalar;
class BinaryOp;
class TernaryOp;
class ReductionOp;
-class WelfordOp;
class BroadcastOp;
-class TransposeOp;
-class ShiftOp;
-class GatherOp;
-// By default, all IR nodes are handled in this dispatch, and will call an empty
-// function on all nodes.
-class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
+// Kernel IR
+namespace kir {
+
+class Bool;
+class Float;
+class Half;
+class Int;
+class NamedScalar;
+
+class IterDomain;
+class TensorDomain;
+class TensorView;
+
+class UnaryOp;
+class BinaryOp;
+class TernaryOp;
+class ReductionOp;
+class BroadcastOp;
+
+class TensorIndex;
+class Allocate;
+class ForLoop;
+class IfThenElse;
+class GridReduction;
+class Sync;
+
+} // namespace kir
+
+/*
+ * By default, all IR nodes are handled in this dispatch, and will call an empty
+ * function on all nodes.
+ */
+class TORCH_CUDA_CU_API OptOutConstDispatch {
public:
+ virtual ~OptOutConstDispatch() = default;
+ OptOutConstDispatch() = default;
+
+ OptOutConstDispatch(const OptOutConstDispatch& other) = default;
+ OptOutConstDispatch& operator=(const OptOutConstDispatch& other) = default;
+
+ OptOutConstDispatch(OptOutConstDispatch&& other) = default;
+ OptOutConstDispatch& operator=(OptOutConstDispatch&& other) = default;
+
  // Hierarchical dispatch functions for handle
virtual void handle(const Statement*);
virtual void handle(const Expr*);
virtual void handle(const TensorDomain*) {}
virtual void handle(const TensorView*) {}
virtual void handle(const Bool*) {}
- virtual void handle(const Double*) {}
+ virtual void handle(const Float*) {}
+ virtual void handle(const Half*) {}
virtual void handle(const Int*) {}
virtual void handle(const NamedScalar*) {}
virtual void handle(const BinaryOp*) {}
virtual void handle(const TernaryOp*) {}
virtual void handle(const ReductionOp*) {}
- virtual void handle(const WelfordOp*) {}
virtual void handle(const BroadcastOp*) {}
- virtual void handle(const TransposeOp*) {}
- virtual void handle(const ShiftOp*) {}
- virtual void handle(const GatherOp*) {}
+
+ // Kernel IR nodes
+ virtual void handle(const kir::Bool*) {}
+ virtual void handle(const kir::Float*) {}
+ virtual void handle(const kir::Half*) {}
+ virtual void handle(const kir::Int*) {}
+ virtual void handle(const kir::NamedScalar*) {}
+
+ virtual void handle(const kir::IterDomain*) {}
+ virtual void handle(const kir::TensorDomain*) {}
+ virtual void handle(const kir::TensorView*) {}
+
+ virtual void handle(const kir::UnaryOp*) {}
+ virtual void handle(const kir::BinaryOp*) {}
+ virtual void handle(const kir::TernaryOp*) {}
+ virtual void handle(const kir::ReductionOp*) {}
+ virtual void handle(const kir::BroadcastOp*) {}
+
+ virtual void handle(const kir::TensorIndex*) {}
+ virtual void handle(const kir::GridReduction*) {}
+ virtual void handle(const kir::ForLoop*) {}
+ virtual void handle(const kir::IfThenElse*) {}
+ virtual void handle(const kir::Allocate*) {}
+ virtual void handle(const kir::Sync*) {}
};
-class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
+class TORCH_CUDA_CU_API OptOutDispatch {
public:
+ virtual ~OptOutDispatch() = default;
+ OptOutDispatch() = default;
+
+ OptOutDispatch(const OptOutDispatch& other) = default;
+ OptOutDispatch& operator=(const OptOutDispatch& other) = default;
+
+ OptOutDispatch(OptOutDispatch&& other) = default;
+ OptOutDispatch& operator=(OptOutDispatch&& other) = default;
+
  // Hierarchical dispatch functions for handle
virtual void handle(Statement*);
virtual void handle(Expr*);
virtual void handle(TensorDomain*) {}
virtual void handle(TensorView*) {}
virtual void handle(Bool*) {}
- virtual void handle(Double*) {}
+ virtual void handle(Float*) {}
+ virtual void handle(Half*) {}
virtual void handle(Int*) {}
virtual void handle(NamedScalar*) {}
virtual void handle(BinaryOp*) {}
virtual void handle(TernaryOp*) {}
virtual void handle(ReductionOp*) {}
- virtual void handle(WelfordOp*) {}
virtual void handle(BroadcastOp*) {}
- virtual void handle(TransposeOp*) {}
- virtual void handle(ShiftOp*) {}
- virtual void handle(GatherOp*) {}
+
+ // Kernel IR nodes
+ virtual void handle(kir::Bool*) {}
+ virtual void handle(kir::Float*) {}
+ virtual void handle(kir::Half*) {}
+ virtual void handle(kir::Int*) {}
+ virtual void handle(kir::NamedScalar*) {}
+
+ virtual void handle(kir::IterDomain*) {}
+ virtual void handle(kir::TensorDomain*) {}
+ virtual void handle(kir::TensorView*) {}
+
+ virtual void handle(kir::UnaryOp*) {}
+ virtual void handle(kir::BinaryOp*) {}
+ virtual void handle(kir::TernaryOp*) {}
+ virtual void handle(kir::ReductionOp*) {}
+ virtual void handle(kir::BroadcastOp*) {}
+
+ virtual void handle(kir::TensorIndex*) {}
+ virtual void handle(kir::GridReduction*) {}
+ virtual void handle(kir::ForLoop*) {}
+ virtual void handle(kir::IfThenElse*) {}
+ virtual void handle(kir::Allocate*) {}
+ virtual void handle(kir::Sync*) {}
};
-class TORCH_CUDA_CU_API OptInConstDispatch : public PolymorphicBase {
+class TORCH_CUDA_CU_API OptInConstDispatch {
public:
+ virtual ~OptInConstDispatch() = default;
+ OptInConstDispatch() = default;
+
+ OptInConstDispatch(const OptInConstDispatch& other) = default;
+ OptInConstDispatch& operator=(const OptInConstDispatch& other) = default;
+
+ OptInConstDispatch(OptInConstDispatch&& other) = default;
+ OptInConstDispatch& operator=(OptInConstDispatch&& other) = default;
+
  // Hierarchical dispatch functions for handle
virtual void handle(const Statement*);
virtual void handle(const Expr*);
virtual void handle(const Bool*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool.");
}
- virtual void handle(const Double*) {
- TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Double.");
+ virtual void handle(const Float*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Float.");
+ }
+ virtual void handle(const Half*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Half.");
}
virtual void handle(const Int*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int.");
virtual void handle(const BinaryOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BinaryOp.");
}
- virtual void handle(const WelfordOp*) {
- TORCH_INTERNAL_ASSERT(false, "Handle not overriden for WelfordOp.");
- }
virtual void handle(const TernaryOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TernaryOp.");
}
virtual void handle(const BroadcastOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp.");
}
- virtual void handle(const TransposeOp*) {
- TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TransposeOp.");
+
+ // Kernel IR
+ //
+ // TODO: move to a specialized visitor
+ //
+
+ virtual void handle(const kir::Bool*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Bool.");
+ }
+ virtual void handle(const kir::Float*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Float.");
+ }
+ virtual void handle(const kir::Half*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Half.");
+ }
+ virtual void handle(const kir::Int*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Int.");
+ }
+ virtual void handle(const kir::NamedScalar*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::NamedScalar.");
+ }
+
+ virtual void handle(const kir::IterDomain*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::IterDomain.");
+ }
+ virtual void handle(const kir::TensorDomain*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TensorDomain.");
+ }
+ virtual void handle(const kir::TensorView*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TensorView.");
+ }
+
+ virtual void handle(const kir::UnaryOp*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::UnaryOp.");
+ }
+ virtual void handle(const kir::BinaryOp*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::BinaryOp.");
+ }
+ virtual void handle(const kir::TernaryOp*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TernaryOp.");
+ }
+ virtual void handle(const kir::ReductionOp*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::ReductionOp.");
+ }
+ virtual void handle(const kir::BroadcastOp*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::BroadcastOp.");
+ }
+
+ virtual void handle(const kir::GridReduction*) {
+ TORCH_INTERNAL_ASSERT(
+ false, "Handle not overriden for kir::GridReduction.");
+ }
+ virtual void handle(const kir::ForLoop*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::ForLoop.");
}
- virtual void handle(const ShiftOp*) {
- TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ShiftOp.");
+ virtual void handle(const kir::Allocate*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Allocate.");
}
- virtual void handle(const GatherOp*) {
- TORCH_INTERNAL_ASSERT(false, "Handle not overriden for GatherOp.");
+ virtual void handle(const kir::Sync*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Sync.");
+ }
+ virtual void handle(const kir::IfThenElse*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::IfThenElse.");
+ }
+
+ virtual void handle(const kir::TensorIndex*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TensorIndex.");
}
};
-class TORCH_CUDA_CU_API OptInDispatch : public PolymorphicBase {
+class TORCH_CUDA_CU_API OptInDispatch {
public:
+ virtual ~OptInDispatch() = default;
+ OptInDispatch() = default;
+
+ OptInDispatch(const OptInDispatch& other) = default;
+ OptInDispatch& operator=(const OptInDispatch& other) = default;
+
+ OptInDispatch(OptInDispatch&& other) = default;
+ OptInDispatch& operator=(OptInDispatch&& other) = default;
+
  // Hierarchical dispatch functions for handle
virtual void handle(Statement* s);
virtual void handle(Expr* e);
virtual void handle(Bool*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool.");
}
- virtual void handle(Double*) {
- TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Double.");
+ virtual void handle(Float*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Float.");
+ }
+ virtual void handle(Half*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Half.");
}
virtual void handle(Int*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int.");
virtual void handle(ReductionOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ReductionOp.");
}
- virtual void handle(WelfordOp*) {
- TORCH_INTERNAL_ASSERT(false, "Handle not overriden for WelfordOp.");
- }
virtual void handle(BroadcastOp*) {
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp.");
}
- virtual void handle(TransposeOp*) {
- TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TransposeOp.");
+
+ // Kernel IR
+ //
+ // TODO: move to a specialized visitor
+ //
+
+ virtual void handle(kir::Bool*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool.");
+ }
+ virtual void handle(kir::Float*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Float.");
+ }
+ virtual void handle(kir::Half*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Half.");
+ }
+ virtual void handle(kir::Int*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int.");
+ }
+ virtual void handle(kir::NamedScalar*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::NamedScalar.");
+ }
+ virtual void handle(kir::TensorIndex*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TensorIndex.");
+ }
+
+ virtual void handle(kir::IterDomain*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::IterDomain.");
+ }
+ virtual void handle(kir::TensorDomain*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TensorDomain.");
+ }
+ virtual void handle(kir::TensorView*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TensorView.");
}
- virtual void handle(ShiftOp*) {
- TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ShiftOp.");
+
+ virtual void handle(kir::UnaryOp*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::UnaryOp.");
+ }
+ virtual void handle(kir::BinaryOp*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::BinaryOp.");
+ }
+ virtual void handle(kir::TernaryOp*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::TernaryOp.");
+ }
+ virtual void handle(kir::ReductionOp*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::ReductionOp.");
+ }
+ virtual void handle(kir::BroadcastOp*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::BroadcastOp.");
+ }
+
+ virtual void handle(kir::GridReduction*) {
+ TORCH_INTERNAL_ASSERT(
+ false, "Handle not overriden for kir::GridReduction.");
+ }
+ virtual void handle(kir::ForLoop*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::ForLoop.");
}
- virtual void handle(GatherOp*) {
- TORCH_INTERNAL_ASSERT(false, "Handle not overriden for GatherOp.");
+ virtual void handle(kir::Allocate*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Allocate.");
+ }
+ virtual void handle(kir::Sync*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::Sync.");
+ }
+ virtual void handle(kir::IfThenElse*) {
+ TORCH_INTERNAL_ASSERT(false, "Handle not overriden for kir::IfThenElse.");
}
};
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
-class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
+class TORCH_CUDA_CU_API OptOutMutator {
public:
+ virtual ~OptOutMutator() = default;
+ OptOutMutator() = default;
+
+ OptOutMutator(const OptOutMutator& other) = default;
+ OptOutMutator& operator=(const OptOutMutator& other) = default;
+
+ OptOutMutator(OptOutMutator&& other) = default;
+ OptOutMutator& operator=(OptOutMutator&& other) = default;
+
+ virtual void mutate(Fusion* fusion);
+
  // Hierarchical dispatch functions for mutate
virtual Statement* mutate(Statement* s);
virtual Statement* mutate(Expr* e);
virtual Statement* mutate(IterDomain*);
virtual Statement* mutate(TensorDomain*);
virtual Statement* mutate(TensorView*);
+ virtual Statement* mutate(kir::TensorIndex*);
virtual Statement* mutate(Bool*);
- virtual Statement* mutate(Double*);
+ virtual Statement* mutate(Float*);
+ virtual Statement* mutate(Half*);
virtual Statement* mutate(Int*);
virtual Statement* mutate(NamedScalar*);
virtual Statement* mutate(BinaryOp*);
virtual Statement* mutate(TernaryOp*);
virtual Statement* mutate(ReductionOp*);
- virtual Statement* mutate(WelfordOp*);
+ virtual Statement* mutate(kir::GridReduction*);
virtual Statement* mutate(BroadcastOp*);
- virtual Statement* mutate(TransposeOp*);
- virtual Statement* mutate(ShiftOp*);
- virtual Statement* mutate(GatherOp*);
+ virtual Statement* mutate(kir::ForLoop*);
+ virtual Statement* mutate(kir::IfThenElse*);
+ virtual Statement* mutate(kir::Allocate*);
+ virtual Statement* mutate(kir::Sync*);
};
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
-class TORCH_CUDA_CU_API OptInMutator : public PolymorphicBase {
+class TORCH_CUDA_CU_API OptInMutator {
public:
- std::unordered_map<Val*, Val*> mutations;
+ virtual ~OptInMutator() = default;
+ OptInMutator() = default;
+
+ OptInMutator(const OptInMutator& other) = default;
+ OptInMutator& operator=(const OptInMutator& other) = default;
+
+ OptInMutator(OptInMutator&& other) = default;
+ OptInMutator& operator=(OptInMutator&& other) = default;
- public:
void registerMutation(Val* val, Val* mutation) {
TORCH_INTERNAL_ASSERT(
mutations.find(val) == mutations.end(),
mutations[val] = mutation;
}
+ std::unordered_map<Val*, Val*> mutations;
+
  // Hierarchical dispatch functions for mutate
virtual Statement* mutate(Statement*);
virtual Statement* mutate(Expr*);
virtual Statement* mutate(TensorView*) {
TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TensorView.");
}
+ virtual Statement* mutate(kir::TensorIndex*) {
+ TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TensorIndex.");
+ }
virtual Statement* mutate(Bool*) {
TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Bool.");
}
+ virtual Statement* mutate(Float*) {
+ TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Float.");
+ }
virtual Statement* mutate(Int*) {
TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Int.");
}
virtual Statement* mutate(ReductionOp*) {
TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ReductionOp.");
}
- virtual Statement* mutate(WelfordOp*) {
- TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for WelfordOp.");
+ virtual Statement* mutate(kir::GridReduction*) {
+ TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for GridReduction.");
}
virtual Statement* mutate(BroadcastOp*) {
TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for BroadcastOp.");
}
- virtual Statement* mutate(TransposeOp*) {
- TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for TransposeOp.");
+ virtual Statement* mutate(kir::ForLoop*) {
+ TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ForLoop.");
+ }
+ virtual Statement* mutate(kir::Allocate*) {
+ TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Allocate.");
}
- virtual Statement* mutate(ShiftOp*) {
- TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for ShiftOp.");
+ virtual Statement* mutate(kir::Sync*) {
+ TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Sync.");
}
- virtual Statement* mutate(GatherOp*) {
- TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for GatherOp.");
+ virtual Statement* mutate(kir::IfThenElse*) {
+ TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for IfThenElse.");
}
};
-
-#include <torch/csrc/jit/codegen/cuda/executor.h>
-
#include <torch/csrc/jit/codegen/cuda/codegen.h>
#include <torch/csrc/jit/codegen/cuda/executor_kernel_arg.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
+
+#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/irange.h>
-#include <fstream>
+#include <cstdlib>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
-int FusionExecutor::fusion_id_counter_ = 0; // NOLINT
-
-namespace {
-
-static const char* defineIndexMode(KernelIndexMode index_mode) {
- switch (index_mode) {
- case KernelIndexMode::INT32:
- return "typedef int nvfuser_index_t;\n";
- case KernelIndexMode::INT64:
- return "typedef int64_t nvfuser_index_t;\n";
- default:
- break;
- }
-
- TORCH_INTERNAL_ASSERT(false, "unknow indexing mode");
- return "";
-}
-
-static const char* defineIntegerTypes() {
- return R"(
-typedef unsigned char uint8_t;
-typedef signed char int8_t;
-typedef short int int16_t;
-typedef unsigned int uint32_t;
-typedef long long int int64_t;
-typedef unsigned long long int uint64_t;
-)";
-}
-
-} // namespace
+int FusionExecutor::fusion_id_counter_ = 0;
std::string FusionExecutor::getStructuredCode(const std::string& kernel) {
// generating cuda code;
#endif
#endif
code += std::string("namespace ") + FusionExecutor::kernelNamespace() +
- " {\n" + defineIntegerTypes() + defineIndexMode(options_.index_mode) +
- executor_utils::kernelPreamble() + kernel + "}\n";
-
- if (isDebugDumpEnabled(DebugDumpOption::CudaKernel)) {
- std::cout << "\n======= Codegen output for kernel: " << kernelName()
- << " =======\n\n"
- << kernel << "\n======================================\n\n";
- } else if (isDebugDumpEnabled(DebugDumpOption::CudaFull)) {
- std::cout << "\n======= Codegen output for kernel: " << kernelName()
- << " =======\n\n"
- << code << "\n======================================\n\n";
- } else if (isDebugDumpEnabled(DebugDumpOption::CudaToFile)) {
- std::stringstream file_name;
- file_name << "__tmp_kernel" << fusion_id_ << ".cu";
- std::cout << "PRINTING: " << file_name.str() << std::endl;
- std::ofstream out(file_name.str());
- out << code << std::endl;
- out.close();
+ " {\n" + executor_utils::kernelPreamble() + kernel + "}\n";
+
+ const char* debug_env = std::getenv("PYTORCH_CUDA_FUSER_DEBUG");
+ if (debug_env && atoi(debug_env)) {
+ std::cout << "\n==== codegen output for kernel: " << kernelName()
+ << " ====" << std::endl
+ << code << std::endl
+ << "======================================\n"
+ << std::endl;
}
return code;
}
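
As a usage note, a sketch of how this dump could be triggered from a test (assumes a POSIX environment; `fusion` and `options` are placeholders):

    #include <cstdlib> // std::getenv / setenv (POSIX)

    setenv("PYTORCH_CUDA_FUSER_DEBUG", "1", /*overwrite=*/1);
    FusionExecutor fe;
    fe.compileFusion(&fusion, options); // getStructuredCode() now prints the kernel
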
-// TODO: come up with a more user friendly interface
void FusionExecutor::debugCompileFusionFromStr(
Fusion* fusion,
const std::string& code,
FusionGuard fg(&fusion_);
options_ = options;
- if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) {
- fusion->print();
- } else if (isDebugDumpEnabled(DebugDumpOption::FusionIrMath)) {
- fusion->printMath();
- }
-
- if (isDebugDumpEnabled(DebugDumpOption::CudaFull)) {
+ const char* debug_env = std::getenv("PYTORCH_CUDA_FUSER_DEBUG");
+ if (debug_env && atoi(debug_env)) {
std::cout << "\n==== codegen output for kernel: " << kernelName()
<< " ====" << std::endl
<< code << std::endl
lowered_ = GpuLower(&fusion_);
const auto kernel = lowered_.kernel();
- if (isDebugDumpEnabled(DebugDumpOption::KernelIr)) {
+ const char* dump_kir_env = std::getenv("PYTORCH_CUDA_FUSER_DUMP_KIR");
+ if (dump_kir_env && atoi(dump_kir_env)) {
kernel->print();
}
const auto& kernel_summary = kernel->summary();
+ has_block_reductions = kernel_summary.has_block_reductions;
+ has_grid_reductions = kernel_summary.has_grid_reductions;
+ has_block_broadcasts = kernel_summary.has_block_broadcasts;
if (!kernel_summary.static_smem_allocations.empty()) {
- kir::ExpressionEvaluator static_evaluator;
+ StatefulExpressionEvaluator static_evaluator(&fusion_);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- const auto static_smem_size = computeSharedMemory(
+ unsigned static_smem_size = computeSharedMemory(
static_evaluator, kernel_summary.static_smem_allocations);
TORCH_INTERNAL_ASSERT(
static_smem_size < max_device_smem,
fusion_id_ > 0, "assign a fusion_id_ <= 0 is not accepted.");
}
-void FusionExecutor::compileFusion(
- Fusion* fusion,
- CompileOptions options,
- const at::ArrayRef<IValue>& inputs,
- const LaunchParams& launch_constraints) {
+void FusionExecutor::compileFusion(Fusion* fusion, CompileOptions options) {
FUSER_PERF_SCOPE("compileFusion");
TORCH_INTERNAL_ASSERT(
"Output types from fusions that are not tensors are not supported at this point.");
}
- if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) {
- fusion->print();
- } else if (isDebugDumpEnabled(DebugDumpOption::FusionIrMath)) {
- fusion->printMath();
- }
-
// Clone the fusion so we can store it
fusion_ = *fusion;
FusionGuard fg(&fusion_);
options_ = options;
- c10::DeviceGuard dg(options_.device);
TORCH_INTERNAL_ASSERT(
options.device.is_cuda(), "Provided device to CUDA fuser is the CPU.");
lowered_ = GpuLower(&fusion_);
const auto kernel = lowered_.kernel();
- if (isDebugDumpEnabled(DebugDumpOption::KernelIr)) {
+ const char* dump_kir_env = std::getenv("PYTORCH_CUDA_FUSER_DUMP_KIR");
+ if (dump_kir_env && atoi(dump_kir_env)) {
kernel->print();
}
const auto structured_code = getStructuredCode(kernel_code);
const auto& kernel_summary = kernel->summary();
+ has_block_reductions = kernel_summary.has_block_reductions;
+ has_grid_reductions = kernel_summary.has_grid_reductions;
+ has_block_broadcasts = kernel_summary.has_block_broadcasts;
if (!kernel_summary.static_smem_allocations.empty()) {
- kir::ExpressionEvaluator static_evaluator;
+ StatefulExpressionEvaluator static_evaluator(&fusion_);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- const auto static_smem_size = computeSharedMemory(
+ unsigned static_smem_size = computeSharedMemory(
static_evaluator, kernel_summary.static_smem_allocations);
TORCH_INTERNAL_ASSERT(
static_smem_size < max_device_smem,
"The static shared memory allocation is larger than available memory.");
}
- if (kernel_summary.has_dynamic_local_memory_allocations) {
- std::stringstream ss;
- ss << "Allocations must be based on constant integers for local memory. However, found: ";
- for (auto alloc : kernel_summary.dynamic_lmem_allocations) {
- ss << toString(alloc->buffer(), false) << ", ";
- }
- ss << " have dynamic allocations but are placed in local memory.";
- TORCH_INTERNAL_ASSERT(false, ss.str());
- }
-
- TORCH_CHECK(
- !kernel_summary.has_grid_reduction_in_loop,
- "Grid reduction must not be placed inside a loop.");
-
- // TODO: pass block_size here;
- c10::optional<int> block_size = c10::nullopt;
- if (!inputs.empty()) {
- auto expr_eval = executor_utils::bindKernelInputs(inputs, kernel);
- auto launch_params = computeLaunchParams(launch_constraints, expr_eval);
- block_size = launch_params.nThreads();
- TORCH_INTERNAL_ASSERT(
- block_size > 0, "launch param inferred block size < 0");
- }
-
compiled_kernel_ = executor_utils::nvrtcCompile(
structured_code,
(kernelNamespace() + "::" + kernelName()).c_str(),
- fusion_id_,
- block_size);
+ fusion_id_);
TORCH_INTERNAL_ASSERT(
fusion_id_ > 0, "failed to assign a fusion_id_ after compilation.");
}
namespace {
at::Tensor inferAndAlloc(
- const kir::TensorView* tv,
- const std::vector<kir::Val*>& sizes,
- kir::ExpressionEvaluator& expr_eval,
+ const TensorView* tv,
+ StatefulExpressionEvaluator& see,
const CompileOptions& options,
bool zero_init = false) {
FUSER_PERF_SCOPE("inferAndAlloc");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<int64_t> inferred_sizes;
-
- for (const auto size : sizes) {
- const auto inferred_val = expr_eval.evaluate(size);
+ std::vector<int64_t> sizes;
+ for (auto id : TensorDomain::noReductions(tv->getMaybeRFactorDomain())) {
+ auto inferred_val = see.inferValue(id->rawExtent());
TORCH_INTERNAL_ASSERT(
inferred_val.has_value(),
"Could not launch kernel as program could not infer ",
- kir::toString(size),
+ id->rawExtent(),
" for the buffer ",
- kir::toString(tv));
- inferred_sizes.push_back(inferred_val.value());
+ tv);
+ sizes.push_back(inferred_val.value());
}
- const auto at_type = data_type_to_aten(tv->dtype());
+ auto at_type = data_type_to_aten(tv->getDataType().value());
if (zero_init) {
- const auto tensor_options =
+ auto tensor_options =
at::TensorOptions().dtype(at_type).device(options.device);
- c10::IntArrayRef isizes(inferred_sizes);
+ c10::IntArrayRef isizes(sizes);
return at::zeros(isizes, tensor_options);
} else {
- c10::IntArrayRef isizes(inferred_sizes);
+ c10::IntArrayRef isizes(sizes);
// Non Variable type guard for empty_cuda call
at::AutoDispatchBelowADInplaceOrView non_variable_type_mode;
return at::native::empty_cuda(
}
}
-at::Tensor inferAndAllocOutput(
- const kir::TensorView* tv,
- kir::ExpressionEvaluator& expr_eval,
- const CompileOptions& options,
- bool zero_init = false) {
- const auto domain = tv->domain();
- const auto maybe_rfactor_domain =
- domain->hasRFactor() ? domain->rfactorDomain() : domain->rootDomain();
-
- std::vector<kir::Val*> sizes;
-
- for (const auto id : maybe_rfactor_domain) {
- if (id->isReduction() ||
- id->iterType() == IterType::BroadcastWithoutStride) {
- continue;
- }
- sizes.push_back(id->extent());
- }
-
- return inferAndAlloc(tv, sizes, expr_eval, options, zero_init);
-}
-
} // namespace
uint64_t FusionExecutor::computeSharedMemory(
- kir::ExpressionEvaluator& expr_eval,
- const std::vector<const kir::Allocate*>& buffers,
+ StatefulExpressionEvaluator& see,
+ const std::vector<kir::Allocate*>& buffers,
bool align_padding,
uint64_t total) {
FUSER_PERF_SCOPE("computeSharedMemory");
// If this buffer aliases another buffer,
// then do not allocate memory for this buffer.
if (smem_alloc->alias() == nullptr) {
- const auto inferred_val = expr_eval.evaluate(smem_alloc->size());
+ auto inferred_val = see.inferValue(smem_alloc->size());
if (inferred_val.has_value()) {
- const uint64_t data_size = dataTypeSize(smem_alloc->buffer()->dtype());
+ const uint64_t data_size = dataTypeSize(smem_alloc->buffer_type());
// Add padding to align dynamic shared memory
if (align_padding) {
total = ceilDiv(total, data_size) * data_size;
LaunchParams FusionExecutor::computeLaunchParams(
const LaunchParams& launch_constraints,
- kir::ExpressionEvaluator& expr_eval) {
- FUSER_PERF_SCOPE("FusionExecutor::ComputeLaunchParams");
+ StatefulExpressionEvaluator& see) {
+ FUSER_PERF_SCOPE("computeLaunchParams");
LaunchParams launch_params;
// Lets collect all IterDomains that are bound to a thread binding
- std::unordered_map<ParallelType, std::vector<const kir::Val*>, TypeHash>
+ std::unordered_map<ParallelType, std::vector<IterDomain*>, TypeHash>
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- parallel_iter_extents;
+ parallel_iter_domains;
for (auto tv : getUsedTVs()) {
for (auto id : tv->domain()->domain()) {
if (id->isThread() && !id->isBroadcast()) {
- // TODO(kir): we should rewrite this logic based on the Kernel object
- auto kir_extent = lowered_.lowerValue(id->extent());
- const auto it = parallel_iter_extents.find(id->getParallelType());
- if (it != parallel_iter_extents.end()) {
- it->second.push_back(kir_extent);
+ if (parallel_iter_domains.find(id->getParallelType()) !=
+ parallel_iter_domains.end()) {
+ parallel_iter_domains.at(id->getParallelType()).push_back(id);
} else {
- parallel_iter_extents[id->getParallelType()] = {kir_extent};
+ parallel_iter_domains[id->getParallelType()] =
+ std::vector<IterDomain*>({id});
}
}
}
  // If any dimension was set in the launch constraints, we need to run
  // through the parallelized IterDomains and bind those values, or, if they
  // could be inferred, make sure the inference matches what was set.
- for (auto& entry : parallel_iter_extents) {
- auto p_type = entry.first;
- if (launch_constraints.hasDim(p_type)) {
- auto parallel_extents = entry.second;
- for (auto extent : parallel_extents) {
- auto inferred_val = expr_eval.evaluate(extent);
- if (inferred_val.has_value()) {
- // This value could have been inferred, make sure it was set right.
- bool valid =
- inferred_val.value() == launch_constraints.getDim(p_type) ||
- launch_constraints.getRawVal(p_type) == -1;
- if (!useFallback() && !valid) {
- TORCH_WARN_ONCE(
- "Cannot validate parallelization scheme, "
- "this may be due to mixed broadcast axes that are parallelized.");
+ if (launch_constraints.nBlocks() * launch_constraints.nThreads() != -1) {
+ for (auto& entry : parallel_iter_domains) {
+ auto p_type = entry.first;
+ if (launch_constraints.hasDim(p_type)) {
+ auto parallel_ids = entry.second;
+ for (auto parallel_id : parallel_ids) {
+ auto inferred_val = see.inferValue(parallel_id->rawExtent());
+ if (inferred_val.has_value()) {
+ // This value could have been inferred, make sure it was set right.
+ TORCH_CHECK(
+ inferred_val.value() == launch_constraints.getDim(p_type) ||
+ launch_constraints.getRawVal(p_type) == -1,
+ "inferred that ",
+ p_type,
+ " should be set to ",
+ inferred_val.value(),
+ " but launch constraints specified ",
+ launch_constraints.getDim(p_type));
+ } else {
+ // Bind the launch constraint into our evaluation context
+ see.safeBind(
+ parallel_id->rawExtent(),
+ launch_constraints.getDim(entry.first),
+ &lowered_);
+ launch_params.bind(launch_constraints.getDim(p_type), p_type);
}
- } else {
- // Bind the launch constraint into our evaluation context
- expr_eval.bind(extent, launch_constraints.getDim(p_type));
- launch_params.bind(launch_constraints.getDim(p_type), p_type);
}
}
}
}
// Run through the rest of the parallel IterDomains and infer their size
- for (auto& entry : parallel_iter_extents) {
+ for (auto& entry : parallel_iter_domains) {
auto p_type = entry.first;
- auto parallel_extents = entry.second;
-    // Select the maximum value out of all the parallel extents
- int64_t maximum_value = std::numeric_limits<int64_t>::min();
- for (auto extent : parallel_extents) {
- const auto val = expr_eval.evaluate(extent);
+ auto parallel_ids = entry.second;
+ for (auto parallel_id : parallel_ids) {
+ auto val = see.inferValue(parallel_id->rawExtent());
TORCH_INTERNAL_ASSERT(
- val.has_value(),
+ val,
"Tried to evaluate the extent of ",
- p_type,
+ parallel_id,
" to set launch bounds but could not.");
- maximum_value = std::max(maximum_value, *val);
+ launch_params.bind(val.value(), p_type);
}
- launch_params.bind(maximum_value, p_type);
}
const auto kernel = lowered_.kernel();
// Calculate Dynamic Shared Memory Size
// Add workspace for reduction and broadcast
uint64_t reduction_broadcast_workspace = 0;
- const bool has_workspace = kernel_summary.has_block_reductions ||
- kernel_summary.number_of_grid_reductions > 0 ||
- kernel_summary.has_block_broadcasts;
- if (has_workspace &&
- kernel_summary.largest_smem_data_type != DataType::Null) {
+ if (has_block_reductions || has_grid_reductions || has_block_broadcasts) {
    // Not using nThreads here since it does not handle uninitialized values
-
- // TODO: here is an optimization opportunity since welford uses int64_t for
-    // N while the data type is not necessarily double. But it may need more
- // work on the alignment
- const int welford_factor =
- kernel_summary.has_block_welford || kernel_summary.has_grid_welford ? 3
- : 1;
reduction_broadcast_workspace =
- dataTypeSize(kernel_summary.largest_smem_data_type) * welford_factor *
+ dataTypeSize(kernel_summary.largest_smem_data_type) *
launch_params.bdimx() * launch_params.bdimy() * launch_params.bdimz();
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const uint64_t dynamic_smem_size = computeSharedMemory(
- expr_eval,
+ see,
kernel_summary.dynamic_smem_allocations,
true,
reduction_broadcast_workspace);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const uint64_t static_smem_size =
- computeSharedMemory(expr_eval, kernel_summary.static_smem_allocations);
+ computeSharedMemory(see, kernel_summary.static_smem_allocations);
TORCH_INTERNAL_ASSERT(
(dynamic_smem_size + static_smem_size) < max_device_smem,
}
FusionExecutor::GlobalBuffers FusionExecutor::allocGlobalVals(
- kir::ExpressionEvaluator& expr_eval) {
- FUSER_PERF_SCOPE("FusionExecutor::AllocGlobalVals");
+ StatefulExpressionEvaluator& see) {
+ FUSER_PERF_SCOPE("allocGlobalVals");
GlobalBuffers global_buffers;
- const auto kernel = lowered_.kernel();
const auto& kernel_summary = lowered_.kernel()->summary();
for (auto alloc : kernel_summary.global_allocations) {
TORCH_INTERNAL_ASSERT(
- alloc->buffer()->isA<kir::TensorView>(),
+ alloc->buffer()->getValType() == ValType::KirTensorView,
"Cannot allocate global buffers that are not tensors.");
- auto tv = alloc->buffer()->as<kir::TensorView>();
- if (kernel->isOutput(tv)) {
- continue;
- }
- if (alloc->zeroInit()) {
- global_buffers.buffers.push_back(
- inferAndAlloc(tv, alloc->shape(), expr_eval, options_, true));
- global_buffers.zero_init.push_back(true);
+ if (!alloc->zeroInit()) {
+ global_buffers.empty_buffers.push_back(inferAndAlloc(
+ alloc->buffer()->as<kir::TensorView>()->fuserTv(),
+ see,
+ options_,
+ false));
} else {
- global_buffers.buffers.push_back(
- inferAndAlloc(tv, alloc->shape(), expr_eval, options_, false));
- global_buffers.zero_init.push_back(false);
+ global_buffers.zero_buffers.push_back(inferAndAlloc(
+ alloc->buffer()->as<kir::TensorView>()->fuserTv(),
+ see,
+ options_,
+ true));
}
}
}
std::vector<at::Tensor> FusionExecutor::allocOutputs(
- kir::ExpressionEvaluator& expr_eval,
- const std::unordered_set<int>& alias_indices) {
- FUSER_PERF_SCOPE("FusionExecutor::AllocOutputs");
- const auto kernel = lowered_.kernel();
+ StatefulExpressionEvaluator& see) {
+ FUSER_PERF_SCOPE("allocOutputs");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<at::Tensor> outputs;
- for (size_t i = 0; i < kernel->outputs().size(); ++i) {
+ for (auto output : fusion_.outputs()) {
TORCH_INTERNAL_ASSERT(
- kernel->outputs()[i]->isA<kir::TensorView>(),
+ output->getValType() == ValType::TensorView,
"Cannot allocate outputs that are not tensors.");
- auto output = kernel->outputs()[i]->as<kir::TensorView>();
- if (alias_indices.count(i) == 0) {
- outputs.push_back(
- inferAndAllocOutput(output, expr_eval, options_, false));
- } else {
- // aliasing to inputs, no need to allocate real output
- outputs.push_back(inferAndAlloc(output, {}, expr_eval, options_, false));
- }
+ outputs.push_back(
+ inferAndAlloc(output->as<TensorView>(), see, options_, false));
}
return outputs;
}
void FusionExecutor::setUsedTVs() {
- auto used_vals = fusion_.usedMathVals();
- auto used_tvs = ir_utils::filterByType<TensorView>(used_vals);
used_tvs_.clear();
-
- for (auto tv : used_tvs)
- used_tvs_.push_back(tv);
+ auto used_vals = DependencyCheck::getAllValsBetween(
+ {fusion_.inputs().begin(), fusion_.inputs().end()}, fusion_.outputs());
+ for (auto val : used_vals) {
+ if (val->getValType().value() == ValType::TensorView) {
+ used_tvs_.push_back(val->as<TensorView>());
+ }
+ }
}
std::vector<at::Tensor> FusionExecutor::runFusion(
const std::vector<at::Tensor>& outputs,
const LaunchParams& launch_constraints,
const c10::optional<size_t>& opt_code) {
- FUSER_PERF_SCOPE("FusionExecutor::RunFusion");
+ FUSER_PERF_SCOPE("runFusion");
TORCH_INTERNAL_ASSERT(
fusion_id_ > 0, "Cannot run fusion, it was not compiled.");
LaunchParams launch_params;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<at::Tensor> allocated_outputs = outputs;
+ std::vector<at::Tensor> alloced_outputs = outputs;
GlobalBuffers global_buffers;
uint64_t rand_offset = 0;
if (executor_entry && executor_entry->init) {
{
- // context manager to disable auto grad for `empty_cuda` calls later
+ // context manager to disable auto grad for `empty_cuda` calls later;
at::AutoDispatchBelowADInplaceOrView non_variable_type_mode;
- // take the short-cut for launch if we see a recorded input set again
+ // take the short-cut for launch if we see a recorded input set again;
launch_params = executor_entry->launch_params;
- // only allocate outputs when not given
- if (outputs.empty()) {
- FUSER_PERF_SCOPE("ExecutorRunFusion::OutputAlloc");
- for (const auto i : c10::irange(executor_entry->output_sizes.size())) {
- allocated_outputs.push_back(at::native::empty_cuda(
- executor_entry->output_sizes[i],
- executor_entry->output_types[i],
- c10::nullopt,
- options_.device,
- c10::nullopt));
- }
- for (const auto& entry : executor_entry->io_alias_indices) {
- TORCH_INTERNAL_ASSERT(
- inputs[entry.second].isTensor(), "alias io only supports tensor");
- allocated_outputs[entry.first] = inputs[entry.second].toTensor();
- }
- } else {
- TORCH_INTERNAL_ASSERT(
- outputs.size() == fusion_.outputs().size(),
- __func__,
- " provided number of outputs does match fusion output");
+ for (const auto i : c10::irange(executor_entry->output_sizes.size())) {
+ alloced_outputs.push_back(at::native::empty_cuda(
+ executor_entry->output_sizes[i],
+ executor_entry->output_types[i],
+ c10::nullopt,
+ options_.device,
+ c10::nullopt));
}
- {
- FUSER_PERF_SCOPE("ExecutorRunFusion::IntermediateBufferAlloc");
- for (const auto i : c10::irange(executor_entry->buffer_sizes.size())) {
- if (executor_entry->buffer_zero_init[i]) {
- global_buffers.buffers.push_back(at::zeros(
- executor_entry->buffer_sizes[i],
- at::TensorOptions()
- .dtype(executor_entry->buffer_types[i])
- .device(options_.device)));
- } else {
- global_buffers.buffers.push_back(at::native::empty_cuda(
- executor_entry->buffer_sizes[i],
- executor_entry->buffer_types[i],
- c10::nullopt,
- options_.device,
- c10::nullopt));
- }
- }
+ for (const auto i :
+ c10::irange(executor_entry->empty_buffer_sizes.size())) {
+ global_buffers.empty_buffers.push_back(at::native::empty_cuda(
+ executor_entry->empty_buffer_sizes[i],
+ executor_entry->empty_buffer_types[i],
+ c10::nullopt,
+ options_.device,
+ c10::nullopt));
}
}
+ for (const auto i : c10::irange(executor_entry->zero_buffer_sizes.size())) {
+ auto tensor_options = at::TensorOptions()
+ .dtype(executor_entry->zero_buffer_types[i])
+ .device(options_.device);
+ global_buffers.zero_buffers.push_back(
+ at::zeros(executor_entry->zero_buffer_sizes[i], tensor_options));
+ }
rand_offset = executor_entry->rand_offset;
} else {
- FUSER_PERF_SCOPE("ExecutorRunFusion::ValidateAndInitialize");
// code path to take when either:
- // 1. no opt_code is provided or
+ // 1. no opt_code is provided or;
// 2. `executor_entry` is not initialized
executor_utils::validateKernelInputs(&fusion_, inputs, options_.device);
- const auto kernel = lowered_.kernel();
-
- auto expr_eval = executor_utils::bindKernelInputs(inputs, kernel);
-
- launch_params = computeLaunchParams(launch_constraints, expr_eval);
-
- executor_utils::validateVectorizedTensors(
- &fusion_, inputs, outputs, lowered_, expr_eval);
+ StatefulExpressionEvaluator evaluator =
+ executor_utils::statefulBindInputs(inputs, &fusion_, &lowered_);
- auto alias_indices = fusion_.getInputAliasIndices();
+ launch_params = computeLaunchParams(launch_constraints, evaluator);
- // ditch pre-allocated outputs if the number doesn't match.
// NOLINTNEXTLINE(bugprone-branch-clone)
- if (outputs.empty()) {
- allocated_outputs =
- allocOutputs(expr_eval, fusion_.getOutputAliasIndices());
-
- for (const auto& entry : alias_indices) {
- TORCH_INTERNAL_ASSERT(
- inputs[entry.second].isTensor(), "alias io only supports tensor");
- allocated_outputs[entry.first] = inputs[entry.second].toTensor();
- }
+ if (outputs.empty() || outputs.size() != fusion_.outputs().size()) {
+ alloced_outputs = allocOutputs(evaluator);
} else {
- // TODO: Update this as well;
executor_utils::validateKernelOutputs(
- &fusion_, allocated_outputs, options_.device);
+ &fusion_, alloced_outputs, options_.device);
}
- global_buffers = allocGlobalVals(expr_eval);
+ global_buffers = allocGlobalVals(evaluator);
- if (kernel->summary().is_stochastic) {
+ if (lowered_.kernel()->summary().is_stochastic) {
      // NOTE: this is how we map the offset to pointwise (PW) kernels in
      // order to produce the same random numbers as native PyTorch. It
      // doesn't really work, though, as it makes assumptions about how
      // threads are mapped.
rand_offset = 4 *
(std::ceil(
- allocated_outputs[0].numel() /
+ alloced_outputs[0].numel() /
(4.0 * 128 * launch_params.gdimx())) + // NOLINT
1);
}
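      // A worked instance of the formula above (illustrative numbers, not
      // from this change): with 2^20 output elements and gdimx == 256,
      //   rand_offset = 4 * (ceil(1048576 / (4.0 * 128 * 256)) + 1)
      //               = 4 * (8 + 1) = 36,
      // i.e. roughly four values per thread per sweep, padded by one sweep.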
// This is the entry when we have provided `opt_code` but the entry has not
// been initialized yet.
if (executor_entry) {
- FUSER_PERF_SCOPE("ExecutorRunFusion::FillCacheEntry");
      // record the short-cut executor entry for the given input set;
executor_entry->launch_params = launch_params;
- executor_entry->io_alias_indices = alias_indices;
- for (const auto& output : allocated_outputs) {
+ for (const auto& output : alloced_outputs) {
executor_entry->output_sizes.push_back(output.sizes().vec());
executor_entry->output_types.push_back(output.scalar_type());
}
-
- for (const auto& i : c10::irange(global_buffers.buffers.size())) {
- executor_entry->buffer_sizes.push_back(
- global_buffers.buffers[i].sizes().vec());
- executor_entry->buffer_types.push_back(
- global_buffers.buffers[i].scalar_type());
- executor_entry->buffer_zero_init.push_back(global_buffers.zero_init[i]);
+ for (const auto& buffer : global_buffers.empty_buffers) {
+ executor_entry->empty_buffer_sizes.push_back(buffer.sizes().vec());
+ executor_entry->empty_buffer_types.push_back(buffer.scalar_type());
+ }
+ for (const auto& buffer : global_buffers.zero_buffers) {
+ executor_entry->zero_buffer_sizes.push_back(buffer.sizes().vec());
+ executor_entry->zero_buffer_types.push_back(buffer.scalar_type());
}
executor_entry->rand_offset = rand_offset;
executor_entry->init = true;
}
}
- KernelArgumentHolder kernel_arguments(options_.index_mode);
- {
- FUSER_PERF_SCOPE("ExecutorRunFusion::FillKernelArgStructure");
- kernel_arguments.push(inputs);
- kernel_arguments.push(allocated_outputs);
- kernel_arguments.push(global_buffers.buffers);
- if (lowered_.kernel()->summary().is_stochastic) {
- kernel_arguments.appendPhiloxRNGSeed(rand_offset);
- }
- }
-
- if (isDebugDumpEnabled(DebugDumpOption::LaunchParam)) {
- launch_params.print();
- }
-
- if (isDebugDumpEnabled(DebugDumpOption::PrintRuntimeArgs)) {
- std::cout << "Arguments for kernel" << fusion_id_ << ":" << std::endl
- << "Inputs:" << std::endl;
- for (const auto& input : inputs) {
- if (input.isTensor()) {
- std::cout << input.toTensor().scalar_type() << " "
- << input.toTensor().sizes() << std::endl;
- }
- }
- std::cout << "Outputs:" << std::endl;
- for (const auto& output : allocated_outputs) {
- std::cout << " " << output.scalar_type() << " " << output.sizes()
- << std::endl;
- }
- std::cout << "Reduction and semaphore buffers:" << std::endl;
- for (const auto& buffer : global_buffers.buffers) {
- std::cout << " " << buffer.scalar_type() << " " << buffer.sizes()
- << std::endl;
- }
- }
-
- cudaEvent_t start_event = {};
- cudaEvent_t finish_event = {};
-
- if (measure_kernel_time_ ||
- isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth)) {
- cudaEventCreate(&start_event);
- cudaEventCreate(&finish_event);
- cudaEventRecord(start_event);
+ KernelArgumentHolder kernel_arguments;
+ kernel_arguments.push(inputs);
+ kernel_arguments.push(alloced_outputs);
+ kernel_arguments.push(global_buffers.empty_buffers);
+ kernel_arguments.push(global_buffers.zero_buffers);
+ if (lowered_.kernel()->summary().is_stochastic) {
+ kernel_arguments.appendPhiloxRNGSeed(rand_offset);
}
- if (execute_kernel_) {
- FUSER_PERF_SCOPE("ExecutorRunFusion::cuLaunchKernel");
+ {
+ FUSER_PERF_SCOPE("cuLaunchKernel");
AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLaunchKernel(
compiled_kernel_.function,
launch_params.gdimx(),
stream,
kernel_arguments.getBuffer(),
nullptr));
+ at::cuda::stream_synchronize(stream);
}
- if (measure_kernel_time_ ||
- isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth)) {
- cudaEventRecord(finish_event);
- cudaEventSynchronize(start_event);
- cudaEventSynchronize(finish_event);
- cudaEventElapsedTime(&kernel_time_ms_, start_event, finish_event);
- cudaEventDestroy(start_event);
- cudaEventDestroy(finish_event);
-
- if (isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth)) {
- size_t bytes = 0;
-      // Figure out how many bytes are in inputs, outputs, and temporary buffers
- for (auto input : inputs) {
- if (input.isTensor()) {
- bytes += input.toTensor().numel() *
- dataTypeSize(aten_to_data_type(input.toTensor().scalar_type()));
- }
- }
- for (auto output : allocated_outputs) {
- bytes += output.numel() *
- dataTypeSize(aten_to_data_type(output.scalar_type()));
- }
- double gb_per_s =
- ((double)bytes / ((double)kernel_time_ms_ / 1000)) / (double)1.0e9;
- std::cout << "kernel" << fusion_id_ << " run in " << kernel_time_ms_
- << " ms, achieved: " << gb_per_s << " GB/s" << std::endl;
- }
- }
-
- return allocated_outputs;
-}
-
-void FusionExecutor::compileRtc(
- const std::string& code,
- const std::string& name,
- bool structured) {
- FUSER_PERF_SCOPE("ExecutorRunFusion::compileRtc");
- std::string scode;
- if (!structured) {
- scode = getStructuredCode(code);
- } else {
- scode = code;
- }
- fusion_id_ = 1;
- options_ = CompileOptions();
- compiled_kernel_ = executor_utils::nvrtcCompile(scode, name, fusion_id_);
-}
-
-void FusionExecutor::runRtc(
- const LaunchParams& launch_params,
- const std::vector<at::Tensor>& args) {
- FUSER_PERF_SCOPE("runFusion");
-
- c10::DeviceGuard dg(options_.device);
- auto stream = at::cuda::getCurrentCUDAStream();
-
- KernelArgumentHolder kernel_arguments(options_.index_mode);
- kernel_arguments.push(args);
- AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLaunchKernel(
- compiled_kernel_.function,
- launch_params.gdimx(),
- launch_params.gdimy(),
- launch_params.gdimz(),
- launch_params.bdimx(),
- launch_params.bdimy(),
- launch_params.bdimz(),
- launch_params.smem(),
- stream,
- kernel_arguments.getBuffer(),
- nullptr));
+ return alloced_outputs;
}
} // namespace cuda
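// End-to-end usage sketch of the executor API as it stands after this
// change (assembled from the signatures in this diff; the fusion-building
// step and the exact inputs are placeholders):
//   Fusion fusion;
//   // ... build the fusion IR ...
//   FusionExecutor fe;
//   fe.compileFusion(&fusion);               // simplified signature below
//   std::vector<at::Tensor> outs =
//       fe.runFusion({/* at::IValue inputs */});  // outputs allocated for us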
#pragma once
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
+#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
// TODO: Should this actually be in launch params?
struct TORCH_CUDA_CU_API CompileOptions {
c10::Device device = c10::Device(c10::DeviceType::CUDA, 0);
- KernelIndexMode index_mode = KernelIndexMode::INT64;
};
class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
int id,
CompileOptions options = CompileOptions());
- void compileFusion(
- Fusion* fusion,
- CompileOptions options = CompileOptions(),
- const at::ArrayRef<IValue>& inputs = {},
- const LaunchParams& launch_constraints = LaunchParams());
+ void compileFusion(Fusion* fusion, CompileOptions options = CompileOptions());
std::vector<at::Tensor> runFusion(
const at::ArrayRef<IValue>& inputs,
executor_entry_lookup_.erase(cache_id);
}
- // struct used to hold necessary information to launch compiled kernel on a
- // given input set.
- //
// TODO: strides would also be important when we handle permutations in
// codegen.
- //
+ // struct used to hold necessary information to launch compiled kernel on a
+ // given input set.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct ExecutorEntry {
bool init = false;
LaunchParams launch_params;
- std::vector<std::pair<int, int>> io_alias_indices;
std::vector<std::vector<int64_t>> output_sizes;
std::vector<at::ScalarType> output_types;
- std::vector<std::vector<int64_t>> buffer_sizes;
- std::vector<at::ScalarType> buffer_types;
- std::vector<bool> buffer_zero_init;
+ std::vector<std::vector<int64_t>> empty_buffer_sizes;
+ std::vector<at::ScalarType> empty_buffer_types;
+ std::vector<std::vector<int64_t>> zero_buffer_sizes;
+ std::vector<at::ScalarType> zero_buffer_types;
uint64_t rand_offset;
};
- kir::Kernel* kernel() const {
+ Kernel* kernel() const {
return lowered_.kernel();
}
- //! Internal knob used for debugging/profiling only
- void setExecuteKernelFlag(bool execute_kernel) {
- execute_kernel_ = execute_kernel;
- }
-
- //! Internal knob used for debugging/profiling only
- void setMeasureKernelTimeFlag(bool measure_kernel_time) {
- measure_kernel_time_ = measure_kernel_time;
- }
-
- //! Returns the last kernel execution time, in milliseconds
- //!
- //! \note The kernel time is only tracked if enabled by calling
- //! setMeasureKernelTimeFlag(true)
- //!
- float kernelTimeMs() const {
- return measure_kernel_time_ ? kernel_time_ms_ : 0;
- }
-
- //! Internal tests only. Compiles CUDA code with NVRTC directly from
- //! string. This util provides a path to test runtime code, i.e. the resource
- //! strings.
- void compileRtc(
- const std::string& code,
- const std::string& name,
- bool structured = false);
-
- //! Internal tests only. Runs the compiled CUDA kernel from compileRtc.
- void runRtc(
- const LaunchParams& launch_params,
- const std::vector<at::Tensor>& args);
-
private:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct GlobalBuffers {
- std::vector<at::Tensor> buffers;
- std::vector<bool> zero_init;
+ std::vector<at::Tensor> empty_buffers;
+ std::vector<at::Tensor> zero_buffers;
};
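  // The split above replaces the former parallel (buffers, zero_init)
  // representation: a buffer that used to carry zero_init[i] == true now goes
  // in zero_buffers instead. Sketch (buffer names are illustrative):
  //   GlobalBuffers gb;
  //   gb.empty_buffers.push_back(work_buf);  // scratch, contents don't matter
  //   gb.zero_buffers.push_back(sem_buf);    // e.g. reduction semaphores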
std::string kernelName() const {
LaunchParams computeLaunchParams(
const LaunchParams& launch_constraints,
- kir::ExpressionEvaluator& expr_eval);
+ StatefulExpressionEvaluator& see);
uint64_t computeSharedMemory(
- kir::ExpressionEvaluator& expr_eval,
- const std::vector<const kir::Allocate*>& buffers,
+ StatefulExpressionEvaluator& see,
+ const std::vector<kir::Allocate*>& buffers,
bool align_padding = false,
uint64_t total = 0);
  // returns a pair of vectors of tensors, where tensors in the first vector
  // are uninitialized, while the second vector contains zero-initialized
  // tensors
- GlobalBuffers allocGlobalVals(kir::ExpressionEvaluator& expr_eval);
+ GlobalBuffers allocGlobalVals(StatefulExpressionEvaluator& see);
-  // alias_index: index of outputs that are aliases to inputs; we skip
-  // allocating real storage for those, but still reserve their spots to
-  // preserve the indexing from output aliases to inputs
- std::vector<at::Tensor> allocOutputs(
- kir::ExpressionEvaluator& expr_eval,
- const std::unordered_set<int>& alias_indices = {});
+ std::vector<at::Tensor> allocOutputs(StatefulExpressionEvaluator& see);
void setUsedTVs();
private:
Fusion fusion_;
+ // TODO(kir): caching the values here is no longer needed
+ bool has_block_reductions = false;
+ bool has_grid_reductions = false;
+ bool has_block_broadcasts = false;
+
CompileOptions options_;
size_t max_device_smem = std::numeric_limits<size_t>().max();
executor_utils::NvrtcFunction compiled_kernel_;
  // lookup table used as a short cut to retrieve recorded information, so
  // kernels can be launched without re-inferring parameters.
std::unordered_map<size_t, ExecutorEntry> executor_entry_lookup_;
-
- // Profiling support: knob to control wheter we actually execute the
- // kernel on the GPU or not
- bool execute_kernel_ = true;
-
- // Profiling support: knob to enable measuring kernel execution time
- bool measure_kernel_time_ = false;
-
- // The last kernel execution time, if measure_kernel_time_ is true
- float kernel_time_ms_ = 0;
};
} // namespace cuda
namespace fuser {
namespace cuda {
-namespace {
-
-template <typename T, typename nvfuser_index_t>
-std::unique_ptr<TensorArgAbstract> getTensorArg(int nDims) {
- switch (nDims) {
- case (0):
- return std::make_unique<TensorArg<
- TensorArgCodegen<T, 0, nvfuser_index_t>,
- nvfuser_index_t>>();
- case (1):
- return std::make_unique<TensorArg<
- TensorArgCodegen<T, 1, nvfuser_index_t>,
- nvfuser_index_t>>();
- case (2):
- return std::make_unique<TensorArg<
- TensorArgCodegen<T, 2, nvfuser_index_t>,
- nvfuser_index_t>>();
- case (3):
- return std::make_unique<TensorArg<
- TensorArgCodegen<T, 3, nvfuser_index_t>,
- nvfuser_index_t>>();
- case (4):
- return std::make_unique<TensorArg<
- TensorArgCodegen<T, 4, nvfuser_index_t>,
- nvfuser_index_t>>();
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- case (5):
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- return std::make_unique<TensorArg<
- TensorArgCodegen<T, 5, nvfuser_index_t>,
- nvfuser_index_t>>();
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- case (6):
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- return std::make_unique<TensorArg<
- TensorArgCodegen<T, 6, nvfuser_index_t>,
- nvfuser_index_t>>();
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- case (7):
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- return std::make_unique<TensorArg<
- TensorArgCodegen<T, 7, nvfuser_index_t>,
- nvfuser_index_t>>();
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- case (8):
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- return std::make_unique<TensorArg<
- TensorArgCodegen<T, 8, nvfuser_index_t>,
- nvfuser_index_t>>();
- default:
- TORCH_INTERNAL_ASSERT(
- false,
- "Tried to gerneate a tensor to run a generated kernel with ",
- nDims,
- " dimensions, however it must be a 1-8 dimensional tensor.");
- }
- return nullptr;
-}
-
-template <typename INDEX_MODE>
std::unique_ptr<TensorArgAbstract> getTensorArg(
c10::ScalarType dtype,
int nDims) {
switch (dtype) {
- case c10::ScalarType::Double:
- return getTensorArg<double, INDEX_MODE>(nDims);
case c10::ScalarType::Float:
- return getTensorArg<float, INDEX_MODE>(nDims);
+ return getTensorArg<float>(nDims);
case c10::ScalarType::Half:
- return getTensorArg<at::Half, INDEX_MODE>(nDims);
+ return getTensorArg<at::Half>(nDims);
case c10::ScalarType::Bool:
- return getTensorArg<bool, INDEX_MODE>(nDims);
+ return getTensorArg<bool>(nDims);
case c10::ScalarType::Long:
- return getTensorArg<int64_t, INDEX_MODE>(nDims);
- case c10::ScalarType::Int:
- return getTensorArg<int32_t, INDEX_MODE>(nDims);
+ return getTensorArg<int64_t>(nDims);
default:
TORCH_CHECK(
false,
}
}
-} // namespace
-
-std::unique_ptr<TensorArgAbstract> getTensorArg(
- c10::ScalarType dtype,
- int nDims,
- KernelIndexMode index_mode) {
- switch (index_mode) {
- case KernelIndexMode::INT32:
- return getTensorArg<int>(dtype, nDims);
- case KernelIndexMode::INT64:
- return getTensorArg<int64_t>(dtype, nDims);
- default:
- break;
- }
-
- TORCH_INTERNAL_ASSERT(false, "unknown index mode");
- return nullptr;
-}
-
// Push a tensor to the arguments
void KernelArgumentHolder::push(const at::Tensor& tensor) {
changed_ = true;
int nDims = tensor.ndimension();
c10::ScalarType dtype = tensor.scalar_type();
- std::unique_ptr<TensorArgAbstract> tensor_arg =
- getTensorArg(dtype, nDims, index_mode_);
+ std::unique_ptr<TensorArgAbstract> tensor_arg = getTensorArg(dtype, nDims);
tensor_arg->setPointer(tensor.data_ptr());
for (const auto i : c10::irange(nDims)) {
tensor_arg->setSize(i, tensor.sizes()[i]);
val.isScalar(),
"Tried to push an arg to run in a fused kernel, expected a scalar but got, ",
val);
- auto scalar_val = val.toScalar();
- switch (scalar_val.type()) {
+ switch (val.toScalar().type()) {
// NOLINTNEXTLINE(bugprone-branch-clone)
case c10::ScalarType::Double:
- arguments_.push_back(std::make_unique<DoubleArg>(scalar_val.toDouble()));
+ arguments_.push_back(std::make_unique<FloatArg>((float)val.toDouble()));
return;
case c10::ScalarType::Long:
- arguments_.push_back(std::make_unique<LongArg>(scalar_val.toLong()));
- return;
- case c10::ScalarType::Bool:
- arguments_.push_back(std::make_unique<BoolArg>(scalar_val.toBool()));
+ arguments_.push_back(std::make_unique<LongArg>(val.toInt()));
return;
default:
TORCH_INTERNAL_ASSERT(
" Tried to create argument to send to a fused kernel, but got a non-scalar type.");
}
-void KernelArgumentHolder::push(const at::PhiloxCudaState& val) {
- arguments_.push_back(std::make_unique<PhiloxCudaStateArg>(val));
+void KernelArgumentHolder::push(const uint64_t& val) {
+ arguments_.push_back(std::make_unique<ULongArg>(val));
}
// Create buffer, flatten arguments into it, align by 8 Bytes, return pointers
}
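// A minimal sketch of the 8-byte alignment rule described above (the real
// getBuffer() body is elided from this hunk, so `offset` is illustrative):
//   size_t offset = 0;
//   for (auto& arg : arguments_) {
//     offset = (offset + 7) & ~size_t(7);  // round up to an 8-byte boundary
//     // copy the argument bytes at `offset`; record a pointer in void_ptrs_
//   }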
void KernelArgumentHolder::appendPhiloxRNGSeed(uint64_t rand_offset) {
- at::PhiloxCudaState philox_engine_inputs;
+ std::pair<uint64_t, uint64_t> philox_engine_inputs;
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
philox_engine_inputs =
- at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(
+ at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
rand_offset);
}
- push(philox_engine_inputs);
+ push(philox_engine_inputs.first);
+ push(philox_engine_inputs.second);
}
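// Sketch of how the generated kernel would consume the two pushed values;
// the curand-style naming is an assumption for illustration, not code from
// this diff:
//   __global__ void kernel(..., uint64_t seed, uint64_t offset) {
//     curandStatePhilox4_32_10_t state;
//     int idx = blockIdx.x * blockDim.x + threadIdx.x;
//     curand_init(seed, idx, offset, &state);
//     float r = curand_uniform(&state);
//   }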
} // namespace cuda
#pragma once
-#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
namespace cuda {
// This should match the tensor used in the code generation (almost exactly)
-template <typename T, int N, typename nvfuser_index_t>
+template <typename T, int N>
struct TensorArgCodegen {
- T& operator[](nvfuser_index_t ind) {
+ T& operator[](int64_t ind) {
return data[ind];
};
T* data;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
- nvfuser_index_t size[N];
+ int64_t size[N];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
- nvfuser_index_t stride[N];
+ int64_t stride[N];
constexpr int nDims() {
return N;
}
- void setSize(int i, nvfuser_index_t s) {
+ void setSize(int i, int64_t s) {
size[i] = s;
}
- void setStride(int i, nvfuser_index_t s) {
+ void setStride(int i, int64_t s) {
stride[i] = s;
}
};
-template <typename T, typename nvfuser_index_t>
-struct TensorArgCodegen<T, 0, nvfuser_index_t> {
- T& operator[](nvfuser_index_t ind) {
+template <typename T>
+struct TensorArgCodegen<T, 0> {
+ T& operator[](int64_t ind) {
return data[ind];
};
constexpr int nDims() {
return 0;
}
- void setSize(int, nvfuser_index_t) {
+ void setSize(int, int64_t) {
TORCH_INTERNAL_ASSERT(false, "Tried to set size of a 0-dim tensor");
}
- void setStride(int, nvfuser_index_t) {
+ void setStride(int, int64_t) {
TORCH_INTERNAL_ASSERT(false, "Tried to set stride of a 0-dim tensor");
}
};
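// For intuition, a 2-D float tensor argument in the generated kernel is
// declared with the device-side mirror of this struct, roughly (sketch):
//   Tensor<float, 2> T0;  // T0.data, T0.size[0..1], T0.stride[0..1]
// TensorArgCodegen<float, 2> provides the matching host-side layout, so the
// bytes pushed by KernelArgumentHolder line up field for field.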
virtual void* arg() = 0;
};
-struct PhiloxCudaStateArg : public ArgAbstract {
- at::PhiloxCudaState val_;
- PhiloxCudaStateArg(at::PhiloxCudaState _val) : val_(_val){};
- // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
- void* arg() {
+struct ULongArg : public ArgAbstract {
+ uint64_t val_;
+ ULongArg(uint64_t _val) : val_(_val) {}
+ void* arg() override {
return &val_;
}
};
struct LongArg : public ArgAbstract {
int64_t val_;
- explicit LongArg(int64_t _val) : val_(_val){};
- // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
- void* arg() {
+ LongArg(int64_t _val) : val_(_val) {}
+ void* arg() override {
return &val_;
}
};
-struct DoubleArg : public ArgAbstract {
- double val_;
- explicit DoubleArg(double _val) : val_(_val){};
- // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
- void* arg() {
+struct IntArg : public ArgAbstract {
+ int val_;
+ IntArg(int _val) : val_(_val) {}
+ void* arg() override {
return &val_;
}
};
-struct BoolArg : public ArgAbstract {
- bool val_;
- explicit BoolArg(bool _val) : val_(_val){};
- // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
- void* arg() {
+struct FloatArg : public ArgAbstract {
+ float val_;
+ FloatArg(float _val) : val_(_val) {}
+ void* arg() override {
return &val_;
}
};
};
// This should match the tensor used in the code generation (almost exactly)
-template <typename TENSOR_TYPE, typename nvfuser_index_t>
+template <typename TENSOR_TYPE>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct TensorArg : public TensorArgAbstract {
TENSOR_TYPE instance_;
void setSize(int i, int64_t size) override {
- instance_.setSize(i, (nvfuser_index_t)size);
+ instance_.setSize(i, size);
}
void setStride(int i, int64_t stride) override {
- instance_.setStride(i, (nvfuser_index_t)stride);
+ instance_.setStride(i, stride);
}
void setPointer(void* ptr) override {
instance_.data = static_cast<decltype(TENSOR_TYPE::data)>(ptr);
}
};
+template <typename T>
+std::unique_ptr<TensorArgAbstract> getTensorArg(int nDims) {
+ switch (nDims) {
+ case (0):
+ return std::make_unique<TensorArg<TensorArgCodegen<T, 0>>>();
+ case (1):
+ return std::make_unique<TensorArg<TensorArgCodegen<T, 1>>>();
+ case (2):
+ return std::make_unique<TensorArg<TensorArgCodegen<T, 2>>>();
+ case (3):
+ return std::make_unique<TensorArg<TensorArgCodegen<T, 3>>>();
+ case (4):
+ return std::make_unique<TensorArg<TensorArgCodegen<T, 4>>>();
+ case (5):
+ return std::make_unique<TensorArg<TensorArgCodegen<T, 5>>>();
+ case (6):
+ return std::make_unique<TensorArg<TensorArgCodegen<T, 6>>>();
+ case (7):
+ return std::make_unique<TensorArg<TensorArgCodegen<T, 7>>>();
+ case (8):
+ return std::make_unique<TensorArg<TensorArgCodegen<T, 8>>>();
+ default:
+ TORCH_INTERNAL_ASSERT(
+ false,
+ "Tried to gerneate a tensor to run a generated kernel with ",
+ nDims,
+ " dimensions, however it must be a 1-8 dimensional tensor.");
+ }
+}
+
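+// Minimal host-side usage sketch (mirrors KernelArgumentHolder::push above;
+// the tensor itself is a placeholder):
+//   at::Tensor t = at::randn({8, 16}, at::kCUDA);
+//   auto arg = getTensorArg(t.scalar_type(), t.ndimension());
+//   arg->setPointer(t.data_ptr());
+//   for (int i = 0; i < t.ndimension(); i++) {
+//     arg->setSize(i, t.sizes()[i]);
+//     arg->setStride(i, t.strides()[i]);
+//   }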
std::unique_ptr<TensorArgAbstract> getTensorArg(
c10::ScalarType dtype,
int nDims);
class KernelArgumentHolder {
public:
- explicit KernelArgumentHolder(KernelIndexMode index_mode)
- : index_mode_(index_mode) {}
-
// Push a tensor to the arguments
void push(const at::Tensor& tensor);
// Push a scalar or integer to the arguments
void push(const IValue& val);
- void push(const at::PhiloxCudaState& val);
+ void push(const uint64_t& val);
// Create buffer, flatten arguments into it, align by 8 Bytes, return pointers
// in the buffer
std::vector<std::unique_ptr<ArgAbstract>> arguments_;
std::vector<void*> void_ptrs_;
bool changed_ = true;
- KernelIndexMode index_mode_ = KernelIndexMode::INT64;
};
} // namespace cuda
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
-#include <ATen/cuda/CUDAContext.h>
-
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
-void LaunchParams::assertValid() {
- TORCH_INTERNAL_ASSERT(
-      bdimx() * bdimy() * bdimz() > 0 &&
-          bdimx() * bdimy() * bdimz() <=
-              (int64_t)at::cuda::getCurrentDeviceProperties()
-                  ->maxThreadsPerMultiProcessor,
-      "Selected invalid number of threads for cuda: ",
-      bdimx() * bdimy() * bdimz());
- TORCH_INTERNAL_ASSERT(
- gdimx() > 0 && gdimx() < (std::int64_t(1) << 32) - 1,
- "Invalid number of blocks in x direction: ",
- gdimx());
- TORCH_INTERNAL_ASSERT(
- gdimy() > 0 && gdimy() <= 65535,
- "Invalid number of blocks in y direction: ",
- gdimy());
- TORCH_INTERNAL_ASSERT(
- gdimz() > 0 && gdimz() <= 65535,
- "Invalid number of blocks in z direction: ",
- gdimz());
-}
-
void LaunchParams::bind(int64_t val, ParallelType p_type) {
switch (p_type) {
case ParallelType::TIDx:
"Tried to bind invalid parallel type in launch config: ",
p_type);
}
- assertValid();
}
int64_t LaunchParams::getDim(ParallelType p_type) const {
bdimx_ == other.bdimx_ && bdimy_ == other.bdimy_ && smem_ == other.smem_;
}
-void LaunchParams::print() const {
- std::cout << toString();
-}
-
-std::string LaunchParams::toString() const {
- std::stringstream ss;
- ss << "Launch Parameters \n"
- << "BlockDim.x = " << bdimx() << "\n"
- << "BlockDim.y = " << bdimy() << "\n"
- << "BlockDim.z = " << bdimz() << "\n"
- << "GridDim.x = " << gdimx() << "\n"
- << "GridDim.y = " << gdimy() << "\n"
- << "GridDim.z = " << gdimz() << "\n"
- << "Smem Size = " << smem() << "\n";
- return ss.str();
-}
-
} // namespace cuda
} // namespace fuser
} // namespace jit
gdimz_(gdimz),
bdimx_(bdimx),
bdimy_(bdimy),
- bdimz_(bdimz) {
- assertValid();
- }
-
- void assertValid();
+ bdimz_(bdimz) {}
void setSmem(int64_t smem) {
smem_ = smem;
}
int64_t nBlocks() const {
- return std::abs(gdimx_ * gdimy_ * gdimz_);
+ return gdimx_ * gdimy_ * gdimz_;
}
int64_t nThreads() const {
- return std::abs(bdimx_ * bdimy_ * bdimz_);
+ return bdimx_ * bdimy_ * bdimz_;
}
int64_t bdimx() const {
if (class_val == UNINITIALIZED_VAL) {
class_val = incoming_val;
}
- assertValid();
}
  // Binds the dim associated with p_type to val
bool operator==(const LaunchParams& other) const;
- void print() const;
-
- std::string toString() const;
-
private:
  // Spelled out as signed ints so we can tell whether they were initialized
  // or not.
// TODO: Fill in output sizes
std::vector<std::vector<int64_t>> output_sizes;
};
-
} // namespace cuda
} // namespace fuser
} // namespace jit
#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
#include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>
#include <torch/csrc/jit/resource_guard.h>
-#include <nvfuser_resources/PhiloxCudaStateRaw.h>
#include <nvfuser_resources/block_reduction.h>
-#include <nvfuser_resources/block_sync_atomic.h>
-#include <nvfuser_resources/block_sync_default.h>
#include <nvfuser_resources/broadcast.h>
#include <nvfuser_resources/fp16_support.h>
#include <nvfuser_resources/grid_reduction.h>
#include <nvfuser_resources/helpers.h>
#include <nvfuser_resources/random_numbers.h>
#include <nvfuser_resources/tensor.h>
-#include <nvfuser_resources/welford.h>
-
-#ifndef USE_ROCM
-#include <cuda_occupancy.h>
-#endif
#include <fstream>
ss << nvfuser_resources::tensor_cu;
ss << nvfuser_resources::random_numbers_cu;
ss << nvfuser_resources::helpers_cu;
- if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
- ss << nvfuser_resources::block_sync_atomic_cu;
- } else {
- ss << nvfuser_resources::block_sync_default_cu;
- }
ss << nvfuser_resources::block_reduction_cu;
ss << nvfuser_resources::grid_reduction_cu;
ss << nvfuser_resources::broadcast_cu;
- ss << nvfuser_resources::welford_cu;
- ss << nvfuser_resources::PhiloxCudaStateRaw_cu;
return ss.str();
}
DataType param_data_type = *param->getDataType();
bool match = false;
switch (arg_data_type) {
- case at::ScalarType::Double:
- match = param_data_type == DataType::Double;
- break;
case at::ScalarType::Half:
match = param_data_type == DataType::Half;
break;
case at::ScalarType::Float:
match = param_data_type == DataType::Float;
break;
- case at::ScalarType::Long:
- match = param_data_type == DataType::Int;
- break;
- case at::ScalarType::Int:
- match = param_data_type == DataType::Int32;
- break;
case at::ScalarType::Bool:
match = param_data_type == DataType::Bool;
break;
// Return false if arg_type doesn't match the type in param
bool validateKernelArgScalar(
- const c10::IValue& arg,
+ const c10::TypePtr& arg_type,
const Val* param,
std::stringstream& msg) {
- if (!arg.isScalar()) {
+ if (!param->isScalar()) {
msg << "Argument is a scalar, but the parameter is not."
<< "\n";
return false;
}
DataType param_type = *param->getDataType();
bool match = false;
- switch (arg.toScalar().type()) {
- case c10::ScalarType::Long:
- match = param_type == DataType::Int || param_type == DataType::Int32;
+ switch (arg_type->kind()) {
+ case c10::TypeKind::IntType:
+ match = param_type == DataType::Int;
break;
- case c10::ScalarType::Double:
- match = param_type == DataType::Double || param_type == DataType::Float ||
- param_type == DataType::Half;
+ case c10::TypeKind::FloatType:
+ match = param_type == DataType::Float;
break;
- case c10::ScalarType::Bool:
+ case c10::TypeKind::BoolType:
match = param_type == DataType::Bool;
break;
default:
match = false;
}
if (!match) {
- msg << "Argument type is " << arg.toScalar().type()
- << ", but the parameter is " << param_type << "\n";
+ msg << "Argument type is " << *arg_type << ", but the parameter is "
+ << param_type << "\n";
}
return match;
}
if (arg.isTensor()) {
return validateKernelArgTensor(arg.toTensor(), param, device, msg);
} else {
- return validateKernelArgScalar(arg, param, msg);
- }
-}
-
-// Return true if all the tensors have the same stride, assumes all tensors are
-// contiguous
-bool checkSameStride(const std::vector<c10::IValue>& tensors) {
- if (tensors.size() < 2) {
- return true;
- }
- for (size_t idx = 0; idx < tensors.size() - 1; ++idx) {
- auto current = tensors[idx];
- auto next = tensors[idx + 1];
- if (!current.isTensor() || !next.isTensor()) {
- return false;
- }
-
- const auto& current_tensor = current.toTensor();
- const auto& next_tensor = next.toTensor();
- if (current_tensor.ndimension() != next_tensor.ndimension()) {
- return false;
- }
-
- for (int64_t i = 0; i < current_tensor.ndimension(); ++i) {
- if (current_tensor.stride(i) != next_tensor.stride(i)) {
- return false;
- }
- }
- }
- return true;
-}
-
-// Return true if all the tensors are contiguous and have the same striding
-bool checkSameContiguity(const std::vector<c10::IValue>& tensors) {
- auto reference = tensors.front();
- if (!reference.isTensor()) {
- return false;
- }
-
- // Determine if the reference tensor is contiguous
- const auto& reference_tensor = reference.toTensor();
- int64_t expected_stride = 1;
- for (int64_t i = 1; i <= reference_tensor.ndimension(); ++i) {
- int64_t ind = reference_tensor.ndimension() - i;
- if (reference_tensor.size(ind) == 1) {
- continue;
- }
- if (reference_tensor.stride(ind) != expected_stride) {
- return false;
- }
- expected_stride *= reference_tensor.size(ind);
- }
-
- // Check if all the tensors have the same contiguity
- return checkSameStride(tensors);
-}
-
-bool checkValidMisalignedTensors(
- const std::unordered_set<TensorView*>& inp_tv,
- const std::unordered_set<TensorView*>& out_tv,
- const std::vector<c10::IValue>& inp_tensors,
- const std::vector<c10::IValue>& out_tensors) {
- if (out_tv.empty()) {
- // Only check input tensors
- return checkSameStride(inp_tensors);
- } else if (!out_tv.empty() && out_tensors.empty()) {
- // Assume out tensors are contiguous
- return checkSameContiguity(inp_tensors);
- } else {
-    // Check both input and output tensors
- std::vector<c10::IValue> tensors;
- tensors.insert(tensors.end(), inp_tensors.begin(), inp_tensors.end());
- tensors.insert(tensors.end(), out_tensors.begin(), out_tensors.end());
- return checkSameStride(tensors);
+ return validateKernelArgScalar(arg.type(), param, msg);
}
}
Fusion* fusion,
const at::ArrayRef<IValue>& inputs,
const c10::Device& device) {
- FUSER_PERF_SCOPE("executor_utils::ValidateKernelInputs");
+ FUSER_PERF_SCOPE("validateKernelInputs");
  // This is necessary as we traverse the fusion graph later in the check
FusionGuard fg(fusion);
Fusion* fusion,
const std::vector<at::Tensor>& outputs,
const c10::Device& device) {
- FUSER_PERF_SCOPE("executor_utils::ValidateKernelOutputs");
+ FUSER_PERF_SCOPE("validateKernelOutputs");
TORCH_INTERNAL_ASSERT(
fusion->outputs().size() != 0,
!mismatch, "Found one or more invalid arguments: ", msg.str());
}
-bool canVectorize(const IValue& aten_val, int word_size) {
- if (!aten_val.isTensor()) {
- return false;
- }
-
- const auto& aten_tensor = aten_val.toTensor();
-
- if (reinterpret_cast<size_t>(aten_tensor.data_ptr()) %
- (word_size * aten_tensor.dtype().itemsize()) !=
- 0) {
- return false;
- }
-
- for (size_t i = aten_tensor.ndimension(); i > 0; i--) {
- if (aten_tensor.size(i - 1) != 1) {
- if (aten_tensor.size(aten_tensor.ndimension() - 1) % word_size != 0 ||
- aten_tensor.stride(aten_tensor.ndimension() - 1) != 1) {
- return false;
- }
- break;
- }
- }
-
- for (auto stride : aten_tensor.strides()) {
- if (stride != 1 && stride % word_size != 0) {
- return false;
- }
- }
-
- return true;
-}
-
-bool canVectorize(
- TensorView* fusion_tv,
- int word_size,
- GpuLower& lower,
- kir::ExpressionEvaluator& expr_eval) {
- IterDomain* last_root_dim = nullptr;
- // TODO: Should this be rfactor instead of root??
- for (size_t i = fusion_tv->getRootDomain().size(); i > 0; i--) {
- auto r_id = fusion_tv->getRootDomain()[i - 1];
- if (r_id->isReduction() || r_id->isBroadcast()) {
- continue;
- }
- last_root_dim = r_id;
- break;
- }
-
- if (last_root_dim == nullptr) {
- return false;
- }
-
- auto last_dim_size =
- expr_eval.evaluate(lower.lowerValue(last_root_dim->extent()));
-
- if (!last_dim_size.has_value()) {
- return false;
- }
-
- if (last_dim_size.value() % word_size != 0) {
- return false;
- }
-
- return true;
-}
-
-// Misaligned vectorization check. Currently misaligned vectorization is limited
-// to global-register and register-global load/store patterns. However, this
-// could be improved to include shared memory.
-void validateVectorizedTensors(
- Fusion* fusion,
- const at::ArrayRef<IValue>& inputs,
- const std::vector<at::Tensor>& outputs,
- GpuLower& lower,
- kir::ExpressionEvaluator& expr_eval) {
- std::unordered_set<TensorView*> global_inp_misaligned_tv;
- std::unordered_set<TensorView*> global_out_misaligned_tv;
- std::unordered_map<TensorView*, int> tv_to_vector_word_size;
- // Find all vectorized tensors and their word size
- for (auto expr : fusion->exprs()) {
- if (!expr->isA<UnaryOp>() ||
- expr->as<UnaryOp>()->getUnaryOpType() != UnaryOpType::Set) {
- continue;
- }
- auto uop = expr->as<UnaryOp>();
- if (!uop->out()->isA<TensorView>() || !uop->in()->isA<TensorView>()) {
- continue;
- }
- auto out_tv = uop->out()->as<TensorView>();
- auto in_tv = uop->in()->as<TensorView>();
- IterDomain* vector_dim = nullptr;
- for (auto id : out_tv->domain()->domain()) {
- if (id->getParallelType() == ParallelType::Vectorize ||
- id->getParallelType() == ParallelType::MisalignedVectorize) {
- TORCH_INTERNAL_ASSERT(
- vector_dim == nullptr,
- "Found multiple vectorized dimensions on tensor ",
- out_tv);
- vector_dim = id;
- }
- }
- if (vector_dim == nullptr) {
- continue;
- }
- auto vector_word_size =
- expr_eval.evaluate(lower.lowerValue(vector_dim->extent()));
- TORCH_INTERNAL_ASSERT(
- vector_word_size.has_value(),
- "Non constant vector dimension found in ",
- out_tv);
- tv_to_vector_word_size[out_tv] = vector_word_size.value();
- tv_to_vector_word_size[in_tv] = vector_word_size.value();
-
- if (vector_dim->getParallelType() == ParallelType::MisalignedVectorize) {
- if (out_tv->getMemoryType() == MemoryType::Global &&
- in_tv->getMemoryType() == MemoryType::Local) {
- global_out_misaligned_tv.insert(out_tv);
- } else if (
- in_tv->getMemoryType() == MemoryType::Global &&
- out_tv->getMemoryType() == MemoryType::Local) {
- global_inp_misaligned_tv.insert(in_tv);
- } else {
- TORCH_INTERNAL_ASSERT(
- false,
- "Unsupported memory configuration for misaligned vectorization.");
- }
- }
- }
-
- // Check striding information on input and outputs as well as size information
- // of all
- std::vector<c10::IValue> inp_misaligned_tensors;
- std::vector<c10::IValue> out_misaligned_tensors;
- for (auto entry : tv_to_vector_word_size) {
- auto tv = entry.first;
- auto word_size = entry.second;
- if (tv->isFusionInput()) {
- auto inp_it =
- std::find(fusion->inputs().begin(), fusion->inputs().end(), tv);
- TORCH_INTERNAL_ASSERT(
- inp_it != fusion->inputs().end(),
- "Could not find ",
- tv,
- " in fusion inputs.");
- auto inp_pos = std::distance(fusion->inputs().begin(), inp_it);
- auto aten_inp = inputs[inp_pos];
-
- if (global_inp_misaligned_tv.find(tv) != global_inp_misaligned_tv.end()) {
- inp_misaligned_tensors.emplace_back(aten_inp);
- } else {
- TORCH_INTERNAL_ASSERT(
- canVectorize(aten_inp, word_size),
- "Error vectorizing, ",
- tv,
- " as input provided does not allowed vectorization by word size, ",
- word_size);
- }
- } else if (tv->isFusionOutput() && outputs.size() > 0) {
- auto out_it =
- std::find(fusion->outputs().begin(), fusion->outputs().end(), tv);
- TORCH_INTERNAL_ASSERT(
- out_it != fusion->outputs().end(),
- "Could not find ",
- tv,
- " in provided fusion outputs.");
- auto out_pos = std::distance(fusion->outputs().begin(), out_it);
- auto aten_out = outputs[out_pos];
-
- if (global_out_misaligned_tv.find(tv) != global_out_misaligned_tv.end()) {
- out_misaligned_tensors.emplace_back(aten_out);
- } else {
- TORCH_INTERNAL_ASSERT(
- canVectorize(aten_out, word_size),
- "Error vectorizing, ",
- tv,
- " as output provided does not allowed vectorization by word size, ",
- word_size);
- }
- } else {
- if (!tv_to_vector_word_size.count(tv)) {
- TORCH_INTERNAL_ASSERT(
- canVectorize(tv, word_size, lower, expr_eval),
- "Could not vectorize ",
- tv,
- " it's inner most dim is not a multiple of ",
- word_size);
- }
- }
- }
-
- // If input stride is non-contiguous + no outputs, return false
- TORCH_INTERNAL_ASSERT(
- checkValidMisalignedTensors(
- global_inp_misaligned_tv,
- global_out_misaligned_tv,
- inp_misaligned_tensors,
- out_misaligned_tensors),
- "All global tensors must have the same stride for misaligned vectorization.");
-}
-
-kir::ExpressionEvaluator bindKernelInputs(
- const at::ArrayRef<IValue>& aten_inputs,
- kir::Kernel* kernel) {
- FUSER_PERF_SCOPE("executor_utils::BindKernelInputs");
-
- TORCH_INTERNAL_ASSERT(
- kernel->inputs().size() == aten_inputs.size(),
- "Something went wrong configuring launch. Inputs no longer match.");
-
- kir::ExpressionEvaluator expr_eval;
- const auto& inputs = kernel->inputs();
-
- for (size_t i = 0; i < inputs.size(); i++) {
- const auto input = inputs[i];
-
- if (auto tensor_input = dynamic_cast<kir::TensorView*>(input)) {
- TORCH_INTERNAL_ASSERT(
- aten_inputs[i].isTensor(),
- "Something went wrong configuring launch. Inputs no longer match.");
-
- const auto aten_tensor = aten_inputs[i].toTensor();
- const auto root_domain =
- kir::TensorDomain::noReductions(tensor_input->domain()->rootDomain());
- TORCH_INTERNAL_ASSERT(
- aten_tensor.ndimension() == static_cast<int>(root_domain.size()),
- "Something went wrong configuring launch. Inputs no longer match.");
-
- for (size_t dim = 0; dim < root_domain.size(); dim++) {
- const auto extent = root_domain[dim]->extent();
- const auto value = aten_tensor.sizes()[dim];
- const auto prev_value = expr_eval.evaluate(extent);
- if (prev_value.has_value()) {
- TORCH_CHECK(
- *prev_value == value,
- "Attempting to bind ",
- kir::toString(extent),
- " to ",
- value,
- "but it's already set to ",
- *prev_value);
- } else {
- expr_eval.bind(extent, value);
- }
- }
- // NOLINTNEXTLINE: https://bugs.llvm.org/show_bug.cgi?id=48525
- } else if (input->isScalar() && input->dtype() == DataType::Int) {
- TORCH_INTERNAL_ASSERT(
- aten_inputs[i].type()->kind() == c10::TypeKind::IntType);
- expr_eval.bind(input, aten_inputs[i].toInt());
- }
- }
-
- return expr_eval;
-}
-
-ExpressionEvaluator bindFusionInputs(
+StatefulExpressionEvaluator statefulBindInputs(
const at::ArrayRef<IValue>& aten_inputs,
- Fusion* fusion) {
- FUSER_PERF_SCOPE("executor_utils::BindFusionInputs");
+ Fusion* fusion,
+ GpuLower* lower) {
+ FUSER_PERF_SCOPE("statefulBindInputs");
TORCH_INTERNAL_ASSERT(
fusion->inputs().size() == aten_inputs.size(),
"Something went wrong configuring launch. Inputs no longer match.");
- ExpressionEvaluator evaluator(fusion);
- auto inputs = fusion->inputs();
+ auto fusion_inputs = fusion->inputs();
+ StatefulExpressionEvaluator evaluator(fusion);
// This should probably move to EvaluationContext as we may want to bind
// input values frequently. Bind fusion input values to runtime values.
for (const auto i : c10::irange(fusion->inputs().size())) {
- if (inputs[i]->getValType() == ValType::TensorView) {
- TensorView* cg_tensor = inputs[i]->as<TensorView>();
+ if (fusion->inputs()[i]->getValType() == ValType::TensorView) {
+ TensorView* cg_tensor = fusion->inputs()[i]->as<TensorView>();
TORCH_INTERNAL_ASSERT(
aten_inputs[i].isTensor(),
"Something went wrong configuring launch. Inputs no longer match.");
for (const auto dim : c10::irange(root_dom.size())) {
- const auto extent = root_dom[dim]->extent();
- const auto value = aten_tensor.sizes()[dim];
- const auto prev_value = evaluator.evaluate(extent);
- if (prev_value.has_value()) {
- TORCH_CHECK(
- *prev_value == value,
- "Attempting to bind ",
- extent,
- " to ",
- value,
- "but it's already set to ",
- *prev_value);
- } else {
- evaluator.bind(extent, value);
- }
+ evaluator.safeBind(
+ root_dom[dim]->extent(), aten_tensor.sizes()[dim], lower);
}
} else if (
- inputs[i]->getValType().value() == ValType::Scalar &&
- inputs[i]->getDataType().value() == DataType::Int) {
+ fusion->inputs()[i]->getValType().value() == ValType::Scalar &&
+ fusion->inputs()[i]->getDataType().value() == DataType::Int) {
TORCH_INTERNAL_ASSERT(
aten_inputs[i].type()->kind() == c10::TypeKind::IntType);
- evaluator.bind(inputs[i], aten_inputs[i].toInt());
+ evaluator.safeBind(fusion->inputs()[i], aten_inputs[i].toInt(), lower);
}
}
return evaluator;
NvrtcFunction nvrtcCompile(
const std::string& code,
const std::string& func_name,
- int id,
- c10::optional<int> opt_block_size) {
- FUSER_PERF_SCOPE("executor_utils::NVRTC");
+ int id) {
+ FUSER_PERF_SCOPE("NVRTC");
// lazily construct context if non-existing yet;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
nvrtcProgram program; // NOLINT(cppcoreguidelines-init-variables)
{
- FUSER_PERF_SCOPE("executor_utils::NvrtcCreateProgram");
+ FUSER_PERF_SCOPE("nvrtcCreateProgram");
AT_CUDA_NVRTC_CHECK(at::globalContext().getNVRTC().nvrtcCreateProgram(
&program, code.c_str(), nullptr, 0, nullptr, nullptr));
}
ResourceGuard holdProgram([&] {
- FUSER_PERF_SCOPE("executor_utils::NvrtcDestroyProgram");
+ FUSER_PERF_SCOPE("nvrtcDestroyProgram");
AT_CUDA_NVRTC_CHECK(
at::globalContext().getNVRTC().nvrtcDestroyProgram(&program));
});
args.push_back("-hip-pch");
#endif
#else
-#if CUDA_VERSION < 11010
- // compile to sass is not allowed prior to CUDA 11.1
- compile_to_sass = false;
-#endif
- // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
- // which gives better backwards compatibility to work on older driver,
-  // (since older driver doesn't necessarily recognize PTX emitted by new
- // toolkit);
- // Meanwhile, for forward compatibility (future device with
- // `unsupported_arch==True`), since SASS are not necessarily compatible,
- // we fallback to PTX instead.
const std::string compute = std::string("--gpu-architecture=") +
- (compile_to_sass ? "sm_" : "compute_") + std::to_string(major) +
- std::to_string(minor);
+#if CUDA_VERSION >= 11010
+      // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX
+      // (compute_), which gives better backwards compatibility on older
+      // drivers (an older driver doesn't necessarily recognize PTX emitted
+      // by a newer toolkit);
+      // Meanwhile, for forward compatibility (a future device with
+      // `unsupported_arch==True`), since SASS is not necessarily compatible,
+      // we fall back to PTX instead.
+ (compile_to_sass ? "sm_" : "compute_") +
+#else
+ "compute_" +
+#endif
+ std::to_string(major) + std::to_string(minor);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const char*> args = {
"--std=c++14", compute.c_str(), "-default-device"};
#endif
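// Example of the resulting flag (illustrative): on an sm_80 device with
// CUDA >= 11.1 and compile_to_sass == true this yields
//   "--gpu-architecture=sm_80"        (direct to SASS)
// and otherwise
//   "--gpu-architecture=compute_80"   (PTX, JIT-compiled by the driver)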
- const char* disable_fastmath = getenv("PYTORCH_NVFUSER_DISABLE_FASTMATH");
- if (!disable_fastmath || (atoi(disable_fastmath) == 0)) {
- args.push_back("--use_fast_math");
- } else {
- TORCH_WARN_ONCE(
- "fast math disabled in nvfuser, try set `PYTORCH_NVFUSER_DISABLE_FASTMATH=0`");
- }
-
- const char* disable_fma = getenv("PYTORCH_NVFUSER_DISABLE_FMA");
+ const char* disable_fma = getenv("PYTORCH_CUDA_FUSER_DISABLE_FMA");
// int disable_fma_flag = disable_fma ? atoi(disable_fma) : 0;
if (disable_fma && atoi(disable_fma)) {
#ifdef __HIP_PLATFORM_HCC__
#endif
}
-#ifndef NDEBUG
- // Add line info to generated kernels
- args.push_back("-lineinfo");
-#else
- // Avoid excessive register usage from assertion
- args.push_back("-DNDEBUG");
-#endif
-
- const char* ptxas_opt_level = getenv("PYTORCH_NVFUSER_JIT_OPT_LEVEL");
- std::string jit_opt_level = "-O";
+ const char* ptxas_opt_level = getenv("PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL");
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ uint32_t jit_opt_level;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<CUjit_option> options;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<void*> option_vals;
- std::vector<char> info_log;
- unsigned int log_size = 8196;
-
- if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) {
- // show register usage in compilation log
- if (compile_to_sass) {
- args.push_back("--ptxas-options");
- args.push_back("--verbose");
- } else {
- options.push_back(CU_JIT_LOG_VERBOSE);
- option_vals.push_back((void*)1);
- info_log.reserve(log_size);
-
- options.push_back(CU_JIT_INFO_LOG_BUFFER);
- option_vals.push_back((void*)info_log.data());
-
- options.push_back(CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES);
- option_vals.push_back((void*)(long)log_size);
- }
- }
if (ptxas_opt_level) {
int val = atoi(ptxas_opt_level);
if (val <= 4 && val >= 0) {
- if (compile_to_sass) {
- jit_opt_level += std::to_string(val);
- args.push_back("--ptxas-options");
- args.push_back(jit_opt_level.c_str());
- } else {
- options.push_back(CU_JIT_OPTIMIZATION_LEVEL);
- option_vals.push_back((void*)(intptr_t)val);
- }
+ jit_opt_level = static_cast<uint32_t>(val);
+ options.push_back(CU_JIT_OPTIMIZATION_LEVEL);
+ option_vals.emplace_back(&jit_opt_level);
} else {
TORCH_WARN_ONCE(
- "acceptable range for PYTORCH_NVFUSER_JIT_OPT_LEVEL is between 0 and 4, but received ",
- val,
+ "acceptable range for PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL is between 0 and 4, but received ",
+          val,
", ignoring the option");
}
}
-#ifndef USE_ROCM
- // keeping the string outside the loop for lifetime
- std::string max_register_usage = "--maxrregcount=";
- uint32_t max_register = 0;
- if (opt_block_size.has_value() && opt_block_size.value() > 0) {
- int num_partition = 0;
- int reg_allocation_granularity = 0;
- int max_regs_per_thread = 0;
- cudaOccDeviceProp occ_prop(*prop);
- cudaOccSubPartitionsPerMultiprocessor(&num_partition, &occ_prop);
- cudaOccRegAllocationGranularity(®_allocation_granularity, &occ_prop);
- cudaOccRegAllocationMaxPerThread(&max_regs_per_thread, &occ_prop);
- int warp_size = prop->warpSize;
- int num_warps = ceilDiv(opt_block_size.value(), warp_size);
-
- // warps could be distributed unevenly across partition
- int max_warps_per_sm_partition = ceilDiv(num_warps, num_partition);
-    // registers are evenly distributed across partitions; the partition with
-    // the most warps determines the maximum registers available per warp
- int max_reg_per_warp =
- prop->regsPerBlock / num_partition / max_warps_per_sm_partition;
- // clamp down to register allocation granularity at warp level
- int effective_max_reg_per_warp = max_reg_per_warp /
- reg_allocation_granularity * reg_allocation_granularity;
- max_register = static_cast<uint32_t>(
- std::min(effective_max_reg_per_warp / warp_size, max_regs_per_thread));
-
- if (compile_to_sass) {
- max_register_usage += std::to_string(max_register);
- args.push_back(max_register_usage.c_str());
- } else {
- options.push_back(CU_JIT_MAX_REGISTERS);
- option_vals.push_back((void*)(intptr_t)max_register);
- }
- }
-#endif
-
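// A worked instance of the computation removed above (illustrative numbers):
// block size 256, warpSize 32, 4 SM sub-partitions, regsPerBlock 65536,
// allocation granularity 256, max_regs_per_thread 255:
//   num_warps                  = ceilDiv(256, 32)    = 8
//   max_warps_per_sm_partition = ceilDiv(8, 4)       = 2
//   max_reg_per_warp           = 65536 / 4 / 2       = 8192
//   effective_max_reg_per_warp = 8192 / 256 * 256    = 8192
//   max_register               = min(8192 / 32, 255) = 255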
at::globalContext().getNVRTC().nvrtcAddNameExpression(
program, func_name.c_str());
{
- FUSER_PERF_SCOPE("executor_utils::Nvrtc::CompileProgram");
+ FUSER_PERF_SCOPE("nvrtcCompileProgram");
const auto result = at::globalContext().getNVRTC().nvrtcCompileProgram(
program, args.size(), args.data());
TORCH_INTERNAL_ASSERT(
false, code.c_str(), "\nCUDA NVRTC compile error: ", log.data());
- } else if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) {
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- size_t logsize;
- at::globalContext().getNVRTC().nvrtcGetProgramLogSize(program, &logsize);
- std::vector<char> log(logsize);
- at::globalContext().getNVRTC().nvrtcGetProgramLog(program, log.data());
-
- std::cout << log.data() << std::endl;
}
AT_CUDA_NVRTC_CHECK(result);
std::vector<char> ptx;
{
- FUSER_PERF_SCOPE("executor_utils::Nvrtc::GetPTX");
+ FUSER_PERF_SCOPE("get PTX");
#if CUDA_VERSION >= 11010
// compile_to_sass determines whether we are generating SASS or PTX, hence
// the different API.
// TODO: We do go through a different code path here; should investigate
// whether this has an impact on the generated binary.
+ const char* prefix_env = getenv("PYTORCH_CUDA_FUSER_CUBIN");
#ifndef __HIP_PLATFORM_HCC__
- const char* prefix_env = getenv("PYTORCH_NVFUSER_CUBIN");
if (prefix_env) {
- FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadCUBIN");
+#if CUDA_VERSION >= 11010
+ TORCH_CHECK(
+ !compile_to_sass,
+ "PYTORCH_NVFUSER_CUBIN cannot be used when compile direct to SASS. Please set PYTORCH_NVFUSER_CUBIN to empty");
+#endif
+ FUSER_PERF_SCOPE("load CUBIN");
// Output ptx file
- std::stringstream output_file_name;
- output_file_name << prefix_env << "_" << id
- << (compile_to_sass ? ".cubin" : ".ptx");
- std::ofstream outputFile(output_file_name.str().c_str(), std::ios::out);
- if (outputFile.is_open()) {
- outputFile.write(ptx.data(), ptx.size());
- outputFile.close();
+ std::stringstream ptx_file_name;
+ ptx_file_name << prefix_env << "_" << id << ".ptx";
+ std::ofstream myPtxFile(ptx_file_name.str().c_str(), std::ios::out);
+ if (myPtxFile.is_open()) {
+ myPtxFile.write(ptx.data(), ptx.size());
+ myPtxFile.close();
}
- if (compile_to_sass) {
- FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadPTX");
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ CUlinkState linkState;
- // load sass directly
- AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx(
- &(compiled_kernel_.module),
- ptx.data(),
- options.size(),
- options.data(),
- option_vals.data()));
- } else {
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- CUlinkState linkState;
-
- AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkCreate(
- // 0, nullptr, nullptr, &linkState));
- options.size(),
- options.data(),
- option_vals.data(),
- &linkState));
-
- AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkAddData(
- linkState,
- CU_JIT_INPUT_PTX,
- ptx.data(),
- ptx_size,
- "compiling PTX",
- 0,
- nullptr,
- nullptr));
-
- if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) {
- std::cout << info_log.data() << std::endl;
- }
+ AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkCreate(
+ 0, nullptr, nullptr, &linkState));
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- size_t cubinSize;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- void* cubin;
- AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkComplete(
- linkState, &cubin, &cubinSize));
+ AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkAddData(
+ linkState,
+ CU_JIT_INPUT_PTX,
+ ptx.data(),
+ ptx_size,
+ "compiling PTX",
+ options.size(),
+ options.data(),
+ option_vals.data()));
- // Output binary file
- std::stringstream cubin_file_name;
- cubin_file_name << prefix_env << "_" << id << ".cubin";
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ size_t cubinSize;
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ void* cubin;
+ AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLinkComplete(
+ linkState, &cubin, &cubinSize));
- std::ofstream myCubinFile(
- cubin_file_name.str().c_str(), std::ios::out | std::ios::binary);
+ // Output binary file
+ std::stringstream cubin_file_name;
+ cubin_file_name << prefix_env << "_" << id << ".cubin";
- if (myCubinFile.is_open()) {
- myCubinFile.write(static_cast<const char*>(cubin), cubinSize);
- myCubinFile.close();
- }
- // load compiled cubin
- // AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData(
- // &(compiled_kernel_.module), cubin));
- AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx(
- &(compiled_kernel_.module),
- cubin,
- options.size(),
- options.data(),
- option_vals.data()));
+ std::ofstream myCubinFile(
+ cubin_file_name.str().c_str(), std::ios::out | std::ios::binary);
+
+ if (myCubinFile.is_open()) {
+ myCubinFile.write(static_cast<const char*>(cubin), cubinSize);
+ myCubinFile.close();
}
+ // load compiled cubin
+ AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData(
+ &(compiled_kernel_.module), cubin));
} else {
- FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadPTX");
+ FUSER_PERF_SCOPE("load PTX");
// load ptx directly
AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx(
options.size(),
options.data(),
option_vals.data()));
-
- if (!compile_to_sass &&
- isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) {
- std::cout << info_log.data() << std::endl;
- }
}
#else
// load ptx directly
#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>
-#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <string>
// Include all the functions we might need in generated code
std::string kernelPreamble();
-// TODO(kir): rewrite in terms of Kernel inputs
void validateKernelInputs(
Fusion* fusion,
const at::ArrayRef<IValue>& inputs,
const c10::Device& device);
-// TODO(kir): rewrite in terms of Kernel outputs
void validateKernelOutputs(
Fusion* fusion,
const std::vector<at::Tensor>& outputs,
const c10::Device& device);
-// Returns if vectorizing the aten value by word size is possible
-bool canVectorize(const IValue& aten_val, int word_size);
-
-// Returns if vectorizing the aten value by word size is possible
-bool canVectorize(
- TensorView* fusion_tv,
- int word_size,
- GpuLower& lower,
- kir::ExpressionEvaluator& expr_eval);
-
-// TODO(kir): rewrite in terms of Kernel tensors
-void validateVectorizedTensors(
- Fusion* fusion,
- const at::ArrayRef<IValue>& inputs,
- const std::vector<at::Tensor>& outputs,
- GpuLower& lower,
- kir::ExpressionEvaluator& expr_eval);
-
-//! Bind kernel input values to runtime values
-kir::ExpressionEvaluator bindKernelInputs(
+StatefulExpressionEvaluator statefulBindInputs(
const at::ArrayRef<IValue>& aten_inputs,
- kir::Kernel* kernel);
-
-//! Bind fusion input values to runtime values
-TORCH_CUDA_CU_API ExpressionEvaluator
-bindFusionInputs(const at::ArrayRef<IValue>& aten_inputs, Fusion* fusion);
+ Fusion* fusion,
+ GpuLower* lower = nullptr);
struct NvrtcFunction {
CUmodule module = CUmodule();
NvrtcFunction nvrtcCompile(
const std::string& code,
const std::string& func_name,
- int id,
- c10::optional<int> opt_block_size = c10::nullopt);
+ int id);
} // namespace executor_utils
} // namespace cuda
-
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
namespace fuser {
namespace cuda {
-void ExpressionEvaluator::bind(Val* value, Int::ScalarType concrete_value) {
- TORCH_CHECK(value->isAnInt());
- auto val = value->getInt();
- if (val.has_value() && val.value() == concrete_value) {
- return;
+void StatefulExpressionEvaluator::safeBind(
+ Val* value,
+ Int::ScalarType concrete_value,
+ GpuLower* lower) {
+ auto already_concrete_val = getValue(value);
+
+ if (already_concrete_val.has_value()) {
+    TORCH_INTERNAL_ASSERT(
+        concrete_value == already_concrete_val.value(),
+        "Tried to bind ",
+        value,
+        " to ",
+        concrete_value,
+        ", but it's already set to ",
+        already_concrete_val.value());
+ } else {
+    TORCH_INTERNAL_ASSERT(
+        value->getOrigin() == nullptr,
+        "Tried to bind to a value that is computed in the fusion IR. ",
+        "Can only bind symbolic values that do not have an origin expr.");
+
+ bindings_[value] = concrete_value;
}
- TORCH_CHECK(!value->isConstScalar(), "Tried to bind to a constant value");
- TORCH_CHECK(
- value->definition() == nullptr,
- "Tried to bind to a value that is computed in the fusion IR");
- known_values_[value] = concrete_value;
-}
-c10::optional<Int::ScalarType> ExpressionEvaluator::evaluate(Val* value) {
- FUSER_PERF_SCOPE("ExpressionEvaluator::evaluate");
- auto maybe_concrete_value = getValue(value);
- if (!maybe_concrete_value.has_value()) {
- if (value->definition() != nullptr) {
- OptOutDispatch::handle(value->definition());
- maybe_concrete_value = getValue(value);
+ if (lower != nullptr) {
+ // TODO(kir): we should not need to lower (or mutate the IR in any way)
+ // during expression evaluation
+ auto lowered_val = lower->getLowerValue(value);
+ already_concrete_val = getValue(lowered_val);
+
+    if (already_concrete_val.has_value()) {
+      TORCH_INTERNAL_ASSERT(
+          concrete_value == already_concrete_val.value(),
+          "Tried to bind ",
+          lowered_val,
+          " to ",
+          concrete_value,
+          ", but it's already set to ",
+          already_concrete_val.value());
+ } else {
+      TORCH_INTERNAL_ASSERT(
+          lowered_val->getOrigin() == nullptr,
+          "Tried to bind to a value that is computed in the fusion IR. ",
+          "Can only bind symbolic values that do not have an origin expr.");
+
+ bindings_[lowered_val] = concrete_value;
}
}
- return maybe_concrete_value;
}
-void ExpressionEvaluator::print() const {
+c10::optional<Int::ScalarType> StatefulExpressionEvaluator::inferValue(
+ Val* value) {
+ FUSER_PERF_SCOPE("inferValue");
+ return maybeHandle(value);
+}
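+
+// A minimal usage sketch (hypothetical values; assumes an active Fusion and
+// the usual arith ops in scope):
+//
+//   Int* n = new Int();                // symbolic extent, no origin expr
+//   Val* m = mul(n, new Int(2));       // m = n * 2, computed in the IR
+//   StatefulExpressionEvaluator see(fusion);
+//   see.safeBind(n, 64);               // OK: n has no origin
+//   auto v = see.inferValue(m);        // traverses the mul; yields 128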
+
+void StatefulExpressionEvaluator::print() const {
std::cout << "\nEvaluation context\n";
std::cout << "--------------------\n";
- for (const auto& kv : known_values_) {
- TORCH_INTERNAL_ASSERT(!kv.first->isConstScalar());
- std::cout << kv.first << " = " << kv.second << " ; "
- << *kv.first->getValType() << "\n";
+ for (const auto& kv : bindings_) {
+ std::cout << kv.first << " = " << kv.second;
+ if (kv.first->isConstScalar()) {
+ std::cout << " ; original value = "
+ << kv.first->as<Int>()->value().value();
+ }
+ std::cout << " ; " << *kv.first->getValType() << "\n";
}
std::cout << "--------------------\n\n";
}
-c10::optional<Int::ScalarType> ExpressionEvaluator::getValue(Val* value) {
+c10::optional<Int::ScalarType> StatefulExpressionEvaluator::getValue(
+ Val* value) {
TORCH_INTERNAL_ASSERT(
value->isAnInt(),
"Expression Evaluation does not support values other than integers at this time.");
- if (value->getValType().value() == ValType::Scalar) {
- if (value->as<Int>()->value().has_value()) {
- return value->as<Int>()->value();
+ switch (value->getValType().value()) {
+ case ValType::Scalar:
+ if (value->as<Int>()->value().has_value()) {
+ return value->as<Int>()->value();
+ }
+ break;
+ case ValType::KirScalar:
+ if (value->as<kir::Int>()->value().has_value()) {
+ return value->as<kir::Int>()->value();
+ }
+ break;
+ default:
+ break;
+ }
+
+ const auto it = bindings_.find(value);
+ return it != bindings_.end() ? c10::optional<Int::ScalarType>(it->second)
+ : c10::nullopt;
+}
+
+c10::optional<Int::ScalarType> StatefulExpressionEvaluator::maybeHandle(
+ Val* val) {
+ auto maybe_concrete_value = getValue(val);
+ if (!maybe_concrete_value.has_value()) {
+ auto origin = val->getOrigin();
+ if (origin != nullptr) {
+ handle(origin);
+ maybe_concrete_value = getValue(val);
}
}
+ return maybe_concrete_value;
+}
- const auto it = known_values_.find(value);
- return it != known_values_.end() ? c10::optional<Int::ScalarType>(it->second)
- : c10::nullopt;
+void StatefulExpressionEvaluator::handle(UnaryOp* uop) {
+ const auto in = maybeHandle(uop->in());
+ if (in.has_value()) {
+ switch (uop->getUnaryOpType()) {
+ case UnaryOpType::Neg:
+ bindings_[uop->out()] = -*in;
+ break;
+ case UnaryOpType::Cast:
+ bindings_[uop->out()] = *in;
+ break;
+ default:
+ TORCH_CHECK(!"Unexpected operator type");
+ }
+ }
+}
+
+void StatefulExpressionEvaluator::handle(BinaryOp* bop) {
+ const auto lhs = maybeHandle(bop->lhs());
+ const auto rhs = maybeHandle(bop->rhs());
+ if (lhs.has_value() && rhs.has_value()) {
+ switch (bop->getBinaryOpType()) {
+ case BinaryOpType::Add:
+ bindings_[bop->out()] = *lhs + *rhs;
+ break;
+ case BinaryOpType::Sub:
+ bindings_[bop->out()] = *lhs - *rhs;
+ break;
+ case BinaryOpType::Mul:
+ bindings_[bop->out()] = *lhs * *rhs;
+ break;
+ case BinaryOpType::Div:
+ TORCH_CHECK(*rhs != 0);
+ bindings_[bop->out()] = *lhs / *rhs;
+ break;
+ case BinaryOpType::Mod:
+ TORCH_CHECK(*rhs != 0);
+ bindings_[bop->out()] = *lhs % *rhs;
+ break;
+ case BinaryOpType::CeilDiv:
+ TORCH_CHECK(*rhs != 0);
+ bindings_[bop->out()] = (*lhs + *rhs - 1) / *rhs;
+ break;
+ case BinaryOpType::And:
+ bindings_[bop->out()] = Int::ScalarType(*lhs && *rhs);
+ break;
+ default:
+ TORCH_CHECK(!"Unexpected operator type");
+ }
+ }
}
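// For reference, CeilDiv rounds the quotient up: with lhs = 10 and rhs = 4,
// (10 + 4 - 1) / 4 = 13 / 4 = 3 in integer arithmetic, matching ceil(10/4).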
-void ExpressionEvaluator::handle(UnaryOp* uop) {
- const auto in = evaluate(uop->in());
+void StatefulExpressionEvaluator::handle(kir::UnaryOp* uop) {
+ const auto in = maybeHandle(uop->in());
if (in.has_value()) {
switch (uop->getUnaryOpType()) {
case UnaryOpType::Neg:
- known_values_[uop->out()] = -*in;
+ bindings_[uop->out()] = -*in;
break;
case UnaryOpType::Cast:
- known_values_[uop->out()] = *in;
+ bindings_[uop->out()] = *in;
break;
default:
TORCH_CHECK(!"Unexpected operator type");
}
}
-void ExpressionEvaluator::handle(BinaryOp* bop) {
- const auto lhs = evaluate(bop->lhs());
- const auto rhs = evaluate(bop->rhs());
+void StatefulExpressionEvaluator::handle(kir::BinaryOp* bop) {
+ const auto lhs = maybeHandle(bop->lhs());
+ const auto rhs = maybeHandle(bop->rhs());
if (lhs.has_value() && rhs.has_value()) {
switch (bop->getBinaryOpType()) {
case BinaryOpType::Add:
- known_values_[bop->out()] = *lhs + *rhs;
+ bindings_[bop->out()] = *lhs + *rhs;
break;
case BinaryOpType::Sub:
- known_values_[bop->out()] = *lhs - *rhs;
+ bindings_[bop->out()] = *lhs - *rhs;
break;
case BinaryOpType::Mul:
- known_values_[bop->out()] = *lhs * *rhs;
+ bindings_[bop->out()] = *lhs * *rhs;
break;
case BinaryOpType::Div:
TORCH_CHECK(*rhs != 0);
- known_values_[bop->out()] = *lhs / *rhs;
+ bindings_[bop->out()] = *lhs / *rhs;
break;
case BinaryOpType::Mod:
TORCH_CHECK(*rhs != 0);
- known_values_[bop->out()] = *lhs % *rhs;
+ bindings_[bop->out()] = *lhs % *rhs;
break;
case BinaryOpType::CeilDiv:
TORCH_CHECK(*rhs != 0);
- known_values_[bop->out()] = (*lhs + *rhs - 1) / *rhs;
+ bindings_[bop->out()] = (*lhs + *rhs - 1) / *rhs;
break;
case BinaryOpType::And:
- known_values_[bop->out()] = Int::ScalarType(*lhs && *rhs);
+ bindings_[bop->out()] = Int::ScalarType(*lhs && *rhs);
break;
default:
TORCH_CHECK(!"Unexpected operator type");
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
+#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <c10/util/Optional.h>
namespace fuser {
namespace cuda {
-//! Calculate Fusion IR expressions
-class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch {
+class TORCH_CUDA_CU_API StatefulExpressionEvaluator : private OptOutDispatch {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- explicit ExpressionEvaluator(Fusion* fusion) : fusion_(fusion) {}
+ explicit StatefulExpressionEvaluator(Fusion* fusion) : fusion_(fusion) {}
- //! Returns the associated fusion object
Fusion* fusion() const {
return fusion_;
}
- //! Bind a concrete value to an IR variable
- void bind(Val* value, Int::ScalarType concrete_value);
+ void safeBind(
+ Val* value,
+ Int::ScalarType concrete_value,
+ GpuLower* lower = nullptr);
- //! Try to evaluate a Fusion IR value
- c10::optional<Int::ScalarType> evaluate(Val* value);
+ // Returns value if found in mapping, otherwise returns c10::nullopt
+ c10::optional<Int::ScalarType> getValue(Val* value);
+
+  // Checks if the value has already been inferred; returns the inferred value
+  // if so, otherwise runs traversal on the value. Warning: should not be
+  // called during traversal.
+ c10::optional<Int::ScalarType> inferValue(Val* value);
- //! Debugging helper, prints all the currently known values
+ // Debugging helper, prints all the currently set values
void print() const;
private:
- c10::optional<Int::ScalarType> getValue(Val* value);
+ using OptOutDispatch::handle;
+
+ void handle(Expr* expr) override {
+ switch (expr->getExprType().value()) {
+ case ExprType::UnaryOp:
+ handle(expr->as<UnaryOp>());
+ break;
+ case ExprType::BinaryOp:
+ handle(expr->as<BinaryOp>());
+ break;
+ case ExprType::KirUnaryOp:
+ handle(expr->as<kir::UnaryOp>());
+ break;
+ case ExprType::KirBinaryOp:
+ handle(expr->as<kir::BinaryOp>());
+ break;
+ default:
+ TORCH_INTERNAL_ASSERT(
+ false,
+ "Cannot handle Expr type: ",
+ expr->getExprType().value(),
+ " in stateful expression evaluator.");
+ }
+ }
+
+ void handle(UnaryOp*) override;
+ void handle(BinaryOp*) override;
+
+ // TODO(kir): remove this
+ void handle(kir::UnaryOp*) override;
+ void handle(kir::BinaryOp*) override;
- void handle(UnaryOp*) override final;
- void handle(BinaryOp*) override final;
+ c10::optional<Int::ScalarType> maybeHandle(Val*);
private:
- std::unordered_map<const Val*, Int::ScalarType> known_values_;
+ std::unordered_map<const Val*, Int::ScalarType> bindings_;
Fusion* fusion_ = nullptr;
};
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/codegen.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
+
+#include <torch/csrc/jit/codegen/cuda/codegen.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
+#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
+// TODO(kir): only needed until we can fix Fusion::origin()
#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
namespace torch {
namespace fuser {
namespace cuda {
-static thread_local Fusion* ACTIVE_FUSION = nullptr; // NOLINT
+static thread_local Fusion* ACTIVE_FUSION = nullptr;
FusionGuard::FusionGuard(Fusion* fusion) {
prev_fusion = ACTIVE_FUSION;
return ACTIVE_FUSION;
}
-TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept {
+void swap(Fusion& a, Fusion& b) noexcept {
FUSER_PERF_SCOPE("Fusion swap");
using std::swap;
swap(a.val_type_name_map_, b.val_type_name_map_);
swap(a.expr_name_counter_, b.expr_name_counter_);
+ swap(a.origin_, b.origin_);
+ swap(a.uses_, b.uses_);
+
swap(a.inputs_, b.inputs_);
swap(a.outputs_, b.outputs_);
- swap(a.io_alias_, b.io_alias_);
-
// Fixup the Statement::fusion_ links for a
for (auto val : a.val_set_) {
val->fusion_ = &a;
for (auto expr : b.expr_set_) {
expr->fusion_ = &b;
}
-}
-Fusion::Fusion(const Fusion& other) {
- FUSER_PERF_SCOPE("Fusion copy");
- Fusion::copy(&other, this);
-}
+ // Lowered IR nodes
+ swap(a.lowered_val_set_, b.lowered_val_set_);
+ swap(a.lowered_expr_set_, b.lowered_expr_set_);
+ swap(a.lowered_origin_, b.lowered_origin_);
-std::unique_ptr<SegmentedFusion> Fusion::segment(
- const at::ArrayRef<IValue>& inputs) {
- FUSER_PERF_SCOPE("Segment Fusion");
- return SegmentCandidateFinder::segment(this, inputs);
+ for (auto val : a.lowered_val_set_) {
+ val->fusion_ = &a;
+ }
+ for (auto expr : a.lowered_expr_set_) {
+ expr->fusion_ = &a;
+ }
+ for (auto val : b.lowered_val_set_) {
+ val->fusion_ = &b;
+ }
+ for (auto expr : b.lowered_expr_set_) {
+ expr->fusion_ = &b;
+ }
}
-IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
- to->clear();
- IrCloner ir_cloner(to);
+Fusion::Fusion(const Fusion& other) {
+ FUSER_PERF_SCOPE("Fusion copy");
- for (auto val : from->val_set_) {
- to->val_set_.insert(ir_cloner.clone(val));
- }
+ IrCloner ir_cloner(this);
- for (auto expr : from->expr_set_) {
- to->expr_set_.insert(ir_cloner.clone(expr));
+ for (auto val : other.val_set_) {
+ val_set_.insert(ir_cloner.clone(val));
}
- for (auto val : from->val_deque_) {
- to->val_deque_.push_back(ir_cloner.clone(val));
+ for (auto expr : other.expr_set_) {
+ expr_set_.insert(ir_cloner.clone(expr));
}
- for (auto val : from->val_set_) {
- ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
- ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
+ for (auto val : other.val_deque_) {
+ val_deque_.push_back(ir_cloner.clone(val));
}
- to->val_type_name_map_ = from->val_type_name_map_;
- to->expr_name_counter_ = from->expr_name_counter_;
+ val_type_name_map_ = other.val_type_name_map_;
+ expr_name_counter_ = other.expr_name_counter_;
- to->inputs_ = ir_cloner.clone(from->inputs_);
- to->outputs_ = ir_cloner.clone(from->outputs_);
+ for (const auto& kv : other.origin_) {
+ auto val = ir_cloner.clone(kv.first);
+ auto expr = ir_cloner.clone(kv.second);
+ origin_.insert({val, expr});
+ }
- // TODO: put this into ir_cloner instead
- for (const auto& entry : from->io_alias_) {
- Val* copied_output = ir_cloner.clone(entry.first);
- Val* copied_input = ir_cloner.clone(entry.second);
- to->io_alias_[copied_output] = copied_input;
+ for (const auto& kv : other.uses_) {
+ auto val = ir_cloner.clone(kv.first);
+ std::unordered_set<Expr*> val_uses;
+ for (auto expr : kv.second) {
+ val_uses.insert(ir_cloner.clone(expr));
+ }
+ uses_.insert({val, std::move(val_uses)});
}
- return ir_cloner;
+ inputs_ = ir_cloner.clone(other.inputs_);
+ outputs_ = ir_cloner.clone(other.outputs_);
}
Fusion::Fusion(Fusion&& other) noexcept {
expr_name_counter_ = 0;
+ origin_.clear();
+ uses_.clear();
+
inputs_.clear();
outputs_.clear();
- io_alias_.clear();
+ // Lowered IR nodes
+ for (auto ptr : lowered_val_set_) {
+ delete ptr;
+ }
+ for (auto ptr : lowered_expr_set_) {
+ delete ptr;
+ }
+ lowered_val_set_.clear();
+ lowered_expr_set_.clear();
+ lowered_origin_.clear();
}
void Fusion::removeExpr(Expr* expr) {
// that removing something that doesn't exist simply does nothing. For now,
// we're going with the strictest model which errors.
- for (auto out : expr->outputs()) {
- out->setDefinition(nullptr);
- }
+ for (auto out : expr->outputs())
+ if (origin_.find(out) != origin_.end())
+ if (origin_.find(out)->second == expr)
+ origin_.erase(out);
for (auto inp : expr->inputs()) {
- auto uses_copy = inp->uses();
- auto it = std::find(uses_copy.begin(), uses_copy.end(), expr);
- if (it != uses_copy.end()) {
- uses_copy.erase(it);
- inp->setUses(uses_copy);
+ if (uses_.find(inp) != uses_.end()) {
+ if (uses_.find(inp)->second.find(expr) != uses_.find(inp)->second.end()) {
+ uses_.find(inp)->second.erase(expr);
+ }
}
}
void Fusion::removeVal(Val* val) {
assertInFusion(val, "Cannot remove val ");
- TORCH_CHECK(
- !val->isFusionInput(),
- "Cannot remove val as it is an input of the fusion.");
- TORCH_CHECK(
- !val->isFusionOutput(),
- "Cannot remove val as it is an output of the fusion.");
+ for (Val* inp : inputs())
+ if (val->sameAs(inp))
+ TORCH_CHECK(false, "Cannot remove val as it is an input of the fusion.");
+
+ for (Val* out : outputs())
+ if (val->sameAs(out))
+ TORCH_CHECK(false, "Cannot remove val as it is an output of the fusion.");
- Expr* orig = val->definition();
+ Expr* orig = origin(val);
if (orig != nullptr)
- removeExpr(val->definition());
+ removeExpr(origin(val));
for (Expr* use : unordered_uses(val))
removeExpr(use);
if (input->getValType().value() == ValType::TensorView) {
auto tv = input->as<TensorView>();
+ if (tv->hasReduction()) {
+ TORCH_WARN_ONCE(
+ "Registered input ",
+ input,
+ " has a reduction axis, but this does nothing in the fusion.");
+ }
tv->setMemoryType(MemoryType::Global);
}
+ TORCH_INTERNAL_ASSERT(
+ input->getOrigin() == nullptr,
+ input,
+ " cannot be registered as an input as it is used as an output of an expression (",
+ input->getOrigin(),
+ ").");
inputs_.push_back(input);
- input->setIsFusionInput(true);
-
- all_tv_uses_valid_ = false;
}
void Fusion::addOutput(Val* output) {
tv->setMemoryType(MemoryType::Global);
}
outputs_.push_back(output);
- output->setIsFusionOutput(true);
-
- all_tv_uses_valid_ = false;
}
-void Fusion::addOutput(WelfordResult& wr) {
- // Want to always make sure the avg gets added last
- // since avg will be the out() value of welfordOp,
- // and want to make it the top of the computeAt chain
- addOutput(wr.var_sum);
- addOutput(wr.n);
- addOutput(wr.avg);
-}
+bool Fusion::inFusion(const Statement* stmt) const {
+ bool in_fusion = stmt->fusion() == this;
+ Statement* nonconst_stmt = const_cast<Statement*>(stmt); // NOLINT
-void Fusion::removeInput(Val* input) {
- auto find_input = std::find(inputs_.begin(), inputs_.end(), input);
- if (find_input != inputs_.end()) {
- inputs_.erase(find_input);
+ if (stmt->isExpr()) {
+ in_fusion &= expr_set_.find(nonconst_stmt->as<Expr>()) != expr_set_.end();
}
- input->setIsFusionInput(false);
- all_tv_uses_valid_ = false;
-}
-
-void Fusion::removeOutput(Val* output) {
- auto find_output = std::find(outputs_.begin(), outputs_.end(), output);
- if (find_output != outputs_.end()) {
- outputs_.erase(find_output);
+ if (stmt->isVal()) {
+ in_fusion &= val_set_.find(nonconst_stmt->as<Val>()) != val_set_.end();
}
- output->setIsFusionOutput(false);
- all_tv_uses_valid_ = false;
-}
-
-void Fusion::replaceOutput(Val* output, Val* replacement) {
- auto find_output = std::find(outputs_.begin(), outputs_.end(), output);
- TORCH_CHECK(find_output != outputs_.end(), "Unable to find output in Fusion");
- if (find_output != outputs_.end()) {
- *find_output = replacement;
-
- if (replacement->getValType().value() == ValType::TensorView) {
- replacement->setIsFusionOutput(true);
- replacement->as<TensorView>()->setMemoryType(MemoryType::Global);
- }
- if (output->getValType().value() == ValType::TensorView) {
- output->setIsFusionOutput(false);
- output->as<TensorView>()->setMemoryType(MemoryType::Local);
- }
- resetTvUses();
- }
+ return in_fusion;
}
-bool Fusion::inFusion(const Statement* stmt) const {
+bool Fusion::inKernelIr(const Statement* stmt) const {
bool in_fusion = stmt->fusion() == this;
Statement* nonconst_stmt = const_cast<Statement*>(stmt); // NOLINT
if (stmt->isExpr()) {
- in_fusion &= expr_set_.find(nonconst_stmt->as<Expr>()) != expr_set_.end();
+ in_fusion &= lowered_expr_set_.find(nonconst_stmt->as<Expr>()) !=
+ lowered_expr_set_.end();
}
if (stmt->isVal()) {
- in_fusion &= val_set_.find(nonconst_stmt->as<Val>()) != val_set_.end();
+ in_fusion &= lowered_val_set_.find(nonconst_stmt->as<Val>()) !=
+ lowered_val_set_.end();
}
return in_fusion;
void Fusion::assertInFusion(const Statement* stmt, const std::string& msg)
const {
- TORCH_CHECK(inFusion(stmt), msg, " it was not found in the active fusion.");
+ if (inFusion(stmt)) {
+ return;
+ }
+ if (inKernelIr(stmt)) {
+ return;
+ }
+ TORCH_CHECK(false, msg, " it was not found in the active fusion.");
}
-std::vector<Expr*> Fusion::exprs() {
- return ExprSort::getExprs(this);
+std::vector<Expr*> Fusion::exprs(bool from_outputs_only) {
+ return ExprSort::getExprs(this, from_outputs_only);
}
-std::vector<Val*> Fusion::inputsOf(Val* val) {
+std::unordered_set<Val*> Fusion::inputsOf(Val* val) {
return InputsOf::output(this, val);
}
}
}
for (Val* input : all_inputs) {
- if (!input->isConstScalar()) {
+ if (!input->isConstScalar())
TORCH_CHECK(
- hasInput(input) || inFusion(input),
+ hasInput(input),
"Could not figure out how ",
input,
" is generated, however it was not specified as an input.");
- }
}
}
FUSER_PERF_SCOPE("Fusion::print");
FusionGuard fg(this);
- std::cout << "\n%kernel {\n";
+ std::cout << "%kernel {\n";
IrMathPrinter op_exprs(std::cout);
op_exprs.handle(this);
- std::cout << "\nTransformPrinter : \n";
IrTransformPrinter t_exprs(std::cout);
t_exprs.handle(this);
- std::cout << "}\n\n";
+ std::cout << "}\n";
}
void Fusion::printKernel() {
std::cout << codegen::generateCudaKernel(GpuLower(this).kernel());
}
-void Fusion::printMath(bool from_outputs_only) {
+void Fusion::printMath() {
FUSER_PERF_SCOPE("Fusion::printMath");
FusionGuard fg(this);
- auto exprs_for_print = exprs();
- std::cout << "Inputs:" << std::endl;
- for (auto inp : inputs()) {
- std::cout << " " << inp << ", " << inp->getDataType().value() << std::endl;
- }
-
- std::cout << "Outputs:" << std::endl;
- for (auto out : outputs()) {
- std::cout << " " << out << ", " << out->getDataType().value() << std::endl;
- }
-
- // If we want everything in the fusion, grab all values without uses to
- // traverse from.
- if (!from_outputs_only) {
- std::vector<Val*> leaf_vals;
- for (auto val : deterministic_vals()) {
- if (val->uses().empty()) {
- leaf_vals.push_back(val);
- }
- }
- exprs_for_print = ExprSort::getExprs(this, leaf_vals);
- }
-
- std::cout << "\n%kernel_math {\n";
- for (auto expr : exprs_for_print) {
+ for (auto expr : exprs(true))
std::cout << expr;
- }
- std::cout << "}\n\n";
}
void Fusion::printTransforms() {
}
StmtNameType Fusion::registerVal(Val* val) {
+ TORCH_CHECK(!inKernelIr(val));
+
if (val->fusion()) {
if (val->fusion() != this) {
TORCH_CHECK(false, val, " was not found in the active fusion.");
}
StmtNameType Fusion::registerExpr(Expr* expr) {
+ TORCH_CHECK(!inKernelIr(expr));
+
if (expr->fusion()) {
if (expr->fusion() != this) {
TORCH_CHECK(false, expr, " was not found in the active fusion.");
for (Val* input : expr->inputs()) {
assertInFusion(input, "Input to expr is invalid, ");
- auto uses_copy = input->uses();
- if (std::find(uses_copy.begin(), uses_copy.end(), expr) ==
- uses_copy.end()) {
- uses_copy.push_back(expr);
- input->setUses(uses_copy);
+ TORCH_CHECK(!inKernelIr(input));
+ if (uses_.find(input) == uses_.end()) {
+ uses_[input] = {expr};
+ } else {
+ uses_.find(input)->second.emplace(expr);
}
}
for (Val* output : expr->outputs()) {
assertInFusion(output, "Output to expr is invalid, ");
- if (output->definition() != nullptr) {
- removeExpr(output->definition());
+ TORCH_CHECK(!inKernelIr(output));
+ auto it = origin_.find(output);
+ if (it != origin_.end()) {
+ removeExpr(it->second); // will also remove origin entry
}
- output->setDefinition(expr);
+
+ origin_[output] = expr;
}
expr_set_.emplace(expr);
-
- resetTvUses();
return getExprName();
}
TORCH_INTERNAL_ASSERT(
false,
"Could not register statement as Fusion could not recognize its type.");
- return kInvalidStmName;
+ return UNINITIALIZED_STMTNAMETYPE;
}
-void Fusion::resetTvUses() {
- FUSER_PERF_SCOPE("Fusion::resetTvUses");
- is_during_update_uses_ = true;
+StmtNameType Fusion::registerLoweredVal(Val* val) {
+ TORCH_INTERNAL_ASSERT(val->fusion() == this);
+ TORCH_INTERNAL_ASSERT(!inFusion(val));
+ TORCH_INTERNAL_ASSERT(!inKernelIr(val));
+ lowered_val_set_.insert(val);
+ return getValName(*val->getValType());
+}
- // getExprs only uses definition, so even if we've modified uses already to
-  // remove dead exprs, this could reinsert them. getExprs is also bounded by
- // inputs as registered inputs will return nullptr as their definition.
- const auto all_tvs = ir_utils::filterByType<TensorView>(val_set_);
- const auto used_exprs = ExprSort::getExprs(this);
+StmtNameType Fusion::registerLoweredExpr(Expr* expr) {
+ TORCH_INTERNAL_ASSERT(expr->fusion() == this);
+ TORCH_INTERNAL_ASSERT(!inFusion(expr));
+ TORCH_INTERNAL_ASSERT(!inKernelIr(expr));
- for (auto tv : all_tvs) {
- tv->setUses({});
+ for (Val* input : expr->inputs()) {
+ TORCH_CHECK(inKernelIr(input));
}
- // Same as in register expr
- for (auto expr : used_exprs) {
- for (Val* input : expr->inputs()) {
- auto uses_copy = input->uses();
- if (std::find(uses_copy.begin(), uses_copy.end(), expr) ==
- uses_copy.end()) {
- uses_copy.push_back(expr);
- input->setUses(uses_copy);
- }
- }
+ for (Val* output : expr->outputs()) {
+ TORCH_CHECK(inKernelIr(output));
+ TORCH_CHECK(lowered_origin_.insert({output, expr}).second);
}
- all_tv_uses_valid_ = true;
- is_during_update_uses_ = false;
+ lowered_expr_set_.insert(expr);
+ return getExprName();
+}
+
+bool Fusion::used(Val* val) const {
+ assertInFusion(val, "Cannot detect if val was used, ");
+ return (uses_.find(val) != uses_.end()) &&
+ (uses_.find(val)->second.size() > 0);
}
const std::unordered_set<Val*>& Fusion::vals() const noexcept {
  return val_set_;
}
-std::vector<Val*> Fusion::usedMathVals() {
- // Note that using fusion->inputs() as the argument for the first
- // parameter of getAllValsBetween does not grab all used vals as
- // there can be vals that are created inside a fusion without using
- // anything from inputs. See, for example, tv0 in the
- // FusionOuterSplit test.
- const auto inputs = InputsOf::outputs(this, outputs());
- auto used_math_vals = DependencyCheck::getAllValsBetween(
- {inputs.begin(), inputs.end()}, outputs());
-  // When an expr has multiple outputs and only some of them are
- // used, the rest aren't included in used_math_vals as they are not
- // used. However, we want them to be included as they must show up
- // in the fusion.
- std::vector<Val*> vals_to_add;
- std::unordered_set<Val*> added_vals;
-
- for (auto val : used_math_vals) {
- auto def = val->definition();
- if (def == nullptr || def->outputs().size() < 2) {
- continue;
- }
- for (auto out : def->outputs()) {
- if (std::find(used_math_vals.begin(), used_math_vals.end(), out) ==
- used_math_vals.end()) {
- if (!added_vals.count(out)) {
- vals_to_add.push_back(out);
- added_vals.insert(out);
- }
- }
- }
- }
-
- used_math_vals.insert(
- used_math_vals.end(), vals_to_add.begin(), vals_to_add.end());
-
- return used_math_vals;
-}
-
const std::unordered_set<Expr*>& Fusion::unordered_exprs() const noexcept {
return expr_set_;
}
std::unordered_set<Expr*> Fusion::unordered_uses(Val* val) const {
- return std::unordered_set<Expr*>(val->uses().begin(), val->uses().end());
+ assertInFusion(val, "Cannot detect where val was used, ");
+ if (uses_.find(val) != uses_.end()) {
+ auto ret = uses_.find(val)->second;
+ return ret;
+ }
+ return std::unordered_set<Expr*>();
}
-Expr* Fusion::definition(const Val* val) const {
- assertInFusion(val, "Cannot detect the definition of val, ");
- return val->definition();
+Expr* Fusion::origin(const Val* val) const {
+ // TODO(kir): remove the lowered branch
+ if (kir::isLoweredVal(val)) {
+ TORCH_INTERNAL_ASSERT(inKernelIr(val));
+ auto it = lowered_origin_.find(val);
+ return it != lowered_origin_.end() ? it->second : nullptr;
+ } else {
+ assertInFusion(val, "Cannot detect the origin of val, ");
+ auto it = origin_.find(val);
+ return it != origin_.end() ? it->second : nullptr;
+ }
}
bool Fusion::hasInput(const Val* val) const {
- assertInFusion(val, "Cannot check if val is an input, ");
- return val->isFusionInput();
+ return std::find(inputs_.begin(), inputs_.end(), val) != inputs_.end();
}
bool Fusion::hasOutput(const Val* val) const {
- assertInFusion(val, "Cannot check if val is an output, ");
- return val->isFusionOutput();
+ return std::find(outputs_.begin(), outputs_.end(), val) != outputs_.end();
+}
+
+void Fusion::replaceInput(Val* replace, Val* with) {
+ std::replace(inputs_.begin(), inputs_.end(), replace, with);
+}
+
+void Fusion::replaceOutput(Val* replace, Val* with) {
+ std::replace(outputs_.begin(), outputs_.end(), replace, with);
}
StmtNameType Fusion::getValName(ValType vtype) {
// Indicate to kernel to set itself up to generate random numbers
bool Fusion::isStochastic() {
- for (auto expr : exprs())
+ for (auto expr : exprs(true))
if (expr->getExprType() == ExprType::UnaryOp)
if (expr->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::RandLike)
return true;
bool Fusion::hasReduction() {
FUSER_PERF_SCOPE("Fusion::hasReduction");
- for (auto expr : exprs())
+ for (auto expr : exprs(true))
for (auto out : expr->outputs())
if (out->getValType() == ValType::TensorView)
if (out->as<TensorView>()->hasReduction())
return false;
}
-bool Fusion::hasWelford() {
- FUSER_PERF_SCOPE("Fusion::hasWelford");
- for (auto expr : exprs()) {
- if (expr->isA<WelfordOp>()) {
- return true;
- }
- }
+bool Fusion::hasBlockReduction() {
+ FUSER_PERF_SCOPE("Fusion::hasBlockReduction");
+
+ for (auto expr : exprs(true))
+ for (auto out : expr->outputs())
+ if (out->getValType() == ValType::TensorView)
+ if (out->as<TensorView>()->hasBlockReduction())
+ return true;
+
return false;
}
-std::vector<Val*> Fusion::getTerminatingOutputs() {
- FUSER_PERF_SCOPE("getTerminatingOutputs");
+bool Fusion::hasGridReduction() {
+ FUSER_PERF_SCOPE("Fusion::hasGridReduction");
- auto is_reachable_to_output = [](Val* val) {
- // traverse to consumers of val and see if there is an output
- std::deque<Val*> consumers;
- for (auto use : val->uses()) {
- for (auto consumer : use->outputs()) {
- consumers.push_back(consumer);
- }
- }
- while (!consumers.empty()) {
- auto consumer = consumers.back();
- consumers.pop_back();
- if (consumer->isFusionOutput()) {
- return true;
- }
- // consumer is not an output; proceed to its consumers
- for (auto use : consumer->uses()) {
- for (auto consumer_of_consumer : use->outputs()) {
- consumers.push_back(consumer_of_consumer);
- }
- }
- }
- return false;
- };
+ for (auto expr : exprs(true))
+ for (auto out : expr->outputs())
+ if (out->getValType() == ValType::TensorView)
+ if (out->as<TensorView>()->hasGridReduction())
+ return true;
- std::vector<Val*> terminating_outputs;
+ return false;
+}
- for (auto out : outputs()) {
- // If there is another output reachable from this output, it's not
- // terminating.
- if (is_reachable_to_output(out)) {
- continue;
+bool Fusion::hasBlockBroadcast() {
+ for (auto expr : exprs(true)) {
+ for (auto out : expr->outputs()) {
+ if (out->getValType() == ValType::TensorView) {
+ if (out->as<TensorView>()->hasBlockBroadcast()) {
+ return true;
+ }
+ }
}
- terminating_outputs.push_back(out);
}
-
- return terminating_outputs;
+ return false;
}
-bool Fusion::isAliasCompatible(Val* left, Val* right) {
- // Nullptr check
- if (left == nullptr || right == nullptr) {
- return false;
- }
-
- // DataType check
- if (!left->getDataType().has_value() || !right->getDataType().has_value() ||
- left->getDataType().value() != right->getDataType().value()) {
- return false;
- }
-
- // ValType check
- if (!left->getValType().has_value() || !right->getValType().has_value() ||
- left->getValType().value() != right->getValType().value()) {
- return false;
- }
+bool Fusion::hasBroadcast() {
+ for (auto expr : exprs(true))
+ for (auto out : expr->outputs())
+ if (out->getValType() == ValType::TensorView)
+ if (out->as<TensorView>()->hasBroadcast())
+ return true;
- // Check same number of dimensions if both values are TensorViews
- if (ir_utils::isTV(left) && ir_utils::isTV(right)) {
- return left->as<TensorView>()->nDims() == right->as<TensorView>()->nDims();
- }
return false;
}
-void Fusion::aliasOutputToInput(Val* output, Val* input) {
- TORCH_INTERNAL_ASSERT(
- isAliasCompatible(input, output),
- "The input and output values are not alias-compatible.");
- io_alias_[output] = input;
-}
+std::vector<Val*> Fusion::getTerminatingOutputs() {
+ FUSER_PERF_SCOPE("getTerminatingOutputs");
-std::unordered_set<int> Fusion::getOutputAliasIndices() const {
- if (io_alias_.empty()) {
- return {};
- }
+ FusionGuard fg(this);
- std::unordered_set<int> alias_indices;
+ std::unordered_set<Val*> used_vals;
- for (size_t i = 0; i < outputs_.size(); i++) {
- if (io_alias_.count(outputs_[i]) != 0) {
- alias_indices.insert(i);
- }
- }
- return alias_indices;
-}
+ const auto exprs = ExprSort::getExprs(
+ this, std::vector<Val*>(outputs().begin(), outputs().end()));
-std::vector<std::pair<int, int>> Fusion::getInputAliasIndices() const {
- if (io_alias_.empty()) {
- return {};
+ for (auto expr : exprs) {
+ for (auto inp : expr->inputs())
+ used_vals.emplace(inp);
}
- std::vector<std::pair<int, int>> alias_indices;
- for (size_t i = 0; i < outputs_.size(); i++) {
- if (io_alias_.count(outputs_[i]) != 0) {
- bool found = false;
- for (size_t j = 0; j < inputs_.size(); j++) {
- if (io_alias_.at(outputs_[i]) == inputs_[j]) {
- alias_indices.emplace_back(i, j);
- found = true;
- break;
- }
- }
- TORCH_INTERNAL_ASSERT(
- found,
- "io_alias_ mapping failure, alias output is not present in inputs");
- }
+ std::vector<Val*> terminating_outputs;
+ for (auto out : outputs()) {
+ if (used_vals.find(out) != used_vals.end())
+ continue;
+ terminating_outputs.push_back(out);
}
- // can't assert here, we could have segmented fusion where not all alias
- // outputs are present
-
- return alias_indices;
+ return terminating_outputs;
}
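// For example, with registered outputs {o1, o2} where o2 is computed from o1
// (a hypothetical case), o1 feeds another expression and lands in used_vals,
// so only o2 is returned as a terminating output.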
} // namespace cuda
#pragma once
-#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <unordered_map>
#include <unordered_set>
namespace fuser {
namespace cuda {
-//! Usage: FusionGuard and Fusion are required user interfaces for any operation
-//! underlying the code generator. In order to create values, expressions, and
-//! generate code a Fusion instance must be active. It is the responsibility of
-//! the user to create a Fusion instance and register it with the fusion guard.
-//! The simplest example of this is:
-//!
-//! Fusion fusion;
-//! FusionGuard fg(&fusion);
-//!
-//! Once a fusion is active all values and operations will be registered with
-//! it.
-//!
-//! FusionGuard and Fusion are critical to the lifetime model of the IR system.
-//! FusionGuard is a convenient way to set what base container instance holds
-//! the defined IR. Statements that are defined are registered through the
-//! FusionGuard with a particular Fusion. FusionGuard provides convenient
-//! methods to access the active fusion so it doesn't need to be passed around
-//! constantly. Any IR node derived classes from Statement must register with
-//! Fusion to avoid memory leaks.
-//!
-//! Fusion is generally thought of as a translated fusion group from the JIT. It
-//! is likely a single kernel, although, we don't have to stick to this in the
-//! future and could in theory generate multiple kernels with an executor to run
-//! them.
-//!
-//! Fusion also allows users to set input/output values that will allow us to
-//! figure out how to hook up runtime data to and from the JIT as well as
-//! provide us mechanisms for dependency analysis and DCE including safety
-//! checks.
+/*
+ * Usage: FusionGuard and Fusion are required user interfaces for any operation
+ * underlying the code generator. In order to create values, expressions, and
+ * generate code a Fusion instance must be active. It is the responsibility of
+ * the user to create a Fusion instance and register it with the fusion guard.
+ * The simplest example of this is:
+ *
+ *   Fusion fusion;
+ *   FusionGuard fg(&fusion);
+ *
+ * Once a fusion is active all values and operations will be registered with
+ * it.
+ *
+ * FusionGuard and Fusion are critical to the lifetime model of the IR system.
+ * FusionGuard is a convenient way to set what base container instance holds the
+ * defined IR. Statements that are defined are registered through the
+ * FusionGuard with a particular Fusion. FusionGuard provides convenient methods
+ * to access the active fusion so it doesn't need to be passed around
+ * constantly. Any IR node derived classes from Statement must register with
+ * Fusion to avoid memory leaks.
+ *
+ * Fusion is generally thought of as a translated fusion group from the JIT. It
+ * is likely a single kernel, although, we don't have to stick to this in the
+ * future and could in theory generate multiple kernels with an executor to run
+ * them.
+ *
+ * Fusion also allows users to set input/output values that will allow us to
+ * figure out how to hook up runtime data to and from the JIT as well as provide
+ * us mechanisms for dependency analysis and DCE including safety checks.
+ */
class Fusion;
class TensorView;
-class WelfordResult;
-class SegmentCandidateFinder;
-class SegmentedFusion;
-
-//! Fusion Guard is our "context manager". It holds the actrive fusion and
-//! allows it to be accessed anywhere through FusionGuard::getCurFusion()
+// Fusion Guard is our "context manager". It holds the active fusion and allows
+// it to be accessed anywhere through FusionGuard::getCurFusion().
class TORCH_CUDA_CU_API FusionGuard {
public:
Fusion* prev_fusion;
- //! Set the active fusion so it can be manipulated.
+ // Set the active fusion so it can be manipulated.
explicit FusionGuard(Fusion* fusion);
~FusionGuard();
static Fusion* getCurFusion();
};
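
// A short sketch of the intended guard nesting (illustrative only):
//
//   Fusion a, b;
//   FusionGuard ga(&a);   // active fusion: a
//   {
//     FusionGuard gb(&b); // active fusion: b
//   }                     // gb destroyed; active fusion: a again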
-//! Fusion is mutable but unique. Nodes cannot be copied in any way from one
-//! Fusion to another. If anything like that is desired, it would require
-//! duplicating all associated values and exprs. Fusion is considered to SSA,
-//! though this could also change in the future if there is a good reason to do
-//! so.
-//!
-//! The Fusion owns the whole IR graph (Vals and Exprs)
-//!
+/*
+ * Fusion is mutable but unique. Nodes cannot be copied in any way from one
+ * Fusion to another. If anything like that is desired, it would require
+ * duplicating all associated values and exprs. Fusion is considered to be SSA,
+ * though this could also change in the future if there is a good reason to do
+ * so.
+ *
+ * The Fusion owns the whole IR graph (Vals and Exprs)
+ */
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API Fusion final {
public:
void clear() noexcept;
- //! Break dependency chains associated with Expr, remove references to expr
- //! delete expr
+ // Break dependency chains associated with Expr, remove references to expr
+ // delete expr.
void removeExpr(Expr* expr);
- //! Completely remove val from the fusion, break all dependencies associated
- //! with it
+ // Completely remove val from the fusion, break all dependencies associated
+ // with it.
void removeVal(Val* val);
- //! Register input as an input of the fusion
- // TODO: Rename to register
+ // Register input as an input of the fusion
void addInput(Val* input);
- //! Register output as an output of the fusion
- // TODO: Rename to register
+ // Register output as an output of the fusion
void addOutput(Val* output);
- //! Register output as an output of the fusion
- // TODO: Rename to register
- void addOutput(WelfordResult& output);
-
- //! Deregister input as an input of the fusion
- // TODO: Rename to register
- void removeInput(Val* input);
-
- //! Deregister output as an output of the fusion
- // TODO: Rename to register
- void removeOutput(Val* output);
-
- //! Replace output with another value
- void replaceOutput(Val* output, Val* replacement);
-
- //! Clear Expr's from TV uses that are not required to produce outputs from
- //! inputs
- void resetTvUses();
-
- //! Check if stmt is properly registered with this fusion
+ // Check if stmt is properly registered with this fusion
bool inFusion(const Statement* stmt) const;
- //! Throw an error if stmt is not in this fusion
+ // Throw an error if stmt is not in this fusion. Message will be:
+ // msg + " it was not found in the active fusion."
void assertInFusion(const Statement* stmt, const std::string& msg = "") const;
- //! Assert that all leaves found from outputs are registered as an input
+ /*
+ * Return a list of topologically sorted expressions. We can start
+ * by only traversing back from registered outputs, or from all terminating
+ * Vals.
+ *
+ * from_outputs_only:
+ * True - Sort from DAG associated with registered outputs
+ * False - Sort from all terminating Vals.
+ */
+ std::vector<Expr*> exprs(bool from_outputs_only = false);
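+  //
+  // e.g. fusion.exprs(true) returns just the topologically sorted expressions
+  // needed to compute the registered outputs.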
+
+  // Return the set of fusion inputs that feed this Val
+ std::unordered_set<Val*> inputsOf(Val* val);
+
+ // Assert that all leaves found from outputs are registered as an input.
void validateInputs();
- //! Print this fusion to the console
+ // Print this fusion to cout.
void print();
- //! Print Arith exprs
- //! \param from_outputs_only Only print exprs reachable from outputs
- void printMath(bool from_outputs_only = true);
+ // Print Arith exprs used in outputs
+ void printMath();
- //! Print transformations used in fusion (can be very verbose)
+ // Print transformations used in fusion (can be very verbose)
void printTransforms();
- //! Lower the fusion and print a kernel
+ // Lower the fusion and print a kernel
void printKernel();
- //! Register the Val with this fusion
+ // Register the Val with this fusion
StmtNameType registerVal(Val* val);
- //! Register expr with this fusion.
- //! When we register an expression, we want to update the dependency tracking
- //! of Vals. We add expr to our general expr_set_,
+ // Register expr with this fusion.
+ // When we register an expression, we want to update the dependency tracking
+ // of Vals. We add expr to our general expr_set_, we add use tracking for
+ // inputs and origin tracking for outputs.
StmtNameType registerExpr(Expr* expr);
- //! Register stmt with this fusion
+ // Register stmt with this fusion.
StmtNameType registerStatement(Statement* stmt);
- //! Return a list of topologically sorted expressions. This only includes
-  //! exprs required to generate registered outputs.
- std::vector<Expr*> exprs();
+ // Lowered nodes
+ // TODO(kir): to be removed
+ StmtNameType registerLoweredVal(Val* val);
+ StmtNameType registerLoweredExpr(Expr* expr);
- //! Return a vector of fusion inputs that feed this Val
- std::vector<Val*> inputsOf(Val* val);
+ // Lowered counterpart to inFusion()
+ // TODO(kir): to be removed
+ bool inKernelIr(const Statement* stmt) const;
- //! Return the set of Vals registered with this fusion
- const std::unordered_set<Val*>& vals() const noexcept;
+  // Check if val is used in this fusion. Not equivalent to DCE.
+ bool used(Val* val) const;
- //! Return in insertion order
+ // Return the set of Vals registered with this fusion
+ const std::unordered_set<Val*>& vals() const noexcept;
+ // Return in insertion order
const std::deque<Val*>& deterministic_vals() const noexcept;
- //! Return all Vals in math expressions that cannot be eliminated.
- //!
- //! It is generally equivalent to vals that are used to generate
- //! outputs, however, when a multi-output expression exists, and only
- //! some of the outputs are used, the remaining unused outputs are
- //! also included as they must show up in the final code.
- std::vector<Val*> usedMathVals();
-
- //! Return the set of Exprs registered with this fusion. Warning: This will
- //! return exprs outside inputs/outputs, so can be unsafe for use with
- //! segmented fusions.
+ // Return the set of Exprs registered with this fusion
const std::unordered_set<Expr*>& unordered_exprs() const noexcept;
- //! Return all Exprs that use val
+ // Return all Exprs that use val
std::unordered_set<Expr*> unordered_uses(Val* val) const;
- //! Return the Expr that produces val
- Expr* definition(const Val* val) const;
+ // Return the Expr that produces val
+ Expr* origin(const Val* val) const;
- //! Indicate to kernel to set itself up to generate random numbers
+ // Indicate to kernel to set itself up to generate random numbers
bool isStochastic();
- //! Indicate that the fusion contains reduction operations
+ // TODO(kir): revisit to see how many of these are still needed
bool hasReduction();
-
- //! Indicate that the fusion contains welford operations
- bool hasWelford();
-
- //! Run fusion segmentation algorithm to create a segmented fusion
- std::unique_ptr<SegmentedFusion> segment(
- const at::ArrayRef<at::IValue>& inputs);
+ bool hasBlockReduction();
+ bool hasGridReduction();
+ bool hasBlockBroadcast();
+ bool hasBroadcast();
+ size_t gridReductionTempBufferSize();
const auto& inputs() const {
return inputs_;
bool hasInput(const Val* val) const;
bool hasOutput(const Val* val) const;
- // Aliasing output to input value, this is a WAR to allow inplace update on
- // input tensor.
- // Note: this is not always safe and should be used with extra caution.
- // Currently the only place it's used is in the running stats update for batch
- // normalization.
- // TODO: alias should be made aware to segmentation, so we'll always include
- // the input tensor to the section where output is produced.
- void aliasOutputToInput(Val* output, Val* input);
- std::unordered_set<int> getOutputAliasIndices() const;
- std::vector<std::pair<int, int>> getInputAliasIndices() const;
-
- bool isTVUseInfoValid() {
- return all_tv_uses_valid_;
- }
-
- bool isUpdatingTVUseInfo() {
- return is_during_update_uses_;
- }
-
- protected:
- friend SegmentCandidateFinder;
- friend SegmentedFusion;
- friend class TranslateApplicableWelford;
-
- static IrCloner copy(const Fusion* from, Fusion* to);
+ void replaceInput(Val* replace, Val* with);
+ void replaceOutput(Val* replace, Val* with);
private:
// Return an int that monotonically increases for each val/expr, some are
StmtNameType getValName(ValType vtype);
StmtNameType getExprName();
- // Determine if the two values are compatible for aliasing
- // Same DataType, ValType, and number of dimensions
- bool isAliasCompatible(Val* left, Val* right);
-
private:
// Sets of all Vals/Exprs registered with this fusion
// (val_deque_ is not owning the objects)
// Expression names counter
StmtNameType expr_name_counter_ = 0;
+ // Dependency tracking for Vals. Where did it come from? Where is it used?
+ std::unordered_map<const Val*, Expr*> origin_;
+ std::unordered_map<Val*, std::unordered_set<Expr*>> uses_;
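+  //
+  // e.g. after registering "c = add(a, b)": origin_[c] is the add expr, and
+  // uses_[a] and uses_[b] both contain that expr (see registerExpr above).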
+
// Fusion inputs and outputs
std::vector<Val*> inputs_;
std::vector<Val*> outputs_;
- // io alias pointing from output to input
- std::unordered_map<Val*, Val*> io_alias_;
-
- // Records if the current use data in the IR nodes are valid
- // the states are either all valid or all invalid
- bool all_tv_uses_valid_ = false;
- bool is_during_update_uses_ = false;
+ // Lowered IR
+ std::unordered_set<Val*> lowered_val_set_;
+ std::unordered_set<Expr*> lowered_expr_set_;
+ std::unordered_map<const Val*, Expr*> lowered_origin_;
};
} // namespace cuda
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
-#include <torch/csrc/jit/codegen/cuda/ir_graphviz.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-
-#include <sstream>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::getNeighborGroups() {
- std::vector<NeighborGroup> neighbors;
- for (auto inp : producer_edges) {
- if (inp->val->isFusionOutput()) {
- // Don't fuse across output nodes, would need to find another path.
- continue;
- }
- neighbors.emplace_back(inp->from, inp);
- }
- for (auto out : consumer_edges) {
- if (out->val->isFusionOutput()) {
- // Don't fuse across output nodes, would need to find another path.
- continue;
- }
- neighbors.emplace_back(out->to, out);
- }
- return neighbors;
-}
-
-std::vector<SegmentedGroup*> SegmentedGroup::getNeighbors() {
- std::vector<SegmentedGroup*> neighbors;
- auto neighbors_pair = getNeighborGroups();
-
- std::transform(
- neighbors_pair.begin(),
- neighbors_pair.end(),
- std::back_inserter(neighbors),
- [](auto& neighbor_group) { return neighbor_group.group; });
- return neighbors;
-}
-
-std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::
- getMergeCandidates() {
- // Don't look for candidates if already merged
- if (merged_) {
- return {};
- }
-
- std::vector<NeighborGroup> neighbors = getNeighborGroups();
-
- // Can this node be merged with another? Check if neighbors are merged, if
- // so and merged neighbor is within 1 level or node merged with neighbor is
- // within 1 level, can't merge this node with anything else.
- bool can_merge_this = true;
- for (auto& neighbor : neighbors) {
- if (!neighbor.group->merged_) {
- continue;
- }
- if (std::abs(neighbor.group->level_ - level_) <= 1) {
- can_merge_this = false;
- }
- if (std::abs(neighbor.group->merge_with_->level_ - level_) <= 1) {
- can_merge_this = false;
- }
- }
- if (!can_merge_this) {
- return {};
- }
-
-  std::vector<bool> can_merge(neighbors.size(), true);
-
-  // Find neighbors whose level differs from this group's level by only 1
- for (size_t i = 0; i < neighbors.size(); i++) {
- if (std::abs(neighbors[i].group->level_ - level_) > 1) {
- can_merge[i] = false;
- }
- }
-
- // Check neighbor of neighbors we're considering, if any of them are merged
- // with another node, make sure the resulting edge wouldn't have a level
- // difference of 1
- for (size_t i = 0; i < neighbors.size(); i++) {
- if (!can_merge[i]) {
- continue;
- }
-
- for (auto neighbor_neighbor : neighbors[i].group->getNeighbors()) {
- // Don't check self
- if (neighbor_neighbor == neighbors[i].group) {
- continue;
- }
- if (neighbor_neighbor->merged_) {
- // check neighbor_neighbor level
- if (std::abs(neighbor_neighbor->level_ - level_) <= 1) {
- can_merge[i] = false;
- }
- if (std::abs(neighbor_neighbor->level_ - neighbors[i].group->level_) <=
- 1) {
- can_merge[i] = false;
- }
-
-      // check neighbor_neighbor->merge_with_->level_
- if (std::abs(neighbor_neighbor->merge_with_->level_ - level_) <= 1) {
- can_merge[i] = false;
- }
- if (std::abs(
- neighbor_neighbor->merge_with_->level_ -
- neighbors[i].group->level_) <= 1) {
- can_merge[i] = false;
- }
- }
- }
- }
-
- std::vector<NeighborGroup> merge_candidates;
- for (size_t i = 0; i < neighbors.size(); i++) {
- if (can_merge[i]) {
- merge_candidates.push_back(neighbors[i]);
- }
- }
- return merge_candidates;
-}
-
-void SegmentedGroup::clearTraversalInfo() {
- level_ = -1;
- visited_ = false;
- merge_with_ = nullptr;
- merge_through_ = nullptr;
- merged_ = false;
-}
-
-std::vector<Val*> SegmentedGroup::edgesToVals(
- const std::vector<SegmentedEdge*>& se_v) {
- std::vector<Val*> ret_v;
- ret_v.reserve(se_v.size());
-
- std::transform(
- se_v.cbegin(),
- se_v.cend(),
- std::back_inserter(ret_v),
- [](SegmentedEdge* se) { return se->val; });
- return ret_v;
-}
-
-template <typename PREDICATE>
-void insertUniquePredicated(
- std::vector<Val*>& v,
- const std::vector<SegmentedEdge*>& e,
- PREDICATE pred) {
- std::unordered_set<Val*> to_add;
- std::transform(
- e.cbegin(),
- e.cend(),
- std::inserter(to_add, to_add.end()),
- [](SegmentedEdge* se) { return se->val; });
- std::copy_if(
- to_add.begin(), to_add.end(), std::back_inserter(v), [pred](Val* val) {
- return pred(val);
- });
-}
-
-void SegmentedGroup::finalize() {
- // Move all the edges to group input/output
- // Inputs
- insertUniquePredicated(
- input_vals, producer_edges, [](Val* v) { return !v->isFusionInput(); });
-
- std::unordered_set<Val*> input_set(input_vals.begin(), input_vals.end());
-
- for (auto expr : exprs_) {
- for (auto i : expr->inputs()) {
- if (i->isAnInt() && i->definition() == nullptr && !i->isConstScalar() &&
- !i->isFusionInput() && !input_set.count(i)) {
- input_set.insert(i);
- input_vals.push_back(i);
- }
- }
- }
-
- // Outputs
- insertUniquePredicated(
- output_vals, consumer_edges, [](Val* v) { return !v->isFusionOutput(); });
-
- // alias aware segmentation. we add inputs that are aliased by output
- // generated in this SegmentedGroup
- for (auto output : output_vals) {
- if (auto aliased_input = segmented_fusion_->findAlias(output)) {
- // aliasing currently only supported as output to input
- TORCH_INTERNAL_ASSERT(
- aliased_input->isFusionInput(),
- "aliased input is not found in the complete fusion");
- if (!input_set.count(aliased_input)) {
- input_set.insert(aliased_input);
- input_vals.push_back(aliased_input);
- }
- }
- }
-}
-
-std::ostream& operator<<(std::ostream& os, const SegmentedGroup* group) {
- os << "g{";
- auto expr_to_print = group->exprs();
- std::sort(
- expr_to_print.begin(),
- expr_to_print.end(),
- [](auto expr_a, auto expr_b) -> bool {
- return expr_a->name() < expr_b->name();
- });
- for (size_t i = 0; i < expr_to_print.size(); i++) {
- os << expr_to_print[i]->name();
- if (i + 1 != expr_to_print.size())
- os << ", ";
- }
- os << "}\n";
- return os;
-}
-
-void SegmentedGroup::print() const {
- std::cout << this << "\n";
-}
-
-std::string toString(const SegmentedGroup* group) {
- std::stringstream ss;
- ss << group;
- return ss.str();
-}
-
-std::ostream& operator<<(std::ostream& os, const SegmentedEdge* edge) {
- os << "e{ " << edge->from << " -> " << edge->to << "(";
- IrPrinter irp(os);
- irp.handle(edge->val);
- os << ") }\n";
- return os;
-}
-
-void SegmentedEdge::print() const {
- std::cout << this << "\n";
-}
-
-std::string toString(const SegmentedEdge* edge) {
- std::stringstream ss;
- ss << edge;
- return ss.str();
-}
-
-SegmentedFusion::SegmentedFusion(std::unique_ptr<Fusion> fusion)
- : impl_(this), complete_fusion_(std::move(fusion)) {
- segmented_fusion_name_ = segmentedFusionName();
- annotateFP16IntermediateTensors();
-}
-
-SegmentedGroup* SegmentedFusion::Impl::makeGroup() {
- groups_.emplace_back(std::make_unique<SegmentedGroup>(owning_fusion_));
- return groups_.back().get();
-}
-
-SegmentedGroup* SegmentedFusion::Impl::makeGroup(Expr* expr) {
- groups_.emplace_back(std::make_unique<SegmentedGroup>(expr, owning_fusion_));
- return groups_.back().get();
-}
-
-SegmentedEdge* SegmentedFusion::Impl::makeEdge(
- SegmentedGroup* from,
- SegmentedGroup* to,
- Val* val) {
- edges_.emplace_back(std::make_unique<SegmentedEdge>(from, to, val));
- return edges_.back().get();
-}
-
-void SegmentedFusion::Impl::cleanUnused() {
- std::unordered_set<SegmentedGroup*> g_used(
- owning_fusion_->groups().begin(), owning_fusion_->groups().end());
- std::unordered_set<SegmentedEdge*> e_used(
- owning_fusion_->edges().begin(), owning_fusion_->edges().end());
-
- groups_.erase(
- std::remove_if(
- groups_.begin(),
- groups_.end(),
- [&g_used](auto& g) { return g_used.count(g.get()) == 0; }),
- groups_.end());
-
- edges_.erase(
- std::remove_if(
- edges_.begin(),
- edges_.end(),
- [&e_used](auto& e) { return e_used.count(e.get()) == 0; }),
- edges_.end());
-}
-
-SegmentedGroup* SegmentedFusion::newGroup() {
- SegmentedGroup* g = impl_.makeGroup();
- groups_.push_back(g);
- return g;
-}
-
-SegmentedGroup* SegmentedFusion::newGroup(Expr* expr) {
- SegmentedGroup* g = impl_.makeGroup(expr);
- groups_.push_back(g);
- return g;
-}
-
-SegmentedEdge* SegmentedFusion::newEdge(
- SegmentedGroup* from,
- SegmentedGroup* to,
- Val* val) {
- SegmentedEdge* e = impl_.makeEdge(from, to, val);
- edges_.push_back(e);
- return e;
-}
-
-void SegmentedFusion::draw() {
- size_t group_index = 0;
- std::unordered_map<const Expr*, size_t> expr_color_map;
-
- for (auto group : groups()) {
- for (auto expr : group->exprs()) {
- if (ir_utils::isTVOp(expr)) {
- expr_color_map[expr] = group_index;
- }
- }
- group_index++;
- }
-
- std::stringstream sstream;
- sstream << "segmented_fusion" << segmented_fusion_name_ << ".dot";
- auto filename = sstream.str();
-
- IrGraphGenerator::print(
- completeFusion(),
- filename.c_str(),
- IrGraphGenerator::DetailLevel::ComputeOnly,
- &expr_color_map);
-}
-
-namespace {
-
-std::vector<Val*> uniqueValConcat(
- const std::vector<std::vector<Val*>>& val_vecs) {
- std::vector<Val*> unique_vals;
- std::unordered_set<Val*> added;
- for (const auto& vec : val_vecs) {
- for (auto val : vec) {
- if (added.find(val) == added.end()) {
- unique_vals.push_back(val);
- added.emplace(val);
- }
- }
- }
- return unique_vals;
-}
-
-// Concatenates producer edges of sg1 and sg2, but removes any edges from/to sg1/sg2
-std::vector<SegmentedEdge*> getMergedProducerEdges(
- const SegmentedGroup* sg1,
- const SegmentedGroup* sg2) {
- TORCH_INTERNAL_ASSERT(
- sg1 != nullptr && sg2 != nullptr,
-      "This function doesn't handle the trivial case.");
-
- auto producer_edges = sg1->producer_edges;
-
- producer_edges.insert(
- producer_edges.end(),
- sg2->producer_edges.begin(),
- sg2->producer_edges.end());
-
- // Register producers into sg2
- std::unordered_set<Val*> sg2_vals;
- for (auto se : sg2->producer_edges) {
- sg2_vals.emplace(se->val);
- }
-
- producer_edges.erase(
- std::remove_if(
- producer_edges.begin(),
- producer_edges.end(),
- [&sg1, &sg2, &sg2_vals](SegmentedEdge* se) {
- // remove edges in between the groups and common uses
- return (se->to == sg1 && se->from == sg2) ||
- (se->to == sg2 && se->from == sg1) ||
- (se->to == sg1 && sg2_vals.count(se->val));
- }),
- producer_edges.end());
-
-  // TODO: remove duplicate edges
-
- return producer_edges;
-}
-
-// Concatenates the consumer edges of sg1 and sg2, removing any edges between sg1 and sg2
-std::vector<SegmentedEdge*> getMergedConsumerEdges(
- const SegmentedGroup* sg1,
- const SegmentedGroup* sg2) {
-  TORCH_INTERNAL_ASSERT(
-      sg1 != nullptr && sg2 != nullptr,
-      "This function doesn't handle the trivial case.");
-
- auto consumer_edges = sg1->consumer_edges;
- consumer_edges.insert(
- consumer_edges.end(),
- sg2->consumer_edges.begin(),
- sg2->consumer_edges.end());
-
- consumer_edges.erase(
- std::remove_if(
- consumer_edges.begin(),
- consumer_edges.end(),
- [&sg1, &sg2](SegmentedEdge* se) {
- return (se->to == sg1 && se->from == sg2) ||
- (se->to == sg2 && se->from == sg1);
- }),
- consumer_edges.end());
-
- return consumer_edges;
-}
-
-// Returns a deterministic, unique set of inputs of the segmented group sg1,
-// or of the combined group sg1 + sg2
-std::vector<Val*> getAllInputs(
- const SegmentedGroup* sg1,
- const SegmentedGroup* sg2 = nullptr) {
- std::vector<SegmentedEdge*> merged_producer_edges;
-
- if (sg1 != nullptr && sg2 != nullptr) {
- merged_producer_edges = getMergedProducerEdges(sg1, sg2);
- } else if (sg1 != nullptr) {
- merged_producer_edges = sg1->producer_edges;
- } else if (sg2 != nullptr) {
- merged_producer_edges = sg2->producer_edges;
- }
-
- std::vector<Val*> producer_edge_vals;
-
- std::transform(
- merged_producer_edges.begin(),
- merged_producer_edges.end(),
- std::back_inserter(producer_edge_vals),
- [](SegmentedEdge* se) { return se->val; });
-
- return uniqueValConcat(
- {sg1 == nullptr ? std::vector<Val*>() : sg1->input_vals,
- sg2 == nullptr ? std::vector<Val*>() : sg2->input_vals,
- producer_edge_vals});
-}
-
-// Returns a deterministic, unique set of outputs of the segmented group sg1,
-// or of the combined group sg1 + sg2
-std::vector<Val*> getAllOutputs(
- const SegmentedGroup* sg1,
- const SegmentedGroup* sg2 = nullptr) {
- std::vector<SegmentedEdge*> merged_consumer_edges;
-
- if (sg1 != nullptr && sg2 != nullptr) {
- merged_consumer_edges = getMergedConsumerEdges(sg1, sg2);
- } else if (sg1 != nullptr) {
- merged_consumer_edges = sg1->consumer_edges;
- } else if (sg2 != nullptr) {
- merged_consumer_edges = sg2->consumer_edges;
- }
-
- std::vector<Val*> consumer_edge_vals;
-
- std::transform(
- merged_consumer_edges.begin(),
- merged_consumer_edges.end(),
- std::back_inserter(consumer_edge_vals),
- [](SegmentedEdge* se) { return se->val; });
-
- auto output_vals = uniqueValConcat(
- {sg1 == nullptr ? std::vector<Val*>() : sg1->output_vals,
- sg2 == nullptr ? std::vector<Val*>() : sg2->output_vals,
- consumer_edge_vals});
-
- return output_vals;
-}
-
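-// Usage sketch (illustrative; sg1 and sg2 stand for two candidate groups):
-//   auto merged_inputs = getAllInputs(sg1, sg2);
-//   auto merged_outputs = getAllOutputs(sg1, sg2);
-// These describe the boundary of the would-be merged group, with edges
-// running between sg1 and sg2 excluded.
-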
-// Set version of getting merged inputs or outputs if segmented_groups were
-// merged. The result respects the order in segmented_groups for a
-// deterministic merge trace.
-// Returns inputs if get_inputs is true, otherwise returns outputs.
-// TODO: merge with the binary counterparts above
-std::vector<Val*> allInputsIfTrueElseOutputs(
- const std::vector<SegmentedGroup*>& segmented_groups,
- bool get_inputs = true) {
-  // Shorthand type aliases
- using EdgeVec = std::vector<SegmentedEdge*>;
- using ValVec = std::vector<Val*>;
-
- // Get producer edges to get inputs, consumer edges to get outputs
- auto edges_to_process_from_or_to_group =
- [get_inputs](SegmentedGroup* group) -> EdgeVec& {
- return get_inputs ? group->producer_edges : group->consumer_edges;
- };
-
-  // Get the complete fusion's global input/output vals recorded on the group
- auto global_vals_from_or_to_group =
- [get_inputs](SegmentedGroup* group) -> ValVec& {
- return get_inputs ? group->input_vals : group->output_vals;
- };
-
- // Get the group that is connected to current group by given edge
- auto opposite_end_of_edge = [get_inputs](SegmentedEdge* edge) {
- return get_inputs ? edge->from : edge->to;
- };
-
- // Keep track of value and order to ensure deterministic result
- std::vector<Val*> merged_vals;
- std::unordered_set<Val*> merged_vals_set;
-
- // Put groups in a set for quick look up
- std::unordered_set<SegmentedGroup*> segmented_groups_set(
- segmented_groups.begin(), segmented_groups.end());
-
- // Collect vals associated with edges
- for (auto group : segmented_groups) {
- for (auto edge : edges_to_process_from_or_to_group(group)) {
- if (
- // Need to de-duplicate values so we don't get multiple of any input
- !merged_vals_set.count(edge->val) &&
- // One side of this edge will be `group`, if the other end is
- // also in segmented_groups, then this is an internal edge
- // that we don't want.
- !segmented_groups_set.count(opposite_end_of_edge(edge))) {
- merged_vals.push_back(edge->val);
- merged_vals_set.insert(edge->val);
- }
- }
- }
-
- // Collect original fusion's inputs/outputs and append at the end
- for (auto group : segmented_groups) {
- for (auto global_val : global_vals_from_or_to_group(group)) {
- // de-duplicate
- if (!merged_vals_set.count(global_val)) {
- merged_vals.push_back(global_val);
- merged_vals_set.insert(global_val);
- }
- }
- }
-
- return merged_vals;
-}
-
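-// Usage sketch (illustrative):
-//   auto ins = allInputsIfTrueElseOutputs(groups_to_merge, true);
-//   auto outs = allInputsIfTrueElseOutputs(groups_to_merge, false);
-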
-// A sorting utility used for debug printing only
-// sorts the given vector of expressions in topological
-// order, with equal cases respecting the original order
-// in the vector.
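-// Note: this is roughly an O(n^2) selection-style topological sort, which is
-// acceptable since it is only used for debug printing.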
-std::vector<Expr*> groupExprPrintSorting(const std::vector<Expr*>& exprs) {
- std::vector<Expr*> exprs_to_print(exprs.begin(), exprs.end());
- std::unordered_set<Expr*> exprs_to_print_set(exprs.begin(), exprs.end());
- std::unordered_set<Expr*> exprs_visited;
- std::vector<Expr*> sorted_list;
- while (sorted_list.size() != exprs_to_print.size()) {
- bool expr_added_to_sorted_list = false;
- for (auto expr : exprs_to_print) {
- if (!exprs_visited.count(expr)) {
- bool add_this_expr = true;
-        // Check whether any input of the current expression that is
-        // defined within the group hasn't been visited yet
- for (auto input : expr->inputs()) {
- if (input->definition() &&
- exprs_to_print_set.count(input->definition()) &&
- !exprs_visited.count(input->definition())) {
- add_this_expr = false;
- break;
- }
- }
-
-        // Append the current expr to the sorted list
-        // and mark it visited
- if (add_this_expr) {
- expr_added_to_sorted_list = true;
- exprs_visited.insert(expr);
- sorted_list.push_back(expr);
- break;
- }
- }
- }
- TORCH_INTERNAL_ASSERT(
- expr_added_to_sorted_list,
- "group debug print failed, exprs within given vector not a DAG");
- }
- return sorted_list;
-}
-
-// Utility function to list all expressions in a group
-void detailGroupPrint(std::ostream& os, const SegmentedGroup* group) {
- IrPrinter irp(os);
-
- auto sort_val_by_name = [](std::vector<Val*> vals_to_sort) {
- std::sort(vals_to_sort.begin(), vals_to_sort.end(), [](Val* a, Val* b) {
- return a->name() < b->name();
- });
- return vals_to_sort;
- };
-
- os << "g{"
- << "(" << toString(group->heuristic()) << ")\n";
- os << "inputs: \n";
- for (auto input : sort_val_by_name(getAllInputs(group))) {
- os << input << " " << input->getDataType().value() << "\n";
- }
- os << "outputs: \n";
- for (auto output : sort_val_by_name(getAllOutputs(group))) {
- os << output << " " << output->getDataType().value() << "\n";
- }
-
- os << "\n\n";
-
- auto expr_to_print = groupExprPrintSorting(group->exprs());
-
- for (size_t i = 0; i < expr_to_print.size(); i++) {
- irp.handle(expr_to_print[i]);
- }
- os << "}\n\n";
-}
-
-//! Insert casts for an intermediate tensorview, i.e. ones
-//! that are in segmentedEdges. The insertion is done on
-//! the complete fusion, which should be owned by a segmented
-//! fusion so that only one segmented fusion will be affected.
-//! The replacement pattern is:
-//! TV0
-//! replaced as:
-//! fp16_tv = cast(TV0)
-//! fp32_tv = cast(fp16_tv)
-//!
-//! All segmented groups that take TV0 as input will then
-//! take fp16_tv instead and the cast to fp32 will be
-//! automatically included in each of the groups.
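-//! For example (illustrative), a consumer expression
-//!   TV2 = op(TV0)
-//! in the complete fusion becomes
-//!   TV2 = op(fp32_tv)
-//! while the segmented edges carry fp16_tv instead of TV0.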
-TensorView* castIntermediateValueInCompleteFusion(
- Fusion* fusion,
- TensorView* original_tv) {
- FusionGuard fg(fusion);
-
-  // A utility lambda that creates a consumer tensordomain of
-  // the given tv and wraps a new tensorview around the
-  // new tensordomain with the given data type.
- auto make_consumer_tv = [&](TensorView* from, DataType data_type) {
- // Keep broadcast axes and remove reduction axes
- size_t i = 0;
- auto no_reduction_root_domain =
-        TensorDomain::noReductions(from->getRootDomain());
- std::vector<IterDomain*> new_root_domain(no_reduction_root_domain.size());
- for (const auto& dom : no_reduction_root_domain) {
- new_root_domain[i++] = dom->clone();
- }
-
- // Create the actual domain and tv.
- return new TensorView(
- new TensorDomain(
- new_root_domain, std::vector<bool>(new_root_domain.size(), true)),
- data_type);
- };
-
- // create the tv's to cast
- auto fp16_tv = make_consumer_tv(original_tv, DataType::Half);
- auto fp32_tv = make_consumer_tv(original_tv, DataType::Float);
-
- // replace uses of original tv with fp32_tv in the complete
- // fusion
- for (auto expr : fusion->unordered_uses(original_tv)) {
- ir_utils::replaceValInExpr(expr, original_tv, fp32_tv);
- }
-
- // Insert the cast ops.
- new UnaryOp(UnaryOpType::Cast, fp16_tv, original_tv);
- new UnaryOp(UnaryOpType::Cast, fp32_tv, fp16_tv);
-
- // Return the new tv to replace original tv with
- // on the segmented edges.
- return fp16_tv;
-}
-
-} // namespace
-
-void SegmentedFusion::finalize() {
- impl_.cleanUnused();
-
- // Insert casts for the tensorviews that are on
- // segmented edges and also on the force_to_fp16 list
- //
- // Note:
- // The cast is inserted after the segmenter canSchedule check, which
- // shouldn't cause problem short-term. The reason we put the cast here
- // is we don't want to keep making copies of the original fusion
- // during segmentation. Could consider making the cast insertion
- // reversible if we do have to test canSchedule with the casts inserted
- // during segmentation process in the future.
-
- // Keep track of groups that need to update expr list,
- // including both the producer and consumer of the selected tv's that
- // we cast to fp16.
- std::unordered_set<SegmentedGroup*> affected_group_set;
-
- // A map to keep track of the tv's that have been inserted cast
- // and its fp16 version.
- std::unordered_map<TensorView*, TensorView*> fp32_to_fp16_cast_map;
-
- // Go through all edges of the segmented fusion.
- for (auto edge : edges()) {
- auto edge_tv = edge->val->as<TensorView>();
- // Only look at ones that need to cast to fp16
- if (force_fp16_tv_set_.count(edge_tv)) {
- auto cast_tv_it = fp32_to_fp16_cast_map.find(edge->val->as<TensorView>());
- TensorView* cast_tv = nullptr;
- // Insert cast ops for this tv if we haven't done so.
- if (cast_tv_it == fp32_to_fp16_cast_map.end()) {
- cast_tv = castIntermediateValueInCompleteFusion(
- complete_fusion_.get(), edge_tv);
- fp32_to_fp16_cast_map[edge->val->as<TensorView>()] = cast_tv;
- } else {
- cast_tv = cast_tv_it->second;
- }
-
- // Update the edge to use the fp16 version
- edge->val = cast_tv;
-
- // Mark the groups for update later
- affected_group_set.insert(edge->from);
- affected_group_set.insert(edge->to);
- }
- }
-
- // Reset expression lists of all affected groups
- // TODO : this could have been a general operation that
- // the group supports. Could consider moving this into
- // segmentedGroup in a follow up.
- for (auto group : affected_group_set) {
- auto input_group_vec = getAllInputs(group);
- std::unordered_set<Val*> input_group_set(
- input_group_vec.begin(), input_group_vec.end());
-
- auto expr_set = DependencyCheck::getAllExprsBetween(
- input_group_set, getAllOutputs(group));
- group->exprs_ = std::vector<Expr*>(expr_set.begin(), expr_set.end());
- }
-}
-
-//! A utility class to compute and maintain the "producers of"
-//! relationship in a segmented graph. Space-heavy; avoid
-//! using on very large graphs.
-//!
-//! Currently trying to move as far as possible with only a
-//! producer map, without transposing it to make a consumer map.
-//! Making it NonCopyable because we should never need to
-//! copy an instance of this class.
-//! TODO: Space efficiency of this class will be important,
-//! because we need it in the pre-merging of segmentedGroups,
-//! currently O(n^2). O(nlogn) would be a reasonable
-//! goal to achieve.
-class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
- using GroupSet = std::unordered_set<SegmentedGroup*>;
- using GroupSetOwningPtr = std::unique_ptr<GroupSet>;
- using DependencyMap = std::unordered_map<SegmentedGroup*, GroupSetOwningPtr>;
-
- public:
- //! Populate producers of all groups in segmented fusion
- explicit GroupDependencyAnalysis(const SegmentedFusion* segmented_fusion)
- : segmented_fusion_(segmented_fusion) {
- computeAllProducers();
- }
-
- //! Checks if group is consumer of any group in groups_to_check
-  //! TODO: refactor this to be similar to isConsumerOf
- bool isConsumerOfAny(
- SegmentedGroup* group,
- const std::vector<SegmentedGroup*>& groups_to_check) {
- auto& producers_of_group = getAllKnownProducersSet(group);
- for (const auto& potential_producer : groups_to_check) {
- if (producers_of_group->count(potential_producer)) {
- return true;
- }
- }
- return false;
- }
-
- bool isConsumerOf(SegmentedGroup* a, SegmentedGroup* b) {
- auto it = known_producers_of_.find(a);
- if (it == known_producers_of_.end()) {
- return false;
- }
- return it->second->count(b);
- }
-
- bool isProducerOf(SegmentedGroup* a, SegmentedGroup* b) {
- return isConsumerOf(b, a);
- }
-
- //! Finds the common producers of given set of groups
- GroupSet getCommonProducersOf(std::vector<SegmentedGroup*> groups);
-
-  //! Update the map when the given two groups have been merged to create `ab`.
-  //! This method is for bookkeeping and queries only; it doesn't implicitly
-  //! check for DAG-ness
- void mergeGroups(SegmentedGroup* a, SegmentedGroup* b, SegmentedGroup* ab);
-
-  //! Update the map when the given set of groups has been merged to create
-  //! `merged`. This method is for bookkeeping and queries only; it doesn't
-  //! implicitly check for DAG-ness
- void mergeGroups(const GroupSet& groups, SegmentedGroup* merged);
-
-  //! Collect all groups that are on a path from producer to consumer.
-  //!  Efficiency can be important here. (TODO)
- GroupSet valuesBetween(SegmentedGroup* producer, SegmentedGroup* consumer) {
- if (producer == consumer) {
- return {};
- }
-
- GroupSet values_between;
- auto& all_producers_of_consumer = known_producers_of_.at(consumer);
- TORCH_INTERNAL_ASSERT(
- all_producers_of_consumer->count(producer),
- "Fusion segment: Trying to compute path between two nodes that are not producer-consumer pairs");
-
- std::copy_if(
- all_producers_of_consumer->begin(),
- all_producers_of_consumer->end(),
- std::inserter(values_between, values_between.end()),
- [this, producer](SegmentedGroup* producer_of_consumer) {
- // Checks if producer is on the producer path of this intermediate
- // node
- return known_producers_of_.at(producer_of_consumer)->count(producer);
- });
-
- return values_between;
- }
-
- //! Checks if the segmented fusion this class tracks is still a DAG
- //! used for generating assertions after transforms
- bool isproducerMapDAG() const {
- for (auto& it : known_producers_of_) {
- if (it.second->count(it.first)) {
- return false;
- }
- }
- return true;
- }
-
- private:
- //! Collect initial producer info using
- //! a work list algorithm through forward traversal
- //! a backward DFS would do the same
- void computeAllProducers();
-
- //! Add all consumers of `producer` to `to_visit`
- void addConsumersToWorkList(SegmentedGroup* producer, GroupSet& to_visit) {
- for (auto e : producer->consumer_edges) {
-      // A consumer won't have been processed before all of its producers
- to_visit.insert(e->to);
- }
- }
-
- //! Propagate all known producers of `from` into `into`, used to keep track
- //! of:
- //! 1. `from` is a producer of `into`
-  //!  2. `from` has been merged with another group to create `into`
- void mergeAllKnownProducersIntoFrom(
- SegmentedGroup* into,
- SegmentedGroup* from) {
- auto& producer_set_to_merge = *getAllKnownProducersSet(from);
- for (auto group : producer_set_to_merge) {
- getAllKnownProducersSet(into)->insert(group);
- }
- }
-
- //! Utility to access known producers of a group so far
- GroupSetOwningPtr& getAllKnownProducersSet(SegmentedGroup* group) {
- auto& producer_set_ptr = known_producers_of_[group];
- if (!producer_set_ptr) {
- producer_set_ptr = std::make_unique<GroupSet>();
- }
- return producer_set_ptr;
- }
-
- // utility to compute the set intersection of group sets a,b
- GroupSet groupSetIntersection(const GroupSet& a, const GroupSet& b) {
- bool a_is_smaller = a.size() < b.size();
- const auto& smaller_group_set = a_is_smaller ? a : b;
- const auto& bigger_group_set = a_is_smaller ? b : a;
-
- GroupSet intersection;
- for (auto group : smaller_group_set) {
- if (bigger_group_set.count(group)) {
- intersection.insert(group);
- }
- }
- return intersection;
- }
-
- private:
- const SegmentedFusion* segmented_fusion_;
- DependencyMap known_producers_of_;
-};
-
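-// Typical queries (illustrative):
-//   GroupDependencyAnalysis analysis(segmented_fusion);
-//   bool depends = analysis.isConsumerOf(consumer, producer);
-//   auto path_groups = analysis.valuesBetween(producer, consumer);
-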
-//! Finds the common producers of given set of groups
-GroupDependencyAnalysis::GroupSet GroupDependencyAnalysis::getCommonProducersOf(
- std::vector<SegmentedGroup*> groups) {
- if (groups.empty()) {
- return {};
- }
-
- // Optimization: start with the smallest producer set
- std::sort(
- groups.begin(),
- groups.end(),
- [this](SegmentedGroup* a, SegmentedGroup* b) {
- return known_producers_of_.at(a)->size() <
- known_producers_of_.at(b)->size();
- });
-
- // Get intersection of producers
- GroupSet common_producers = *(known_producers_of_.at(groups[0]));
- for (size_t i = 1; i < groups.size(); i++) {
- common_producers = groupSetIntersection(
- common_producers, *(known_producers_of_.at(groups[i])));
- }
-
- return common_producers;
-}
-
-//! Update the map when the given two groups have been merged to create `ab`.
-//! This method is for bookkeeping and queries only; it doesn't implicitly
-//! check for DAG-ness
-void GroupDependencyAnalysis::mergeGroups(
- SegmentedGroup* a,
- SegmentedGroup* b,
- SegmentedGroup* ab) {
- // Access/Create the producer set of ab
- auto& ab_set = getAllKnownProducersSet(ab);
-
- // propagate a's and b's known producers into ab
- mergeAllKnownProducersIntoFrom(ab, a);
- mergeAllKnownProducersIntoFrom(ab, b);
-
- // a, b are now merged, so no longer exist
- ab_set->erase(a);
- ab_set->erase(b);
-
- // a, b no longer exist, remove their producer sets
- known_producers_of_.erase(a);
- known_producers_of_.erase(b);
-
- // update producer maps of other groups
- for (auto& it : known_producers_of_) {
- // for all groups that are produced by either a or b
- if (it.second->count(a) || it.second->count(b)) {
- // insert ab as the new producer
- it.second->insert(ab);
- // all producers of both a and b are now producers of `it`
- mergeAllKnownProducersIntoFrom(it.first, ab);
- }
- // a, b no longer exist, remove them from `it`
- it.second->erase(a);
- it.second->erase(b);
- }
-}
-
-//! Update the map when the given set of groups has been merged to create
-//! `merged`. This method is for bookkeeping and queries only; it doesn't
-//! implicitly check for DAG-ness
-void GroupDependencyAnalysis::mergeGroups(
- const GroupSet& groups,
- SegmentedGroup* merged) {
- // Access/Create the producer set of merged
- auto& merged_set = getAllKnownProducersSet(merged);
-
- // Populate all producers of groups and
- // write into producer map of merged
- std::for_each(
- groups.begin(), groups.end(), [this, merged](SegmentedGroup* group) {
- mergeAllKnownProducersIntoFrom(merged, group);
- });
-
-  // Erase all groups that were merged from the producer map
- std::for_each(
- groups.begin(), groups.end(), [this, &merged_set](SegmentedGroup* group) {
- // erase inter dependencies
- merged_set->erase(group);
-        // erase producer map entries tracking merged groups
- known_producers_of_.erase(group);
- });
-
- // Update producer relationships with other groups in producer map
- for (auto& it : known_producers_of_) {
- auto producer_intersection = groupSetIntersection(*(it.second), groups);
- // if current node has any producer that was merged
- if (producer_intersection.size() > 0) {
- for (auto merged_producer : producer_intersection) {
- // delete all disappearing producers
- it.second->erase(merged_producer);
- }
- // insert the new group as producer
- it.second->insert(merged);
- }
- }
-}
-
-//! Collect initial producer info using
-//! a work list algorithm through forward traversal
-//! a backward DFS would do the same
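-//! Each group is processed only after all of its producers have been
-//! visited, so its producer set is complete by the time it is written.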
-void GroupDependencyAnalysis::computeAllProducers() {
- GroupSet visited;
- GroupSet to_visit;
-
-  // Collect source nodes; a DAG is guaranteed to have at least
-  // one node with no producers
- std::copy_if(
- segmented_fusion_->cgroups().begin(),
- segmented_fusion_->cgroups().end(),
- std::inserter(visited, visited.end()),
- [](SegmentedGroup* group) { return group->producer_edges.empty(); });
-
-  // visited now contains only source nodes,
-  // which have no producers to traverse back to
- for (auto group : visited) {
- addConsumersToWorkList(group, to_visit);
- }
-
- while (!to_visit.empty()) {
- SegmentedGroup* to_update = nullptr;
- for (auto visiting_group : to_visit) {
- if (std::all_of(
- visiting_group->producer_edges.begin(),
- visiting_group->producer_edges.end(),
- [&visited](SegmentedEdge* e) {
- return visited.count(e->from);
- })) {
- // filter multi-edges
- GroupSet producers_of_visiting_group;
- for (auto edge : visiting_group->producer_edges) {
- producers_of_visiting_group.insert(edge->from);
- }
-
- // populate all possible paths
- // from producer backward, including
- // the producer
- for (auto producer : producers_of_visiting_group) {
- getAllKnownProducersSet(visiting_group)->insert(producer);
- mergeAllKnownProducersIntoFrom(visiting_group, producer);
- }
- to_update = visiting_group;
- break;
- }
- }
- if (to_update) {
- addConsumersToWorkList(to_update, to_visit);
- to_visit.erase(to_update);
- visited.insert(to_update);
- } else {
- TORCH_INTERNAL_ASSERT(false, "unreachable, original graph not a DAG");
- }
- }
-}
-
-std::ostream& operator<<(
- std::ostream& os,
- const SegmentedFusion* segmented_fusion) {
- // Topologically sort groups
- GroupDependencyAnalysis dependency(segmented_fusion);
- std::vector<SegmentedGroup*> groups_to_print(
- segmented_fusion->cgroups().begin(), segmented_fusion->cgroups().end());
- std::vector<SegmentedGroup*> sorted_groups_to_print;
-
- // Sort groups topologically from producer to consumer before printing
- while (!groups_to_print.empty()) {
- auto group_it_to_append = groups_to_print.begin();
- for (auto group_it_to_compare = groups_to_print.begin();
- group_it_to_compare != groups_to_print.end();
- group_it_to_compare++) {
- if (dependency.isProducerOf(*group_it_to_compare, *group_it_to_append)) {
- group_it_to_append = group_it_to_compare;
- }
- }
- sorted_groups_to_print.push_back(*group_it_to_append);
- groups_to_print.erase(group_it_to_append);
- }
-
-  // Build a reverse lookup of each group's position in the sorted order
- std::unordered_map<SegmentedGroup*, size_t> group_order;
- for (size_t i = 0; i < sorted_groups_to_print.size(); i++) {
- group_order[sorted_groups_to_print[i]] = i;
- }
-
- // Sort edges to print
- std::vector<SegmentedEdge*> sorted_edges_to_print(
- segmented_fusion->cedges().begin(), segmented_fusion->cedges().end());
- std::sort(
- sorted_edges_to_print.begin(),
- sorted_edges_to_print.end(),
- [&group_order](SegmentedEdge* edge_a, SegmentedEdge* edge_b) {
- return group_order.at(edge_a->from) < group_order.at(edge_b->from);
- });
-
- os << "Segmented_Fusion{ \n";
- os << "groups: \n";
- for (const auto g : sorted_groups_to_print) {
- os << g << "\n";
- }
- os << "edges: \n";
- for (const auto e : sorted_edges_to_print) {
- os << e << "\n";
- }
- os << "\ngroup details:\n";
- for (const auto g : sorted_groups_to_print) {
- detailGroupPrint(os, g);
- }
- os << "} //Segmented_Fusion\n";
- return os;
-}
-
-void SegmentedFusion::print() const {
- std::cout << this << "\n";
-}
-
-std::string toString(SegmentedFusion* segmented_fusion) {
- std::stringstream ss;
- ss << segmented_fusion;
- return ss.str();
-}
-
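-//! Builds a standalone Fusion for a single segmented group: copies the
-//! complete fusion, strips the original global inputs/outputs, then
-//! registers the group's boundary values (getAllInputs/getAllOutputs,
-//! mapped into the copy) as the segment's inputs and outputs.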
-std::unique_ptr<Fusion> SegmentedFusion::makeFusion(SegmentedGroup* sg) {
- std::unique_ptr<Fusion> fusion_segment = std::make_unique<Fusion>();
-
- auto complete_to_segment_map =
- Fusion::copy(completeFusion(), fusion_segment.get());
-
- std::vector<Val*> input_list(
- fusion_segment->inputs().begin(), fusion_segment->inputs().end());
- for (auto inp : input_list) {
- fusion_segment->removeInput(inp);
- }
-
- std::vector<Val*> output_list(
- fusion_segment->outputs().begin(), fusion_segment->outputs().end());
- for (auto out : output_list) {
- fusion_segment->removeOutput(out);
- }
-
- for (auto inp : getAllInputs(sg)) {
- fusion_segment->addInput(complete_to_segment_map.clone(inp));
- }
-
- for (auto out : getAllOutputs(sg)) {
- fusion_segment->addOutput(complete_to_segment_map.clone(out));
- }
-
- return fusion_segment;
-}
-
-void SegmentCandidateFinder::resetTraversal() {
- for (auto group : groups()) {
- // Start traversal at input groups
- if (group->producer_edges.empty()) {
- to_visit_.push_back(group);
- }
- group->visited_ = false;
- group->level_ = 0;
- }
-}
-
-void SegmentCandidateFinder::resetLevels() {
- while (!to_visit_.empty()) {
- auto visit = to_visit_.front();
- to_visit_.pop_front();
-
- // All inputs processed?
- bool ready = true;
- if (!visit->producer_edges.empty()) {
- ready = std::all_of(
- visit->producer_edges.begin(),
- visit->producer_edges.end(),
- [&](SegmentedEdge* dep) { return dep->from->visited_; });
- }
-
- if (!ready) {
- // In case traversal doesn't complete because there's an error in the
- // DAG topology.
- next_to_visit_.push_back(visit);
- continue;
- }
-
- visit->visited_ = true;
-
- to_visit_.insert(
- to_visit_.end(), next_to_visit_.begin(), next_to_visit_.end());
- next_to_visit_.clear();
-
- for (auto out : visit->consumer_edges) {
- to_visit_.push_back(out->to);
- }
-
- visit->level_ = 0;
- for (auto inp : visit->producer_edges) {
- visit->level_ = std::max(visit->level_, inp->from->level_ + 1);
- }
- }
- TORCH_INTERNAL_ASSERT(
- next_to_visit_.empty(), "Error in graph, is not a DAG.");
-}
-
-// Disconnect group from neighbors, and return edges that were disconnected
-std::unordered_set<SegmentedEdge*> SegmentCandidateFinder::disconnectGroup(
- SegmentedGroup* group) {
- std::unordered_set<SegmentedEdge*> removed_edges(
- group->producer_edges.begin(), group->producer_edges.end());
-
- for (auto edge : group->producer_edges) {
- auto from = edge->from;
- auto& from_edges = from->consumer_edges;
- auto from_edge_it = std::find(from_edges.begin(), from_edges.end(), edge);
- TORCH_INTERNAL_ASSERT(
- from_edge_it != from_edges.end(), "Could not find edge to remove.");
- from_edges.erase(from_edge_it);
- }
-
- for (auto edge : group->consumer_edges) {
- removed_edges.insert(edge);
- auto to = edge->to;
- auto& to_edges = to->producer_edges;
- auto to_edge_it = std::find(to_edges.begin(), to_edges.end(), edge);
- TORCH_INTERNAL_ASSERT(
- to_edge_it != to_edges.end(), "Could not find edge to remove.");
- to_edges.erase(to_edge_it);
- }
-
- group->producer_edges.clear();
- group->consumer_edges.clear();
-
- return removed_edges;
-}
-
-void SegmentCandidateFinder::eraseGroups(
- std::unordered_set<SegmentedGroup*>& groups_to_erase) {
- std::unordered_set<SegmentedEdge*> edges_to_erase;
- for (auto group : groups_to_erase) {
- auto disconnected_edges = disconnectGroup(group);
- edges_to_erase.insert(disconnected_edges.begin(), disconnected_edges.end());
- }
-
-  edges().erase(
-      std::remove_if(
-          edges().begin(),
-          edges().end(),
-          [&edges_to_erase](SegmentedEdge* edge) {
-            return edges_to_erase.count(edge) > 0;
-          }),
-      edges().end());
-
-  groups().erase(
-      std::remove_if(
-          groups().begin(),
-          groups().end(),
-          [&groups_to_erase](SegmentedGroup* group) {
-            return groups_to_erase.count(group) > 0;
-          }),
-      groups().end());
-}
-
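-// Merges the groups queued in to_merge_ pairwise ((g1, g2), (g3, g4), ...)
-// and returns the last joined group created; assumes an even-sized queue.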
-SegmentedGroup* SegmentCandidateFinder::mergeNodes() {
- SegmentedGroup* last_merged = nullptr;
- auto it = to_merge_.begin();
- TORCH_INTERNAL_ASSERT(to_merge_.size() % 2 == 0);
- while (it != to_merge_.end()) {
- auto group1 = *it++;
- auto group2 = *it++;
-
- clean_up_groups_.emplace(group1);
- clean_up_groups_.emplace(group2);
-
- // Make the new joined node
- auto joined_group = segmented_fusion_->newGroup();
-
- joined_group->input_vals =
- uniqueValConcat({group1->input_vals, group2->input_vals});
-
- joined_group->output_vals =
- uniqueValConcat({group1->output_vals, group2->output_vals});
-
- joined_group->exprs_ = group1->exprs_;
- joined_group->exprs_.insert(
- joined_group->exprs_.end(),
- group2->exprs_.begin(),
- group2->exprs_.end());
-
- auto producer_edges = getMergedProducerEdges(group1, group2);
- // Connect joined group to resulting neighbors
- for (auto edge : producer_edges) {
- auto from = edge->from;
- auto val = edge->val;
-
- auto new_edge = segmented_fusion_->newEdge(from, joined_group, val);
- joined_group->producer_edges.push_back(new_edge);
- from->consumer_edges.push_back(new_edge);
- }
-
- auto consumer_edges = getMergedConsumerEdges(group1, group2);
-
- for (auto edge : consumer_edges) {
- auto to = edge->to;
- auto val = edge->val;
-
- auto new_edge = segmented_fusion_->newEdge(joined_group, to, val);
- joined_group->consumer_edges.push_back(new_edge);
- edge->to->producer_edges.push_back(new_edge);
- }
-
- joined_group->setHeuristic(deriveHeuristic(joined_group));
-    // Need to maintain the group dependency data if it has been initialized
- // by previous merging
- if (group_dependency_) {
- group_dependency_->as<GroupDependencyAnalysis>()->mergeGroups(
- group1, group2, joined_group);
- }
- last_merged = joined_group;
- }
-
- to_merge_.clear();
- for (auto group : clean_up_groups_) {
- auto disconnected_edges = disconnectGroup(group);
- clean_up_edges_.insert(
- disconnected_edges.begin(), disconnected_edges.end());
- }
-
-  edges().erase(
-      std::remove_if(
-          edges().begin(),
-          edges().end(),
-          [this](SegmentedEdge* edge) {
-            return clean_up_edges_.count(edge) > 0;
-          }),
-      edges().end());
-
-  groups().erase(
-      std::remove_if(
-          groups().begin(),
-          groups().end(),
-          [this](SegmentedGroup* group) {
-            return clean_up_groups_.count(group) > 0;
-          }),
-      groups().end());
-
- clean_up_edges_.clear();
- clean_up_groups_.clear();
-
- return last_merged;
-}
-
-// Logic largely parallels mergeNodes, but the two are used
-// in different phases of segmentation. Should consider
-// a cleanup to share the implementations.
-SegmentedGroup* SegmentCandidateFinder::mergeAllGivenGroups(
- const std::vector<SegmentedGroup*>& groups_to_merge) {
- TORCH_INTERNAL_ASSERT(
- !groups_to_merge.empty(),
- "fusion segment :(mergeAllGivenGroups) tried to merge no groups")
-
- // Make a set to detect internal edges
- std::unordered_set<SegmentedGroup*> group_set(
- groups_to_merge.begin(), groups_to_merge.end());
-
- // Sets to de-duplicate multiple uses of
- // input/edge values and re-computations of exprs
- std::unordered_set<Val*> used_edge_vals_set;
- std::unordered_set<Val*> used_input_vals_set;
- std::unordered_set<Expr*> exprs_set;
-
- // Create new group
- auto joined_group = segmented_fusion_->newGroup();
-
- // Populate edges, exprs, global vals
- // from each of the groups
- for (auto group : groups_to_merge) {
- // Populate complete fusion inputs to the group
- for (auto input_val : group->input_vals) {
- if (!used_input_vals_set.count(input_val)) {
- used_input_vals_set.insert(input_val);
- joined_group->input_vals.push_back(input_val);
- }
- }
-
- // Populate complete fusion outputs from the group
- for (auto output_val : group->output_vals) {
- joined_group->output_vals.push_back(output_val);
- }
-
- // Populate producer edges to the group
- for (auto edge : group->producer_edges) {
- if (
- // Check this is not internal edge
- !group_set.count(edge->from) &&
- // Check this val has been added or not
- !used_edge_vals_set.count(edge->val)) {
- used_edge_vals_set.insert(edge->val);
- auto new_producer_edge =
- segmented_fusion_->newEdge(edge->from, joined_group, edge->val);
- joined_group->producer_edges.push_back(new_producer_edge);
- edge->from->consumer_edges.push_back(new_producer_edge);
- }
- }
-
- // Populate consumer edges from the group
- for (auto edge : group->consumer_edges) {
- if (
- // Check this is not internal edge
- !group_set.count(edge->to)) {
- auto new_consumer_edge =
- segmented_fusion_->newEdge(joined_group, edge->to, edge->val);
- joined_group->consumer_edges.push_back(new_consumer_edge);
- edge->to->producer_edges.push_back(new_consumer_edge);
- }
- }
-
- // Populate exprs
- for (auto expr : group->exprs_) {
- if (!exprs_set.count(expr)) {
- joined_group->exprs_.push_back(expr);
- exprs_set.insert(expr);
- }
- }
- }
-
- // Clean up original groups from segmented fusion
- for (auto group : groups_to_merge) {
- auto disconnected_edges = disconnectGroup(group);
- clean_up_edges_.insert(
- disconnected_edges.begin(), disconnected_edges.end());
- }
-
- edges().erase(
- std::remove_if(
- edges().begin(),
- edges().end(),
- [this](SegmentedEdge* edge) { return clean_up_edges_.count(edge); }),
- edges().end());
-
- groups().erase(
- std::remove_if(
- groups().begin(),
- groups().end(),
- [&group_set](SegmentedGroup* group) -> bool {
- return group_set.count(group);
- }),
- groups().end());
-
- clean_up_edges_.clear();
-
- joined_group->setHeuristic(deriveHeuristic(joined_group));
- return joined_group;
-}
-
-namespace {
-
-// Guard to temporarily change the inputs and outputs of a fusion. On
-// destruction, returns the fusion to its original state.
-// Temporarily not used, but will be useful when adding more merging heuristics
-class FusionSegmentGuard : public NonCopyable {
- public:
- FusionSegmentGuard() = delete;
-
- FusionSegmentGuard(
- Fusion* fusion,
- std::vector<Val*> inputs,
- std::vector<Val*> outputs)
- : fusion_(fusion),
- old_inputs_(fusion->inputs()),
- old_outputs_(fusion->outputs()),
- new_inputs_(std::move(inputs)),
- new_outputs_(std::move(outputs)) {
- TORCH_INTERNAL_ASSERT(fusion_ != nullptr);
- for (auto old_inp : old_inputs_) {
- fusion_->removeInput(old_inp);
- }
-
- for (auto old_out : old_outputs_) {
- fusion_->removeOutput(old_out);
- }
-
- for (auto new_inp : new_inputs_) {
- fusion_->addInput(new_inp);
- }
-
- for (auto new_out : new_outputs_) {
- fusion_->addOutput(new_out);
- }
- }
-
- ~FusionSegmentGuard() {
- FUSER_PERF_SCOPE("~Segmenter::FusionSegmentGuard");
-
- if (fusion_ == nullptr) {
- return;
- }
- for (auto new_inp : new_inputs_) {
- fusion_->removeInput(new_inp);
- }
-
- for (auto new_out : new_outputs_) {
- fusion_->removeOutput(new_out);
- }
-
- for (auto old_inp : old_inputs_) {
- fusion_->addInput(old_inp);
- }
-
- for (auto old_out : old_outputs_) {
- fusion_->addOutput(old_out);
- }
- }
-
- private:
- Fusion* const fusion_ = nullptr;
- const std::vector<Val*> old_inputs_;
- const std::vector<Val*> old_outputs_;
- const std::vector<Val*> new_inputs_;
- const std::vector<Val*> new_outputs_;
-};
-
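-// RAII usage sketch (illustrative):
-//   {
-//     FusionSegmentGuard fsg(fusion, getAllInputs(a, b), getAllOutputs(a, b));
-//     // fusion temporarily exposes only the candidate segment's I/O here
-//   } // destructor restores the original inputs and outputs
-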
-c10::optional<ScheduleHeuristic> tryMerge(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- SegmentedGroup* a,
- SegmentedGroup* b = nullptr) {
- FusionSegmentGuard fsg(fusion, getAllInputs(a, b), getAllOutputs(a, b));
-
- return SchedulerEntry::proposeHeuristics(fusion, runtime_info);
-}
-
-c10::optional<ScheduleHeuristic> tryMerge(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- const std::vector<SegmentedGroup*>& segmented_groups) {
- FusionSegmentGuard fsg(
- fusion,
- allInputsIfTrueElseOutputs(segmented_groups, true),
- allInputsIfTrueElseOutputs(segmented_groups, false));
- return SchedulerEntry::proposeHeuristics(fusion, runtime_info);
-}
-
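-// Both tryMerge overloads return c10::nullopt when no scheduler accepts the
-// temporarily re-scoped fusion, i.e. when the candidate merge is rejected.
-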
-// This function is for cleanup and
-// easier debugging. It shouldn't affect functionality
-// since segmented fusions are compiled with fusion
-// guard on the edges instead of actually looking
-// at the exprs.
-void deDuplicateScalarExprs(std::vector<Expr*>& exprs) {
- // Exprs in SegmentedGroup are not ordered
- // so it is ok to insert them from unordered
- // set
- std::unordered_set<Expr*> scalar_expr_set;
-
- std::copy_if(
- exprs.begin(),
- exprs.end(),
- std::inserter(scalar_expr_set, scalar_expr_set.end()),
- [](Expr* expr) { return ir_utils::isScalarOp(expr); });
-
- if (!scalar_expr_set.empty()) {
- exprs.erase(
- std::remove_if(
- exprs.begin(),
- exprs.end(),
- [&scalar_expr_set](Expr* expr) {
- return scalar_expr_set.count(expr);
- }),
- exprs.end());
- exprs.insert(exprs.end(), scalar_expr_set.begin(), scalar_expr_set.end());
- }
-}
-
-} // namespace
-
-c10::optional<std::unique_ptr<SchedulerEntry>> SegmentedGroup::
- getMaybeSchedulerEntry(SchedulerRuntimeInfo& runtime_info) {
- FUSER_PERF_SCOPE("SegmentedGroup::getMaybeSchedulerEntry");
- auto fusion = segmented_fusion_->completeFusion();
- auto data_cache = segmented_fusion_->getCachedHeuristicDataFor(this);
- FusionSegmentGuard fsg(fusion, getAllInputs(this), getAllOutputs(this));
- if (!SchedulerEntry::canSchedule(
- heuristic(), fusion, runtime_info, data_cache)) {
- return c10::nullopt;
- }
- return SchedulerEntry::makeEntry(
- heuristic(), fusion, runtime_info, data_cache);
-}
-
-// Custom merge node passes:
-// These passes are added at the beginning or the end of
-// the node merging process to direct its heuristics
-//
-// Should consider generalization and make a proper interface
-// if we have more merge node heuristics like this
-
-//! Translate Welford
-//!
-//! This pass can be inserted at any stage of segmentation,
-//! and it tries to replace welford ops with persistent
-//! mean and var ops.
-//!
-//! The checking of feasibility of persistent kernels
-//! is through normalization schedulers. The general idea
-//! is to first try to translate on a copy, and see if
-//! normalization scheduler is willing to produce a
-//! persistent kernel.
-//!
-//! For complete fusion this pass checks if all the
-//! welford ops can be translated simultaneously to
-//! produce a persistent normalization kernel and
-//! will perform translation if checks pass.
-//!
-//! For segmented fusion, the same check is performed within
-//! each segmented group to collect applicable welford ops,
-//! and actual translations are performed on the complete
-//! fusion after all the checks are done.
-class TranslateApplicableWelford {
- public:
-  //! Try translation on each segmented group of the
-  //! given segmented fusion;
-  //! returns true if any welford has been translated
- static bool run(
- SegmentedFusion* segmented_fusion,
- const at::ArrayRef<IValue>& runtime_inputs) {
- TranslateApplicableWelford translate_welford(
- segmented_fusion, runtime_inputs);
- return translate_welford.translated_any_welford_;
- }
-
- //! Try translation on complete fusion,
- //! returns true if any welford has been translated
- static bool run(Fusion* fusion, const at::ArrayRef<IValue>& runtime_inputs) {
- TranslateApplicableWelford translate_welford(fusion, runtime_inputs);
- return translate_welford.translated_any_welford_;
- }
-
- private:
- explicit TranslateApplicableWelford(
- SegmentedFusion* segmented_fusion,
- const at::ArrayRef<IValue>& runtime_inputs);
-
- explicit TranslateApplicableWelford(
- Fusion* fusion,
- const at::ArrayRef<IValue>& runtime_inputs);
-
-  //! Given a vector of welford ops from the same fusion,
-  //! checks if translating all of them results in a
-  //! persistent normalization kernel by try-running on
-  //! a test copy of the original fusion.
- //!
- //! Supported use cases are either un-segmented fusion,
- //! or all the given welfords are within the same
- //! segmented group. In the latter case, the segmented
- //! group containing all the welford ops needs to be
- //! provided.
- bool wouldTranslateToPersistent(
- const std::vector<WelfordOp*>& orignal_welfords,
- SegmentedGroup* group = nullptr);
-
- //! Translate the given welford op into separate
- //! average and standard deviation calculation.
- void translateSingleWelford(WelfordOp* welford);
-
- //! Utility to test if a translated fusion
- //! gives a persistent kernel. Uses normalization
- //! scheduler to do the test.
- bool isValidPersistentFusion(
- Fusion* translated_fusion,
- SchedulerRuntimeInfo& runtime_info);
-
- //! Update expression list of groups containing
- //! welford ops that have been translated.
- void updateGroupExprs(SegmentedGroup* group);
-
- private:
- //! Indicates any translation happened.
- bool translated_any_welford_ = false;
-
- //! a reference to global fusion runtime inputs
- const at::ArrayRef<IValue>& runtime_inputs_;
-
-  //! For translation within a single group only:
-  //! the group boundary cloned onto the test copy
-  //! (see the wouldTranslateToPersistent implementation)
- std::vector<Val*> test_group_inputs_;
- std::vector<Val*> test_group_outputs_;
-};
-
-TranslateApplicableWelford::TranslateApplicableWelford(
- Fusion* fusion,
- const at::ArrayRef<IValue>& runtime_inputs)
- : runtime_inputs_(runtime_inputs) {
- std::vector<WelfordOp*> orignal_welfords(
- ir_utils::filterByType<WelfordOp>(fusion->unordered_exprs()).begin(),
- ir_utils::filterByType<WelfordOp>(fusion->unordered_exprs()).end());
-
- if (wouldTranslateToPersistent(orignal_welfords)) {
- for (auto welford : orignal_welfords) {
- translateSingleWelford(welford);
- }
- translated_any_welford_ = true;
- }
-}
-
-TranslateApplicableWelford::TranslateApplicableWelford(
- SegmentedFusion* segmented_fusion,
- const at::ArrayRef<IValue>& runtime_inputs)
- : runtime_inputs_(runtime_inputs) {
- std::vector<SegmentedGroup*> translated_groups;
- std::vector<WelfordOp*> welford_to_translate;
- // Find welfords that can be translated in each group
- for (auto group : segmented_fusion->groups()) {
- std::vector<WelfordOp*> welford_in_group(
- ir_utils::filterByType<WelfordOp>(group->exprs()).begin(),
- ir_utils::filterByType<WelfordOp>(group->exprs()).end());
-
- if (wouldTranslateToPersistent(welford_in_group, group)) {
- translated_groups.push_back(group);
- welford_to_translate.insert(
- welford_to_translate.end(),
- welford_in_group.begin(),
- welford_in_group.end());
- }
- }
-
-  // Actually translate the welford ops
-  // on the complete fusion.
- for (auto welford : welford_to_translate) {
- translateSingleWelford(welford);
- }
-
- for (auto translated_group : translated_groups) {
- // Update heuristics and expr list of translated groups
- translated_group->heuristic_ = ScheduleHeuristic::Normalization;
- updateGroupExprs(translated_group);
- }
-}
-
-bool TranslateApplicableWelford::isValidPersistentFusion(
- Fusion* translated_fusion,
- SchedulerRuntimeInfo& runtime_info) {
- if (!SchedulerEntry::canSchedule(
- ScheduleHeuristic::Normalization, translated_fusion, runtime_info)) {
- return false;
- }
-
- auto scheduler = SchedulerEntry::makeEntry(
- ScheduleHeuristic::Normalization, translated_fusion, runtime_info);
-
- return scheduler->reductionParams().persistent_kernel;
-}
-
-bool TranslateApplicableWelford::wouldTranslateToPersistent(
- const std::vector<WelfordOp*>& orignal_welfords,
- SegmentedGroup* group) {
- if (orignal_welfords.empty()) {
- return false;
- }
-
- // Make sure all welford ops come from the same complete fusion
- auto fusion = orignal_welfords[0]->fusion();
- TORCH_INTERNAL_ASSERT(
- std::all_of(
- orignal_welfords.begin(),
- orignal_welfords.end(),
- [fusion](WelfordOp* welford) { return welford->fusion() == fusion; }),
- "Welfords in given vector not in the same fusion");
-
-  // Make the initial test copy of the fusion
- auto test_copy = std::make_unique<Fusion>();
- auto original_to_test_map = Fusion::copy(fusion, test_copy.get());
-
- std::vector<WelfordOp*> copied_welfords;
- std::transform(
- orignal_welfords.begin(),
- orignal_welfords.end(),
- std::back_inserter(copied_welfords),
- [&original_to_test_map](auto welford) {
- return original_to_test_map.clone(welford);
- });
-
- // Translate the welford ops
- for (auto welford_to_translate : copied_welfords) {
- translateSingleWelford(welford_to_translate);
- }
-
- SchedulerRuntimeInfo runtime_info(test_copy.get(), runtime_inputs_, true);
-  // If we are looking at a segment of the fusion,
-  // we maintain the segmented group boundary:
-  // one set of boundary vals for the original fusion
-  // and one for the test copy
- if (group != nullptr) {
- auto original_inputs = getAllInputs(group);
- auto original_outputs = getAllOutputs(group);
- test_group_inputs_.clear();
- test_group_outputs_.clear();
- std::transform(
- original_inputs.begin(),
- original_inputs.end(),
- std::back_inserter(test_group_inputs_),
- [&original_to_test_map](Val* in) {
- return original_to_test_map.clone(in);
- });
- std::transform(
- original_outputs.begin(),
- original_outputs.end(),
- std::back_inserter(test_group_outputs_),
- [&original_to_test_map](Val* out) {
- return original_to_test_map.clone(out);
- });
-
- // Temporarily localize test copy around
- // the group boundary
- FusionSegmentGuard fsg(
- test_copy.get(), test_group_inputs_, test_group_outputs_);
-
- // Test if the translated copy is persistent
- return isValidPersistentFusion(test_copy.get(), runtime_info);
- }
- // In the case where we work on un-segmented
- // fusion, no group boundary logic, just
- // translate and test.
- return isValidPersistentFusion(test_copy.get(), runtime_info);
-}
-
-void TranslateApplicableWelford::translateSingleWelford(WelfordOp* welford) {
- auto fusion = welford->fusion();
- FusionGuard fg(fusion);
-  // Only support translation of welford ops that
-  // don't take inputs that are already statistics,
-  // i.e. an r-factor product.
-  // This translation works on un-scheduled fusions, so
-  // we shouldn't expect to see such inputs here.
- TORCH_INTERNAL_ASSERT(welford->inN()->isOneInt());
-
- // Grab the inputs and outputs of the welford
- auto in_val = welford->in()->as<TensorView>();
- auto out_avg = welford->outAvg()->as<TensorView>();
- auto out_var = welford->outVar()->as<TensorView>();
- auto out_N = welford->outN()->as<TensorView>();
-
- fusion->removeExpr(welford);
-
-  // Create a normalization-based welford graph,
-  // largely taken from the batchnorm cpp benchmark
- auto& in_root = in_val->getRootDomain();
- auto& out_root = out_avg->getRootDomain();
- std::vector<int> red_axes;
-
- // Create scalar version of the feature element
- // counting.
- Val* num_features = new Double(1);
- std::vector<bool> broadcast_mask(in_root.size(), false);
- for (size_t i = 0; i < in_root.size(); i++) {
- if (out_root[i]->isReduction()) {
- red_axes.push_back(i);
- broadcast_mask[i] = true;
- num_features = mul(num_features, out_root[i]->extent());
- }
- }
-
- // Build a normalization expression group that is
- // equivalent to a welford operation.
- auto x_sum = sum(in_val, red_axes);
- new BinaryOp(BinaryOpType::Div, out_avg, x_sum, num_features);
- // welford.avg may be broadcast. Reuse it if found.
- TensorView* x_avg_bcast = nullptr;
- for (auto& use_expr : out_avg->uses()) {
- if (auto bcast = dynamic_cast<BroadcastOp*>(use_expr)) {
- if (bcast->getBroadcastDimFlags() == broadcast_mask) {
- // Same broadcast found.
- x_avg_bcast = bcast->out()->as<TensorView>();
- break;
- }
- }
- }
-
- // x_mean_sub may already exist. Reuse it if found.
- TensorView* x_mean_sub = nullptr;
- if (x_avg_bcast != nullptr) {
- for (auto& use_expr : x_avg_bcast->uses()) {
- if (auto bop = dynamic_cast<BinaryOp*>(use_expr)) {
- if (bop->getBinaryOpType() == BinaryOpType::Sub) {
- if (bop->lhs() == in_val && bop->rhs() == x_avg_bcast) {
- x_mean_sub = bop->out()->as<TensorView>();
- }
- }
- }
- }
- }
-
- if (x_avg_bcast == nullptr) {
- x_avg_bcast = broadcast(out_avg, broadcast_mask);
- }
-
- if (x_mean_sub == nullptr) {
- x_mean_sub = sub(in_val, x_avg_bcast);
- }
-
- auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub);
- new ReductionOp(BinaryOpType::Add, new Double(0.0), out_var, x_mean_sub_pow);
- new UnaryOp(UnaryOpType::Set, out_N, num_features);
-
-  // out_avg and out_N are now outputs of pointwise ops, and we
-  // need to clear out their reduction domains.
- out_avg->clearReductionIterDomains();
- out_N->clearReductionIterDomains();
-}
-
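-// In scalar terms, for a reduction over N elements the translation
-// produces (illustrative):
-//   out_avg = sum(x) / N
-//   out_var = sum((x - out_avg)^2)   // unnormalized; Welford's M2
-//   out_N   = N
-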
-void TranslateApplicableWelford::updateGroupExprs(SegmentedGroup* group) {
- // Re-evaluate expression list of the translated group
- auto input_vec = getAllInputs(group);
- auto output_vec = getAllOutputs(group);
-
- if (input_vec.empty() || output_vec.empty()) {
- return;
- }
-
- std::unordered_set<Val*> input_set(input_vec.begin(), input_vec.end());
- auto expr_set = DependencyCheck::getAllExprsBetween(input_set, output_vec);
- group->exprs_ = std::vector<Expr*>(expr_set.begin(), expr_set.end());
-}
-
-bool SegmentCandidateFinder::TranslateWelfordInFusion(
- Fusion* fusion,
- const at::ArrayRef<IValue>& runtime_inputs) {
- return TranslateApplicableWelford::run(fusion, runtime_inputs);
-}
-
-//! CombineReductions:
-//! This pass works before the main merge node process
-//! It identifies reduction operations that can be combined
-//! together to form a normalization kernel.
-//! Two reductions are considered the same type if they have
-//!  the same root domain length and the same reduction axes.
-//! This pass tries to merge nodes with the same reduction type based
-//! on the graph structure.
-class CombineReductions {
- using GroupSet = std::unordered_set<SegmentedGroup*>;
- using GroupVec = std::vector<SegmentedGroup*>;
- class ReductionSignature;
-
- public:
- static void run(SegmentCandidateFinder* segment_candidate_finder) {
- CombineReductions combine_reductions(segment_candidate_finder);
- }
- static bool shouldRun(SegmentCandidateFinder* segment_candidate_finder);
-
- private:
- CombineReductions(SegmentCandidateFinder* segment_candidate_finder)
- : segment_candidate_finder_(segment_candidate_finder) {
- // Run pass over the segments
-
-    // Collect segmented groups with reductions in them.
-    // Assuming this runs before any merging has happened, we
-    // should see at most one non-trivial reduction in each group
- for (auto group : segment_candidate_finder_->groups()) {
- if (auto rop_signature =
- ReductionSignature::makeReductionSignature(group)) {
- // Ignore pure squeeze operations in this analysis
- if (!rop_signature->hasNonTrivialReduction()) {
- continue;
- }
-
- groups_with_reductions_.push_back(group);
- // Check if this reduction signature is one that we have seen before
- auto signature_match_it = std::find_if(
- known_reduction_signatures_.begin(),
- known_reduction_signatures_.end(),
-            [&rop_signature](auto& known_signature) {
-              return known_signature->sameAs(rop_signature.get());
- });
- // Unmatched: Create a new signature entry if not known
- if (signature_match_it == known_reduction_signatures_.end()) {
- group_reduction_signature_map_[group] = rop_signature.get();
- known_reduction_signatures_.emplace_back(std::move(rop_signature));
- } else {
-          // Matched known signature: mark that this group belongs to a
-          // known signature
- group_reduction_signature_map_[group] = signature_match_it->get();
- }
- }
- }
-
- // Keep trying to merge groups with compatible reductions and compatible
- // paths
- // until no more merge opportunity can be identified
- bool merged_groups = true;
- while (merged_groups) {
- merged_groups = false;
-
- // Merge one pair of reduction groups at a time, and need
- // the pass to update dependency info along the way to avoid cycles
- for (size_t first_group_index = 0;
- first_group_index < groups_with_reductions_.size();
- first_group_index++) {
- if (merged_groups) {
- // Need to break and re-enter this loop because
- // groups_with_reductions_ will be updated
- break;
- }
-
- // Select one of the group to merge and get its reduction signature
- auto first_group = groups_with_reductions_[first_group_index];
- auto first_group_signature =
- group_reduction_signature_map_.at(first_group);
-
- for (size_t second_group_index = first_group_index + 1;
- second_group_index < groups_with_reductions_.size();
- second_group_index++) {
- if (merged_groups) {
- // Need to break and re-enter this loop because
- // groups_with_reductions_ will be updated
- break;
- }
- auto second_group = groups_with_reductions_[second_group_index];
- auto second_group_signature =
- group_reduction_signature_map_.at(second_group);
-
- // Cannot merge if their signatures are not the same
- if (!first_group_signature->sameAs(second_group_signature)) {
- continue;
- }
-
- // first try a vertical merge
- merged_groups =
- verticalReductionMerge(first_group, second_group) != nullptr;
- if (!merged_groups) {
- // vertical merge didn't happen, try a horizontal merge
- merged_groups =
- horizontalReductionMerge(first_group, second_group) != nullptr;
- }
- }
- }
- }
- }
-
- //! Merge a vertical pair of producers and consumers,
- //! the resulting group will include all nodes that are
- //! also consumers of producer and producers of consumer,
- //! i.e. values between the given producer-consumer pair.
- //! Can be proven that:
- //! 1. Including all of these nodes will be cycle-free
-  //!  2. These nodes are the minimal set of nodes to include
-  //!   for the producer-consumer pair to be in the same group cycle-free
- //!
- //! Returns nullptr if such merge cannot be achieved.
- //! Reasons for not merging will include:
- //! 1. Given groups do not form producer-consumer pair
- //! 2. Merge will create cycle on the graph
- //! 3. The merged joined group cannot be scheduled
- SegmentedGroup* verticalReductionMerge(
- SegmentedGroup* first_group,
- SegmentedGroup* second_group) {
- // This is part of ReductionCombine pass, and we should only call this
- // function on a pair of
- // reduction/normalization groups
- TORCH_INTERNAL_ASSERT(
- group_reduction_signature_map_.at(first_group)
- ->sameAs(group_reduction_signature_map_.at(second_group)));
- TORCH_INTERNAL_ASSERT(first_group != second_group);
- // Get the group dependency data from segment finder
- auto dependency_analysis = segment_candidate_finder_->getGroupDependency();
-
- // Check producer-consumer relationship
- SegmentedGroup* producer = nullptr;
- SegmentedGroup* consumer = nullptr;
- if (dependency_analysis->isConsumerOf(first_group, second_group)) {
- producer = second_group;
- consumer = first_group;
- } else if (dependency_analysis->isProducerOf(first_group, second_group)) {
- producer = first_group;
- consumer = second_group;
- } else {
- // Given groups aren't producer-consumer pair, won't merge
- return nullptr;
- }
-
- // Collect all groups that we need to merge along with the producer and
- // consumer
- auto all_groups_to_merge =
- getValidMinVerticalMergedGroupSet(producer, consumer);
-
- if (all_groups_to_merge.empty()) {
-      // The vertical paths from producer to consumer have incompatible
-      // reductions, so this vertical merge cannot be done.
- return nullptr;
- }
-
-    // TODO: this step is not deterministic because valuesBetween isn't;
-    //       could fix this with a topological order
- std::vector<SegmentedGroup*> all_groups_to_merge_vec(
- all_groups_to_merge.begin(), all_groups_to_merge.end());
-
- // Final sanity check: the merged group can actually be scheduled
- Fusion* fusion =
- segment_candidate_finder_->segmented_fusion_->completeFusion();
- if (!tryMerge(
- fusion,
- segment_candidate_finder_->runtimeInfo(),
- all_groups_to_merge_vec)) {
- return nullptr;
- }
-
- // Merge this group
- auto joined_group =
- segment_candidate_finder_->mergeAllGivenGroups(all_groups_to_merge_vec);
-
- // Update dependency analysis
- dependency_analysis->mergeGroups(all_groups_to_merge, joined_group);
-
- // Update the reduction groups that are merged
- groups_with_reductions_.push_back(joined_group);
- group_reduction_signature_map_[joined_group] =
- group_reduction_signature_map_.at(first_group);
- groups_with_reductions_.erase(
- std::remove_if(
- groups_with_reductions_.begin(),
- groups_with_reductions_.end(),
- [&all_groups_to_merge](SegmentedGroup* group) {
- return all_groups_to_merge.count(group);
- }),
- groups_with_reductions_.end());
-
- return joined_group;
- }
-
- //! Horizontal reduction merging:
- //! merge two horizontal groups with reduction expressions to make a joined
-  //! normalization group. A horizontal pair of groups is one that does not
-  //! form a producer-consumer pair but shares either a common producer or a
-  //! common consumer.
- //!
-  //! TODO: This implementation looks at common producers only, since common
-  //! consumers are not computed easily with current dependency analysis.
- SegmentedGroup* horizontalReductionMerge(
- SegmentedGroup* first_group,
- SegmentedGroup* second_group) {
-    // This is part of the CombineReductions pass, and we should only call
-    // this function on a pair of reduction/normalization groups
- TORCH_INTERNAL_ASSERT(
- group_reduction_signature_map_.at(first_group)
- ->sameAs(group_reduction_signature_map_.at(second_group)));
- TORCH_INTERNAL_ASSERT(first_group != second_group);
-
- auto dependency_analysis = segment_candidate_finder_->getGroupDependency();
-
-    // Check that the two groups are not a producer-consumer pair
- if (dependency_analysis->isConsumerOf(first_group, second_group) ||
- dependency_analysis->isProducerOf(first_group, second_group)) {
- // This merge pass will not handle producer-consumer pairs
- return nullptr;
- }
-
-    // Get common producers of the two groups
- auto common_producers_set =
- dependency_analysis->getCommonProducersOf({first_group, second_group});
- if (common_producers_set.empty()) {
- // The given pair doesn't have a common producer.
- // Either they have a common consumer, which we don't handle for now,
- // or maybe the two given groups are not connected.
- return nullptr;
- }
-
-    // We are looking for a very specific pattern here. The cases that this
-    // pattern will not capture are ones where reductions of different
-    // signatures are so interleaved that we cannot find a clear cut as
-    // explained below, without graph rewriting. Some graph rewriting at the
-    // segmented-group level could provide extra merging opportunities for
-    // free, which could be part of a next step.
- //
- // The specific pattern we look for contains a common producer P with
- // immediate consumers C1, C2 such that all paths from C1 to first_group and
- // all paths from C2
- // to second_group won't hit a reduction with a different signature.
-
-    // Topologically sort the common producers and start with the
-    // topologically minimal ones,
-    // i.e. the ones closest to the two groups. This will cut the search
-    // space.
- std::vector<SegmentedGroup*> common_producers(
- common_producers_set.begin(), common_producers_set.end());
- std::sort(
- common_producers.begin(),
- common_producers.end(),
- [&dependency_analysis](SegmentedGroup* a, SegmentedGroup* b) {
- return dependency_analysis->isConsumerOf(a, b);
- });
-
- // Use a visited filter to prune search space.
- GroupSet visited_common_producers;
-
-    // Visit the common producers found, starting from the topological
-    // minimum, i.e. the ones closest to the groups
- for (auto common_producer : common_producers) {
- // Visit this common producer
-      // Use a double loop in case the schedulers like some patterns
-      // better than others
- for (auto first_consumer_edge : common_producer->consumer_edges) {
- auto producer_of_first_group = first_consumer_edge->to;
- if (visited_common_producers.count(producer_of_first_group)) {
- // We have visited this node as common producer before and it
- // had conflicts. It'd hit the same conflict again if we continued
- // to pursue this edge.
- continue;
- }
- auto to_merge_with_first_group = getValidMinVerticalMergedGroupSet(
- producer_of_first_group, first_group);
- if (to_merge_with_first_group.empty()) {
-          // There's no valid merge path from this consumer of the common
-          // producer, either due to a conflicting reduction signature, or
-          // simply because there's no path to the first group
- continue;
- }
- for (auto second_consumer_edge : common_producer->consumer_edges) {
- auto producer_of_second_group = second_consumer_edge->to;
- if (visited_common_producers.count(producer_of_second_group)) {
- // We have visited this node as common producer before and it
- // had conflicts. It'd hit the same conflict again if we continued
- // to pursue this edge.
- continue;
- }
- auto to_merge_with_second_group = getValidMinVerticalMergedGroupSet(
- producer_of_second_group, second_group);
- if (to_merge_with_second_group.empty()) {
-              // There's no valid merge path from this consumer of the
-              // common producer,
-              // either due to a conflicting reduction signature, or simply
-              // because there's no path to the second group
- continue;
- }
-
-          // At this point we should have a pair of valid candidates; the
-          // final check is to see if the combined group
-          // can be scheduled by the schedulers.
-          // Merge the two paths and de-duplicate,
-          // re-using the to_merge_with_second_group container here
- auto& groups_to_merge_set = to_merge_with_second_group;
- groups_to_merge_set.insert(
- to_merge_with_first_group.begin(),
- to_merge_with_first_group.end());
- std::vector<SegmentedGroup*> groups_to_merge_vec(
- groups_to_merge_set.begin(), groups_to_merge_set.end());
- Fusion* fusion =
- segment_candidate_finder_->segmented_fusion_->completeFusion();
- if (tryMerge(
- fusion,
- segment_candidate_finder_->runtimeInfo(),
- groups_to_merge_vec)) {
- // Found a valid horizontal merge, want to proceed with merging here
- auto joined_group = segment_candidate_finder_->mergeAllGivenGroups(
- groups_to_merge_vec);
- dependency_analysis->mergeGroups(groups_to_merge_set, joined_group);
-
- groups_with_reductions_.push_back(joined_group);
- group_reduction_signature_map_[joined_group] =
- group_reduction_signature_map_.at(first_group);
- groups_with_reductions_.erase(
- std::remove_if(
- groups_with_reductions_.begin(),
- groups_with_reductions_.end(),
- [&groups_to_merge_set](SegmentedGroup* group) {
- return groups_to_merge_set.count(group);
- }),
- groups_with_reductions_.end());
-
- return joined_group;
- }
- }
- }
-      // Here we should have searched all consumer edges of this common
-      // producer and found no valid pattern.
-      // Just add it to the visited list.
- visited_common_producers.insert(common_producer);
- }
-
-    // Searched all possibilities and found no valid horizontal merge
-    // pattern.
- return nullptr;
- }
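-
-  // Illustrative sketch (not in the original source): the pattern searched
-  // for above is a common producer P with immediate consumers C1 and C2:
-  //
-  //          P
-  //         / \
-  //        C1  C2
-  //        |    |
-  //       ...  ...
-  //        |    |
-  //     first  second
-  //     group   group
-  //
-  // The merge succeeds only if no path from C1 to first_group and no path
-  // from C2 to second_group crosses a reduction with a different signature.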
-
-  //! This is a utility method used in both vertical merging and
-  //! horizontal merging.
-  //! It identifies the smallest set of groups to merge vertically
-  //! involving the two given nodes.
-  //! Given a pair of nodes, this utility distinguishes 3 cases:
-  //! 1. if maybe_producer is the same as maybe_consumer, then returns
-  //!    {maybe_producer}
-  //! 2. if maybe_producer is actually a producer of maybe_consumer, returns
-  //!    a set containing the smallest merged group that would contain
-  //!    producer and consumer and would not introduce a cycle. Returns an
-  //!    empty set if such a group has a conflicting reduction signature.
-  //! 3. returns an empty set if neither of the conditions above applies.
- GroupSet getValidMinVerticalMergedGroupSet(
- SegmentedGroup* maybe_producer,
- SegmentedGroup* maybe_consumer) {
- auto dependency_analysis = segment_candidate_finder_->getGroupDependency();
- if (maybe_consumer == maybe_producer) {
-      // maybe_producer is the same as maybe_consumer
- return {maybe_consumer};
- } else if (dependency_analysis->isConsumerOf(
- maybe_consumer, maybe_producer)) {
- auto groups_to_check =
- dependency_analysis->valuesBetween(maybe_producer, maybe_consumer);
- groups_to_check.insert(maybe_producer);
- groups_to_check.insert(maybe_consumer);
-
- // Check that either no group has a reduction or all groups have the same
- // reduction signature
- ReductionSignature* reduction_signature = nullptr;
-
-      // Iterate through the minimal group set to see if there are conflicts
- for (auto group : groups_to_check) {
-        // Check that this group does not involve an output edge contraction.
-        // This pass is intended to be a pre-merging pass. Since contracting an
-        // output edge does not generate much savings in global memory access,
-        // we want to postpone merging these edges until the very final pass
- for (auto producer_edge_of_group : group->producer_edges) {
- if (groups_to_check.count(producer_edge_of_group->from) &&
- producer_edge_of_group->val->isFusionOutput()) {
- return {};
- }
- }
- for (auto consumer_edge_of_group : group->consumer_edges) {
- if (groups_to_check.count(consumer_edge_of_group->to) &&
- consumer_edge_of_group->val->isFusionOutput()) {
- return {};
- }
- }
-
- // Check that this group does not have a conflicting reduction signature
- if (group_reduction_signature_map_.count(group)) {
- if (reduction_signature != nullptr) {
- if (!group_reduction_signature_map_.at(group)->sameAs(
- reduction_signature)) {
- // Found a conflict in reduction signature, cannot do a vertical
- // merge
- return {};
- }
- } else {
- reduction_signature = group_reduction_signature_map_.at(group);
- }
- }
- }
- return groups_to_check;
- }
-    // maybe_producer is not a producer of maybe_consumer
- return {};
- }
-
- private:
- SegmentCandidateFinder* segment_candidate_finder_;
-
-  // Wrapper class for reduction type.
-  // Assuming there wouldn't be too many of them,
-  // we won't need to create a hash.
-  // TODO:
-  //  Want to reconsider this for transpose operations;
-  //  needs refactoring to handle reduction fusions across a transpose operation
- class ReductionSignature {
- public:
- bool sameAs(const ReductionSignature* reduction_signature) {
- if (reduction_signature == this) {
- return true;
- }
-
- if (root_domain_size_ != reduction_signature->root_domain_size_ ||
- has_nontrivial_reduction_ !=
- reduction_signature->has_nontrivial_reduction_ ||
- reduction_axes_.size() !=
- reduction_signature->reduction_axes_.size()) {
- return false;
- }
-
- for (size_t i = 0; i < reduction_axes_.size(); i++) {
- if (reduction_axes_[i] != reduction_signature->reduction_axes_[i]) {
- return false;
- }
- }
-
- return true;
- }
-
- bool sameAs(const ReductionSignature& reduction_signature) {
- return sameAs(&reduction_signature);
- }
-
- bool hasNonTrivialReduction() const {
- return has_nontrivial_reduction_;
- }
-
- static std::unique_ptr<ReductionSignature> makeReductionSignature(
- SegmentedGroup* group) {
- std::unique_ptr<ReductionSignature> signature = nullptr;
-
- for (auto expr : group->exprs()) {
- std::unique_ptr<ReductionSignature> new_signature = nullptr;
-
- if (auto rop = dynamic_cast<ReductionOp*>(expr)) {
- new_signature = std::make_unique<ReductionSignature>(rop);
- }
- if (auto wop = dynamic_cast<WelfordOp*>(expr)) {
- new_signature = std::make_unique<ReductionSignature>(wop);
- }
-
- if (new_signature != nullptr) {
- TORCH_INTERNAL_ASSERT(
- signature == nullptr || !signature->has_nontrivial_reduction_ ||
- !new_signature->has_nontrivial_reduction_ ||
- signature->sameAs(new_signature.get()),
- "Conflicting signature found in this group");
- signature = std::move(new_signature);
- }
- }
- return signature;
- }
-
- template <typename REDUCTION = ReductionOp>
-    ReductionSignature(REDUCTION* rop) {
-      auto out_tv = rop->out()->template as<TensorView>();
-      TORCH_INTERNAL_ASSERT(out_tv != nullptr);
-      has_nontrivial_reduction_ = out_tv->hasReduction();
- auto& root_domain = out_tv->getRootDomain();
- root_domain_size_ = root_domain.size();
-
-      // Trivial reduction, i.e. squeeze, is tricky here:
-      // this pass doesn't want to touch any pure squeeze, i.e.:
-      //  T0 [R(1), I(i0), I(i1)]
-      // meanwhile, for two reductions having
-      //  squeezes, we do require they have squeezes at the
-      //  same positions so that they can be easily root-domain mapped.
-      // So T0 and T1 below have the same signature,
-      //  T0 [R(1), R(i0), I(i1)]
-      //  T1 [R(1), R(i0), I(i1)]
-      // but T2 and T3 below do not:
-      //  T2 [R(1), R(1), R(i0), I(i1)]
-      //  T3 [R(1), R(i0), I(i1)]
- for (size_t i = 0; i < root_domain_size_; i++) {
- if (root_domain[i]->isReduction()) {
- reduction_axes_.push_back(i);
- }
- if (!root_domain[i]->isTrivialReduction()) {
- has_nontrivial_reduction_ = true;
- }
- }
- }
-
- private:
- size_t root_domain_size_ = 0;
- std::vector<int> reduction_axes_;
- bool has_nontrivial_reduction_ = false;
- };
-
- //! Keeps track of groups with reduction expressions,
- //! using a vector here to maintain a deterministic ordering
- GroupVec groups_with_reductions_;
-
- //! Maps groups to their corresponding signature type
- std::unordered_map<SegmentedGroup*, ReductionSignature*>
- group_reduction_signature_map_;
-
- //! Maintains all reduction signatures seen in the segmented fusion
- std::vector<std::unique_ptr<ReductionSignature>> known_reduction_signatures_;
-};
-
-//! Returns true if this pass should run, i.e. if at least two groups
-//! share the same non-trivial reduction signature
-bool CombineReductions::shouldRun(
- SegmentCandidateFinder* segment_candidate_finder) {
- std::vector<std::unique_ptr<ReductionSignature>> known_reductions;
- // Iterate over group segments we have before segment candidate finder
- // tries to merge any groups
- for (auto group : segment_candidate_finder->groups()) {
- if (auto reduction_signature =
- ReductionSignature::makeReductionSignature(group)) {
- if (reduction_signature->hasNonTrivialReduction() &&
- std::any_of(
- known_reductions.begin(),
- known_reductions.end(),
-                  [&reduction_signature](auto& known_signature) {
-                    return known_signature->sameAs(reduction_signature.get());
- })) {
- // Found two reductions with the same signature, run pass
- return true;
- }
- known_reductions.emplace_back(std::move(reduction_signature));
- }
- }
- return false;
-}
-
-bool SegmentCandidateFinder::codeGenSupportedMerge(SegmentedEdge* edge) {
- Fusion* fusion = segmented_fusion_->completeFusion();
- auto h = tryMerge(fusion, runtime_info_, edge->from, edge->to);
- return h.has_value();
-}
-
-// TODO: consider caching the heuristics value so tryMerge doesn't have to be
-// called twice
-ScheduleHeuristic SegmentCandidateFinder::deriveHeuristic(
- SegmentedGroup* group) {
- Fusion* fusion = segmented_fusion_->completeFusion();
- auto h = tryMerge(fusion, runtime_info_, group);
- TORCH_INTERNAL_ASSERT(h.has_value());
- return h.value();
-}
-
-SegmentCandidateFinder::SegmentCandidateFinder(
- std::unique_ptr<Fusion> fusion,
- const at::ArrayRef<IValue>& inputs,
- SegmentCandidateFinderOptions options)
- : options_(options),
- runtime_info_(fusion.get(), inputs, true),
- runtime_inputs_(inputs) {
- segmented_fusion_ = std::make_unique<SegmentedFusion>(std::move(fusion));
- findSegments();
-}
-
-void SegmentCandidateFinder::findSegments() {
- FUSER_PERF_SCOPE("Finding valid fusion segment solutions");
- // TODO: Make traversal items local to this function.
-
-  // Needed for initialization of the DAG that is processed
- std::unordered_map<Expr*, SegmentedGroup*> expr2group;
-
- // Keep track of complete fusion input use
- std::unordered_map<Val*, SegmentedGroup*> input2group;
-
- // Initialize DAG, convert each expr to a segment group
- auto exprs = completeFusion()->exprs();
- for (auto expr : exprs) {
- if (!ir_utils::isScalarOp(expr)) {
- auto new_group = segmented_fusion_->newGroup(expr);
- expr2group.insert(std::make_pair(expr, new_group));
- }
- }
-
-  // Insert auxiliary groups so that group dependency analysis covers
-  // fusion inputs as well.
-  // TODO: these groups should never be merged into any other groups, but are
-  // just there to support the dependency analysis. A later refactor should
-  // avoid introducing them explicitly on the segmented fusion.
- for (auto input : completeFusion()->inputs()) {
- // These groups are used to represent input as a common
- // producer in horizontal merges, and should never be
- // seen as a candidate for vertical merge
- auto new_group = segmented_fusion_->newGroup();
- input2group.insert({input, new_group});
- }
-
- // Create edges between the Exprs. Mark inputs and outputs of the fusion.
- for (auto expr : exprs) {
- // No group created for scalar ops
- if (ir_utils::isScalarOp(expr)) {
- continue;
- }
-
- auto expr_group = expr2group.at(expr);
- for (auto inp : expr->inputs()) {
- if (inp->isFusionInput()) {
- expr_group->input_vals.push_back(inp);
- auto aux_group = input2group.at(inp);
- auto new_edge = segmented_fusion_->newEdge(aux_group, expr_group, inp);
- expr_group->producer_edges.push_back(new_edge);
- aux_group->consumer_edges.push_back(new_edge);
- continue;
- }
-
- // Could be something like a constant scalar, definition is nullptr, but
- // isn't an "input" to the fusion. At least not one provided by an
- // external source.
- if (inp->definition() == nullptr) {
- continue;
- }
-
- // No group created for scalar ops since they may need to be duplicated
- // to avoid scalar edges. They are handled in resolveScalarsInGroup
- if (inp->isScalar()) {
- continue;
- }
-
- auto def_group = expr2group.at(inp->definition());
- auto new_edge = segmented_fusion_->newEdge(def_group, expr_group, inp);
- expr_group->producer_edges.push_back(new_edge);
- def_group->consumer_edges.push_back(new_edge);
- }
- for (auto out : expr->outputs()) {
- if (out->isFusionOutput()) {
- expr_group->output_vals.push_back(out);
- }
- }
- }
-
- if (options_.run_translate_welford &&
- segmented_fusion_->completeFusion()->hasWelford()) {
- TranslateApplicableWelford::run(segmented_fusion_.get(), runtime_inputs_);
- }
-
- for (auto group : groups()) {
- // Set heuristics in case single reduction kernels were left out
- group->setHeuristic(deriveHeuristic(group));
- }
-
- // Remove all scalar edges since they do not represent actual
- // dependency among segmented groups.
- removeScalarEdges();
-
- // Run pre-merge heuristics
- if (options_.run_combine_reductions && CombineReductions::shouldRun(this)) {
- CombineReductions::run(this);
- }
-
-  // All merges will be vertical beyond this point for now, so
-  // we can remove the input auxiliary groups. Should make the vertical
-  // merges avoid auxiliary groups once we start general horizontal merges
- std::unordered_set<SegmentedGroup*> input_groups;
- for (auto input : completeFusion()->inputs()) {
- input_groups.insert(input2group.at(input));
- }
- eraseGroups(input_groups);
-
- if (options_.run_herrmann_merge) {
- bool merged_nodes = true;
- // Initial merge iteration
- while (merged_nodes) {
- // Reset stateful traversal details in SegmentedGroups
- resetTraversal();
-
- resetLevels();
-
- for (auto& group : groups()) {
- if (group->merged_) {
- continue;
- }
- auto candidates = group->getMergeCandidates();
- if (candidates.empty()) {
- continue;
- }
-
- auto candidate_it = candidates.begin();
- while (candidate_it != candidates.end() &&
- !codeGenSupportedMerge(candidate_it->edge)) {
- candidate_it++;
- }
- if (candidate_it == candidates.end()) {
- continue;
- }
-
- to_merge_.emplace_back(group);
- to_merge_.emplace_back(candidate_it->group);
-
- group->merged_ = true;
- group->merge_with_ = candidate_it->group;
- group->merge_through_ = candidate_it->edge;
-
- candidate_it->group->merged_ = true;
- candidate_it->group->merge_with_ = group;
- candidate_it->group->merge_through_ = candidate_it->edge;
- }
-
- if (to_merge_.empty()) {
- merged_nodes = false;
- }
-
- mergeNodes();
- }
- }
-
- if (options_.run_final_merge) {
-    // TODO: consider interleaving herrmann merge and bruteforce merge, as
-    // bruteforce merge can introduce
-    // opportunities for more herrmann merge
- finalMerge();
- }
-
- finalize();
- if (isDebugDumpEnabled(DebugDumpOption::FusionSegmentsDrawing)) {
- segmented_fusion_->draw();
- }
-}
-
-void SegmentCandidateFinder::finalMerge() {
- auto producer_check = getGroupDependency();
-
- bool merged_nodes = true;
- while (merged_nodes) {
- // Iterate all groups and check if a group
- // can merge with one of its consumers
- for (auto producer_group : groups()) {
- // Populate consumers and their corresponding consumer edges
- std::unordered_map<SegmentedGroup*, SegmentedEdge*> consumer_edge_map;
- std::vector<SegmentedGroup*> all_consumers_of_producer_group;
- for (auto consumer : producer_group->consumer_edges) {
-        // Since this is the last fusion pass, we can enable fusion through
-        // outputs. Priority of this was decreased because if the only
-        // connection between groups is an output node, in the best case we
-        // can save a single pass in memory. Whereas if it weren't an output,
-        // it would be two passes.
- consumer_edge_map.insert({consumer->to, consumer});
- }
-      // Populate all consumers from the map to avoid duplicates
- std::transform(
- consumer_edge_map.begin(),
- consumer_edge_map.end(),
- std::back_inserter(all_consumers_of_producer_group),
- [](auto& it) { return it.first; });
-
- for (auto consumer : all_consumers_of_producer_group) {
- if (!producer_check->isConsumerOfAny(
- consumer, all_consumers_of_producer_group) &&
- codeGenSupportedMerge(consumer_edge_map.at(consumer))) {
- to_merge_.emplace_back(producer_group);
- to_merge_.emplace_back(consumer);
- producer_group->merged_ = true;
- producer_group->merge_with_ = consumer;
- producer_group->merge_through_ = consumer_edge_map.at(consumer);
- consumer->merged_ = true;
- consumer->merge_with_ = producer_group;
- consumer->merge_through_ = producer_group->merge_through_;
- break;
- }
- }
-
- // Only want to merge one pair at a time so break if found any
- if (!to_merge_.empty()) {
- break;
- }
- }
-
- if (to_merge_.empty()) {
- merged_nodes = false;
- } else {
- TORCH_INTERNAL_ASSERT(
- to_merge_.size() == 2, "merging more than 2 nodes in final iter");
- mergeNodes();
- }
- }
-}
-
-void SegmentCandidateFinder::resolveScalarsInGroup(SegmentedGroup* group) {
- std::vector<Val*> to_visit;
- std::unordered_set<Val*> visited;
-
- // Collect all scalar uses in the group
- for (auto expr : group->exprs()) {
- for (auto input : expr->inputs()) {
- if (input->isScalar()) {
- to_visit.push_back(input);
- }
- }
- }
-
- // Keep track of composite fusion inputs used in this group
- std::unordered_set<Val*> input_set(
- group->input_vals.begin(), group->input_vals.end());
-
- // Record and append all missing scalar exprs at the end.
- std::vector<Expr*> exprs_to_add;
-
- // Do a stack based traversal of the scalar ops to avoid
- // combinatorial duplication of exprs.
- while (!to_visit.empty()) {
- auto stack_top_val = to_visit.back();
- if (visited.count(stack_top_val)) {
- to_visit.pop_back();
- } else if (stack_top_val->definition() == nullptr) {
-      // A scalar without a definition can be a constant, a tensor dim,
-      // or a composite fusion input.
-      // The first two cases are handled in finalize(),
-      // the last case needs to add a new input_val to this group.
- visited.insert(stack_top_val);
- // If this is a composite fusion scalar input, make sure this group has it
- if (stack_top_val->isFusionInput() && !input_set.count(stack_top_val)) {
- group->input_vals.push_back(stack_top_val);
- input_set.insert(stack_top_val);
- }
- to_visit.pop_back();
- } else {
- // A scalar with an actual definition
- auto definition_expr = stack_top_val->definition();
- bool all_inputs_visited = true;
- // If any of the inputs are not visited, visit them first
- for (auto input : definition_expr->inputs()) {
- if (!visited.count(input)) {
- all_inputs_visited = false;
- to_visit.push_back(input);
- }
- }
- // This node is ready to be visited
- if (all_inputs_visited) {
- // Collect the defining expr to insert into group
- exprs_to_add.push_back(definition_expr);
- visited.insert(stack_top_val);
- to_visit.pop_back();
- }
- }
- }
-
- // Add all the defining expr to the group
- for (auto expr : exprs_to_add) {
- group->exprs_.push_back(expr);
- }
-}
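-
-// For example (illustrative, not in the original source): if a group's
-// exprs use a scalar s defined as s = i0 + i1, where i1 itself is defined
-// as i1 = i2 * 2, the traversal above appends the defining exprs of both
-// i1 and s to the group, in dependency order, and registers any fusion
-// scalar inputs (e.g. i0, i2) as inputs of the group.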
-
-void SegmentCandidateFinder::removeScalarEdges() {
-  // Remove all scalar edges between groups.
-  // They may have been created by Welford translation;
-  // we will not need them after scalar resolution.
- auto remove_scalar_edges_from_vec = [](std::vector<SegmentedEdge*>& edges) {
- edges.erase(
- std::remove_if(
- edges.begin(),
- edges.end(),
- [](SegmentedEdge* segmented_edge) {
- return segmented_edge->val->isScalar();
- }),
- edges.end());
- };
-
- remove_scalar_edges_from_vec(edges());
- for (auto group : groups()) {
- remove_scalar_edges_from_vec(group->producer_edges);
- remove_scalar_edges_from_vec(group->consumer_edges);
- }
-}
-
-void SegmentCandidateFinder::finalize() {
- // Remove unconnected groups
- groups().erase(
- std::remove_if(
- groups().begin(),
- groups().end(),
- [](SegmentedGroup* sg) { return !sg->isConnected(); }),
- groups().end());
-
- // Add group labeling
- int i = 0;
- for (auto it = groups().begin(); it != groups().end(); it++, i++) {
- deDuplicateScalarExprs((*it)->exprs_);
- (*it)->setID(i);
- }
-
- // TODO: too many things are currently abstracted under the term
- // finalize. Need to re-structure in a follow up.
-
- // Finalize connections between segmented groups
- segmented_fusion_->finalize();
-
- // Resolve all the scalar expressions needed in each group
- for (auto group : segmented_fusion_->groups()) {
- resolveScalarsInGroup(group);
- }
-
- // Finalize each group, fill in the missing inputs, i.e. tensor dims.
- for (auto g : groups()) {
- g->finalize();
- }
-}
-
-GroupDependencyAnalysis* SegmentCandidateFinder::getGroupDependency() {
- if (!group_dependency_) {
- group_dependency_ =
- std::make_unique<GroupDependencyAnalysis>(segmented_fusion_.get());
- }
- return group_dependency_->as<GroupDependencyAnalysis>();
-}
-
-FusionKernelRuntime::SchedulerEntryPtr SegmentedFusion::
- makeInitialSchedulerEntry(
- SegmentedGroup* sg,
- SchedulerRuntimeInfo& runtime_info) {
- auto local_fusion = completeFusion();
- FusionSegmentGuard fsg(local_fusion, getAllInputs(sg), getAllOutputs(sg));
- // This will be the first time each group is scheduled. So we'd want to
- // construct the cache data here.
- auto data_cache_ptr = std::make_unique<HeuristicSummary>(
- local_fusion, sg->heuristic(), runtime_info);
- auto data_cache = data_cache_ptr.get();
- setCachedHeuristicDataFor(sg, std::move(data_cache_ptr));
- return SchedulerEntry::makeEntry(
- sg->heuristic(), local_fusion, runtime_info, data_cache);
-}
-
-std::unique_ptr<FusionHeuristics> SegmentedFusion::makeInitialHeuristics(
- const at::ArrayRef<IValue>& inputs) {
- auto ret = std::make_unique<FusionHeuristics>();
- SchedulerRuntimeInfo runtime_info(completeFusion(), inputs, true);
- for (auto g : groups()) {
- ret->emplaceBack(makeInitialSchedulerEntry(g, runtime_info));
- }
- return ret;
-}
-
-HeuristicSummary* SegmentedFusion::getCachedHeuristicDataFor(
- SegmentedGroup* group) {
- auto data_it = heuristic_summary_cache_.find(group);
- if (data_it == heuristic_summary_cache_.end()) {
- return nullptr;
- }
- return data_it->second.get();
-}
-
-void SegmentedFusion::setCachedHeuristicDataFor(
- SegmentedGroup* group,
- std::unique_ptr<HeuristicSummary> data) {
- TORCH_INTERNAL_ASSERT(!heuristic_summary_cache_.count(group));
- heuristic_summary_cache_[group] = std::move(data);
-}
-
-namespace {
-
-//! A thin traversal class that collects all the tensorviews
-//! that could be cast to fp16 if they were segmented edges.
-//! The selected values are currently defined as all the
-//! tensorviews that
-//! 1. are not complete fusion inputs/outputs,
-//! 2. have a use chain that ends with a fp16
-//!    complete fusion output, and
-//! 3. are of fp32 datatype
-class ForceFP16Annotation : public IterVisitor {
- public:
- static std::unordered_set<TensorView*> getAnnotatedSet(Fusion* fusion) {
- ForceFP16Annotation annotation;
- std::vector<Val*> fp16_outputs;
-
- std::copy_if(
- fusion->outputs().begin(),
- fusion->outputs().end(),
- std::back_inserter(fp16_outputs),
- [](auto* val) {
- return val->template isA<TensorView>() &&
- val->getDataType().has_value() &&
- val->getDataType().value() == DataType::Half;
- });
-
- annotation.traverseFrom(fusion, fp16_outputs);
- return annotation.force_fp16_tv_set_;
- }
-
- private:
- using IterVisitor::handle;
-
- void handle(TensorView* tv) override {
- auto dtype = tv->getDataType();
- if (dtype.has_value() && dtype.value() == DataType::Float &&
- !tv->isFusionOutput() && !tv->isFusionInput()) {
- force_fp16_tv_set_.insert(tv);
- }
- }
-
- std::unordered_set<TensorView*> force_fp16_tv_set_;
-};
-
-} // namespace
-
-void SegmentedFusion::annotateFP16IntermediateTensors() {
- force_fp16_tv_set_ =
- ForceFP16Annotation::getAnnotatedSet(complete_fusion_.get());
-}
-
-TORCH_CUDA_CU_API std::string toString(
- const SegmentCandidateFinderOptions& segment_options) {
- std::stringstream ss;
- ss << "segmentation phases {\n";
- if (segment_options.run_combine_reductions) {
- ss << "combine reductions\n";
- }
- if (segment_options.run_herrmann_merge) {
- ss << "herrmann merging\n";
- }
- if (segment_options.run_final_merge) {
- ss << "final merging\n";
- }
- ss << "\n}\n";
- return ss.str();
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-
-#include <deque>
-#include <list>
-#include <unordered_set>
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class SegmentedGroup;
-class SegmentCandidateFinder;
-
-// A directed edge on the segmented DAG:
-// a wrapper for a Val connecting two segmented groups, which are made up
-// of Exprs. Multiple edges can exist between segmented groups.
-struct SegmentedEdge {
- SegmentedEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val)
- : from(from), to(to), val(val) {}
-
- SegmentedGroup* from;
- SegmentedGroup* to;
- Val* val;
-
- void print() const;
-};
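-
-// For example (illustrative, not in the original source): if group A
-// produces tensors t0 and t1 and group B consumes both, the segmented
-// fusion holds two edges, SegmentedEdge(A, B, t0) and SegmentedEdge(A, B, t1).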
-
-std::ostream& operator<<(std::ostream& os, const SegmentedEdge* edge);
-
-//! Groups together expressions which create a segmented group
-//! Can be used to produce fusions
-class TORCH_CUDA_CU_API SegmentedGroup {
- public:
- SegmentedGroup(SegmentedFusion* segmented_fusion)
- : segmented_fusion_(segmented_fusion) {}
-
- SegmentedGroup(Expr* expr, SegmentedFusion* segmented_fusion)
- : segmented_fusion_(segmented_fusion) {
- exprs_.push_back(expr);
- }
-
-  //! Checks if this group takes any of the original fusion's inputs
-  bool isInputGroup() {
-    return !input_vals.empty();
-  }
-
-  //! Checks if this group is used anywhere in the segmented fusion
- bool isConnected() const {
- return !producer_edges.empty() || !consumer_edges.empty() ||
- !output_vals.empty();
- }
-
-  //! Returns the id assigned by the segment pass
- int groupId() const {
- return group_id_;
- }
-
- //! Returns inputs that this group shares with the original fusion
- const auto& inputs() const {
- return input_vals;
- }
-
- //! Returns outputs that this group shares with the original fusion
- const auto& outputs() const {
- return output_vals;
- }
-
- //! Returns the schedule heuristic associated with this group
- ScheduleHeuristic heuristic() const {
- return heuristic_;
- }
-
- //! Returns the exprs that make up this group
- const auto& exprs() const {
- return exprs_;
- }
-
- //! Debug print function
- void print() const;
-
- //! Returns the segmented fusion that this group is in
- SegmentedFusion* segmentedFusion() const {
- return segmented_fusion_;
- }
-
- //! Try to get a scheduler entry for this group with
- //! the given runtime info.
- //! Returns a new scheduler with the same heuristics
- //! for this group if possible.
- //! Note that the schedule params can be different.
-  //! Returns nullopt if this group cannot be scheduled
- //! with the same heuristics.
- c10::optional<std::unique_ptr<SchedulerEntry>> getMaybeSchedulerEntry(
- SchedulerRuntimeInfo& runtime_info);
-
- public:
- //! "Ancestor nodes", towards inputs of segmentedDAG
- std::vector<SegmentedEdge*> producer_edges;
-
- //! "Descendent nodes", towards outputs of segmentedDAG
- std::vector<SegmentedEdge*> consumer_edges;
-
- //! Composite Fusion inputs in this group
- std::vector<Val*> input_vals;
-
- //! Composite Fusion outputs in this group
- std::vector<Val*> output_vals;
-
- private:
- friend class SegmentCandidateFinder;
- friend class SegmentedFusion;
- friend class FusionKernelRuntime;
- friend class TranslateApplicableWelford;
-
- //! unique identifier of group in the segmented fusion
- int group_id_ = -1;
-
- //! The scheduler to use for compiling this group
- ScheduleHeuristic heuristic_ = ScheduleHeuristic::PointWise;
-
- //! Exprs that make up the group
- std::vector<Expr*> exprs_;
-
- //! Maximum path distance from an input segmented group required for
- //! Theorem 4.2
- int level_ = -1;
-
- //! traversal marker, has this node already been processed
- bool visited_ = false;
-
- //! Did we select another group to merge with
- SegmentedGroup* merge_with_ = nullptr;
-
- //! if we selected another group to merge, which edge is to be contracted
- SegmentedEdge* merge_through_ = nullptr;
-
- //! Has this node been merged?
- bool merged_ = false;
-
- private:
- //! Utility to convert edge vector to value vector
- std::vector<Val*> edgesToVals(const std::vector<SegmentedEdge*>& se_v);
-
-  //! Reset method to call at the beginning of each
-  //! merge node iteration
- void clearTraversalInfo();
-
-  //! To be called at the very end of fusion segmentation;
-  //! no more segment merging should be done beyond this point
- void finalize();
-
- //! Return all segmented groups connected with *this
- std::vector<SegmentedGroup*> getNeighbors();
-
-  //! Utility struct to represent a group connection:
-  //! both the group to connect with and the edge
-  //! to connect through
- struct NeighborGroup {
- NeighborGroup(SegmentedGroup* g, SegmentedEdge* e) : group(g), edge(e) {}
- SegmentedGroup* group;
- SegmentedEdge* edge;
- };
-
- //! TODO: May want to sort this based on size of connections between this and
- //! neighbors as well as if the connection is an output of the fusion (has to
- //! be saved to gmem anyways)
- std::vector<NeighborGroup> getNeighborGroups();
-
- //! Look at all neighbors of this and return who this could merge with based
- //! on level values of this, neighbors, and merged neighbors of neighbors
- std::vector<NeighborGroup> getMergeCandidates();
-
- //! Assign schedule heuristic to this group
- void setHeuristic(ScheduleHeuristic sh) {
- heuristic_ = sh;
- }
-
- //! Assign Id for this group
- void setID(int id) {
- TORCH_INTERNAL_ASSERT(group_id_ == -1);
- group_id_ = id;
- }
-
- //! SegmentedFusion this group belongs to
- SegmentedFusion* segmented_fusion_;
-};
-
-std::ostream& operator<<(std::ostream& os, const SegmentedGroup* group);
-
-//! Auxiliary class for storing heuristics. The managed data is either
-//! a single scheduler entry for a complete fusion,
-//! or a vector of schedulers, one for each segment, for a segmented fusion.
-class TORCH_CUDA_CU_API FusionHeuristics {
- using SchedulerEntryOwningPtr = std::unique_ptr<SchedulerEntry>;
-
- public:
-  //! Constructor for the segmented fusion case. Creates an empty list;
-  //! use emplaceBack for inserting heuristics in order
- explicit FusionHeuristics() = default;
-
- //! Constructor for complete fusion case, generates the scheduler entry
- //! for the fusion owning the given expression
- explicit FusionHeuristics(
- ScheduleHeuristic schedule_heuristic,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr) {
- heuristics_.emplace_back(SchedulerEntry::makeEntry(
- schedule_heuristic, runtime_info.fusion(), runtime_info, data_cache));
- is_segmented_ = false;
- }
-
- //! Place a scheduler entry on the list. Applies to segmented fusion only.
- void emplaceBack(SchedulerEntryOwningPtr&& pt) {
- TORCH_INTERNAL_ASSERT(is_segmented_);
- heuristics_.emplace_back(std::move(pt));
- }
-
-  //! Returns the list of schedulers for a segmented fusion.
- const std::vector<SchedulerEntryOwningPtr>& heuristicsList() const {
- return heuristics_;
- }
-
- //! Returns the single scheduler for a complete fusion.
- SchedulerEntry* singleKernelHeuristics() {
- TORCH_INTERNAL_ASSERT(!is_segmented_);
- return heuristics_.begin()->get();
- }
-
- private:
- std::vector<SchedulerEntryOwningPtr> heuristics_;
- bool is_segmented_ = true;
-};
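-
-//! Usage sketch (illustrative only, not part of the original header):
-//!
-//!   // Complete fusion: the single entry is generated eagerly.
-//!   FusionHeuristics fh(ScheduleHeuristic::PointWise, runtime_info);
-//!   SchedulerEntry* entry = fh.singleKernelHeuristics();
-//!
-//!   // Segmented fusion: typically built through
-//!   // SegmentedFusion::makeInitialHeuristics(inputs), which calls
-//!   // emplaceBack() once per segmented group, in group order.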
-
-//! Exported interface for representing a segmented fusion graph;
-//! this class owns the segmented groups
-class TORCH_CUDA_CU_API SegmentedFusion {
- public:
- explicit SegmentedFusion(std::unique_ptr<Fusion> fusion);
-
- //! Is the fusion segmented?
- bool isSegmented() const {
- return !groups_.empty();
- }
-
- std::vector<SegmentedGroup*>& groups() {
- return groups_;
- }
-
- std::vector<SegmentedEdge*>& edges() {
- return edges_;
- }
-
- const std::vector<SegmentedGroup*>& cgroups() const {
- return groups_;
- }
-
- const std::vector<SegmentedEdge*>& cedges() const {
- return edges_;
- }
-
- //! Returns the original un-segmented fusion
- Fusion* completeFusion() {
- return complete_fusion_.get();
- }
-
- const auto& inputs() const {
- return complete_fusion_->inputs();
- }
-
- const auto& outputs() const {
- return complete_fusion_->outputs();
- }
-
- Val* findAlias(Val* val) const {
- Val* alias_val = nullptr;
- if (complete_fusion_->io_alias_.count(val) != 0) {
- alias_val = complete_fusion_->io_alias_[val];
- }
- return alias_val;
- }
-
- //! Make a clone of the group and convert to fusion
- std::unique_ptr<Fusion> makeFusion(SegmentedGroup* sg);
-
- //! Make heuristics for all groups in this segmented fusion
- std::unique_ptr<FusionHeuristics> makeInitialHeuristics(
- const at::ArrayRef<IValue>& inputs);
-
- //! Inline Debug print for segmented fusion
- std::string toString(int verbosity) const;
-
- //! Debug drawing for graphviz
- void draw();
-
- //! Debug print for segmented fusions
- void print() const;
-
- //! API for adding groups
- SegmentedGroup* newGroup();
-
- //! API shortcut for adding a singleton group
- SegmentedGroup* newGroup(Expr* expr);
-
- //! API for adding edges
- SegmentedEdge* newEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val);
-
-  //! Returns the set of potential intermediate tensors that
-  //! will be cast to fp16 when written to global mem.
-  //! These are not actual intermediate tensors,
-  //! just the ones that will need to be cast to fp16 if
-  //! they end up being an intermediate tensor between
-  //! segmented groups.
- const auto& getForceToFP16Set() {
- return force_fp16_tv_set_;
- }
-
- HeuristicSummary* getCachedHeuristicDataFor(SegmentedGroup* group);
-
- private:
- //! Unique name for segmented fusion
- int segmented_fusion_name_;
-
- //! States representing segmentation
- std::vector<SegmentedEdge*> edges_;
- std::vector<SegmentedGroup*> groups_;
-
- //! Owning object to explicitly manage groups and edges
- class Impl {
- public:
- explicit Impl(SegmentedFusion* sf) : owning_fusion_(sf) {}
-
- SegmentedGroup* makeGroup();
- SegmentedGroup* makeGroup(Expr*);
- SegmentedEdge* makeEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val);
- void cleanUnused();
-
- private:
- using GroupPtr = std::unique_ptr<SegmentedGroup>;
- using EdgePtr = std::unique_ptr<SegmentedEdge>;
- std::vector<GroupPtr> groups_;
- std::vector<EdgePtr> edges_;
- SegmentedFusion* owning_fusion_;
- };
- Impl impl_;
-
-  //! A copy of the original complete fusion
- std::unique_ptr<Fusion> complete_fusion_;
-
- //! A set of intermediate tensors that need to be cast to fp16
- std::unordered_set<TensorView*> force_fp16_tv_set_;
-
- //! Static traversal information to be used for fast heuristics lookup
- std::unordered_map<SegmentedGroup*, std::unique_ptr<HeuristicSummary>>
- heuristic_summary_cache_;
-
- // TODO: this class needs cleanup
- protected:
- friend class SegmentCandidateFinder;
- //! Make a heuristics entry for a group and parameters
- std::unique_ptr<SchedulerEntry> makeInitialSchedulerEntry(
- SegmentedGroup* sg,
- SchedulerRuntimeInfo& runtime_info);
-
-  //! Cleanup function to be called at the end of the fusion
-  //! segmentation pass
- void finalize();
-
-  //! Collect all the intermediate tensors between segmented
-  //! groups that will be cast to fp16
- void annotateFP16IntermediateTensors();
-
-  //! Cache intermediate data used for heuristic checking
- void setCachedHeuristicDataFor(
- SegmentedGroup* group,
- std::unique_ptr<HeuristicSummary> data);
-
-  //! Utility to give a unique name to each segmented fusion
- static size_t segmentedFusionName() {
- static size_t counter = 0;
- return counter++;
- }
-};
-
-//! Base class for segmenter analyses;
-//! provides a minimal implementation in the header so that
-//! a unique_ptr can hold this base class.
-//! Actual implementations of analyses are in the .cpp files.
-//! TODO: In the next refactor PR, should put segment candidate
-//! finder in the .cpp file completely since the API doesn't require these
-//! details
-class SegmenterAnalysis : public PolymorphicBase {};
-class GroupDependencyAnalysis;
-
-// Manual node merging passes
-class CombineReductions;
-
-//! Options to configure/debug candidate finder
-struct TORCH_CUDA_CU_API SegmentCandidateFinderOptions {
- bool run_translate_welford = true;
- bool run_combine_reductions = true;
- bool run_herrmann_merge = true;
- bool run_final_merge = true;
-};
-
-//! SegmentCandidateFinder
-//!  Responsible for going through the DAG and proposing things we could try
-//!  to fuse together, then calls "canGenerateCode" on these proposed segments
-//!  to see if they are valid and we can generate code for them.
-//! FusionSegment
-//! A group of exprs that are segmented together
-//! FusionSegmentConnections
-//! Holds vals and what they connect. In other words it's a val that is an
-//! output of a FusionSegment "from" and an input of FusionSegment "to".
-//!  There's nothing preventing a val from being between segments twice.
-//! TODO: make sure there's nothing wrong with segmentation on nodes that
-//! have the same value input twice. i.e. (B = A*A)
-//! Selecting segments to propose is based on Theorem 4.2 in the paper, which
-//! ensures the segmented graph will be a DAG (assuming the Fusion is
-//! already a DAG). The segmentation code relies on assumptions of DAG-ness
-//! during segmentation, meaning proposed merging of groups must maintain the
-//! DAG property of the graph.
-//!
-//! Julien Herrmann, Yusuf Özkaya, Bora Uçar, Kamer Kaya, Umit Catalyurek.
-//! Multilevel Algorithms for Acyclic Partitioning of Directed Acyclic Graphs.
-//! SIAM Journal on Scientific Computing, Society for Industrial and Applied
-//! Mathematics, 2019, 41 (4), pp. A2117-A2145. doi:10.1137/18M1176865.
-//! hal-02306566
-class TORCH_CUDA_CU_API SegmentCandidateFinder {
- public:
- // Perform segmentation on a copy of the given fusion
- static std::unique_ptr<SegmentedFusion> segment(
- const Fusion* fusion,
- const at::ArrayRef<IValue>& inputs,
- SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) {
- auto fusion_copy = std::make_unique<Fusion>(*fusion);
- SegmentCandidateFinder scf(std::move(fusion_copy), inputs, options);
- return std::move(scf.segmented_fusion_);
- }
-
- // Perform segmentation on and take ownership of the given fusion
- static std::unique_ptr<SegmentedFusion> segment(
- std::unique_ptr<Fusion> fusion,
- const at::ArrayRef<IValue>& inputs,
- SegmentCandidateFinderOptions options = SegmentCandidateFinderOptions()) {
- SegmentCandidateFinder scf(std::move(fusion), inputs, options);
- return std::move(scf.segmented_fusion_);
- }
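-
-  //! Typical call pattern (an illustrative sketch, not part of the original
-  //! header):
-  //!
-  //!   auto segmented = SegmentCandidateFinder::segment(fusion.get(), inputs);
-  //!   if (segmented->isSegmented()) {
-  //!     for (auto group : segmented->groups()) {
-  //!       std::unique_ptr<Fusion> group_fusion = segmented->makeFusion(group);
-  //!       // ...compile/schedule group_fusion per group->heuristic()
-  //!     }
-  //!   }
-  //!
-  //! Here `fusion` is assumed to be a std::unique_ptr<Fusion> held by the
-  //! caller.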
-
- static bool TranslateWelfordInFusion(
- Fusion* fusion,
- const at::ArrayRef<IValue>& runtime_inputs);
-
- private:
- // Perform segmentation on and take ownership of the given fusion
- SegmentCandidateFinder(
- std::unique_ptr<Fusion> fusion,
- const at::ArrayRef<IValue>& inputs,
- SegmentCandidateFinderOptions options);
-
- void resetTraversal();
-
- void resetLevels();
-
- SegmentedGroup* mergeNodes();
-
- bool codeGenSupportedMerge(SegmentedEdge* edge);
-
- void findSegments();
-
- std::unordered_set<SegmentedEdge*> disconnectGroup(SegmentedGroup* group);
-
- std::vector<SegmentedGroup*>& groups() {
- TORCH_INTERNAL_ASSERT(
- segmented_fusion_ != nullptr, "Segment finder not owinging any fusion");
- return segmented_fusion_->groups();
- }
-
- std::vector<SegmentedEdge*>& edges() {
- TORCH_INTERNAL_ASSERT(
- segmented_fusion_ != nullptr, "Segment finder not owinging any fusion");
- return segmented_fusion_->edges();
- }
-
- Fusion* completeFusion() {
- TORCH_INTERNAL_ASSERT(
- segmented_fusion_ != nullptr, "Segment finder not owinging any fusion");
- return segmented_fusion_->completeFusion();
- }
-
- SchedulerRuntimeInfo& runtimeInfo() {
- return runtime_info_;
- }
-
- ExpressionEvaluator& expressionEvaluator() {
- return runtime_info_.expressionEvaluator();
- }
-
-  //! Additional merging iteration that cleans up the remaining
-  //! merging opportunities.
- //! Herrmann et al. is a fast and safe algorithm for finding merge candidates
- //! but can become too conservative in our use cases because we place
- //! additional qualifiers on valid merges other than having to generate DAGs,
- //! i.e. canSchedule. So we need a bruteforce final merging iteration as a
- //! clean up pass. Cost isn't expected to be high since the graph at this
- //! stage is already quite merged. Example cf. test_gpu.cpp:
- //! FusionDAGMerging_CUDA
- //!
- //! This merging algorithm is based on Theorem 4.1 of Herrmann et al.,
- //! to check if a producer-consumer pair can be merged into one group,
- //! it's enough to check if any other consumer of the producer also
- //! produces the consumer.
- void finalMerge();
-
- //! Duplicate and add all exprs producing the used
- //! scalar values in group
- void resolveScalarsInGroup(SegmentedGroup* group);
-
-  //! Remove all scalar edges in group
-  //! (TODO: needs better structure so we don't have to do this)
-
-  //! Utility function to merge a vector of groups in one step;
-  //! the DAG condition must be checked before using this method
- SegmentedGroup* mergeAllGivenGroups(
- const std::vector<SegmentedGroup*>& groups);
-
- //! Utility to remove a group and corresponding edges
- //! TODO: remove inline versions of this as much as possible
- void eraseGroups(std::unordered_set<SegmentedGroup*>& groups_to_erase);
-
- void finalize();
-
-  //! Return the schedule heuristic corresponding to the given
-  //! segmented group
-  ScheduleHeuristic deriveHeuristic(SegmentedGroup* group);
-
- GroupDependencyAnalysis* getGroupDependency();
-
- protected:
-  //! These are the merge node heuristic passes; they should
-  //! eventually have a dedicated interface
-  //! instead of the growing list of friend classes
- friend class CombineReductions;
-
- //! options to configure and debug the segment process
- SegmentCandidateFinderOptions options_;
-
- std::deque<SegmentedGroup*> to_visit_;
- std::vector<SegmentedGroup*> next_to_visit_;
-
- std::unordered_set<SegmentedGroup*> clean_up_groups_;
- std::unordered_set<SegmentedEdge*> clean_up_edges_;
-
- std::vector<SegmentedGroup*> to_merge_;
-
- std::unique_ptr<SegmentedFusion> segmented_fusion_;
-
- std::unique_ptr<SegmenterAnalysis> group_dependency_;
-
- SchedulerRuntimeInfo runtime_info_;
-
- //! Note:
- //! Segmenter should eventually rely only on runtime_info_ for
- //! safe caching. runtime_inputs_ is only used in translateWelford
- //! to initialize expression evaluators on copies of the original
- //! fusion, which doesn't use any un-cached info and is safe.
- //!
- //! Directly using runtime_inputs_ in other cases is in general
- //! risky.
- //!
- //! To get rid of runtime_inputs_ we need mechanisms
- //! to copy expression evaluator values from fusion
- //! to a copy, or even better to a copy of a
- //! sub-graph of original fusion.
- //! TODO:
- //! implement the expression evaluator transfer and
- //! remove runtime_inputs_ in a follow up.
- const at::ArrayRef<IValue>& runtime_inputs_;
-};
-
-TORCH_CUDA_CU_API std::string toString(const SegmentedGroup* group);
-TORCH_CUDA_CU_API std::string toString(const SegmentedEdge* edge);
-TORCH_CUDA_CU_API std::string toString(const SegmentedFusion* segmented_fusion);
-TORCH_CUDA_CU_API std::string toString(
- const SegmentCandidateFinderOptions& segment_options);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
#include <torch/csrc/jit/codegen/cuda/partition.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
-#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/pass_manager.h>
-#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
namespace {
-bool usedOnlyInDtype(Value* v) {
- const auto& uses = v->uses();
- if (uses.empty()) {
- return false;
- }
- return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
- return u.user->matches("prim::dtype(Tensor a) -> int");
- });
-}
-
Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
AT_ASSERT(!sizes.empty());
Graph* graph = sizes[0]->owningGraph();
return broadcast_n->output();
}
-Value* createConditionalConstant(Node* profile_ivalue) {
- TORCH_INTERNAL_ASSERT(profile_ivalue->kind() == prim::profile_ivalue);
-
- auto graph = profile_ivalue->owningGraph();
-
- IValue val; // default to None
- if (profile_ivalue->hasAttribute(Symbol::attr("profiled_int_list"))) {
- // int[]
- val = IValue(profile_ivalue->is(Symbol::attr("profiled_int_list")));
- } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_bool_list"))) {
- // bool[]
- auto int_list = profile_ivalue->is(Symbol::attr("profiled_bool_list"));
- std::vector<bool> bool_list(int_list.begin(), int_list.end());
- val = IValue(bool_list);
- } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_size"))) {
- // int[]
- val = IValue(profile_ivalue->is(Symbol::attr("profiled_size")));
- } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_bool"))) {
- // bool
- val = IValue(
- static_cast<bool>(profile_ivalue->i(Symbol::attr("profiled_bool"))));
- } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_int"))) {
- // int
- val = IValue(
- static_cast<int>(profile_ivalue->i(Symbol::attr("profiled_int"))));
- } else {
- GRAPH_DEBUG("profile_ivalue: ", *profile_ivalue);
- TORCH_WARN(
- __func__,
- " profile_node ",
- *profile_ivalue,
- " does not have profile information");
- return nullptr;
- }
-
- return graph->insertConstant(val);
-}
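-
-// For instance (illustrative, not from the original source): a
-// prim::profile_ivalue node that recorded profiled_bool = 1 becomes a
-// graph constant holding IValue(true), while a node with no recorded
-// profile attribute produces a warning and yields nullptr.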
-
struct CudaGraphFuser {
using FusionCallback = std::function<bool(Node*)>;
std::unordered_map<Value*, Value*> inputs_map;
size_t i = 0;
size_t tensor_insert_idx = 0;
+ AT_ASSERT(group->inputs().size() == subgraph.inputs().size());
for (auto input : group->inputs()) {
inputs_map[input] = subgraph.inputs()[i++];
if (input->type()->isSubtypeOf(TensorType::get()))
} else if (
// TODO: extend the supporting inputs here.
(input->type()->isSubtypeOf(FloatType::get()) &&
- input->node()->kind() != prim::Constant)) {
+ input->node()->kind() != prim::Constant) ||
+ (n->kind() == aten::_grad_sum_to_size &&
+ input->type()->isSubtypeOf(ListType::ofInts()))) {
auto in_group = subgraph.addInput();
in_group->setType(input->type());
inputs_map[input] = in_group;
} else if (input->node()->kind() == prim::Constant) {
// inline the constants directly in the body of the fused group.
Node* in_const =
- subgraph.createClone(input->node(), [&](Value* v) -> Value* {
- if (v->node()->kind() != prim::profile_ivalue) {
- throw std::runtime_error(
- std::string(
- "merging constant with unexpected input from node") +
- v->node()->kind().toDisplayString());
- }
- group->addInput(v->node()->output());
-
-            // we are doing this just to keep alias_analysis silent with
-            // its checks
- auto in_group = subgraph.addInput();
- in_group->setType(v->type());
- return in_group;
+ subgraph.createClone(input->node(), [](Value*) -> Value* {
+ throw std::runtime_error("unexpected input");
});
subgraph.insertNode(in_const);
inputs_map[input] = in_const->output();
// have a valid mapping
group->insertBefore(n);
Node* mergedNode = mergeNodeIntoGroup(group, n);
- for (size_t i = 0; i < n->outputs().size(); i++) {
- getSubgraph(group).registerOutput(mergedNode->output(i));
- auto sel = group->addOutput();
- sel->copyMetadata(n->output(i));
- }
+ getSubgraph(group).registerOutput(mergedNode->output());
+ auto sel = group->addOutput();
+ sel->copyMetadata(n->output());
n->replaceAllUsesWith(group);
n->destroy();
return group;
// but this requires better handling of merging fusion groups so it is not
// done now
bool shouldFuse =
- fuser::cuda::isFusibleCudaFusionGroup(consumer, producer->node()) &&
+ fuser::cuda::isFusableCudaFusionGroup(consumer, producer->node()) &&
// Rearrange nodes such that all uses of producer are after the
// consumer. Fusion will rewrite those later uses to use the version of
// producer generated by the fused blob. In this case, producer becomes
mergeFusionGroups(group, producer->node());
return group;
}
+ AT_ASSERT(producer->node()->outputs().size() == 1);
Node* merged = mergeNodeIntoGroup(group, producer->node());
// remaining uses of this producer can occur because we allow
// fusion in cases where uses remain after the consumer
// if these exist, re-route them to the version of producer
// created in FusionGroup
-
- // We need to apply this to all outputs from producer->node();
- auto producer_outputs = producer->node()->outputs();
- for (size_t i = 0; i < producer_outputs.size(); i++) {
- if (producer_outputs[i]->uses().size() != 0) {
- getSubgraph(group).registerOutput(merged->outputs()[i]);
- Value* new_producer = group->addOutput();
- new_producer->copyMetadata(producer_outputs[i]);
- producer_outputs[i]->replaceAllUsesWith(new_producer);
- }
+ if (producer->uses().size() != 0) {
+ getSubgraph(group).registerOutput(merged->output());
+ Value* new_producer = group->addOutput();
+ new_producer->copyMetadata(producer);
+ producer->replaceAllUsesWith(new_producer);
}
producer->node()->destroy();
return group;
chunk->inputs().begin(),
chunk->inputs().end(),
[&](Value* producer_for_chunk) {
- return fuser::cuda::isFusibleCudaFusionGroup(
+ return fuser::cuda::isFusableCudaFusionGroup(
consumer, producer_for_chunk->node()) &&
allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
});
bchunk = promoteChunkToBroadcastingChunk(chunk);
}
size_t nchunks = bchunk->i(attr::chunks);
- TORCH_INTERNAL_ASSERT(nchunks > 0, "number of chunks cannot be zero");
WithInsertPoint guard(bchunk->next());
std::vector<Value*> producer_chunk_outputs;
// = Node* for chunk_output_idx'th output of the chunk(inputs[input_nr])
std::vector<std::vector<Value*>> chunked_inputs;
- // We have asserted single output earlier
- auto producer_output_sizes =
- producer_for_chunk_node->output()->type()->cast<TensorType>()->sizes();
-
for (auto input : producer_for_chunk_node->inputs()) {
// XXX: we only work with pointwise ops in here, so we know it is valid to
// push the concat only through tensor arguments (and all other args can
// distinct from Node.
bchunk->addInput(input);
chunked_inputs.emplace_back(); // alas, to not be C++17
-
-    // properly compute strides for BroadcastingChunk
-    //
-    // We copy the stride of each dimension from input to output for
-    // BroadcastingChunk. Note that Chunk should not alter strides; however,
-    // a broadcasted dimension should have a stride of 0. Broadcasting can
-    // happen on existing dimensions in the input (case1), as well as on
-    // extended dimensions that do not exist in the input (case2).
-    // e.g.
-    // If we look at an input tensor t0 with shape [3, 1] broadcasted to an
-    // output tensor t1 with shape [4, 1, 3, 3],
-    // we set stride to zero in case of broadcast, which could happen in:
-    // case1: t1.dim[3] (broadcasted as in the description above)
-    // case2: t1.dim[0] (broadcasted implicitly)
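-    //
-    // Worked example (illustrative, not in the original comment), following
-    // the code below with t0 sizes [3, 1], strides [1, 1] and t1 sizes
-    // [4, 1, 3, 3]:
-    //   dim 3: input size 1 vs. output size 3 -> stride 0  (case1)
-    //   dim 2: input size 3 == output size 3  -> stride 1  (copied)
-    //   dim 1: output size 1                  -> copies neighbor stride 1
-    //   dim 0: output size 4, no input dim    -> stride 0  (case2)
-    // yielding strides [0, 1, 1, 0].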
- std::vector<int64_t> strides;
- auto input_type = input->type()->cast<TensorType>();
- auto input_sizes = input_type->sizes();
- auto input_strides = input_type->strides();
- if (producer_output_sizes.isComplete() && input_sizes.isComplete() &&
- input_strides.isComplete()) {
- auto input_c_sizes = input_sizes.concrete_sizes().value();
- auto input_c_strides = input_strides.concrete_sizes().value();
- auto output_c_sizes = producer_output_sizes.concrete_sizes().value();
- int output_index = int(output_c_sizes.size()) - 1;
-      strides.resize(output_index + 1);
- AT_ASSERT(output_index >= int(input_c_sizes.size()) - 1);
- for (int input_index = int(input_c_sizes.size()) - 1; input_index >= 0;
- input_index--, output_index--) {
-          // in broadcast case 1, we set stride to 0;
-          // otherwise, the stride remains the same.
- if (input_c_sizes[input_index] == 1 &&
- output_c_sizes[output_index] != 1) {
- strides[output_index] = 0;
- } else {
- strides[output_index] = input_c_strides[input_index];
- }
- }
-
-      // continue over the extended dimensions, setting stride to 0 for case2
- while (output_index >= 0) {
- strides[output_index] =
- output_c_sizes[output_index] == 1 ? strides[output_index + 1] : 0;
- output_index--;
- }
- }
-
for (auto chunk_sel : producer_chunk_outputs) {
Value* input_chunk_sel = bchunk->addOutput();
- auto chunk_sel_type = chunk_sel->type()->cast<TensorType>();
- if (strides.empty() || !chunk_sel_type->sizes().isComplete()) {
- input_chunk_sel->setType(chunk_sel_type);
- } else {
- input_chunk_sel->setType(chunk_sel_type->withSizesStrides(
- chunk_sel_type->sizes().concrete_sizes().value(), strides));
- }
+ input_chunk_sel->setType(chunk_sel->type());
chunked_inputs.back().push_back(input_chunk_sel);
}
}
bchunk->removeInput(producer_index);
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
for (const auto i : c10::irange(nchunks)) {
- (void)i; // Suppress unused variable warning
bchunk->eraseOutput(nchunks * producer_index);
}
// returns where to continue scanning, and whether any fusion was made
std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
- if (fuser::cuda::isFusibleCudaFusionGroup(consumer)) {
+ if (fuser::cuda::isFusableCudaFusionGroup(consumer)) {
// handle inputs in reverse topological order as well...
// otherwise in f(a,a+b) it will appear a is used twice if we consider
// the f-a fusion before the f-(a+b) fusion first.
}
}
- bool usedInDtype(Value* v) {
- const auto& uses = v->uses();
- return std::any_of(uses.begin(), uses.end(), [](const Use& u) {
- return u.user->matches("prim::dtype(Tensor a) -> int");
- });
- }
-
- bool usedOnlyInDtypeAndSize(Value* v) {
+ bool usedOnlyInSize(Value* v) {
const auto& uses = v->uses();
return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
- return u.user->matches("prim::dtype(Tensor a) -> int") ||
- u.user->matches("aten::size(Tensor self) -> int[]");
+ return u.user->matches("aten::size(Tensor self) -> int[]");
});
}
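  // For illustration (hypothetical IR): a fusion output %t whose only use is
  //   %s : int[] = aten::size(%t)
  // satisfies usedOnlyInSize, so %t can later be pruned from the group and
  // %s rewired to a shape expression computed outside the fusion.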
auto outputs = fusion_group->outputs();
auto soutputs = subgraph->outputs();
AT_ASSERT(outputs.size() == soutputs.size());
- for (size_t i = 0; i < outputs.size(); ++i) {
- if (usedOnlyInDtypeAndSize(outputs[i]))
+ for (const auto i : c10::irange(outputs.size())) {
+ if (usedOnlyInSize(outputs[i]))
continue;
- if (soutputs[i]->type()->isSubtypeOf(TensorType::get())) {
- shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]});
- }
+ shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]});
}
for (Node* n : subgraph->nodes()) {
continue;
}
if (n->kind() == prim::ConstantChunk) {
- TORCH_INTERNAL_ASSERT(
- shape_of.count(n->input()) > 0,
- "buildShapeExpressions failed at accessing input shapes");
Node* sizes_node = graph->insertNode(
graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
sizes_node->i_(attr::dim, n->i(attr::dim));
"only supports reduction axes and keepdim being constant");
// hmmm, do I need to setInsertPoint...
- const auto map_inputs = [&](Value* v) -> Value* {
- // if constant ever has an input, it has to come from
- // profile_ivalue dependency
- if (v->node()->kind() == prim::Param &&
- fusion_group->input(v->offset())->node()->kind() ==
- prim::profile_ivalue) {
- // we need to map it along profile_ivalue dependency
- return fusion_group->input(v->offset());
- } else {
- throw std::runtime_error(
- std::string("unexpected input from node") +
- v->node()->kind().toDisplayString());
- }
- };
- Node* in1_const = graph->createClone(n->input(1)->node(), map_inputs);
+ Node* in1_const =
+ graph->createClone(n->input(1)->node(), [](Value*) -> Value* {
+ throw std::runtime_error("unexpected input");
+ });
graph->insertNode(in1_const);
- Node* in2_const = graph->createClone(n->input(2)->node(), map_inputs);
+ Node* in2_const =
+ graph->createClone(n->input(2)->node(), [](Value*) -> Value* {
+ throw std::runtime_error("unexpected input");
+ });
graph->insertNode(in2_const);
- TORCH_INTERNAL_ASSERT(
- shape_of.count(n->input(0)) > 0,
- "buildShapeExpressions failed at accessing input shapes");
std::vector<Value*> inputs = {
shape_of.at(n->input(0)), in1_const->output(), in2_const->output()};
Node* size_node =
shape_of.emplace(n->output(), size);
continue;
}
- // TODO: output(1) & output(2) should also be marked
- if (n->kind() == aten::native_layer_norm) {
- TORCH_INTERNAL_ASSERT(
- shape_of.count(n->input(0)) > 0,
- "buildShapeExpressions failed at accessing input shapes");
- shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
- continue;
- }
- // TODO: output(1) & output(2) should also be marked
- if (n->kind() == aten::native_layer_norm_backward) {
- TORCH_INTERNAL_ASSERT(
- shape_of.count(n->input(0)) > 0,
- "buildShapeExpressions failed at accessing input shapes");
- shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
- if (shape_of.count(n->input(5)) > 0) {
- shape_of.emplace(n->output(1), shape_of.at(n->input(5)));
- }
- if (shape_of.count(n->input(6)) > 0) {
- shape_of.emplace(n->output(2), shape_of.at(n->input(6)));
- }
- continue;
- }
- // TODO: output(1) & output(2) should also be marked
- if (n->kind() == aten::native_batch_norm) {
- TORCH_INTERNAL_ASSERT(
- shape_of.count(n->input(0)) > 0,
- "buildShapeExpressions failed at accessing input shapes");
- shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
- continue;
- }
- // TODO: output(1) & output(2) should also be marked
- if (n->kind() == aten::native_batch_norm_backward) {
- TORCH_INTERNAL_ASSERT(
- shape_of.count(n->input(0)) > 0,
- "buildShapeExpressions failed at accessing input shapes");
- shape_of.emplace(n->output(0), shape_of.at(n->input(0)));
- if (shape_of.count(n->input(2)) > 0) {
- shape_of.emplace(n->output(1), shape_of.at(n->input(2)));
- // use shape of weight here for grad_bias
- shape_of.emplace(n->output(2), shape_of.at(n->input(2)));
- }
- continue;
- }
auto tensor_inputs = filter(n->inputs(), [](Value* v) {
return v->type()->isSubtypeOf(TensorType::get());
});
- auto shapes = fmap(tensor_inputs, [&](Value* v) {
- TORCH_INTERNAL_ASSERT(
- shape_of.count(v) > 0,
- "buildShapeExpressions failed at accessing input shapes");
- return shape_of.at(v);
- });
+ auto shapes =
+ fmap(tensor_inputs, [&](Value* v) { return shape_of.at(v); });
AT_ASSERT(!shapes.empty());
shape_of.emplace(
- n->output(0),
- shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
+ n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
}
return shape_of;
}
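  // Sketch of the emitted shape IR (names hypothetical): for a pointwise
  //   %out = aten::mul(%a, %b)
  // inside the subgraph, the map ends up holding
  //   %sa : int[] = aten::size(%a)
  //   %sb : int[] = aten::size(%b)
  //   shape_of[%out] = broadcastSizes({%sa, %sb})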
// TODO: failure in buildShapeExpressions should not break fusion execution,
  // we can add a try/catch here to bail out of removeOutputsUsedOnlyInSize.
- GRAPH_DEBUG("before build shape expression: ", *graph_);
auto shape_of = buildShapeExpressions(fusion_group);
- GRAPH_DEBUG("after build shape expression: ", *graph_);
auto outputs = fusion_group->outputs().vec();
auto soutputs = subgraph->outputs().vec();
  // XXX: Iterating in this order is not only good for performance reasons!
  // It is also crucial for correctness (i has to reflect the current true
  // index of outputs[i])!
for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
auto output = outputs[i];
auto soutput = soutputs[i];
- if (usedOnlyInDtypeAndSize(output) && shape_of.count(soutput) > 0) {
- bool has_dtype = usedInDtype(output);
+ if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
auto uses = output->uses();
for (Use u : uses) {
- if (u.user->matches("aten::size(Tensor self) -> int[]")) {
- u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
- u.user->destroy();
- } else if (u.user->matches("prim::dtype(Tensor a) -> int")) {
- continue;
- } else {
- AT_ASSERT(
- false,
- "unrecognized consumer should not trigger removeOutputsUsedOnlyInSize");
- }
- }
- // We only wipe the output when there's no more dtype consumer.
- // This is to be removed by `removeOutputUsedOnlyInDtype`
- if (!has_dtype) {
- fusion_group->eraseOutput(i);
- subgraph->eraseOutput(i);
+ AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
+ u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
+ u.user->destroy();
}
+ fusion_group->eraseOutput(i);
+ subgraph->eraseOutput(i);
}
}
- GRAPH_DEBUG("after build shape expression and re-wiring: ", *graph_);
}
void refreshAliasDb() {
any_changed = false;
refreshAliasDb();
for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
- bool changed = false;
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ bool changed;
std::tie(it, changed) = scanNode(*it);
any_changed |= changed;
}
}
-
- GRAPH_DEBUG("after scan and merge", *graph_);
refreshAliasDb();
// fuseConcats();
// it = scanNodeForChunks(*it);
//}
- GRAPH_DEBUG("before removeOutputsUsedOnlyInSize", *graph_);
// Remove outputs that have been added only because we need their size
for (Node* n : block_->nodes()) {
removeOutputsUsedOnlyInSize(n);
}
- GRAPH_DEBUG("after removeOutputsUsedOnlyInSize", *graph_);
for (Node* node : block_->nodes()) {
for (Block* sub_block : node->blocks()) {
}
};
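-// Remove the guarded fast path for one CudaFusionGuard: hoist the single
-// FallbackGraph node out of the prim::If's fallback block, redirect all
-// uses of the prim::If to the fallback node, then destroy both the If
-// node and the guard.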
-void removeCudaFusionPathForGuardNode(Node* n) {
- auto uses = n->output()->uses();
- TORCH_INTERNAL_ASSERT(
- uses.size() == 1,
- "CudaFusionGuard should only be used by a single prim::If");
- Node* if_node = uses[0].user;
- TORCH_INTERNAL_ASSERT(
- if_node->kind() == prim::If,
- "CudaFusionGuard should only be used by prim::If");
- auto fall_back_graph = if_node->blocks()[1];
- Node* fallback_node = nullptr;
- for (auto fb_n : fall_back_graph->nodes()) {
- TORCH_INTERNAL_ASSERT(
- fb_n->kind() == prim::FallbackGraph,
- "CudaFusionGuard fallback path should only have single fallback node");
- TORCH_INTERNAL_ASSERT(
- fallback_node == nullptr,
- "CudaFusionGuard fallback path should only have single fallback node");
- fallback_node = fb_n;
- }
-
- TORCH_INTERNAL_ASSERT(
- fallback_node != nullptr,
- "CudaFusionGuard fallback path found no fallback node");
- fallback_node->moveBefore(n);
-
- TORCH_INTERNAL_ASSERT(
- fallback_node->outputs().size() == if_node->outputs().size(),
- "CudaFusionGuard fallback should have same number of outputs as with nesting if block");
-
- if_node->replaceAllUsesWith(fallback_node);
- if_node->destroy();
- n->destroy();
-}
-
-bool missingCompleteTypes(const std::vector<TypePtr>& types) {
- for (const auto& type : types) {
- if (auto tensor_type = type->cast<TensorType>()) {
- // if we found one missing value, we know that we are not going to able to
- // generate a kernel, so we bail out;
- if (!tensor_type->device().has_value() ||
- !tensor_type->dim().has_value() ||
- !tensor_type->scalarType().has_value()) {
- return true;
- }
- }
- }
- return false;
-}
-
-void removeFusionWithMissingProfilingInformation(Block* block) {
- FUSER_PERF_SCOPE("compileFusionRecursive");
- std::vector<Node*> removeCudaFusionNodes;
-
- for (auto node : block->nodes()) {
- if (node->kind() == prim::CudaFusionGuard &&
- missingCompleteTypes(node->tys(attr::types))) {
- removeCudaFusionNodes.push_back(node);
- }
- for (auto sub_block : node->blocks()) {
- removeFusionWithMissingProfilingInformation(sub_block);
- }
- }
-
- for (auto node : removeCudaFusionNodes) {
- removeCudaFusionPathForGuardNode(node);
- }
-}
-
void compileFusionRecursive(Block* block) {
FUSER_PERF_SCOPE("compileFusionRecursive");
void guardFusionGroup(Node* fusion) {
// Fixup types of the subgraph inputs
std::vector<TypePtr> guard_types;
- std::vector<Value*> tensor_inputs_to_check;
- std::set<size_t> profiled_ivalue_indices;
-
- for (size_t index = 0; index < fusion->inputs().size(); index++) {
- Value* input = fusion->inputs()[index];
- if (input->type()->cast<TensorType>()) {
- // We only check inputs of the fusion group and expect NNC to infer
- // intermediates and outputs shapes
-
- // note: modified from original implementation, we are guarding fusion
- // outputs
- if (input->node()->kind() == prim::Constant) {
- continue;
- }
- tensor_inputs_to_check.push_back(input);
- guard_types.push_back(input->type());
- } else if (input->node()->kind() == prim::profile_ivalue) {
- // Conditional constant from profiled_ivalue, should be guarded
- profiled_ivalue_indices.insert(index);
+ std::vector<Value*> inputs_to_check;
+ for (Value* input : fusion->inputs()) {
+ // We only check inputs of the fusion group and expect NNC to infer
+ // intermediates and outputs shapes
+ if (!input->type()->cast<TensorType>()) {
+ continue;
}
+
+ // note: modified from original implementation, we are guarding fusion
+ // outputs
+ if (input->node()->kind() == prim::Constant) {
+ continue;
+ }
+ inputs_to_check.push_back(input);
+ guard_types.push_back(input->type());
+ }
+ if (!inputs_to_check.size()) {
+ return;
}
- // we should assert on non-tensor inputs
- TORCH_INTERNAL_ASSERT(
- tensor_inputs_to_check.size(),
- "CudaFusionGuard expects at least one tensor input");
- // insert the if block first;
+ Node* typecheck_node = fusion->owningGraph()
+ ->create(prim::CudaFusionGuard, inputs_to_check, 1)
+ ->insertBefore(fusion);
+ // fix output to BoolType
+ typecheck_node->output()->setType(BoolType::get());
+ Value* typecheck_result = typecheck_node->output();
+ typecheck_node->tys_(attr::types, guard_types);
+
+ std::unordered_map<Value*, Value*> typechecked_inputs;
+
+ // Insert if block
auto versioning_if =
- fusion->owningGraph()->create(prim::If, fusion->outputs().size());
+ fusion->owningGraph()
+ ->create(prim::If, {typecheck_result}, fusion->outputs().size())
+ ->insertAfter(typecheck_node);
for (size_t idx = 0; idx < fusion->outputs().size(); ++idx) {
versioning_if->output(idx)->setType(fusion->output(idx)->type());
fusion->output(idx)->replaceAllUsesWith(versioning_if->output(idx));
auto true_block = versioning_if->addBlock();
auto false_block = versioning_if->addBlock();
- // insert typecheck_node;
- Node* typecheck_node =
- fusion->owningGraph()
- ->create(prim::CudaFusionGuard, tensor_inputs_to_check, 1)
- ->insertBefore(fusion);
- // fix output to BoolType
- typecheck_node->output()->setType(BoolType::get());
- Value* typecheck_result = typecheck_node->output();
- typecheck_node->tys_(attr::types, guard_types);
-
- versioning_if->insertAfter(typecheck_node);
-
// Fill in the false block. It should contain the unoptimized
- // copy of the fused subgraph, unless we have conditional constants from
- // profiled_ivalue;
- auto fusion_graph = fusion->g(attr::Subgraph);
- std::shared_ptr<Graph> fb_graph; // resource holder;
- // Restore the dependency for constant introduced by profiled_ivalue within
- // the graph.
- if (!profiled_ivalue_indices.empty()) {
- // This is necessary as it cleans up the fallback graph, which was copied
- // from subgraph, since the two graph would differ as we cannot use
- // conditional constant in fallback
-
- // 1. RESTORE conditional constant dependency in fallback group;
- fb_graph = fusion_graph->copy();
- GRAPH_DEBUG("re-wiring fallback graph", *fb_graph);
-
- for (const auto& offset : profiled_ivalue_indices) {
- auto val = fb_graph->inputs()[offset];
- auto uses = val->uses();
- // since we are updating use of val in the loop, we have to copy
- // val->uses() before hand.
- for (const auto& use : uses) {
- // re-wire inputs and remove conditional constant nodes;
- TORCH_INTERNAL_ASSERT(
- use.user->kind() == prim::Constant,
- "profile_ivalue at index: ",
- offset,
- " can only be used by conditional constant, instead got: ",
- use.user->kind().toDisplayString());
- use.user->output()->replaceAllUsesWith(val);
- use.user->destroy();
- }
- }
-
- WithInsertPoint guard(false_block->return_node());
- const auto subgraph_outputs =
- insertGraph(*fusion->owningGraph(), *fb_graph, fusion->inputs());
- for (Value* output : subgraph_outputs) {
- false_block->registerOutput(output);
- }
- // types get copied to the fallback graph, so remove specializations before
- // replacing
- // TODO: this is not exposed here, I need to remove that before inserting
- // the graph
- // removeTensorTypeSpecializations(false_block);
- replaceBlockWithFallbackGraph(false_block, fusion->inputs());
-
- // 2. REMOVE conditional constant dependency in fusion group
- size_t compensation = 0;
-
- // get a constant false, which is used by `and` pattern later
- auto const_true = fusion->owningGraph()->insertConstant(IValue(true));
- const_true->node()->moveBefore(versioning_if);
-
- for (const auto& original_offset : profiled_ivalue_indices) {
- size_t offset = original_offset - compensation;
-
- // step a. handle fusion
- // remove inputs to fusion, and update check logic for fallback
- auto profiled_ival = fusion->input(offset)->node()->input();
- auto const_o = createConditionalConstant(fusion->input(offset)->node());
- TORCH_INTERNAL_ASSERT(
- const_o,
- "profile_ivalue node are expected to have profile information, at node: ",
- *fusion->input(offset)->node());
- const_o->node()->moveBefore(versioning_if);
- Value* ivalue_check = nullptr;
-
- if (fusion->input(offset)->node()->hasAttribute(
- Symbol::attr("profiled_bool"))) {
- // aten::eq doesn't support comparison between two boolean
- auto xor_n = fusion->owningGraph()
- ->create(aten::__xor__, {profiled_ival, const_o}, 1)
- ->insertBefore(versioning_if);
- xor_n->output()->setType(BoolType::get());
- ivalue_check =
- fusion->owningGraph()
- ->create(aten::__xor__, {xor_n->output(), const_true}, 1)
- ->insertBefore(versioning_if)
- ->output();
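-          // Note: (a ^ b) ^ true == !(a ^ b), i.e. the two xors build a
-          // boolean equality test, sidestepping the missing aten::eq
-          // overload for a pair of booleans.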
- } else if (fusion->input(offset)->node()->hasAttribute(
- Symbol::attr("profiled_size"))) {
- // TODO(profile_size): check sizes here with special size comparison op
- // TORCH_INTERNAL_ASSERT(false, "not implemented yet");
- ivalue_check =
- fusion->owningGraph()
- ->create(
- c10::Symbol::fromQualString("prim::CudaFusionSizeEq"),
- {profiled_ival, const_o},
- 1)
- ->insertBefore(versioning_if)
- ->output();
- } else {
- ivalue_check = fusion->owningGraph()
- ->create(aten::eq, {profiled_ival, const_o}, 1)
- ->insertBefore(versioning_if)
- ->output();
- }
- ivalue_check->setType(BoolType::get());
-
- typecheck_result =
- fusion->owningGraph()
- ->create(aten::__and__, {ivalue_check, typecheck_result}, 1)
- ->insertBefore(versioning_if)
- ->output();
- typecheck_result->setType(BoolType::get());
-
- // remove inputs to fusion;
- fusion->removeInput(offset);
-
- // step b. remove the extra dependency inside fusion;
- for (const auto& use : fusion_graph->inputs()[offset]->uses()) {
- TORCH_INTERNAL_ASSERT(
- use.user->kind() == prim::Constant,
- "profile_ivalue at index: ",
- offset,
- " can only be used by conditional constant, instead got: ",
- use.user->kind().toDisplayString());
- use.user->removeAllInputs();
- }
- fusion_graph->eraseInput(offset);
- compensation++;
- }
- // update graph in fusion node
- fusion->g_(attr::Subgraph, fusion_graph);
- } else {
- WithInsertPoint guard(false_block->return_node());
- const auto subgraph_outputs =
- insertGraph(*fusion->owningGraph(), *fusion_graph, fusion->inputs());
- for (Value* output : subgraph_outputs) {
- false_block->registerOutput(output);
- }
- // types get copied to the fallback graph, so remove specializations before
- // replacing
- // TODO: this is not exposed here, I need to remove that before inserting
- // the graph
- // removeTensorTypeSpecializations(false_block);
- replaceBlockWithFallbackGraph(false_block, fusion->inputs());
+ // copy of the fused subgraph.
+ auto& subgraph = *fusion->g(attr::Subgraph);
+ WithInsertPoint guard(false_block->return_node());
+ const auto subgraph_outputs =
+ insertGraph(*fusion->owningGraph(), subgraph, fusion->inputs());
+ for (Value* output : subgraph_outputs) {
+ false_block->registerOutput(output);
}
- // wiring up if block
- versioning_if->addInput(typecheck_result);
+ // types get copied to the fallback graph, so remove specializations before
+ // replacing
+ // TODO: this is not exposed here, I need to remove that before inserting the
+ // graph
+ // removeTensorTypeSpecializations(false_block);
+ replaceBlockWithFallbackGraph(false_block, fusion->inputs());
// Fill in the true block. It has all inputs type-checked and its
// body should be the fusion group node.
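  // The resulting structure, sketched as IR (illustrative):
  //   %ok : bool = prim::CudaFusionGuard[types=[...]](%t0, %t1, ...)
  //   %o1, ... = prim::If(%ok)
  //     block0(): -> prim::CudaFusionGroup outputs
  //     block1(): -> prim::FallbackGraph outputs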
}
}
for (Node* fusion : fusions) {
- // step 1: a. add prim::CudaFusionGuard and fallback logic
- // b. insert guard logic of profile_ivalue with if block
- // c. restore conditional constant to non-constant for fallback
guardFusionGroup(fusion);
}
}
-// rewire const integer index & empty byte-typed reserve space tensor outputs,
-// so `CudaFusionGroup` doesn't have to handle those
-void alterBatchNormImplIndex(Node* node) {
- std::set<size_t> bn_index_out_indices;
- std::set<size_t> bn_buffer_out_indices;
-
- auto subgraph = node->g(attr::Subgraph);
- for (size_t i = 0; i < subgraph->outputs().size(); i++) {
- auto val = subgraph->outputs()[i];
- if (val->node()->kind() == aten::_batch_norm_impl_index &&
- val->offset() == 4) {
- bn_index_out_indices.emplace(i);
- } else if (
- val->node()->kind() == aten::_batch_norm_impl_index &&
- val->offset() == 3) {
- bn_buffer_out_indices.emplace(i);
- }
- }
-
- if (!bn_index_out_indices.empty()) {
-    // we set the index output to 0 so that backward goes through
-    // native_batch_norm, which is what we support;
- auto const_1 = node->owningGraph()->insertConstant(IValue(0));
- const_1->node()->moveBefore(node);
- for (auto i : bn_index_out_indices) {
- node->outputs()[i]->replaceAllUsesWith(const_1);
- }
- }
-
- if (!bn_buffer_out_indices.empty()) {
- auto graph = node->owningGraph();
- std::vector<int64_t> sizes{0}; // empty tensor with no size;
- // std::vector<int64_t> sizes; // empty tensor with no size;
- auto const_size_0 = node->owningGraph()->insertConstant(IValue(sizes));
- const_size_0->node()->moveBefore(node);
- auto const_0 = node->owningGraph()->insertConstant(IValue(0));
- const_0->node()->moveBefore(node);
- auto none_val = node->owningGraph()->insertConstant(IValue());
- none_val->node()->moveBefore(node);
- auto device =
- graph->insertNode(graph->create(prim::device, {node->inputs()[0]}, 1));
- device->moveBefore(node);
- device->output()->setType(DeviceObjType::get());
- auto empty_tensor = graph->insertNode(graph->create(
- aten::empty,
- {const_size_0, const_0, none_val, device->output(), none_val, none_val},
- 1));
- empty_tensor->moveBefore(node);
- for (auto i : bn_buffer_out_indices) {
- node->outputs()[i]->replaceAllUsesWith(empty_tensor->output());
- }
- }
-
- bn_index_out_indices.insert(
- bn_buffer_out_indices.begin(), bn_buffer_out_indices.end());
- for (auto iter = bn_index_out_indices.crbegin();
- iter != bn_index_out_indices.crend();
- ++iter) {
- subgraph->eraseOutput(*iter);
- node->eraseOutput(*iter);
- }
-}
-
-// rewire empty byte-typed reserve space tensor input to an empty float-typed
-// tensor, because `CudaFusionGroup` doesn't support byte-typed tensor, nor does
-// it use reserve space.
-void alterBatchNormImplIndexBackward(Node* node) {
- std::set<size_t> bn_buffer_in_indices;
-
- auto subgraph = node->g(attr::Subgraph);
- for (auto n : subgraph->nodes()) {
- if (n->kind() == aten::_batch_norm_impl_index_backward) {
-      // The 11th input is `reserve`, which is not used by the codegen kernel
-      // and whose type (`Byte`) is not supported. So we disconnect it here to
-      // avoid codegen errors.
- auto byte_input = n->inputs()[11];
- // TODO: let's check the data type for buffer and skip if it's good
- // TODO: we can actually support it by adding an extra inputs to the
- // subgraph
- // TODO: assert on empty buffer
- TORCH_INTERNAL_ASSERT(
- byte_input->node() == subgraph->param_node(),
- "Assumption that reserve input to aten::_batch_norm_impl_index_backward comes from forward graph is broken");
- bn_buffer_in_indices.emplace(byte_input->offset());
- }
- }
-
- if (!bn_buffer_in_indices.empty()) {
- auto graph = node->owningGraph();
- std::vector<int64_t> sizes{0}; // empty tensor with no size;
- // std::vector<int64_t> sizes{}; // empty tensor with no size;
- auto const_size_0 = node->owningGraph()->insertConstant(IValue(sizes));
- const_size_0->node()->moveBefore(node);
- auto const_0 = node->owningGraph()->insertConstant(IValue(6));
- const_0->node()->moveBefore(node);
- auto none_val = node->owningGraph()->insertConstant(IValue());
- none_val->node()->moveBefore(node);
- auto device =
- graph->insertNode(graph->create(prim::device, {node->inputs()[1]}, 1));
- device->moveBefore(node);
- device->output()->setType(DeviceObjType::get());
- auto empty_tensor = graph->insertNode(graph->create(
- aten::empty,
- {const_size_0, const_0, none_val, device->output(), none_val, none_val},
- 1));
- empty_tensor->moveBefore(node);
-
- for (auto iter = bn_buffer_in_indices.begin();
- iter != bn_buffer_in_indices.end();
- ++iter) {
- subgraph->inputs()[*iter]->setType(
- node->inputs()[*iter]->type()->cast<TensorType>()->withScalarType(
- at::ScalarType::Float));
- node->replaceInput(*iter, empty_tensor->output());
- }
- }
-}
-
-void alterBatchNormImpls(Block* block) {
- std::vector<Node*> fusions;
- for (Node* n : block->nodes()) {
- for (Block* b : n->blocks()) {
- alterBatchNormImpls(b);
- }
- if (n->kind() == prim::CudaFusionGroup) {
- fusions.push_back(n);
- }
- }
- for (Node* fusion : fusions) {
- // remove index & reserve from outputs;
- alterBatchNormImplIndex(fusion);
- // remove reserve from inputs;
- alterBatchNormImplIndexBackward(fusion);
- }
-}
-
-// We absorb `prim::dtype` node into CudaFusion structure. The structure below
-//
-// %1 = prim::CudaFusionGuard(...)
-// %2, %3 = prim::If(...)
-// block0():
-// %4, %5 = prim::CudaFusionGroup(...)
-// -> (%4, %5)
-// block1():
-// %6, %7 = prim::FallbackGraph(...)
-// -> (%6, %7)
-// %4 = prim::dtype(%3)
-// ... (uses %2 and %4, but never references %3 any more)
-//
-// is updated to:
-//
-// %1 = prim::CudaFusionGuard(...)
-// %2, %3 = prim::If(...)
-// block0():
-// %4 = prim::CudaFusionGroup(...) # %5 is also removed from subgraph
-// %8 = prim::Constant[value=...]()
-// -> (%4, %8)
-// block1():
-// %6, %7 = prim::FallbackGraph(...)
-// %9 = prim::dtype(%7)
-// -> (%6, %9)
-// # %4 = prim::dtype(%3) is removed. All references to %4 are replaced with %3
-// ... (uses %2 and %4, but never references %3 any more)
-void removeOutputUsedOnlyInDtype(Node* fusion_node) {
- auto fusion_block = fusion_node->owningBlock();
- TORCH_INTERNAL_ASSERT(
- fusion_block->owningNode() &&
- fusion_block->owningNode()->kind() == prim::If,
- "CudaFusionGroup should be inside `prim::CudaFusionGuard` / `prim::If`");
-
- auto if_node = fusion_block->owningNode();
- auto fusion_node_graph = fusion_node->g(attr::Subgraph);
- auto fallback_block = if_node->blocks()[1];
-
- bool updated = false;
- // Iterating in this order is crucial for correctness (i has to reflect the
- // current true index of outputs[i])!
- for (int64_t i = static_cast<int64_t>(if_node->outputs().size()) - 1; i >= 0;
- --i) {
- auto output = if_node->outputs()[i];
- // output only used in dtype, we eliminate the output and rely on
- // profiled/static scalar type inference to save on memory IO.
- if (usedOnlyInDtype(output)) {
- updated = true;
- {
- // update fusion_block to output profiled scalar type
- auto fusion_output = fusion_block->outputs()[i];
- auto tensor_type = fusion_output->type()->cast<TensorType>();
- TORCH_INTERNAL_ASSERT(
- tensor_type, "non tensor fed to dtype is not supported");
- auto scalar_type = tensor_type->scalarType();
- TORCH_INTERNAL_ASSERT(
- scalar_type.has_value(),
- "ScalarType should be static for Tensors in fusion for amp optimization");
- auto type_const =
- fusion_block->owningGraph()->insertConstant(IValue(scalar_type));
- type_const->setType(IntType::get());
- type_const->node()->moveBefore(fusion_block->return_node());
- fusion_block->replaceOutput(i, type_const);
-
- // remove the dangling output tensor in CudaFusionGroup
- fusion_node->eraseOutput(i);
- fusion_node_graph->eraseOutput(i);
- }
-
- {
- // update fallback_block to output dtype instead of tensor
- auto tensor_output = fallback_block->outputs()[i];
- auto dtype_node = fallback_block->owningGraph()->create(
- prim::dtype, tensor_output, 1);
- dtype_node->output()->setType(IntType::get());
- fallback_block->appendNode(dtype_node);
- fallback_block->replaceOutput(i, dtype_node->output());
- }
-
-      // we just short-cut the `dtype` node since we are already outputting dtype
- auto uses = output->uses();
- for (Use u : uses) {
- AT_ASSERT(u.user->matches("prim::dtype(Tensor a) -> int"));
- u.user->output()->replaceAllUsesWith(output);
- u.user->destroy();
- }
- output->setType(IntType::get());
- }
- }
-
- if (updated) {
- fusion_node->g_(attr::Subgraph, fusion_node_graph);
- }
-}
-
-// For output tensors in the fusion group that are only used by a dtype node,
-// with CudaFusionGuard we can short-cut them with a constant dtype directly
-// to save IO memory bandwidth.
-// The reason we do this after inserting the guard, instead of during graph
-// fusion/partitioning, is that the fallback needs different handling: it is
-// not inside CudaFusionGuard, and hence doesn't have the dtype as a constant.
-void removeOutputUsedOnlyInDtype(Block* block) {
- std::vector<Node*> fusions;
- for (Node* n : block->nodes()) {
- for (Block* b : n->blocks()) {
- removeOutputUsedOnlyInDtype(b);
- }
- if (n->kind() == prim::CudaFusionGroup) {
- fusions.push_back(n);
- }
- }
- for (Node* fusion : fusions) {
- // remove index & reserve from outputs;
- removeOutputUsedOnlyInDtype(fusion);
- }
-}
-
-void RemoveProfileIValue(Node* profile_ivalue) {
- for (const auto& use : profile_ivalue->output()->uses()) {
- if (use.user->kind() == prim::Constant) {
- use.user->output()->replaceAllUsesWith(profile_ivalue->input());
- use.user->destroy();
- }
- }
- profile_ivalue->output()->replaceAllUsesWith(profile_ivalue->input());
- profile_ivalue->destroy();
-}
-
-void ExtractProfileIValue(Node* profile_ivalue) {
- auto const_o = createConditionalConstant(profile_ivalue);
- if (const_o) {
- auto const_n = const_o->node();
- const_n->moveAfter(profile_ivalue);
- profile_ivalue->output()->replaceAllUsesAfterNodeWith(const_n, const_o);
- // special wiring, we add this input to constant simply in order to create
- // dependency, which we can trace and remove later;
- const_n->addInput(profile_ivalue->output());
- } else {
- // no profile value available, remove profile_ivalue node;
- RemoveProfileIValue(profile_ivalue);
- }
-}
-
-void traverseProfileIValues(
- Block* block,
- const std::function<void(Node*)>& func) {
- std::vector<Node*> profile_ivalues;
- for (Node* n : block->nodes()) {
- for (Block* b : n->blocks()) {
- traverseProfileIValues(b, func);
- }
- if (n->kind() == prim::profile_ivalue) {
- profile_ivalues.push_back(n);
- }
- }
- for (Node* profile_ivalue : profile_ivalues) {
- func(profile_ivalue);
- }
-}
-
-// break a `linear` layer into `matmul` and `add_optional`. This allows us to
-// fuse the binary operation without supporting gemm.
-// Note that we do not break up `linear` layers without bias.
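-// e.g. (sketch): %y = aten::linear(%x, %W, %b) becomes
-//   %Wt = aten::t(%W)
-//   %mm = aten::matmul(%x, %Wt)
-//   %y  = prim::add_optional(%mm, %b)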
-void decomposeLinearOps(Block* block) {
- std::vector<Node*> linear_nodes;
- for (Node* n : block->nodes()) {
- for (Block* b : n->blocks()) {
- decomposeLinearOps(b);
- }
- // only decompose `linear` layer with bias.
- if (n->kind() == aten::linear &&
- !n->input(2)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- linear_nodes.push_back(n);
- }
- }
-
- auto graph = block->owningGraph();
- for (Node* n : linear_nodes) {
- WithInsertPoint guard(n);
- auto weight_t = graph->insertNode(graph->create(aten::t, {n->input(1)}, 1));
- auto matmul = graph->insertNode(
- graph->create(aten::matmul, {n->input(0), weight_t->output()}, 1));
- auto input_tensor_type = n->input(0)->type()->cast<c10::TensorType>();
- auto mat0_size = input_tensor_type->sizes().concrete_sizes();
- auto mat1_size =
- n->input(1)->type()->cast<c10::TensorType>()->sizes().concrete_sizes();
-
- // TODO: The assert is not necessary when we can handle matmul, right now we
- // are splitting the linear between matmul & bias_add. Our fuser can only
- // take the second half and we would need the size information.
- TORCH_INTERNAL_ASSERT(
- mat0_size.has_value() && mat1_size.has_value(),
- "concrete shape for linear input & weight are required");
- auto out_size = mat0_size.value();
- out_size[out_size.size() - 1] = mat1_size.value()[0];
- matmul->output()->setType(input_tensor_type->withSizes(out_size));
-
- // TODO: memory stride should be considered here, our inference above is not
- // safe.
- auto bias = graph->insertNode(
- graph->create(prim::add_optional, {matmul->output(0), n->input(2)}, 1));
- bias->output()->setType(matmul->output(0)->type());
-
- n->output()->replaceAllUsesWith(bias->output());
- n->destroy();
- }
-}
-
} // anonymous namespace
void CudaFuseGraph(std::shared_ptr<Graph>& graph) {
- FUSER_PERF_SCOPE("nvFuser::Manager::CudaFuseGraph");
- GRAPH_DUMP("Before Fusion: ", graph);
-
- // TODO: extract & guard profile_ivalue; but how do we restore it???
- // I don't know how to store edge/node in attribute. so let's abuse data flow
- // dependency and add inputs to conditional constant generated by
- // aten::profile_ivalue
- traverseProfileIValues(graph->block(), ExtractProfileIValue);
- GRAPH_DEBUG("insert conditional constant from profile_ivalue: ", *graph);
-
+ FUSER_PERF_SCOPE("CudaFuseGraph");
// TODO: we need to properly restore shape information after fusion.
// shamelessly use tool from NNC.
RemoveProfileNodesAndSpecializeTypes(graph);
- GRAPH_DEBUG("After Profiling Nodes Removed: ", *graph);
-
- // TODO: separate passes into different file;
- // TODO: restore decomposition after fusion, in case we are decomposing
- // operation that can't be fused;
- decomposeLinearOps(graph->block());
- GRAPH_DEBUG("decompose operations by nvfuser: ", *graph);
CudaGraphFuser(graph->block(), graph).run();
- GRAPH_DEBUG("After Fusion: ", *graph);
-
- // guard input types as well as conditional constants from
- // aten::profile_ivalue
guardFusionGroups(graph->block());
- GRAPH_DEBUG("After Guard Fusion: ", *graph);
-
- // mutate `aten::_batch_norm_impl_index` and
- // `aten::_batch_norm_impl_index_backward` node in the fusion group to WAR
- // the lack of fusion support on integer output as well as byte-typed tensor.
- alterBatchNormImpls(graph->block());
- GRAPH_DEBUG("After _batch_norm_impl_index: ", *graph);
-
- traverseProfileIValues(graph->block(), RemoveProfileIValue);
-
- GRAPH_DEBUG("Before remove missing profiling: ", *graph);
- removeFusionWithMissingProfilingInformation(graph->block());
- GRAPH_DEBUG("After remove missing profiling: ", *graph);
-
- // optimization targeting AMP
- removeOutputUsedOnlyInDtype(graph->block());
- GRAPH_DEBUG("After removeOutputUsedOnlyInDtype: ", *graph);
// After FuseGraph some common subexpressions may come back
EliminateCommonSubexpression(graph);
// We might have emitted a fair amount of useless shape propagating code, so
// shamelessly use tool from NNC.
RemoveTensorTypeSpecializations(graph);
- GRAPH_DUMP("Before Compilation: ", graph);
// Compile CudaFusionGroup
compileFusionRecursive(graph->block());
}
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/index_reference_replay.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
void handle(Split*) override {}
void handle(Merge* merge) override {
- const auto gpu_lower = GpuLower::current();
-
// If either input is non-contiguous so is output.
- const auto inner = merge->inner();
- const auto outer = merge->outer();
-
- if ((!isContig(gpu_lower->lowerValue(inner)->as<kir::IterDomain>()) ||
- !isContig(gpu_lower->lowerValue(outer)->as<kir::IterDomain>()))) {
+ auto inner = merge->inner();
+ auto outer = merge->outer();
+ if (!isContig(GpuLower::lowerValue(inner)->as<kir::IterDomain>()) ||
+ !isContig(GpuLower::lowerValue(outer)->as<kir::IterDomain>())) {
return;
}
if (root_copy.front() == ordered_inputs.front()) {
root_copy.pop_front();
ordered_inputs.pop_front();
- // This is no longer causing an error in:
- // ReductionSchedulerMultiDimNonFastest TODO: test reenablement to make
- // sure it does what's expected
+      // We probably should be able to make access contiguous through
+      // reduction domains; however, for now it's causing issues in predicate
+      // generation. See test: ReductionSchedulerMultiDimNonFastest
// } else if (
// root_copy.front()->isReduction() ||
// root_copy.front()->isBroadcast()) {
// top contig ID, lower ids should be placed in the "within_contig_ids" map
// of top id.
auto kir_inner =
- gpu_lower->lowerValue(merge->inner())->as<kir::IterDomain>();
+ GpuLower::lowerValue(merge->inner())->as<kir::IterDomain>();
auto kir_outer =
- gpu_lower->lowerValue(merge->outer())->as<kir::IterDomain>();
- auto kir_out = gpu_lower->lowerValue(merge->out())->as<kir::IterDomain>();
+ GpuLower::lowerValue(merge->outer())->as<kir::IterDomain>();
+ auto kir_out = GpuLower::lowerValue(merge->out())->as<kir::IterDomain>();
if (ordered_inputs.empty()) {
if (contig_ids.find(kir_inner) != contig_ids.end()) {
contig_ids.erase(kir_inner);
  // Check through the history of ids whose inputs map to root_domain with
// contiguity root_contiguity. Return unordered_set of all merges that are
- // contiguous. Ignore root order is primarily used for predicate generation.
- // In this case we can linearize indexing of any ID that only consists of
- // merge operations.
+ // contiguous.
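+  // Illustrative example (not from the original comment): with root domain
+  // [I0, I1], both contiguous, the single op
+  //   merge(I0, I1)   // extent = I0.extent * I1.extent
+  // yields a contig ID, so one linear index can address memory directly
+  // instead of computing i0 * stride(I0) + i1 * stride(I1).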
ContigIDs(
const std::vector<IterDomain*>& ids,
- const std::vector<IterDomain*>& root_domain,
- const std::vector<bool>& root_contiguity)
- : root_domain_(root_domain), root_contiguity_(root_contiguity) {
+ const std::vector<IterDomain*>& _root_domain,
+ const std::vector<bool>& _root_contiguity)
+ : root_domain_(_root_domain), root_contiguity_(_root_contiguity) {
if (ids.empty()) {
return;
}
" != ",
root_contiguity_.size());
- const auto gpu_lower = GpuLower::current();
-
for (const auto i : c10::irange(root_domain_.size())) {
if (root_contiguity_[i]) {
auto kir_root_domain_i =
- gpu_lower->lowerValue(root_domain_[i])->as<kir::IterDomain>();
+ GpuLower::lowerValue(root_domain_[i])->as<kir::IterDomain>();
contig_ids.emplace(kir_root_domain_i);
within_contig_ids[kir_root_domain_i] =
std::unordered_set<kir::IterDomain*>();
- is_contig_root[root_domain_[i]] = true;
- } else {
- is_contig_root[root_domain_[i]] = false;
}
+ is_contig_root[root_domain_[i]] = root_contiguity_[i];
}
auto exprs = ExprSort::getExprs(ids[0]->fusion(), {ids.begin(), ids.end()});
}
};
-// Update the HaloInfo mappings for a reference tensor by propagating
-// the halo information from the consumer tensor.
-void updateHaloInfoForReference(
- const ReferenceTensor& reference,
- const TensorView* consumer_tv) {
- const auto gpu_lower = GpuLower::current();
-
- auto& halo_info = gpu_lower->haloInfo();
-
- auto* reference_domain = reference.domain;
- const auto& reference_concrete_map = reference.concrete_to_id;
-
- for (auto reference_root_axis : reference_domain->getRootDomain()) {
- // Set default
- halo_info.setRootAxisInfo(reference_root_axis, AxisHaloInfo());
- auto consumer_it = std::find_if(
- consumer_tv->getRootDomain().begin(),
- consumer_tv->getRootDomain().end(),
- [&](IterDomain* consumer_root) {
- auto concrete_id =
- gpu_lower->caIndexMap().getConcreteMappedID(consumer_root);
- auto it = reference_concrete_map.find(concrete_id);
- return it != reference_concrete_map.end() &&
- it->second == reference_root_axis;
- });
- // When no corresponding ID of the consumer exists, the reference
- // axis can be ignored
- if (consumer_it == consumer_tv->getRootDomain().end()) {
- continue;
- }
- auto consumer_root_axis = *consumer_it;
- auto root_axis_info =
- gpu_lower->haloInfo().getRootAxisInfo(consumer_root_axis);
- if (root_axis_info.width() == 0) {
- continue;
- }
- halo_info.setRootAxisInfo(reference_root_axis, root_axis_info);
- }
-
- halo_info.build(reference_domain);
-
- return;
-}
-
-// Get a map of IterDomains to halo-extended extents of corresponding
-// reference IterDomains.
-//
-// ref_map: ref-to-consumer in consumer indexing; ref-to-producer in
-// producer indexing
-std::unordered_map<kir::IterDomain*, kir::Val*> getReferenceHaloExtentMap(
- const ReferenceTensor& reference,
- const TensorView* consumer_tv,
- const std::unordered_map<IterDomain*, IterDomain*>& ref_map,
- const std::unordered_map<kir::IterDomain*, kir::Val*>& extent_map) {
- const auto gpu_lower = GpuLower::current();
-
- // First, update HaloInfo with the reference tensor, which reflects
- // the halo extents of the consumer tensor.
- updateHaloInfoForReference(reference, consumer_tv);
-
- const auto& halo_info = gpu_lower->haloInfo();
-
- std::unordered_map<kir::IterDomain*, kir::Val*> reference_halo_extent_map;
-
- // Propagate halo extents of the reference to the consumer or
- // producer tensor
- for (auto kv : ref_map) {
- auto ref_id = gpu_lower->lowerValue(kv.first)->as<kir::IterDomain>();
- auto producer_or_consumer_id =
- gpu_lower->lowerValue(kv.second)->as<kir::IterDomain>();
- auto extent = halo_info.getExtent(ref_id);
- if (extent == nullptr) {
- auto extent_it = extent_map.find(ref_id);
- if (extent_it != extent_map.end()) {
- extent = extent_it->second;
- } else {
- extent = ref_id->extent();
- }
- }
- reference_halo_extent_map[producer_or_consumer_id] = extent;
- }
-
- return reference_halo_extent_map;
-}
-
-//! Offset of an index of a producer axis with respect to its
-//! corresponding consumer index
-kir::Val* getProducerHaloOffset(
- const TensorView* producer_tv,
- size_t producer_axis,
- const TensorView* consumer_tv) {
- auto p2c =
- PairwiseRootDomainMap(producer_tv, consumer_tv)
- .mapProducerToConsumer(producer_tv->domain(), consumer_tv->domain());
-
- auto producer_id = producer_tv->getMaybeRFactorDomain()[producer_axis];
-
- auto it = p2c.find(producer_id);
- // p2c should always have a mapping for producer_id. The only case
- // where no mapping exists for a producer axis is when it is a
- // reduction axis. Since this function is only used for indexing
- // producer tensors, where reduction axes are skipped, producer_id
- // should never be a reduction axis.
- TORCH_INTERNAL_ASSERT(it != p2c.end());
- IterDomain* consumer_id = it->second;
-
- const auto& halo_map = GpuLower::current()->haloInfo();
- const auto p_pad = halo_map.getRootAxisInfo(producer_id).width(0);
- const auto c_pad = halo_map.getRootAxisInfo(consumer_id).width(0);
-
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- kir::Val* offset = (p_pad->isConst() && c_pad->isConst())
- ? ir_builder.create<kir::Int>(
- p_pad->value().value() - c_pad->value().value())
- : ir_builder.subExpr(p_pad, c_pad);
-
- // If the consumer is a result of shifting the producer, adjust the
- // producer index per the offsets argument of the shift op.
- if (auto shift_op = dynamic_cast<const ShiftOp*>(consumer_tv->definition())) {
- offset = ir_builder.subExpr(
- offset, ir_builder.create<kir::Int>(shift_op->offset(producer_axis)));
- }
-
- return offset;
-}
-
-//! Offset producer index when necessary
-kir::Val* getProducerIndexWithHalo(
- const TensorView* producer_tv,
- size_t producer_axis,
- kir::Val* producer_index,
- const TensorView* consumer_tv) {
- const auto offset =
- getProducerHaloOffset(producer_tv, producer_axis, consumer_tv);
-
- if (offset->isZeroInt()) {
- return producer_index;
- }
-
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- producer_index = ir_builder.addExpr(producer_index, offset);
-
- return producer_index;
-}
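-
-// Net effect of the two helpers above (illustrative): the producer index
-// becomes
-//   producer_index + (producer_pad - consumer_pad) - shift_offset
-// where the shift term only applies when the consumer is a ShiftOp result.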
-
-//! Offset a producer index of a gather expression
-//!
-//! Given an index of a producer root axis, build a new index
-//! expression that accesses a window position that the current loop
-//! structure refers to.
-kir::Val* getProducerIndexWithGather(
- size_t producer_root_axis,
- kir::Val* producer_index,
- const TensorView* producer_tv,
- const TensorView* consumer_tv,
- const std::unordered_map<kir::IterDomain*, kir::Val*>& ref_index_map,
- const std::unordered_map<IterDomain*, IterDomain*>& ref_concrete_map) {
- auto gather_op = dynamic_cast<const GatherOp*>(consumer_tv->definition());
-
- // Just return the producer index as is if this is not a gather
- if (gather_op == nullptr) {
- return producer_index;
- }
-
- // Consumer axis that corresponds to the producer axis
- int consumer_axis = -1;
- for (size_t i = 0; i <= producer_root_axis; ++i) {
- if (producer_tv->getRootDomain()[i]->isReduction()) {
- continue;
- }
- ++consumer_axis;
- }
-
- TORCH_INTERNAL_ASSERT(
- consumer_axis >= 0 &&
- consumer_axis < (int)gather_op->windowShape().size(),
- "Invalid consumer axis",
- consumer_axis,
- ", producer_axis: ",
- producer_root_axis);
-
- // If the window extent is one, no specific offsetting
- // is necessary
- if (gather_op->windowShape()[consumer_axis]->isOneInt()) {
- return producer_index;
- }
-
- // Basically, the goal is to build an expression of producer_index +
- // window_index, so we first need to locate the index expression
- // that corresponds to the window axis of this producer axis.
-
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- // Locate the root IterDomain of the reference that corresponds to the gather
- // axis
- const auto window_root_axis = gather_op->gatherAxis(consumer_axis);
- auto concrete_window_id = gpu_lower->caIndexMap().getConcreteMappedID(
- consumer_tv->getRootDomain().at(window_root_axis));
- auto ref_concrete_map_it = ref_concrete_map.find(concrete_window_id);
- TORCH_INTERNAL_ASSERT(ref_concrete_map_it != ref_concrete_map.end());
- IterDomain* reference_root_of_gather_axis = ref_concrete_map_it->second;
-
- // Now that reference_root_of_gather_axis is the IterDomain for the
- // window axis, take its corresponding index from the index map
- auto window_idx =
- ref_index_map.at(gpu_lower->lowerValue(reference_root_of_gather_axis)
- ->as<kir::IterDomain>());
-
- // Positive (or negative) padding at offset zero means the indexing
- // shifted to the negative (or positive) direction.
- auto pad_width = gather_op->padWidth()[consumer_axis][0];
-
- // producer_index - padding + window_index
- auto offset_producer_index = ir_builder.addExpr(
- ir_builder.subExpr(
- producer_index, ir_builder.create<kir::Int>(pad_width)),
- window_idx);
-
- return offset_producer_index;
-}
-
} // namespace
void IndexCompute::handle(Split* split) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- auto in_id = gpu_lower->lowerValue(split->in())->as<kir::IterDomain>();
- auto outer_id = gpu_lower->lowerValue(split->outer())->as<kir::IterDomain>();
- auto inner_id = gpu_lower->lowerValue(split->inner())->as<kir::IterDomain>();
+ auto in_id = GpuLower::lowerValue(split->in())->as<kir::IterDomain>();
+ auto outer_id = GpuLower::lowerValue(split->outer())->as<kir::IterDomain>();
+ auto inner_id = GpuLower::lowerValue(split->inner())->as<kir::IterDomain>();
auto outer_it = index_map_.find(outer_id);
auto inner_it = index_map_.find(inner_id);
if (outer_it == index_map_.end() || inner_it == index_map_.end())
return;
- const auto outer_ind = outer_it->second;
- const auto inner_ind = inner_it->second;
-
- const bool outer_zero = isZero(outer_id);
- const bool inner_zero = isZero(inner_id);
+ auto outer_ind = outer_it->second;
+ auto inner_ind = inner_it->second;
- // We want to mark as zero merged in if we're working with shared or local
- // memory, and the dimension we're working with is not part of the allocation,
- // as we have special propagation rules for that scenario.
+ bool outer_zero = outer_ind->isZeroInt();
+ bool inner_zero = inner_ind->isZeroInt();
- // Maybe clear in_id as it could have been mapped over from another
- // IndexCompute. Uncertain if this is needed but seems to be safe.
- bool zero_merged_in = hasZeroMerged(in_id) || hasZeroMerged(inner_id) ||
- hasZeroMerged(outer_id);
+ bool outer_bcast = outer_id->isBroadcast();
+ bool inner_bcast = inner_id->isBroadcast();
- // If both are zero, the split input is also zero
- if (inner_zero && outer_zero) {
- zero_.emplace(in_id);
- }
-
- if (zero_merged_in) {
+  // A zero index on a broadcast dim is part of normal traversal; a zero
+  // index on a non-broadcast dim means it came from local or shared memory.
+  // In the latter case we want to propagate this property.
+ if ((outer_zero && !outer_bcast) || (inner_zero && !inner_bcast) ||
+ hasZeroMerged(inner_id) || hasZeroMerged(outer_id)) {
zero_merged_in_.emplace(in_id);
+ } else {
+ // Maybe clear in_id as it could have been mapped over from another
+ // IndexCompute. Uncertain if this is needed but seems to be safe.
+ if (hasZeroMerged(in_id)) {
+ zero_merged_in_.erase(in_id);
+ }
}
- if (isZero(in_id)) {
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+
+ if (outer_zero && inner_zero) {
index_map_[in_id] = ir_builder.create<kir::Int>(0);
extent_map_[in_id] = ir_builder.create<kir::Int>(0);
- } else if (zero_merged_in && outer_zero) {
+ } else if (outer_zero) {
index_map_[in_id] = inner_ind;
+ zero_merged_in_.emplace(in_id);
extent_map_[in_id] = getExtent(inner_id);
- } else if (zero_merged_in && inner_zero) {
+ } else if (inner_zero) {
index_map_[in_id] = outer_ind;
+ zero_merged_in_.emplace(in_id);
extent_map_[in_id] = getExtent(outer_id);
} else {
index_map_[in_id] = ir_builder.addExpr(
ir_builder.mulExpr(outer_ind, getExtent(inner_id)), inner_ind);
- // The extent of a root axis should be only updated when its
- // allocation is partial, i.e., zero_merged_in is true. See issue
- // #1016 and the FusionIssue1016 test.
- if (split->in()->definition() != nullptr || zero_merged_in) {
+ if (extent_map_.find(outer_id) != extent_map_.end() ||
+ extent_map_.find(inner_id) != extent_map_.end()) {
extent_map_[in_id] =
ir_builder.mulExpr(getExtent(outer_id), getExtent(inner_id));
}
}
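
// Illustrative example (not in the original source): for a split
//   [I] -> split(4) -> [I/4, 4]
// the handler above reconstructs
//   index(I) = index(I/4) * extent(4) + index(4)
// i.e. outer_ind * getExtent(inner_id) + inner_ind.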
void IndexCompute::handle(Merge* merge) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- auto out_id = gpu_lower->lowerValue(merge->out())->as<kir::IterDomain>();
- auto outer_id = gpu_lower->lowerValue(merge->outer())->as<kir::IterDomain>();
- auto inner_id = gpu_lower->lowerValue(merge->inner())->as<kir::IterDomain>();
+ auto out_id = GpuLower::lowerValue(merge->out())->as<kir::IterDomain>();
+ auto outer_id = GpuLower::lowerValue(merge->outer())->as<kir::IterDomain>();
+ auto inner_id = GpuLower::lowerValue(merge->inner())->as<kir::IterDomain>();
auto out_it = index_map_.find(out_id);
- if (out_it == index_map_.end()) {
+ if (out_it == index_map_.end())
return;
- }
+
auto out_ind = out_it->second;
- auto zero = ir_builder.zeroVal();
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+ auto zero = ir_builder.create<kir::Int>(0);
- if (isZero(out_id)) {
+ if (out_ind->isZeroInt()) {
index_map_[outer_id] = zero;
index_map_[inner_id] = zero;
extent_map_[outer_id] = zero;
extent_map_[inner_id] = zero;
- zero_.emplace(outer_id);
- zero_.emplace(inner_id);
return;
}
if (!hasZeroMerged(out_id) && contig_ids.find(out_id) != contig_ids.end()) {
- // Contiguous indexing path
auto input_ids = ir_utils::iterDomainInputsOfOrderedAs(
{merge->out()}, td_->getRootDomain());
TORCH_INTERNAL_ASSERT(!input_ids.empty());
for (auto root_id : input_ids) {
- index_map_[gpu_lower->lowerValue(root_id)->as<kir::IterDomain>()] = zero;
+ index_map_[GpuLower::lowerValue(root_id)->as<kir::IterDomain>()] = zero;
}
- index_map_[gpu_lower
- ->lowerValue(*(input_ids.end() - 1))
+ index_map_[GpuLower::lowerValue(*(input_ids.end() - 1))
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
->as<kir::IterDomain>()] = out_ind;
return;
}
- kir::Val* inner_extent = getExtent(inner_id);
-
- // When the reference has halo extent for inner_id, that extent needs to
- // be used to un-merge
- if (reference_halo_extent_map_.find(inner_id) !=
- reference_halo_extent_map_.end()) {
- inner_extent = reference_halo_extent_map_[inner_id];
- }
-
- const auto outer_extent = getExtent(outer_id);
+ Val* inner_extent = getExtent(inner_id);
+ Val* outer_extent = getExtent(outer_id);
if (inner_id->isBroadcast() && inner_extent->isOneInt()) {
- // Propagate away from broadcast dims
index_map_[outer_id] = out_ind;
index_map_[inner_id] = zero;
extent_map_[outer_id] = getExtent(out_id);
} else if (outer_id->isBroadcast() && outer_extent->isOneInt()) {
- // Propagate away from broadcast dims
index_map_[outer_id] = zero;
index_map_[inner_id] = out_ind;
extent_map_[inner_id] = getExtent(out_id);
} else if (hasZeroMerged(out_id)) {
- // Don't propagate to inner id if it's comprised of only broadcast root
- // domains, unless outer is also all broadcast domains. Index shouldn't be
- // anything but zero if both inner and outer are all broadcast domains, but
- // didn't add a hard check for this. See FusionAdvancedIndexing5_CUDA
- if (!inner_id->isBroadcast() && !outer_id->isBroadcast()) {
- // If neither dimension is a broadcast (should be true for reference
- // indexing) pick the preferred path or the inner path.
- if (preferred_paths_.find(outer_id) != preferred_paths_.end() &&
- preferred_paths_.find(inner_id) == preferred_paths_.end()) {
- // Marked that we should prop through outer, not inner.
- index_map_[outer_id] = out_ind;
- extent_map_[outer_id] = getExtent(out_id);
- index_map_[inner_id] = zero;
- extent_map_[inner_id] = zero;
- } else {
- // Prop through inner
- index_map_[inner_id] = out_ind;
- extent_map_[inner_id] = getExtent(out_id);
- index_map_[outer_id] = zero;
- extent_map_[outer_id] = zero;
- }
- } else if (inner_id->isBroadcast() && !outer_id->isBroadcast()) {
- // Inner is broadcast and outer isn't, prop through outer
- index_map_[outer_id] = out_ind;
- extent_map_[outer_id] = getExtent(out_id);
- index_map_[inner_id] = zero;
- extent_map_[inner_id] = zero;
- } else {
- // Default to propagating through inner
- index_map_[inner_id] = out_ind;
- extent_map_[inner_id] = getExtent(out_id);
- index_map_[outer_id] = zero;
- extent_map_[outer_id] = zero;
- }
+ index_map_[inner_id] = out_ind;
+ extent_map_[inner_id] = getExtent(out_id);
+
+ index_map_[outer_id] = zero;
+ extent_map_[outer_id] = zero;
+
zero_merged_in_.emplace(inner_id);
zero_merged_in_.emplace(outer_id);
} else {
- index_map_[outer_id] = ir_builder.divExpr(out_ind, inner_extent);
- index_map_[inner_id] = ir_builder.modExpr(out_ind, inner_extent);
+ Val* I = inner_extent;
+
+ Val* outer_ind = ir_builder.divExpr(out_ind, I);
+ Val* inner_ind = ir_builder.modExpr(out_ind, I);
+
+ index_map_[outer_id] = outer_ind;
+ index_map_[inner_id] = inner_ind;
}
}
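
// e.g. (sketch): for [I0, I1] -> merge -> [I0*I1], un-merging an index i
// gives
//   index(I0) = i / extent(I1)
//   index(I1) = i % extent(I1)
// which is the divExpr/modExpr pair built in the final branch above.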
// using TransformIter::runBackward;
IndexCompute::IndexCompute(
const TensorDomain* _td,
- std::unordered_map<kir::IterDomain*, kir::Val*> initial_index_map,
- std::unordered_map<kir::IterDomain*, kir::Val*> extent_map,
- std::unordered_set<kir::IterDomain*> zero_merged_in,
- const std::vector<bool>& root_contiguity,
- std::unordered_set<kir::IterDomain*> preferred_paths,
- std::unordered_map<kir::IterDomain*, kir::Val*> reference_halo_extent_map)
+ std::unordered_map<kir::IterDomain*, Val*> initial_index_map,
+ std::unordered_map<kir::IterDomain*, Val*> _extent_map,
+ std::unordered_set<kir::IterDomain*> _zero_merged_in,
+ const std::vector<bool>& root_contiguity)
: td_(_td),
index_map_(std::move(initial_index_map)),
- extent_map_(std::move(extent_map)),
- zero_merged_in_(std::move(zero_merged_in)),
- preferred_paths_(std::move(preferred_paths)),
- reference_halo_extent_map_(std::move(reference_halo_extent_map)) {
- FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");
+ extent_map_(std::move(_extent_map)),
+ zero_merged_in_(std::move(_zero_merged_in)) {
+ FUSER_PERF_SCOPE("IndexCompute::IndexCompute");
// Make sure we recompute any indices we can that map to a contiguous access
// in physical memory.
}
}
-  // Initialize the zero_ set with domains that do not contribute to
- // the resulting index. Any domain that is mapped to Int(0), except
- // for vectorized ones, is included in this set.
- const auto gpu_lower = GpuLower::current();
- for (auto dom : td_->domain()) {
- auto kir_dom = gpu_lower->lowerValue(dom)->as<kir::IterDomain>();
- auto it = index_map_.find(kir_dom);
- if (it == index_map_.end()) {
- continue;
- }
- auto idx = it->second;
- if (idx->isZeroInt() && !isParallelTypeVectorize(dom->getParallelType())) {
- zero_.emplace(kir_dom);
- }
- }
-}
-
-void IndexCompute::run() {
const std::vector<Val*> domain_vals(
td_->domain().begin(), td_->domain().end());
traverseFrom(td_->fusion(), domain_vals, false);
}
-kir::Val* IndexCompute::getExtent(kir::IterDomain* id) {
- if (isParallelTypeThread(id->parallelType())) {
- auto parallel_dim =
- GpuLower::current()->parallelDimensionMap().get(id->parallelType());
- TORCH_INTERNAL_ASSERT(parallel_dim != nullptr);
- return parallel_dim;
- } else if (extent_map_.find(id) != extent_map_.end()) {
+Val* IndexCompute::getExtent(kir::IterDomain* id) {
+ if (extent_map_.find(id) != extent_map_.end()) {
return extent_map_.at(id);
} else {
return id->extent();
}
}
-bool IndexCompute::hasZeroMerged(kir::IterDomain* id) const {
- return zero_merged_in_.find(id) != zero_merged_in_.end() || isZero(id);
-}
-
-bool IndexCompute::isZero(kir::IterDomain* id) const {
- return zero_.find(id) != zero_.end();
+bool IndexCompute::hasZeroMerged(kir::IterDomain* id) {
+ return zero_merged_in_.find(id) != zero_merged_in_.end();
}
IndexCompute IndexCompute::updateIndexCompute(
const TensorDomain* new_td,
const std::unordered_map<IterDomain*, IterDomain*>& id_map,
- const std::vector<bool>& root_contiguity,
- const std::unordered_map<kir::IterDomain*, kir::Val*>&
- reference_halo_extent_map) {
- FUSER_PERF_SCOPE("GpuLower::Lower::updateIndexCompute");
-
- const auto gpu_lower = GpuLower::current();
+ std::unordered_map<kir::IterDomain*, Val*> new_index_entries,
+ const std::vector<bool>& root_contiguity) {
+ FUSER_PERF_SCOPE("updateIndexCompute");
- std::unordered_map<kir::IterDomain*, kir::Val*> updated_index_map;
- std::unordered_map<kir::IterDomain*, kir::Val*> updated_extent_map;
+ std::unordered_map<kir::IterDomain*, Val*> updated_index_map =
+ std::move(new_index_entries);
+ std::unordered_map<kir::IterDomain*, Val*> updated_extent_map;
std::unordered_set<kir::IterDomain*> updated_zero_merged_in;
for (auto id_entry : id_map) {
kir::IterDomain* prev_id =
- gpu_lower->lowerValue(id_entry.first)->as<kir::IterDomain>();
+ GpuLower::lowerValue(id_entry.first)->as<kir::IterDomain>();
kir::IterDomain* new_id =
- gpu_lower->lowerValue(id_entry.second)->as<kir::IterDomain>();
+ GpuLower::lowerValue(id_entry.second)->as<kir::IterDomain>();
if (index_map_.find(prev_id) != index_map_.end()) {
updated_index_map[new_id] = index_map_.at(prev_id);
}
- updated_extent_map[new_id] = getExtent(prev_id);
+ if (extent_map_.find(prev_id) != extent_map_.end()) {
+ updated_extent_map[new_id] = extent_map_.at(prev_id);
+ }
if (zero_merged_in_.find(prev_id) != zero_merged_in_.end()) {
updated_zero_merged_in.emplace(new_id);
}
}
- IndexCompute updated_index_compute(
+ return IndexCompute(
new_td,
updated_index_map,
updated_extent_map,
updated_zero_merged_in,
- root_contiguity,
- {},
- reference_halo_extent_map);
- updated_index_compute.run();
+ root_contiguity);
+}
- return updated_index_compute;
+std::vector<bool> IndexCompute::contiguityAnd(
+ const std::vector<bool>& contig1,
+ const std::vector<bool>& contig2) {
+ TORCH_INTERNAL_ASSERT(
+ contig1.size() == contig2.size(),
+ "Called contiguityAnd with mismatched vectors.");
+
+ std::vector<bool> contig_result;
+ std::transform(
+ contig1.begin(),
+ contig1.end(),
+ contig2.begin(),
+ std::back_inserter(contig_result),
+ std::logical_and<>());
+ return contig_result;
}
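+
+// e.g., contiguityAnd({true, false, true}, {true, true, false}) returns
+// {true, false, false}: an axis is contiguous in the result only when it is
+// contiguous in both inputs.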
-namespace {
-// Map indices down to the leaf domains for applying swizzle
-class UpdateLeafIndices : public IterVisitor {
- public:
- UpdateLeafIndices(
- const TensorDomain* td,
- std::unordered_map<kir::IterDomain*, kir::Val*> initial_index_map,
- std::unordered_map<kir::IterDomain*, kir::Val*> extent_map)
- : td_(td),
- index_map_(std::move(initial_index_map)),
- extent_map_(std::move(extent_map)) {
- const std::vector<Val*> domain_vals(
- td_->domain().begin(), td_->domain().end());
-
- traverseFrom(td_->fusion(), domain_vals, false);
+// TODO: use new mapping functions
+// This mapping might need to go through rfactor; this is unclear.
+std::vector<bool> IndexCompute::contiguityPasC(
+ TensorDomain* producer,
+ TensorDomain* consumer) {
+ FUSER_PERF_SCOPE("contiguityPasC");
+
+ const std::vector<bool>& producer_contiguity = producer->contiguity();
+ std::vector<bool> as_consumer_contiguity;
+
+ auto c_root = consumer->getRootDomain();
+ auto p_root = producer->getRootDomain();
+
+ size_t p_ind = 0;
+ size_t c_ind = 0;
+ while (p_ind < p_root.size()) {
+ if (p_root[p_ind]->isReduction()) {
+ p_ind++;
+ } else if (
+ c_root[c_ind]->isBroadcast() &&
+ p_root[p_ind]->getIterType() != c_root[c_ind]->getIterType()) {
+ c_ind++;
+ as_consumer_contiguity.push_back(false);
+ } else {
+ as_consumer_contiguity.push_back(producer_contiguity[p_ind]);
+ c_ind++;
+ p_ind++;
+ }
}
- const std::unordered_map<kir::IterDomain*, kir::Val*>& indexMap() const {
- return index_map_;
+ while (c_ind < c_root.size()) {
+ as_consumer_contiguity.push_back(false);
+ c_ind++;
}
- const std::unordered_map<kir::IterDomain*, kir::Val*>& extentMap() const {
- return extent_map_;
+ return as_consumer_contiguity;
+}
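+
+// Sketch of the mapping above (domains assumed for illustration): with
+// producer root [I0, R1, I2], producer contiguity {true, true, true}, and
+// consumer root [I0, B1, I2], the reduction axis R1 is skipped, the
+// broadcast B1 contributes false, and the result as-consumer is
+// {true, false, true}.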
+
+namespace {
+
+std::deque<TensorView*> getComputeAtTVStackFrom(TensorView* from_tv) {
+ // Find the computeAt root tensor view of this operation: the terminating
+ // tensor in the computeAt DAG walked from the consumer.
+ auto end_tv = from_tv->getComputeAtAxis(0).second;
+
+ // grab all tensor views from from_tv -> computeAtRoot
+ std::deque<TensorView*> tv_stack;
+
+ // Start the walk from from_tv itself
+ auto running_tv = from_tv;
+
+ // Follow computeAt path until we hit end_tv
+ while (running_tv != end_tv) {
+ TORCH_INTERNAL_ASSERT(running_tv->hasComputeAt());
+ tv_stack.push_front(running_tv);
+ running_tv = running_tv->getComputeAtView();
}
- private:
- using IterVisitor::handle;
+ tv_stack.push_front(end_tv);
+
+ return tv_stack;
+}
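+
+// e.g. (chain assumed for illustration): if tv2 is computed at tv1, tv1 is
+// computed at tv0, and tv0 terminates the computeAt chain, then
+// getComputeAtTVStackFrom(tv2) returns [tv0, tv1, tv2], ordered from the
+// computeAt root down to the starting tensor view.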
+
+std::pair<
+ std::unordered_map<kir::IterDomain*, Val*>,
+ std::unordered_map<kir::IterDomain*, Val*>>
+generateIndexAndExtentMap(
+ std::deque<TensorView*> c2p_tv_stack,
+ std::deque<kir::ForLoop*> loops,
+ const std::unordered_map<kir::ForLoop*, Val*>& loop_to_ind_map,
+ const std::vector<bool>& last_tv_root_contiguity) {
+ if (c2p_tv_stack.empty())
+ return std::make_pair(
+ std::unordered_map<kir::IterDomain*, Val*>(),
+ std::unordered_map<kir::IterDomain*, Val*>());
+
+ // Go through our stack and map the intermediate IterDomains of the common
+ // transformations from consumer to producer
+ std::deque<std::unordered_map<IterDomain*, IterDomain*>> c2p_ID_maps;
+ std::deque<std::unordered_map<IterDomain*, IterDomain*>> p2c_ID_maps;
+
+ // c2p_tv_stack comes in as consumer -> producer
+ // We first do a pass from producer -> consumer to propagate iterators
+ // outside the computeAt position back into consumers, so we can then
+ // re-propagate back to the producer. The need for this was exposed in
+ // https://github.com/csarofeen/pytorch/issues/286
+
+ for (size_t i = 0; i + 1 < c2p_tv_stack.size(); i++) {
+ auto c_tv = c2p_tv_stack[i];
+ auto p_tv = c2p_tv_stack[i + 1];
+
+ // Map root ID's from consumer to producer
+ auto c2p_root_map =
+ TensorDomain::mapRootCtoP(c_tv->domain(), p_tv->domain());
+
+ // Look for matching ID transformations in producer and consumer...
+ BestEffortReplay replay(
+ p_tv->domain()->domain(), c_tv->domain()->domain(), c2p_root_map);
+
+ // and grab the intermediate IterDomain map.
+ c2p_ID_maps.push_back(replay.getReplay());
+
+ // Something wasn't symmetric when using:
+ //
+ // auto p2c_root_map = TensorDomain::mapRootPtoC(p_tv->domain(),
+ // c_tv->domain());
+ //
+ // replay = BestEffortReplay(
+ // c_tv->domain()->domain(), p_tv->domain()->domain(), p2c_root_map,
+ // true);
- void handle(Split* split) override {
- const auto gpu_lower = GpuLower::current();
+ BestEffortReplay replay_p2c(
+ p_tv->domain()->domain(), c_tv->domain()->domain(), c2p_root_map, true);
- auto in_id = gpu_lower->lowerValue(split->in())->as<kir::IterDomain>();
- auto outer_id =
- gpu_lower->lowerValue(split->outer())->as<kir::IterDomain>();
- auto inner_id =
- gpu_lower->lowerValue(split->inner())->as<kir::IterDomain>();
+ std::unordered_map<IterDomain*, IterDomain*> p2c_id_map;
- // Nothing need to be done when mappings for the output axes
- // already exist.
- if (index_map_.find(outer_id) != index_map_.end()) {
- TORCH_INTERNAL_ASSERT(
- index_map_.find(inner_id) != index_map_.end(),
- "Outer exists but inner not found");
- return;
+ for (auto ent : replay_p2c.getReplay()) {
+ p2c_id_map[ent.second] = ent.first;
}
- kir::IrBuilder ir_builder(gpu_lower->kernel());
- auto factor = gpu_lower->lowerValue(split->factor());
- index_map_[inner_id] = ir_builder.modExpr(index_map_[in_id], factor);
- extent_map_[inner_id] = factor;
- index_map_[outer_id] = ir_builder.divExpr(index_map_[in_id], factor);
- extent_map_[outer_id] = ir_builder.ceilDivExpr(getExtent(in_id), factor);
+ // and grab the intermediate IterDomain map.
+ p2c_ID_maps.push_front(p2c_id_map);
}
- void handle(Merge* merge) override {
- const auto gpu_lower = GpuLower::current();
+ // Maps to be used in the c2p propagation
+ std::unordered_map<TensorView*, std::unordered_map<kir::IterDomain*, Val*>>
+ p2c_index_maps;
- auto out_id = gpu_lower->lowerValue(merge->out())->as<kir::IterDomain>();
- auto outer_id =
- gpu_lower->lowerValue(merge->outer())->as<kir::IterDomain>();
- auto inner_id =
- gpu_lower->lowerValue(merge->inner())->as<kir::IterDomain>();
+ // PROPAGATE PRODUCER -> CONSUMER START
- // Nothing need to be done when mappings for the output axes
- // already exist.
- if (index_map_.find(out_id) != index_map_.end()) {
- return;
- }
+ std::deque<TensorView*> p2c_tv_stack(
+ c2p_tv_stack.rbegin(), c2p_tv_stack.rend());
+ // Setup initial IndexCompute:
+ auto tv = p2c_tv_stack.front();
+ p2c_tv_stack.pop_front();
+ auto td = tv->domain()->domain();
+
+ std::vector<kir::IterDomain*> kir_td;
+
+ std::transform(
+ td.begin(), td.end(), std::back_inserter(kir_td), [](IterDomain* id) {
+ return GpuLower::lowerValue(id)->as<kir::IterDomain>();
+ });
+
+ // Map from all IterDomains to the corresponding index as we process each
+ // tv in the stack
+ std::unordered_map<kir::IterDomain*, Val*> initial_index_map;
+
+ // Match loops to this TV if the loop matches this TV's ID (could reduce
+ // complexity here)
+
+ while (
+ !loops.empty() &&
+ std::find(kir_td.rbegin(), kir_td.rend(), loops.back()->iter_domain()) !=
+ kir_td.rend()) {
TORCH_INTERNAL_ASSERT(
- index_map_.find(outer_id) != index_map_.end(), "Outer ID not found");
- TORCH_INTERNAL_ASSERT(
- index_map_.find(inner_id) != index_map_.end(), "Inner ID not found");
+ loop_to_ind_map.find(loops.back()) != loop_to_ind_map.end());
+ initial_index_map[loops.back()->iter_domain()] =
+ loop_to_ind_map.at(loops.back());
+ loops.pop_back();
+ }
+
+ IndexCompute index_compute(
+ tv->domain(),
+ initial_index_map,
+ std::unordered_map<kir::IterDomain*, Val*>(),
+ std::unordered_set<kir::IterDomain*>(),
+ std::vector<bool>(tv->getRootDomain().size(), false));
+
+ p2c_index_maps[tv] = index_compute.indexMap();
+
+ // Go through the entire tv stack
+ while (!p2c_tv_stack.empty()) {
+ // Grab the TV
+ tv = p2c_tv_stack.front();
+ p2c_tv_stack.pop_front();
+ td = tv->domain()->domain();
+ kir_td.clear();
+ std::transform(
+ td.begin(), td.end(), std::back_inserter(kir_td), [](IterDomain* id) {
+ return GpuLower::lowerValue(id)->as<kir::IterDomain>();
+ });
- kir::IrBuilder ir_builder(gpu_lower->kernel());
- index_map_[out_id] = ir_builder.mulExpr(
- index_map_[inner_id],
- ir_builder.mulExpr(index_map_[outer_id], getExtent(inner_id)));
+ // Match loops to this TV if the loop matches this TV's ID (could reduce
+ // complexity here)
- extent_map_[out_id] =
- ir_builder.mulExpr(getExtent(outer_id), getExtent(inner_id));
- }
+ // Map from all IterDomains to the corresponding index as we process each
+ // tv in the stack
+ std::unordered_map<kir::IterDomain*, Val*> new_indices;
- // return extent_map_[id] if exists, else return id->extent()
- kir::Val* getExtent(kir::IterDomain* id) {
- if (extent_map_.find(id) != extent_map_.end()) {
- return extent_map_.at(id);
- } else {
- return id->extent();
+ while (!loops.empty() &&
+ std::find(
+ kir_td.rbegin(), kir_td.rend(), loops.back()->iter_domain()) !=
+ kir_td.rend()) {
+ TORCH_INTERNAL_ASSERT(
+ loop_to_ind_map.find(loops.back()) != loop_to_ind_map.end());
+ new_indices[loops.back()->iter_domain()] =
+ loop_to_ind_map.at(loops.back());
+ loops.pop_back();
}
- }
- private:
- const TensorDomain* td_;
- std::unordered_map<kir::IterDomain*, kir::Val*> index_map_;
- std::unordered_map<kir::IterDomain*, kir::Val*> extent_map_;
-};
+ if (!p2c_ID_maps.empty()) {
+ index_compute = index_compute.updateIndexCompute(
+ tv->domain(),
+ p2c_ID_maps.front(),
+ new_indices,
+ std::vector<bool>(tv->getRootDomain().size(), false));
-// Returns halo-extended extent if id has halo. Otherwise, just
-// returns id->extent.
-kir::Val* getHaloExtentOfRootAxis(
- IterDomain* id,
- kir::Val* normal_extent = nullptr) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
+ p2c_index_maps[tv] = index_compute.indexMap();
- if (normal_extent == nullptr) {
- normal_extent = gpu_lower->lowerValue(id->extent());
+ p2c_ID_maps.pop_front();
+ }
}
- const auto& halo = gpu_lower->haloInfo().getRootAxisInfo(id);
- if (halo.hasHalo()) {
- auto halo_extent = ir_builder.addExpr(normal_extent, halo.width());
- return halo_extent;
- } else {
- return normal_extent;
+ // PROPAGATE PRODUCER -> CONSUMER END
+
+ // PROPAGATE CONSUMER -> PRODUCER START
+
+ // Setup initial IndexCompute:
+ tv = c2p_tv_stack.front();
+ c2p_tv_stack.pop_front();
+
+ // Map from all IterDomains to the corresponding index as we process each
+ // tv in the stack
+ initial_index_map = p2c_index_maps.at(tv);
+
+ std::unordered_map<kir::IterDomain*, Val*> initial_extent_map;
+ if (!c2p_ID_maps.empty()) {
+ auto first_id_map = c2p_ID_maps.front();
+ for (auto id_entry : first_id_map) {
+ kir::IterDomain* this_id =
+ GpuLower::lowerValue(id_entry.first)->as<kir::IterDomain>();
+ if (initial_extent_map.find(this_id) == initial_extent_map.end()) {
+ initial_extent_map[this_id] = this_id->extent();
+ }
+ }
}
-}
-} // namespace
+ index_compute = IndexCompute(
+ tv->domain(),
+ initial_index_map,
+ initial_extent_map,
+ std::unordered_set<kir::IterDomain*>(),
+ c2p_tv_stack.empty()
+ ? last_tv_root_contiguity
+ : std::vector<bool>(tv->getRootDomain().size(), false));
-IndexSwizzle::IndexSwizzle(
- const TensorView* tv,
- std::unordered_map<kir::IterDomain*, kir::Val*> initial_index_map,
- std::unordered_map<kir::IterDomain*, kir::Val*> extent_map,
- std::unordered_set<kir::IterDomain*> zero_merged_in)
- : IndexCompute(
+ // Go through the entire tv stack
+ while (!c2p_tv_stack.empty()) {
+ // Grab the TV
+ tv = c2p_tv_stack.front();
+ c2p_tv_stack.pop_front();
+
+ if (!c2p_ID_maps.empty()) {
+ index_compute = index_compute.updateIndexCompute(
tv->domain(),
- std::move(initial_index_map),
- std::move(extent_map),
- std::move(zero_merged_in),
- std::vector<bool>(tv->getRootDomain().size(), false)),
- tv_(tv),
- swizzle_type_(tv->swizzleType()),
- ids_to_swizzle_(tv->axesToSwizzle()) {}
-
-void IndexSwizzle::run() {
- TORCH_INTERNAL_ASSERT(
- swizzle_type_ == SwizzleType::NoSwizzle ||
- swizzle_type_ == SwizzleType::Transpose,
- "Invalid swizzle type");
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
- if (swizzle_type_ == SwizzleType::Transpose) {
- // Shifts the second axis by the first axis as ((idx_1 + idx_2) %
- // ext). Alternatively, ((idx_1 - idx_2) & (ext - 1)) would also
- // work if ext is a power of two. Practically, ext should be 32 if
- // the data type of the tensor is float, so the latter approach
- // should also be fine.
- TORCH_INTERNAL_ASSERT(tv_->getMemoryType() == MemoryType::Shared);
- TORCH_INTERNAL_ASSERT(tv_->axesToSwizzle().size() == 2);
-
- UpdateLeafIndices update_leaves(td_, indexMap(), extentMap());
- index_map_ = update_leaves.indexMap();
- extent_map_ = update_leaves.extentMap();
-
- IterDomain* id_to_swizzle_i = ids_to_swizzle_.at(0);
- IterDomain* id_to_swizzle_j = ids_to_swizzle_.at(1);
- kir::IterDomain* id_to_swizzle_i_kir =
- gpu_lower->lowerValue(id_to_swizzle_i)->as<kir::IterDomain>();
- kir::IterDomain* id_to_swizzle_j_kir =
- gpu_lower->lowerValue(id_to_swizzle_j)->as<kir::IterDomain>();
-
- if (indexMap().find(id_to_swizzle_i_kir) != indexMap().end() &&
- indexMap().find(id_to_swizzle_j_kir) != indexMap().end()) {
- auto idx_to_swizzle_i = indexMap().at(id_to_swizzle_i_kir);
- auto idx_to_swizzle_j = indexMap().at(id_to_swizzle_j_kir);
-
- auto swizzled_idx = ir_builder.modExpr(
- ir_builder.addExpr(idx_to_swizzle_i, idx_to_swizzle_j),
- id_to_swizzle_j_kir->extent());
- index_map_[id_to_swizzle_j_kir] = swizzled_idx;
- swizzled_ids_.insert(id_to_swizzle_j);
- IndexCompute::run();
+ c2p_ID_maps.front(),
+ p2c_index_maps.at(tv),
+ c2p_tv_stack.empty()
+ ? last_tv_root_contiguity
+ : std::vector<bool>(tv->getRootDomain().size(), false));
+
+ c2p_ID_maps.pop_front();
}
}
-}
-void IndexSwizzle::handle(Expr* e) {
- auto out_ids = ir_utils::filterByType<IterDomain>(e->outputs());
- bool needs_update =
- std::any_of(out_ids.begin(), out_ids.end(), [this](IterDomain* id) {
- return swizzled_ids_.find(id) != swizzled_ids_.end();
- });
- if (!needs_update) {
- return;
- }
+ // PROPAGATE CONSUMER -> PRODUCER END
+
+ // Fill in the extent map, as some mapped indices may not have their
+ // extents filled in, but consumers of this function expect them to be there
- IndexCompute::handle(e);
- for (auto input : ir_utils::filterByType<IterDomain>(e->inputs())) {
- swizzled_ids_.insert(input);
+ std::unordered_map<kir::IterDomain*, Val*> extent_map(
+ index_compute.extentMap());
+ for (auto ind_entry : index_compute.indexMap()) {
+ auto id = ind_entry.first;
+ if (extent_map.find(id) == extent_map.end()) {
+ extent_map[id] = id->extent();
+ }
}
+
+ return std::make_pair(index_compute.indexMap(), extent_map);
}
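+
+// Usage sketch (variable names assumed): callers destructure the result as
+//   auto maps = generateIndexAndExtentMap(stack, loops, ind_map, contig);
+//   auto index_map = maps.first;   // kir::IterDomain* -> index Val*
+//   auto extent_map = maps.second; // kir::IterDomain* -> extent Val*
+// which is how the indexing entry points below consume it.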
-std::vector<kir::Val*> Index::getGlobalProducerStridedIndices(
+} // namespace
+
+kir::TensorIndex* Index::getGlobalProducerIndex(
TensorView* producer_tv,
- const TensorView* consumer_tv,
+ TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops) {
- FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex");
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
+ FUSER_PERF_SCOPE("getGlobalProducerIndex");
- // Get a reference tensor replayed as existing loop structure
- auto reference = IndexReferenceReplay::getReference(loops);
- auto reference_domain = reference.domain;
- auto reference_id_map = reference.concrete_to_id;
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
// Replay producer to look like consumer so we can index on producer since our
// loop nests look like consumer
- auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv);
- auto producerAsC =
- TransformReplay::replayPasC(producer_tv, consumer_tv, -1, pairwiseMap)
- .first;
+ auto producerAsC = TransformReplay::replayPasC(
+ producer_tv->domain(), consumer_tv->domain(), -1)
+ .first;
- // Make the producer_tv look like consumer while performing indexing math
+ // Make the actual producer_tv look like consumer while we do the indexing
+ // math in this function
ir_utils::TVDomainGuard domain_guard(producer_tv, producerAsC);
- // Map reference tensor to producer
- std::unordered_map<IterDomain*, IterDomain*> root_ref_to_producer;
- for (auto p_root : producer_tv->getMaybeRFactorDomain()) {
- auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(p_root);
- auto ref_id_it = reference_id_map.find(concrete_id);
- if (ref_id_it != reference_id_map.end()) {
- root_ref_to_producer[ref_id_it->second] = p_root;
- }
- }
-
- // Index into the reference tensor. Reference indexing will handle vectorized
- // dims where index should be set to 0
- auto ref_compute = getReferenceIndexing(loops, reference_domain);
-
- // Replay producer as reference to get reference to producer ID map
- BestEffortReplay replay_producer_as_ref(
- producer_tv->domain()->domain(),
- reference_domain->domain(),
- root_ref_to_producer);
-
- const auto& ref_2_producer = replay_producer_as_ref.getReplay();
-
- // Forward vectorized IDs to index into producer correctly
- // We want p_id to be vectorized like consumer just for the indexing, then we
- // need to switch it back later. Store previous state here when changing. We
- // need to do this as replaying producer as consumer can use replay best
- // effort which means some domains may be the originals.
- std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup;
- for (auto entry : ref_2_producer) {
- auto ref_id = entry.first;
- auto p_id = entry.second;
- if (ref_id->getParallelType() == ParallelType::Vectorize) {
- p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType()));
- p_id->parallelize(ParallelType::Vectorize);
- } else if (ref_id->getParallelType() == ParallelType::MisalignedVectorize) {
- p_id->parallelize(ParallelType::MisalignedVectorize);
- }
- }
-
- const auto reference_halo_extent_map = getReferenceHaloExtentMap(
- reference, consumer_tv, ref_2_producer, ref_compute.extentMap());
+ // grab all tensor views from producer_tv <- computeAtRoot
+ std::deque<TensorView*> tv_stack = getComputeAtTVStackFrom(consumer_tv);
+ tv_stack.push_back(producer_tv);
- // Index into producer using reference indexing
- auto producer_indexing = ref_compute.updateIndexCompute(
- producer_tv->domain(),
- ref_2_producer,
- producer_tv->domain()->contiguity(),
- reference_halo_extent_map);
+ std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
+ std::transform(
+ loops.begin(),
+ loops.end(),
+ std::inserter(loop_to_ind_map, loop_to_ind_map.begin()),
+ [](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); });
- // Revert p_ids
- for (auto entry : p_id_backup) {
- entry.first->parallelize(entry.second);
- }
+ auto index_map = generateIndexAndExtentMap(
+ tv_stack,
+ std::deque<kir::ForLoop*>(loops.begin(), loops.end()),
+ loop_to_ind_map,
+ producer_tv->domain()->contiguity())
+ .first;
// Indices should now be mapped onto IterDomains in producer, so just grab
// and use them.
auto root_dom = producer_tv->getMaybeRFactorDomain();
- // TODO: Abstract stride logic to reuse with consumer indexing
- auto zero = ir_builder.create<kir::Int>(0);
- std::vector<kir::Val*> strides(root_dom.size(), nullptr);
- {
- int stride_i = 0;
- for (size_t i = 0; i < root_dom.size(); i++) {
- if (root_dom[i]->isReduction() ||
- root_dom[i]->getIterType() == IterType::BroadcastWithoutStride) {
- strides[i] = zero;
- continue;
- }
- std::stringstream ss;
- ss << "T" << producer_tv->name() << ".stride[" << stride_i++ << "]";
- strides[i] = ir_builder.create<kir::NamedScalar>(ss.str(), DataType::Int);
- }
- }
-
- kir::Val* cur_contig_stride = ir_builder.create<kir::Int>(1);
- // if we have rfactor we can't simplify the indexing like this, we would need
- // to fix contiguity size to be rfactor size not root size
- if (root_dom.size() == producer_tv->domain()->contiguity().size()) {
- for (size_t i = 0; i < root_dom.size(); i++) {
- auto dim = root_dom.size() - i - 1;
- if (root_dom[dim]->isReduction()) {
- continue;
- }
- if (root_dom[dim]->getIterType() == IterType::BroadcastWithoutStride) {
- continue;
- }
-
- kir::Val* root_ind = nullptr;
- auto kir_root_dom =
- gpu_lower->lowerValue(root_dom[dim])->as<kir::IterDomain>();
- if (producer_indexing.indexMap().find(kir_root_dom) !=
- producer_indexing.indexMap().end()) {
- root_ind = producer_indexing.indexMap().at(kir_root_dom);
- } else if (
- root_dom[dim]->getIterType() == IterType::BroadcastWithStride) {
- root_ind = zero;
- }
-
- TORCH_INTERNAL_ASSERT(
- root_ind != nullptr,
- "Couldn't find root mapping for TV",
- producer_tv->name(),
- " dim: ",
- i,
- " id: ",
- root_dom[dim]);
-
- if (producer_tv->domain()->contiguity()[dim]) {
- // If contig, used the stored stride which may be the previous
- // dimensions stride * previous dimensions size
- strides[dim] = cur_contig_stride;
- // Prepare for the next dimension which may also be contiguous, multiply
- // by extent of this dimension
- auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]);
- cur_contig_stride =
- ir_builder.mulExpr(cur_contig_stride, root_dim_extent);
- } else {
- // If non contiguous dimension, keep local stride information, set cur
- // stride to local stride * local raw extent
- auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]);
- cur_contig_stride = ir_builder.mulExpr(strides[dim], root_dim_extent);
- }
- }
- }
-
- auto vectorize_shift =
- loops.empty() ? nullptr : loops.back()->vectorize_shift();
+ bool inner_most_dim_contig =
+ root_dom[root_dom.size() - 1]->getIterType() == IterType::Iteration &&
+ producer_tv->domain()->contiguity()[root_dom.size() - 1];
// Global striding
- std::vector<kir::Val*> strided_inds(root_dom.size(), ir_builder.zeroVal());
+ int64_t stride_i = 0;
+ std::vector<Val*> strided_inds;
for (const auto i : c10::irange(root_dom.size())) {
- // If the domain is derived from a trivial reduction, no indexing
- // to create.
if (root_dom[i]->isReduction() ||
- root_dom[i]->getIterType() == IterType::BroadcastWithoutStride ||
- root_dom[i]->getIterType() == IterType::BroadcastWithStride ||
- gpu_lower->trivialReductionInfo().isDerived(root_dom[i])) {
+ root_dom[i]->getIterType() == IterType::BroadcastWithoutStride) {
+ continue;
+ } else if (root_dom[i]->getIterType() == IterType::BroadcastWithStride) {
+ stride_i++;
continue;
}
auto kir_root_dom_i =
- gpu_lower->lowerValue(root_dom[i])->as<kir::IterDomain>();
+ GpuLower::lowerValue(root_dom[i])->as<kir::IterDomain>();
TORCH_INTERNAL_ASSERT(
- producer_indexing.indexMap().find(kir_root_dom_i) !=
- producer_indexing.indexMap().end(),
+ index_map.find(kir_root_dom_i) != index_map.end(),
"Couldn't find root mapping for TV",
producer_tv->name(),
" dim: ",
" id: ",
kir::toString(kir_root_dom_i));
- auto root_ind = producer_indexing.indexMap().at(kir_root_dom_i);
-
- root_ind = getProducerIndexWithHalo(producer_tv, i, root_ind, consumer_tv);
+ auto root_ind = index_map.at(kir_root_dom_i);
+ TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ind));
- root_ind = getProducerIndexWithGather(
- i,
- root_ind,
- producer_tv,
- consumer_tv,
- ref_compute.indexMap(),
- reference_id_map);
-
- if (root_ind->isZeroInt()) {
- continue;
+ if (i == root_dom.size() - 1 && inner_most_dim_contig) {
+ strided_inds.push_back(root_ind);
+ } else if (root_ind->isZeroInt()) {
+ stride_i++;
} else {
- auto strided_ind = ir_builder.mulExpr(root_ind, strides[i]);
- if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
- strided_inds[i] = ir_builder.addExpr(strided_ind, vectorize_shift);
- } else {
- strided_inds[i] = strided_ind;
- }
+ std::stringstream ss;
+ ss << "T" << producer_tv->name() << ".stride[" << stride_i++ << "]";
+ strided_inds.push_back(ir_builder.mulExpr(
+ root_ind,
+ ir_builder.create<kir::NamedScalar>(ss.str(), DataType::Int)));
}
}
- return strided_inds;
+ if (strided_inds.size() == 0)
+ strided_inds.push_back(ir_builder.create<kir::Int>(0));
+
+ return ir_builder.create<kir::TensorIndex>(producer_tv, strided_inds);
}
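+
+// e.g. (shapes assumed for illustration): for a 2D global producer T1 whose
+// innermost root dimension is contiguous, the code above produces an index
+// of the form i0 * T1.stride[0] + i1; i1 needs no stride multiply because
+// inner_most_dim_contig holds.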
namespace {
-// Used for local and shared index mapping
-std::unordered_map<kir::ForLoop*, kir::Val*> indexMapFromTV(
- const TensorView* tv,
- const std::vector<kir::ForLoop*>& loops,
- const std::pair<kir::ForLoop*, int64_t>& alloc_point) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
+std::unordered_map<kir::ForLoop*, Val*> indexMapFromTV(
+ TensorView* tv,
+ const std::vector<kir::ForLoop*>& loops) {
+ auto alloc_point = loop_utils::getAllocPoint(tv, loops);
auto alloc_loop = alloc_point.first;
bool within_alloc = false;
within_alloc = true;
}
- const auto zero = ir_builder.create<kir::Int>(0);
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+ Val* zero = ir_builder.create<kir::Int>(0);
- const bool is_global = tv->getMemoryType() == MemoryType::Global;
- const bool is_shared = tv->getMemoryType() == MemoryType::Shared;
- const bool is_local = tv->getMemoryType() == MemoryType::Local;
+ bool is_shared = tv->getMemoryType() == MemoryType::Shared;
+ bool is_local = tv->getMemoryType() == MemoryType::Local;
- std::unordered_map<kir::ForLoop*, kir::Val*> loop_to_ind_map;
+ std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
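+
+ // Zeroing rule sketched (memory types assumed for illustration): loops
+ // above the allocation point index to zero; for a Local tensor,
+ // thread-parallel loops also index to zero since each thread owns a
+ // private copy; for a Shared tensor, block-parallel loops index to zero
+ // since each block owns its own buffer.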
for (auto loop : loops) {
- kir::Val* idx = nullptr;
- // See also LoopNestGenerator::pushAlloc.
// NOLINTNEXTLINE(bugprone-branch-clone)
if (!within_alloc) {
- if ((loop->iter_domain()->isThreadDim() && is_shared) ||
- (loop->iter_domain()->isThread() && is_global)) {
- idx = loop->index();
- } else {
- idx = zero;
- }
- } else if (
- (loop->iter_domain()->isBlockDim() && is_shared) ||
- (loop->iter_domain()->isThread() && is_local) || loop->vectorize()) {
- idx = zero;
+ loop_to_ind_map[loop] = zero;
+ } else if (loop->iter_domain()->isBlockDim() && is_shared) {
+ loop_to_ind_map[loop] = zero;
+ } else if (loop->iter_domain()->isThread() && is_local) {
+ loop_to_ind_map[loop] = zero;
} else {
- idx = loop->index();
+ loop_to_ind_map[loop] = loop->index();
}
- loop_to_ind_map[loop] = idx;
-
if (!within_alloc && loop == alloc_loop) {
within_alloc = true;
}
} // namespace
// Producer index for either shared or local memory
-std::vector<kir::Val*> Index::getNonGlobalProducerStridedIndices(
+kir::TensorIndex* Index::getProducerIndex_impl(
TensorView* producer_tv,
- const TensorView* consumer_tv,
+ TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- // Get a reference tensor replayed as existing loop structure
- auto reference = IndexReferenceReplay::getReference(loops);
- auto reference_domain = reference.domain;
- auto reference_id_map = reference.concrete_to_id;
-
- // Replay producer to look like consumer so we can index on producer since our
- // loop nests look like consumer
- auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv);
- auto producer_replayed_as_consumer =
- TransformReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map)
- .first;
-
- ir_utils::TVDomainGuard domain_guard(
- producer_tv, producer_replayed_as_consumer);
-
- // We want to play producer as consumer instead of the other way around since
- // consumer may have some broadcasted axes producer doesn't have merged into
- // loops producer may use. If we did consumer as producer we wouldn't have
- // this information in the mapping.
- auto replay_PasC =
- BestEffortReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map);
-
- auto c2p_map = replay_PasC.getReplay();
-
- // Grab consumer domain entries and reverse replay map. TODO: Maybe
- // TransformReplay::replayPasC could return this map
- decltype(c2p_map) p2c_map;
- for (auto id : consumer_tv->domain()->domain()) {
- auto c2p_it = c2p_map.find(id);
- if (c2p_it != c2p_map.end()) {
- auto c_id = c2p_it->first;
- auto p_id = c2p_it->second;
- p2c_map[p_id] = c_id;
- }
- }
-
- // Find allocation point of producer relative to loop nests. P2C map is
- // required because producer was replayed as consumer, so we can't use the
- // regular compute at maps to line up its iter domains with the for loops.
- auto alloc_point =
- loop_utils::getAllocPoint(producer_tv, loops, p2c_map, true);
- std::unordered_map<kir::ForLoop*, kir::Val*> loop_to_ind_map =
- indexMapFromTV(producer_tv, loops, alloc_point);
-
- // Map loop nests to indicies, zeroing out those not used due to locality of
- // memory
- std::unordered_map<kir::IterDomain*, kir::Val*> ref_id_to_ind_map;
-
- // Due to rfactor/initialization reference_domain may be bigger than loop nest
- // structure, ignore IterDomains that aren't present in the loop nest when
- // indexing reference.
- TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims());
- for (size_t loop_i = 0; loop_i < loops.size(); loop_i++) {
- auto ref_axis = gpu_lower->lowerValue(reference_domain->axis(loop_i))
- ->as<kir::IterDomain>();
- ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]];
- }
-
- // Map reference tensor to producer
- std::unordered_map<IterDomain*, IterDomain*> root_ref_to_producer;
- for (auto p_root : producer_tv->getMaybeRFactorDomain()) {
- auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(p_root);
- auto ref_id_it = reference_id_map.find(concrete_id);
- if (ref_id_it != reference_id_map.end()) {
- root_ref_to_producer[ref_id_it->second] = p_root;
- }
- }
-
- // Grab roots that map into producer and save them into the preferred roots
- // set for references indexing
- std::unordered_set<IterDomain*> preferred_roots;
- for (auto entry : root_ref_to_producer) {
- if (entry.second->isBroadcast() || entry.second->isReduction()) {
- continue;
- }
- preferred_roots.emplace(entry.first);
- }
-
- // Make sure propagation of indexing while mixing with 0 indicies we propagate
- // in a way that the producer will be able to see what's going on (propagating
- // into common roots of reference and producer).
- auto preferred_paths = buildPreferredPaths(reference_domain, preferred_roots);
-
- // Index into the reference tensor
- auto ref_compute = getReferenceIndexing(
- loops, reference_domain, ref_id_to_ind_map, preferred_paths);
-
- // Directly replay the producer as the reference to get the mapping of
- // reference to producer we will use to map the indexing into producer
- BestEffortReplay replay_producer_as_ref(
- producer_tv->domain()->domain(),
- reference_domain->domain(),
- root_ref_to_producer);
-
- const auto& ref_2_producer = replay_producer_as_ref.getReplay();
-
- // Forward vectorized IDs to index into producer correctly
- // We want p_id to be vectorized like consumer just for the indexing, then we
- // need to switch it back later. Store previous state here when changing. We
- // need to do this as replaying producer as consumer can use replay best
- // effort which means some domains may be the originals.
- std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup;
- for (auto entry : ref_2_producer) {
- auto ref_id = entry.first;
- auto p_id = entry.second;
- if (ref_id->getParallelType() == ParallelType::Vectorize) {
- p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType()));
- p_id->parallelize(ParallelType::Vectorize);
- } else if (ref_id->getParallelType() == ParallelType::MisalignedVectorize) {
- p_id->parallelize(ParallelType::MisalignedVectorize);
- }
- }
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- // Index into producer using reference indexing
+ // producer_tv->domain() is not replayed as the loop structure we were
+ // provided, so replay it to match consumer_tv, which is.
+ auto producerAsC = TransformReplay::replayPasC(
+ producer_tv->domain(), consumer_tv->domain(), -1)
+ .first;
- const auto reference_halo_extent_map = getReferenceHaloExtentMap(
- reference, consumer_tv, ref_2_producer, ref_compute.extentMap());
-
- auto producer_indexing = ref_compute.updateIndexCompute(
- producer_tv->domain(),
- ref_2_producer,
- producer_tv->domain()->contiguity(),
- reference_halo_extent_map);
-
- // Revert p_ids
- for (auto entry : p_id_backup) {
- entry.first->parallelize(entry.second);
- }
+ // Set producer_tv with the domain replayed as consumer to grab the right
+ // indices. The guard will reset the domain when this scope ends.
+ ir_utils::TVDomainGuard domain_guard(producer_tv, producerAsC);
- IndexSwizzle index_swizzle(
- producer_tv,
- producer_indexing.indexMap(),
- producer_indexing.extentMap(),
- producer_indexing.zeroMergedIn());
+ // grab all tensor views from producer_tv <- computeAtRoot
+ std::deque<TensorView*> tv_stack = getComputeAtTVStackFrom(consumer_tv);
+ tv_stack.push_back(producer_tv);
- index_swizzle.run();
+ std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map =
+ indexMapFromTV(producer_tv, loops);
- auto index_map = index_swizzle.indexMap();
- auto extent_map = producer_indexing.extentMap();
+ auto index_and_extent_map = generateIndexAndExtentMap(
+ tv_stack,
+ std::deque<kir::ForLoop*>(loops.begin(), loops.end()),
+ loop_to_ind_map,
+ std::vector<bool>(producer_tv->getRootDomain().size(), false));
+ auto index_map = index_and_extent_map.first;
+ auto extent_map = index_and_extent_map.second;
// Indices should now be mapped onto IterDomains in producer, so just grab
// and use them.
auto root_dom = producer_tv->getMaybeRFactorDomain();
- // Figure out which root axes we don't need to index
- std::unordered_set<IterDomain*> skip_indexing;
-
- for (auto root_id : root_dom) {
- // Already taken care of because we can detect no indexing required
- if (root_id->isBroadcast() || root_id->isReduction() ||
- gpu_lower->trivialReductionInfo().isDerived(root_id)) {
- skip_indexing.insert(root_id);
- continue;
- }
-
- // Already an entry for this root domain, continue
- if (index_map.find(gpu_lower->lowerValue(root_id)->as<kir::IterDomain>()) !=
- index_map.end()) {
- continue;
- }
-
- // Maps to consumers trivial reduction, don't index
- if (p2c_map.find(root_id) != p2c_map.end() &&
- gpu_lower->trivialReductionInfo().isDerived(p2c_map.at(root_id))) {
- skip_indexing.emplace(root_id);
- }
- }
+ std::vector<Val*> strided_inds;
- std::vector<kir::Val*> strided_inds(root_dom.size(), ir_builder.zeroVal());
for (const auto i : c10::irange(root_dom.size())) {
- if (skip_indexing.count(root_dom[i])) {
+ if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast()) {
continue;
}
auto kir_root_dom_i =
- gpu_lower->lowerValue(root_dom[i])->as<kir::IterDomain>();
+ GpuLower::lowerValue(root_dom[i])->as<kir::IterDomain>();
TORCH_INTERNAL_ASSERT(
index_map.find(kir_root_dom_i) != index_map.end(),
kir::toString(kir_root_dom_i));
auto root_ind_i = index_map.at(kir_root_dom_i);
-
- root_ind_i =
- getProducerIndexWithHalo(producer_tv, i, root_ind_i, consumer_tv);
-
- root_ind_i = getProducerIndexWithGather(
- i,
- root_ind_i,
- producer_tv,
- consumer_tv,
- ref_compute.indexMap(),
- reference_id_map);
+ TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ind_i));
if (root_ind_i->isZeroInt()) {
continue;
}
// Compute striding for this index.
- kir::Val* stride = nullptr;
+ Val* stride = nullptr;
for (size_t j = i + 1; j < root_dom.size(); j++) {
- if (skip_indexing.count(root_dom[j])) {
+ if (root_dom[j]->isBroadcast() || root_dom[j]->isReduction()) {
continue;
}
auto kir_root_dom_j =
- gpu_lower->lowerValue(root_dom[j])->as<kir::IterDomain>();
+ GpuLower::lowerValue(root_dom[j])->as<kir::IterDomain>();
TORCH_INTERNAL_ASSERT(
- index_map.find(kir_root_dom_j) != index_map.end(),
+ index_map.find(kir_root_dom_j) != index_map.end() &&
+ extent_map.find(kir_root_dom_j) != extent_map.end(),
"Couldn't find root mapping for TV",
consumer_tv->name(),
" dim: ",
root_dom[i]);
auto root_ind_j = index_map.at(kir_root_dom_j);
- auto root_ext_j = extent_map.find(kir_root_dom_j) == extent_map.end()
- ? kir_root_dom_j->extent()
- : extent_map.at(kir_root_dom_j);
+ auto root_ext_j = extent_map.at(kir_root_dom_j);
- root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j);
+ TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ext_j));
if (!root_ind_j->isZeroInt()) {
if (stride == nullptr) {
}
if (stride != nullptr) {
- strided_inds[i] = ir_builder.mulExpr(root_ind_i, stride);
+ strided_inds.push_back(ir_builder.mulExpr(root_ind_i, stride));
} else {
- strided_inds[i] = root_ind_i;
+ strided_inds.push_back(root_ind_i);
}
}
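+
+ // e.g. (extents assumed for illustration): for a shared- or local-memory
+ // buffer with indexed root extents [E0, E1, E2], the loop above yields
+ // i0 * (E1 * E2) + i1 * E2 + i2; each index is scaled by the product of
+ // the extents of the non-broadcast, non-reduction dimensions to its right.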
- return strided_inds;
+ if (strided_inds.size() == 0)
+ strided_inds.push_back(ir_builder.create<kir::Int>(0));
+
+ return ir_builder.create<kir::TensorIndex>(producer_tv, strided_inds);
}
-std::vector<kir::Val*> Index::getGlobalConsumerStridedIndices(
- const TensorView* consumer_tv,
+kir::TensorIndex* Index::getGlobalConsumerIndex(
+ TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops) {
- FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- // Get a reference tensor replayed as existing loop structure
- auto reference = IndexReferenceReplay::getReference(loops);
- auto reference_domain = reference.domain;
- auto reference_id_map = reference.concrete_to_id;
-
- // Map reference tensor to consumer
- std::unordered_map<IterDomain*, IterDomain*> root_ref_to_consumer;
- for (auto c_root : consumer_tv->getMaybeRFactorDomain()) {
- auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(c_root);
- auto ref_id_it = reference_id_map.find(concrete_id);
- if (ref_id_it != reference_id_map.end()) {
- root_ref_to_consumer[ref_id_it->second] = c_root;
- }
- }
-
- BestEffortReplay replay_consumer_as_ref(
- consumer_tv->domain()->domain(),
- reference_domain->domain(),
- root_ref_to_consumer);
+ FUSER_PERF_SCOPE("getGlobalConsumerIndex");
- const auto& ref_2_consumer = replay_consumer_as_ref.getReplay();
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- // Index into the reference tensor. Reference indexing will handle vectorized
- // dims where index should be set to 0
- auto ref_compute = getReferenceIndexing(loops, reference_domain);
+ // grab all tensor views from consumer_tv <- computeAtRoot
+ std::deque<TensorView*> tv_stack = getComputeAtTVStackFrom(consumer_tv);
- // Index into consumer using reference indexing
-
- const auto reference_halo_extent_map = getReferenceHaloExtentMap(
- reference, consumer_tv, ref_2_consumer, ref_compute.extentMap());
+ std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
+ std::transform(
+ loops.begin(),
+ loops.end(),
+ std::inserter(loop_to_ind_map, loop_to_ind_map.begin()),
+ [](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); });
- auto consumer_indexing = ref_compute.updateIndexCompute(
- consumer_tv->domain(),
- ref_2_consumer,
- consumer_tv->domain()->contiguity(),
- reference_halo_extent_map);
+ auto index_map = generateIndexAndExtentMap(
+ tv_stack,
+ std::deque<kir::ForLoop*>(loops.begin(), loops.end()),
+ loop_to_ind_map,
+ consumer_tv->domain()->contiguity())
+ .first;
// Indices should now be mapped onto IterDomains in consumer, so just grab
// and use them.
auto root_dom = consumer_tv->getMaybeRFactorDomain();
- // TODO: Abstract stride logic to reuse with producer indexing
- auto zero = ir_builder.zeroVal();
- std::vector<kir::Val*> strides(root_dom.size(), zero);
- {
- int stride_i = 0;
- for (size_t i = 0; i < root_dom.size(); i++) {
- if (root_dom[i]->isReduction() ||
- root_dom[i]->getIterType() == IterType::BroadcastWithoutStride) {
- strides[i] = zero;
- continue;
- }
- std::stringstream ss;
- ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]";
- strides[i] = ir_builder.create<kir::NamedScalar>(ss.str(), DataType::Int);
- }
- }
-
- kir::Val* cur_contig_stride = ir_builder.oneVal();
- // if we have rfactor we can't simplify the indexing like this, we would need
- // to fix contiguity size to be rfactor size not root size
- if (root_dom.size() == consumer_tv->domain()->contiguity().size()) {
- for (size_t i = 0; i < root_dom.size(); i++) {
- auto dim = root_dom.size() - i - 1;
- if (root_dom[dim]->isReduction()) {
- continue;
- }
- if (root_dom[dim]->getIterType() == IterType::BroadcastWithoutStride) {
- continue;
- }
-
- kir::Val* root_ind = nullptr;
- auto kir_root_dom =
- gpu_lower->lowerValue(root_dom[dim])->as<kir::IterDomain>();
- if (consumer_indexing.indexMap().find(kir_root_dom) !=
- consumer_indexing.indexMap().end()) {
- root_ind = consumer_indexing.indexMap().at(kir_root_dom);
- } else if (
- root_dom[dim]->getIterType() == IterType::BroadcastWithStride) {
- root_ind = zero;
- }
+ bool inner_most_dim_contig =
+ root_dom[root_dom.size() - 1]->getIterType() == IterType::Iteration &&
+ consumer_tv->domain()->contiguity()[root_dom.size() - 1];
- TORCH_INTERNAL_ASSERT(
- root_ind != nullptr,
- "Couldn't find root mapping for TV",
- consumer_tv->name(),
- " dim: ",
- i,
- " id: ",
- root_dom[dim]);
-
- if (consumer_tv->domain()->contiguity()[dim]) {
- // If contig, used the stored stride which may be the previous
- // dimensions stride * previous dimensions size
- strides[dim] = cur_contig_stride;
- // Prepare for the next dimension which may also be contiguous, multiply
- // by extent of this dimension
- auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]);
- cur_contig_stride =
- ir_builder.mulExpr(cur_contig_stride, root_dim_extent);
- } else {
- // If non contiguous dimension, keep local stride information, set cur
- // stride to local stride * local raw extent
- cur_contig_stride = ir_builder.mulExpr(
- strides[dim], getHaloExtentOfRootAxis(root_dom[dim]));
- }
- }
- }
-
- auto vectorize_shift =
- loops.empty() ? nullptr : loops.back()->vectorize_shift();
-
- // Global striding
- std::vector<kir::Val*> strided_inds(root_dom.size(), ir_builder.zeroVal());
+ int64_t stride_i = 0;
+ std::vector<Val*> strided_inds;
for (const auto i : c10::irange(root_dom.size())) {
- // See a comment in indexing to root domains in getGlobalProducerIndex.
if (root_dom[i]->isReduction() ||
- root_dom[i]->getIterType() == IterType::BroadcastWithoutStride ||
- root_dom[i]->getIterType() == IterType::BroadcastWithStride ||
- gpu_lower->trivialReductionInfo().isDerived(root_dom[i])) {
+ root_dom[i]->getIterType() == IterType::BroadcastWithoutStride) {
+ continue;
+ } else if (root_dom[i]->getIterType() == IterType::BroadcastWithStride) {
+ stride_i++;
continue;
}
auto kir_root_dom_i =
- gpu_lower->lowerValue(root_dom[i])->as<kir::IterDomain>();
+ GpuLower::lowerValue(root_dom[i])->as<kir::IterDomain>();
TORCH_INTERNAL_ASSERT(
- consumer_indexing.indexMap().find(kir_root_dom_i) !=
- consumer_indexing.indexMap().end(),
+ index_map.find(kir_root_dom_i) != index_map.end(),
"Couldn't find root mapping for TV",
consumer_tv->name(),
" dim: ",
i,
" id: ",
kir::toString(kir_root_dom_i));
+ auto ind = index_map.at(kir_root_dom_i);
- auto root_ind = consumer_indexing.indexMap().at(kir_root_dom_i);
-
- if (root_ind->isZeroInt()) {
- continue;
+ if (i == root_dom.size() - 1 && inner_most_dim_contig) {
+ strided_inds.push_back(ind);
+ } else if (ind->isZeroInt()) {
+ stride_i++;
} else {
- auto strided_ind = ir_builder.mulExpr(root_ind, strides[i]);
- if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
- strided_inds[i] = ir_builder.addExpr(strided_ind, vectorize_shift);
- } else {
- strided_inds[i] = strided_ind;
- }
+ std::stringstream ss;
+ ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]";
+ strided_inds.push_back(ir_builder.mulExpr(
+ ind, ir_builder.create<kir::NamedScalar>(ss.str(), DataType::Int)));
}
}
- return strided_inds;
+ if (strided_inds.size() == 0)
+ strided_inds.push_back(ir_builder.create<kir::Int>(0));
+
+ return ir_builder.create<kir::TensorIndex>(consumer_tv, strided_inds);
}
// Consumer index for either shared or local memory
-std::vector<kir::Val*> Index::getNonGlobalConsumerStridedIndices(
- const TensorView* consumer_tv,
+kir::TensorIndex* Index::getConsumerIndex_impl(
+ TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- // Get a reference tensor replayed as existing loop structure
- auto reference = IndexReferenceReplay::getReference(loops);
- auto reference_domain = reference.domain;
- auto reference_id_map = reference.concrete_to_id;
-
- auto alloc_point = loop_utils::getAllocPoint(consumer_tv, loops);
- std::unordered_map<kir::ForLoop*, kir::Val*> loop_to_ind_map =
- indexMapFromTV(consumer_tv, loops, alloc_point);
-
- // Map loop nests to indicies, zeroing out those not used due to locality of
- // memory
- std::unordered_map<kir::IterDomain*, kir::Val*> ref_id_to_ind_map;
-
- // Due to rfactor/initialization reference_domain may be bigger than loop nest
- // structure, ignore IterDomains that aren't present in the loop nest when
- // indexing reference.
- TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims());
- for (size_t loop_i = 0; loop_i < loops.size(); loop_i++) {
- auto ref_axis = gpu_lower->lowerValue(reference_domain->axis(loop_i))
- ->as<kir::IterDomain>();
- ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]];
- }
-
- // Map reference tensor to consumer
- std::unordered_map<IterDomain*, IterDomain*> root_ref_to_consumer;
- for (auto c_root : consumer_tv->getMaybeRFactorDomain()) {
- auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(c_root);
- auto ref_id_it = reference_id_map.find(concrete_id);
- if (ref_id_it != reference_id_map.end()) {
- root_ref_to_consumer[ref_id_it->second] = c_root;
- }
- }
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- // Grab roots that map into consumer and save them into the preferred roots
- // set for references indexing
- std::unordered_set<IterDomain*> preferred_roots;
- for (auto entry : root_ref_to_consumer) {
- if (entry.second->isBroadcast() || entry.second->isReduction()) {
- continue;
- }
- preferred_roots.emplace(entry.first);
- }
-
- // Make sure propagation of indexing while mixing with 0 indicies we propagate
- // in a way that consumer will be able to see what's going on.
- auto preferred_paths = buildPreferredPaths(reference_domain, preferred_roots);
-
- // Index into the reference tensor
- auto ref_compute = getReferenceIndexing(
- loops, reference_domain, ref_id_to_ind_map, preferred_paths);
-
- BestEffortReplay replay_consumer_as_ref(
- consumer_tv->domain()->domain(),
- reference_domain->domain(),
- root_ref_to_consumer);
-
- const auto& ref_2_consumer = replay_consumer_as_ref.getReplay();
+ // grab all tensor views from consumer_tv <- computeAtRoot
+ std::deque<TensorView*> tv_stack = getComputeAtTVStackFrom(consumer_tv);
- const auto reference_halo_extent_map = getReferenceHaloExtentMap(
- reference, consumer_tv, ref_2_consumer, ref_compute.extentMap());
+ std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map =
+ indexMapFromTV(consumer_tv, loops);
- // Index into consumer using reference indexing
- auto consumer_indexing = ref_compute.updateIndexCompute(
- consumer_tv->domain(),
- ref_2_consumer,
- consumer_tv->domain()->contiguity(),
- reference_halo_extent_map);
+ auto index_and_extent_map = generateIndexAndExtentMap(
+ tv_stack,
+ std::deque<kir::ForLoop*>(loops.begin(), loops.end()),
+ loop_to_ind_map,
+ std::vector<bool>(consumer_tv->getRootDomain().size(), false));
- IndexSwizzle index_swizzle(
- consumer_tv,
- consumer_indexing.indexMap(),
- consumer_indexing.extentMap(),
- consumer_indexing.zeroMergedIn());
-
- index_swizzle.run();
-
- auto index_map = index_swizzle.indexMap();
- auto extent_map = consumer_indexing.extentMap();
+ auto index_map = index_and_extent_map.first;
+ auto extent_map = index_and_extent_map.second;
// Indices should now be mapped onto IterDomains in consumer, so just grab
// and use them.
auto root_dom = consumer_tv->getMaybeRFactorDomain();
- std::vector<kir::Val*> strided_inds(root_dom.size(), ir_builder.zeroVal());
+
+ std::vector<Val*> strided_inds;
for (const auto i : c10::irange(root_dom.size())) {
- if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() ||
- gpu_lower->trivialReductionInfo().isDerived(root_dom[i])) {
+ if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast()) {
continue;
}
auto kir_root_dom_i =
- gpu_lower->lowerValue(root_dom[i])->as<kir::IterDomain>();
+ GpuLower::lowerValue(root_dom[i])->as<kir::IterDomain>();
TORCH_INTERNAL_ASSERT(
index_map.find(kir_root_dom_i) != index_map.end(),
i,
" id: ",
kir::toString(kir_root_dom_i));
+ auto root_ind_i = index_map.at(kir_root_dom_i);
+ TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ind_i));
- const auto root_ind_i = index_map.at(kir_root_dom_i);
if (root_ind_i->isZeroInt()) {
continue;
}
// Compute striding for this index.
- kir::Val* stride = nullptr;
+ Val* stride = nullptr;
for (size_t j = i + 1; j < root_dom.size(); j++) {
- if (root_dom[j]->isBroadcast() || root_dom[j]->isReduction() ||
- gpu_lower->trivialReductionInfo().isDerived(root_dom[j])) {
+ if (root_dom[j]->isBroadcast() || root_dom[j]->isReduction()) {
continue;
}
auto kir_root_dom_j =
- gpu_lower->lowerValue(root_dom[j])->as<kir::IterDomain>();
+ GpuLower::lowerValue(root_dom[j])->as<kir::IterDomain>();
TORCH_INTERNAL_ASSERT(
- index_map.find(kir_root_dom_j) != index_map.end(),
+ index_map.find(kir_root_dom_j) != index_map.end() &&
+ extent_map.find(kir_root_dom_j) != extent_map.end(),
"Couldn't find root mapping for TV",
consumer_tv->name(),
" dim: ",
root_dom[i]);
auto root_ind_j = index_map.at(kir_root_dom_j);
- auto root_ext_j = extent_map.find(kir_root_dom_j) == extent_map.end()
- ? kir_root_dom_j->extent()
- : extent_map.at(kir_root_dom_j);
-
- root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j);
-
+ auto root_ext_j = extent_map.at(kir_root_dom_j);
+ TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(root_ext_j));
if (!root_ind_j->isZeroInt()) {
if (stride == nullptr) {
stride = root_ext_j;
}
if (stride != nullptr) {
- strided_inds[i] = ir_builder.mulExpr(root_ind_i, stride);
+ strided_inds.push_back(ir_builder.mulExpr(root_ind_i, stride));
} else {
- strided_inds[i] = root_ind_i;
+ strided_inds.push_back(root_ind_i);
}
}
- return strided_inds;
+ if (strided_inds.size() == 0)
+ strided_inds.push_back(ir_builder.create<kir::Int>(0));
+
+ return ir_builder.create<kir::TensorIndex>(consumer_tv, strided_inds);
}
-std::vector<kir::Val*> Index::getProducerStridedIndices(
+// Producers are the inputs of an expression
+kir::TensorIndex* Index::getProducerIndex(
TensorView* producer,
- const TensorView* consumer,
+ TensorView* consumer,
const std::vector<kir::ForLoop*>& loops) {
- FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices");
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
+ FUSER_PERF_SCOPE("Index::getProducerIndex");
+
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
if (producer->domain()->noReductions().size() == 0) {
- return std::vector<kir::Val*>(
- producer->getMaybeRFactorDomain().size(), ir_builder.zeroVal());
+ return ir_builder.create<kir::TensorIndex>(producer, std::vector<Val*>{});
}
- std::vector<kir::Val*> strided_indices;
if (producer->getMemoryType() == MemoryType::Global) {
- strided_indices =
- getGlobalProducerStridedIndices(producer, consumer, loops);
- } else {
- strided_indices =
- getNonGlobalProducerStridedIndices(producer, consumer, loops);
+ return getGlobalProducerIndex(producer, consumer, loops);
}
- TORCH_INTERNAL_ASSERT(
- strided_indices.size() == producer->getMaybeRFactorDomain().size());
-
- return strided_indices;
+ return getProducerIndex_impl(producer, consumer, loops);
}
-// Producer is the inputs of an expression
-kir::TensorIndex* Index::getProducerIndex(
- TensorView* producer,
- const TensorView* consumer,
+// Consumer is the output of an expression
+kir::TensorIndex* Index::getConsumerIndex(
+ TensorView* consumer,
const std::vector<kir::ForLoop*>& loops) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- auto strided_indices = getProducerStridedIndices(producer, consumer, loops);
- return ir_builder.create<kir::TensorIndex>(producer, strided_indices);
-}
+ FUSER_PERF_SCOPE("Index::getConsumerIndex");
-std::vector<kir::Val*> Index::getConsumerStridedIndices(
- const TensorView* consumer,
- const std::vector<kir::ForLoop*>& loops) {
- FUSER_PERF_SCOPE("GpuLower::Lower::Index::getConsumerStridedIndices");
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
if (consumer->domain()->noReductions().size() == 0) {
- return std::vector<kir::Val*>(
- consumer->getMaybeRFactorDomain().size(), ir_builder.zeroVal());
+ return ir_builder.create<kir::TensorIndex>(consumer, std::vector<Val*>{});
}
- std::vector<kir::Val*> strided_indices;
if (consumer->getMemoryType() == MemoryType::Global) {
- strided_indices = getGlobalConsumerStridedIndices(consumer, loops);
- } else {
- strided_indices = getNonGlobalConsumerStridedIndices(consumer, loops);
+ return getGlobalConsumerIndex(consumer, loops);
}
- TORCH_INTERNAL_ASSERT(
- strided_indices.size() == consumer->getMaybeRFactorDomain().size());
-
- return strided_indices;
-}
-
-// Consumer is the output of an expression
-kir::TensorIndex* Index::getConsumerIndex(
- const TensorView* consumer,
- const std::vector<kir::ForLoop*>& loops) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- auto strided_indices = getConsumerStridedIndices(consumer, loops);
- return ir_builder.create<kir::TensorIndex>(consumer, strided_indices);
+ return getConsumerIndex_impl(consumer, loops);
}
// Basically just copy getGlobalConsumerIndex, just don't do the striding and
// return std::vector of Vals
-//
-// TODO(kir): replace pair with struct
-//
-std::pair<std::vector<kir::Val*>, bool> Index::getConsumerRootPredIndices(
- const kir::TensorView* kir_consumer_tv,
+std::pair<std::vector<Val*>, bool> Index::getConsumerRootPredIndices(
+ TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops,
const std::vector<bool>& root_contiguity,
- bool unswitch) {
- FUSER_PERF_SCOPE("GpuLower::Lower::Index::getConsumerRootPredIndices");
-
- auto consumer_tv = kir_consumer_tv->fuserTv();
-
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- // Get a reference tensor replayed as existing loop structure
- ReferenceTensor reference = IndexReferenceReplay::getReference(loops);
- auto reference_domain = reference.domain;
- auto reference_id_map = reference.concrete_to_id;
-
- // Map reference tensor to consumer
- std::unordered_map<IterDomain*, IterDomain*> root_ref_to_consumer;
- for (auto c_root : consumer_tv->getMaybeRFactorDomain()) {
- auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(c_root);
- auto ref_id_it = reference_id_map.find(concrete_id);
- if (ref_id_it != reference_id_map.end()) {
- root_ref_to_consumer[ref_id_it->second] = c_root;
- }
- }
+ bool unroll) {
+ FUSER_PERF_SCOPE("Index::getConsumerRootPredIndices");
- BestEffortReplay replay_consumer_as_ref(
- consumer_tv->domain()->domain(),
- reference_domain->domain(),
- root_ref_to_consumer);
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- const auto& ref_2_consumer = replay_consumer_as_ref.getReplay();
+ // grab all tensor views from consumer_tv <- computeAtRoot
+ std::deque<TensorView*> tv_stack = getComputeAtTVStackFrom(consumer_tv);
- std::unordered_map<kir::ForLoop*, kir::Val*> loop_to_ind_map;
+ std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
std::transform(
loops.begin(),
std::inserter(loop_to_ind_map, loop_to_ind_map.begin()),
[](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); });
- if (unswitch) {
- bool within_unswitch = false;
- const auto one = ir_builder.create<kir::Int>(1);
+ if (unroll) {
+ bool within_unroll = false;
+ Val* one = ir_builder.create<kir::Int>(1);
for (auto loop : loops) {
- if (loop->iter_domain()->parallelType() == ParallelType::Unroll ||
- loop->iter_domain()->parallelType() == ParallelType::Unswitch ||
- loop->iter_domain()->parallelType() == ParallelType::Vectorize) {
- within_unswitch = true;
+ if (loop->iter_domain()->getParallelType() == ParallelType::Unroll) {
+ within_unroll = true;
}
- if (within_unswitch) {
- if (loop->iter_domain()->isThread()) {
- loop_to_ind_map[loop] = loop->start();
- } else {
- loop_to_ind_map[loop] = ir_builder.subExpr(loop->stop(), one);
- }
+ if (within_unroll && !loop->iter_domain()->isThread()) {
+ loop_to_ind_map[loop] =
+ ir_builder.subExpr(loop->iter_domain()->extent(), one);
}
}
}
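+  // A sketch of the effect (assumed intent): inside an unrolled region a
+  // serial loop of extent N contributes index N - 1, the largest index the
+  // unrolled body can touch, so the predicate covers every unrolled
+  // iteration. Thread-parallel loops keep their regular loop index.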
- std::unordered_map<kir::IterDomain*, kir::Val*> ref_id_to_ind_map;
- // Due to rfactor/initialization reference_domain may be bigger than loop nest
- // structure
- TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims());
- for (size_t loop_i = 0; loop_i < loops.size(); loop_i++) {
- auto ref_axis = gpu_lower->lowerValue(reference_domain->axis(loop_i))
- ->as<kir::IterDomain>();
- ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]];
- }
-
- // Index into the reference tensor
- auto ref_compute =
- getReferenceIndexing(loops, reference_domain, ref_id_to_ind_map, {});
-
- const auto reference_halo_extent_map = getReferenceHaloExtentMap(
- reference, consumer_tv, ref_2_consumer, ref_compute.extentMap());
-
- // Index into consumer using reference indexing
- auto consumer_indexing = ref_compute.updateIndexCompute(
- consumer_tv->domain(),
- ref_2_consumer,
- root_contiguity,
- reference_halo_extent_map);
+ // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
+ auto index_map = generateIndexAndExtentMap(
+ tv_stack,
+ std::deque<kir::ForLoop*>(loops.begin(), loops.end()),
+ loop_to_ind_map,
+ root_contiguity)
+ .first;
// Indices should now be mapped onto IterDomains in consumer, so just grab
// and use them.
- // If we are generating a predicate for initialization, we should use
- // rfactor instead of root_dom. If we are generating a predicate for
- // actual reduction expr, reduction axes should have their indices
- // mapped to non-zero symbolic vals.
- bool buffer_init = false;
- for (auto consumer_id : kir_consumer_tv->domain()->domain()) {
- if (consumer_id->isReduction()) {
- if (consumer_indexing.indexMap().find(consumer_id) !=
- consumer_indexing.indexMap().end()) {
- if (!consumer_indexing.indexMap().at(consumer_id)->isZeroInt()) {
- buffer_init = false;
- break;
+ // If we are generating a predicate for initialization, check whether we
+ // should use rfactor instead of root_dom
+ bool use_rfactor = true;
+ if (consumer_tv->hasRFactor()) {
+ auto rfactor_dom = consumer_tv->getMaybeRFactorDomain();
+ for (auto rfactor_id : rfactor_dom) {
+ if (rfactor_id->isReduction()) {
+ auto kir_rfactor_id =
+ GpuLower::lowerValue(rfactor_id)->as<kir::IterDomain>();
+ if (index_map.find(kir_rfactor_id) != index_map.end()) {
+ if (!index_map.at(kir_rfactor_id)->isZeroInt()) {
+ use_rfactor = false;
+ break;
+ }
}
}
- buffer_init = true;
}
}
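+  // E.g. (sketch, hypothetical fusion): for tv2 = sum(tv1) with an rFactor
+  // split, the init expr leaves the rfactor reduction axes mapped to zero,
+  // so use_rfactor stays true and the predicate is built on the rfactor
+  // domain; the actual reduction expr maps them to non-zero indices and we
+  // fall back to the root domain.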
- // If we are initializing a reduction buffer and the tensor has a
- // rfactor root, the predicate should be based on the rfactor root.
- const auto root_domain =
- (buffer_init && kir_consumer_tv->domain()->hasRFactor())
- ? kir_consumer_tv->domain()->rfactorDomain()
- : kir_consumer_tv->domain()->rootDomain();
-
- const auto zero = ir_builder.create<kir::Int>(0);
- std::vector<kir::Val*> root_inds(root_domain.size(), zero);
+ auto root_dom = use_rfactor ? consumer_tv->getMaybeRFactorDomain()
+ : consumer_tv->getRootDomain();
- for (const auto i : c10::irange(root_domain.size())) {
- if (root_domain[i]->isBroadcast() ||
- gpu_lower->trivialReductionInfo().isDerived(root_domain[i])) {
- continue;
- }
- const auto it = consumer_indexing.indexMap().find(root_domain[i]);
- if (it != consumer_indexing.indexMap().end()) {
- root_inds[i] = it->second;
- }
- }
-
- return {root_inds, buffer_init};
-}
-
-namespace {
-struct PredicateContigInfo {
- public:
- // Iteration domain that is only comprised of merge transformations
- IterDomain* contig_id;
- // The set of root iteration domains that make up the contig_id
- std::unordered_set<IterDomain*> root_ids;
-};
-
-// Find iteration domains in the history of reference comprised only of
-// merge operations. Only return iteration domains that are subsequently fed
-// into a split, or are in the provided domain. In other words, we don't want to
-// return every IterDomain that's contiguous, just the one closest to the
-// leaves. Predicates are not associated with physical memory so we can treat
-// all of them as contiguous merges.
-std::vector<PredicateContigInfo> getPredicateContigIds(
- std::vector<IterDomain*> reference_domain) {
- auto root_vals = IterVisitor::getInputsTo(
- {reference_domain.begin(), reference_domain.end()});
- auto root_ids = ir_utils::filterByType<IterDomain>(root_vals);
-
- // Mark all roots as being originally "contiguous"
- std::vector<IterDomain*> contiguous_ids(root_ids.begin(), root_ids.end());
-
- // Dereference root_vals.begin below, so make sure there's at least one entry
- if (root_vals.empty()) {
- return std::vector<PredicateContigInfo>();
- }
-
- // Run through iteration domain history
- auto exprs = ExprSort::getExprs(
- (*root_vals.begin())->fusion(),
- {reference_domain.begin(), reference_domain.end()});
-
- for (auto expr : exprs) {
- // If not a merge, output is not contiguous
- if (expr->isA<Merge>()) {
- auto merge = expr->as<Merge>();
- auto inner_contig_it = std::find(
- contiguous_ids.begin(), contiguous_ids.end(), merge->inner());
- auto outer_contig_it = std::find(
- contiguous_ids.begin(), contiguous_ids.end(), merge->outer());
-
- if (inner_contig_it != contiguous_ids.end() &&
- outer_contig_it != contiguous_ids.end()) {
- // If inner and outer are contiguous, out must be contiguous. Remove
- // inner and outer, and add out.
- contiguous_ids.erase(outer_contig_it);
- contiguous_ids.erase(std::find(
- contiguous_ids.begin(), contiguous_ids.end(), merge->inner()));
- contiguous_ids.emplace_back(merge->out());
- }
- }
- }
-
- std::vector<PredicateContigInfo> contig_id_infos;
-
- // Create entries and return them
- for (auto contig_id : contiguous_ids) {
- auto contig_root_vals = IterVisitor::getInputsTo({contig_id});
- auto contig_root_ids = ir_utils::filterByType<IterDomain>(contig_root_vals);
- PredicateContigInfo contig_id_info;
- contig_id_info.contig_id = contig_id;
- contig_id_info.root_ids = std::unordered_set<IterDomain*>(
- contig_root_ids.begin(), contig_root_ids.end());
- contig_id_infos.push_back(contig_id_info);
- }
- return contig_id_infos;
-}
-
-} // namespace
-
-// Returns predicates and the concrete (by loop map) root domains they cover
-std::pair<std::vector<kir::Bool*>, std::vector<std::unordered_set<IterDomain*>>>
-Index::getReferenceRootPredicates(
- const kir::TensorView* kir_consumer_tv,
- const std::vector<kir::ForLoop*>& loops,
- bool unswitch) {
- FUSER_PERF_SCOPE("GpuLower::Lower::Index::getReferenceRootPredicates");
-
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- // Get a reference tensor replayed as existing loop structure
- ReferenceTensor reference = IndexReferenceReplay::getReference(loops);
- auto reference_domain = reference.domain;
- auto reference_id_map = reference.concrete_to_id;
-
- std::unordered_map<kir::ForLoop*, kir::Val*> loop_to_ind_map;
-
- std::transform(
- loops.begin(),
- loops.end(),
- std::inserter(loop_to_ind_map, loop_to_ind_map.begin()),
- [](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); });
-
- // If unswitch don't directly use indices from for loop, use for loop extent
- // minus 1
- if (unswitch) {
- TORCH_INTERNAL_ASSERT(
- loops.size() <= reference_domain->nDims(),
- "Invalid reference generated.");
- bool within_unswitch = false;
- const auto one = ir_builder.create<kir::Int>(1);
- for (size_t loop_i = 0; loop_i < loops.size(); loop_i++) {
- auto loop = loops[loop_i];
- auto ref_id = reference_domain->axis(loop_i);
- if (loop->iter_domain()->parallelType() == ParallelType::Unroll ||
- loop->iter_domain()->parallelType() == ParallelType::Unswitch ||
- loop->iter_domain()->parallelType() == ParallelType::Vectorize) {
- within_unswitch = true;
- }
-
- if (within_unswitch) {
- // Rely on the reference to check broadcasting. The for loop could be
- // broadcasted on a constant value from an unroll split. Since reference
- // may convert this to an iter domain, that for loop could be valid to
- // generate predication from.
- if (ref_id->isBroadcast()) {
- // Ignore indexing into broadcasted dimensions.
- continue;
- } else if (loop->iter_domain()->isThread()) {
- loop_to_ind_map[loop] = loop->start();
- } else {
- loop_to_ind_map[loop] = ir_builder.subExpr(loop->stop(), one);
- }
- }
- }
- }
-
- // Add magic zero to a loop pretty far inside in indexing
- kir::IterDomain* magic_zero_loop = nullptr;
- std::unordered_map<kir::IterDomain*, kir::Val*> ref_id_to_ind_map;
- // Due to rfactor/initialization reference_domain may be bigger than loop nest
- // structure
- TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims());
- for (size_t loop_i = 0; loop_i < loops.size(); loop_i++) {
- auto loop = loops[loop_i];
- auto ind = loop_to_ind_map[loops[loop_i]];
- auto ref_axis = reference_domain->axis(loop_i);
- auto kir_ref_axis = gpu_lower->lowerValue(ref_axis)->as<kir::IterDomain>();
-
- if (Index::protectWithMagicZero(loop, ref_axis, ind)) {
- magic_zero_loop = kir_ref_axis;
- }
-
- ref_id_to_ind_map[kir_ref_axis] = loop_to_ind_map[loop];
- }
-
- if (ref_id_to_ind_map.count(magic_zero_loop)) {
- ref_id_to_ind_map[magic_zero_loop] = ir_builder.addExpr(
- ref_id_to_ind_map[magic_zero_loop], ir_builder.magicZeroVal());
- }
-
- auto consumer_tv = kir_consumer_tv->fuserTv();
-
- // Map reference tensor to consumer
- std::unordered_map<IterDomain*, IterDomain*> root_ref_to_consumer;
- for (auto c_root : consumer_tv->getMaybeRFactorDomain()) {
- auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(c_root);
- auto ref_id_it = reference_id_map.find(concrete_id);
- if (ref_id_it != reference_id_map.end()) {
- root_ref_to_consumer[ref_id_it->second] = c_root;
- }
- }
-
- BestEffortReplay replay_consumer_as_ref(
- consumer_tv->domain()->domain(),
- reference_domain->domain(),
- root_ref_to_consumer);
-
- const auto& ref_2_consumer = replay_consumer_as_ref.getReplay();
-
- // Halo information is not currently used as lower_shift will take care of the
- // predicate generation and is still using the older function:
- // getConsumerRootPredIndices
-
- // Generate halo information for reference.
- updateHaloInfoForReference(reference, consumer_tv);
-
- std::unordered_map<kir::IterDomain*, kir::Val*> reference_halo_extent_map;
-
- const auto& halo_info = gpu_lower->haloInfo();
-
- // Generate map from reference iter domains to halo extents
- for (auto entry : ref_2_consumer) {
- auto ref_id = entry.first;
- auto extent = halo_info.getExtent(ref_id);
- if (extent != nullptr) {
- reference_halo_extent_map[gpu_lower->lowerValue(ref_id)
- ->as<kir::IterDomain>()] = extent;
- }
- }
-
- // Index into the reference tensor
- auto ref_indexing = getReferenceIndexing(
- loops,
- reference_domain,
- ref_id_to_ind_map,
- {},
- reference_halo_extent_map);
-
- // If we are initializing a reduction buffer and the tensor has a
- // rfactor root, the predicate should be based on the rfactor root.
- const auto root_domain = reference_domain->getRootDomain();
-
- // Get the contiguous ids we need to generate predicates for
- auto contig_id_infos = getPredicateContigIds(reference_domain->domain());
-
- // Roots in contiguity processing are based on reference roots; we want to
- // convert these to concrete roots, so flip the reference's concrete_to_id
- // map, as reference ids are not part of the compute at maps.
- decltype(reference_id_map) ref_id_to_concrete;
- std::transform(
- reference_id_map.begin(),
- reference_id_map.end(),
- std::inserter(ref_id_to_concrete, ref_id_to_concrete.begin()),
- [](auto entry) { return std::make_pair(entry.second, entry.first); });
-
- // Track which roots have been handled by the generated predicates
- std::vector<std::unordered_set<IterDomain*>> handeled_roots;
-
- std::vector<kir::Bool*> predicates;
-
- for (auto contig_id_entry : contig_id_infos) {
- auto contig_id = contig_id_entry.contig_id;
- // No predicates needed for broadcast indices.
- if (contig_id->isBroadcast() ||
- gpu_lower->trivialReductionInfo().isDerived(contig_id)) {
- continue;
- }
-
- auto root_ids = contig_id_entry.root_ids;
- auto kir_contig_id =
- gpu_lower->lowerValue(contig_id)->as<kir::IterDomain>();
-
- const auto it = ref_indexing.indexMap().find(kir_contig_id);
-
- // First condition below is due to broadcasts in consumers of the consumer
- // that are not in the consumer itself; these can leave unresolved indexing
- // in the reference
- // tensor. This can happen when we have something like: TV3[i1o*i2, i1i] and
- // TV1[i2] where tv3 and tv1 share their outer dimension. i1 will be part of
- // reference tensors root domain, but when indexing into TV1 there aren't
- // enough indices to resolve it.
- //
- // The condition also happens with Misaligned predicates, where
- // inner-most vectorized loops are not included in the loops
- // parameter. Predicates involving vectorized loops are separately
- // generated in lower_misaligned_vectorization.
- //
- // Second condition is simply to avoid predication on broadcasting axes as
- // it's not required.
- if (it == ref_indexing.indexMap().end() || it->second->isZeroInt()) {
+ std::vector<Val*> root_inds(root_dom.size(), ir_builder.create<kir::Int>(0));
+ for (const auto i : c10::irange(root_dom.size())) {
+ if (root_dom[i]->isBroadcast()) {
continue;
}
- // Use the iteration domains extent unless there's a halo extent
- auto extent = kir_contig_id->extent();
-
- auto halo_extent_it = reference_halo_extent_map.find(kir_contig_id);
- if (halo_extent_it != reference_halo_extent_map.end()) {
- extent = halo_extent_it->second;
- }
-
- // If the index definition is "simple" and the extent is "simple" then our
- // for loop goes exactly across the iteration domain extent so no predicate
- // needed.
- if (it->second->definition() == nullptr &&
- extent->definition() == nullptr) {
- continue;
+ auto kir_root_dom_i =
+ GpuLower::lowerValue(root_dom[i])->as<kir::IterDomain>();
+ if (index_map.find(kir_root_dom_i) != index_map.end()) {
+ auto ind = index_map.at(kir_root_dom_i);
+ TORCH_INTERNAL_ASSERT(kir::isLoweredScalar(ind));
+ root_inds[i] = ind;
}
-
- predicates.push_back(
- ir_builder.ltExpr(it->second, extent)->as<kir::Bool>());
-
- // Transform roots from reference to concrete roots (based on loop compute
- // at map)
- std::unordered_set<IterDomain*> concrete_root_ids;
- std::transform(
- contig_id_entry.root_ids.begin(),
- contig_id_entry.root_ids.end(),
- std::inserter(concrete_root_ids, concrete_root_ids.begin()),
- [&ref_id_to_concrete](IterDomain* root_id) {
- return ref_id_to_concrete.at(root_id);
- });
- handeled_roots.push_back(concrete_root_ids);
}
- return {predicates, handeled_roots};
-}
-
-bool Index::protectWithMagicZero(
- kir::ForLoop* loop,
- IterDomain* reference_domain,
- kir::Val* ind) {
- bool ref_dom_simple =
- (reference_domain == nullptr ? true
- : reference_domain->definition() != nullptr);
- bool ind_simple =
- (ind == nullptr ? true
- : ind->definition() != nullptr && !ind->isZeroInt());
- return loop->isUnrollable() && (!ref_dom_simple || !ind_simple);
+ return std::make_pair(root_inds, use_rfactor);
}
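+// Usage sketch (hypothetical caller): each returned index is typically
+// compared against the extent of the matching root domain, e.g.
+//   root_inds[i] < root_dom[i]->extent()
+// and the returned bool tells the caller whether root_dom must be taken
+// from the rfactor domain rather than the root domain.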
} // namespace cuda
#pragma once
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <unordered_map>
#include <unordered_set>
namespace cuda {
class IndexCompute : public BackwardVisitor {
- protected:
+ private:
using BackwardVisitor::handle;
-
void handle(Split*) override;
void handle(Merge*) override;
void handle(Expr*) override;
// return extent_map_[id] if exists, else return id->extent()
- kir::Val* getExtent(kir::IterDomain* id);
+ Val* getExtent(kir::IterDomain* id);
- //! True if a domain is not used to index
- bool isZero(kir::IterDomain* id) const;
- //! True if any dependent of a domain is not used to index
- bool hasZeroMerged(kir::IterDomain* id) const;
+ bool hasZeroMerged(kir::IterDomain* id);
// Tensor domain we're mapping back to root
- const TensorDomain* td_; // NOLINT
+ const TensorDomain* td_;
// Map we update as we propagate backward, containing all IDs in the
// propagation. Initial indices are mapped with this map at tv->domain()
// and are back propagated to tv->rootDomain(). This index_map_ keeps the
// indices at intermediate IterDomain's in that back propagation.
- std::unordered_map<kir::IterDomain*, kir::Val*> index_map_; // NOLINT
+ std::unordered_map<kir::IterDomain*, Val*> index_map_;
// Map from IterDomain to their broadcasted extent. If a TV has I0*I1 but its
// producer has B0*I1 this map will contain a mapping from the ID{B0*I1} to
// the extent I0*I1. Also contains updated extents if we merge in a 0 index.
// See zero_merged_in_.
- std::unordered_map<kir::IterDomain*, kir::Val*> extent_map_; // NOLINT
-
- // Keeps track of domains that do not contribute to indexing
- std::unordered_set<kir::IterDomain*> zero_; // NOLINT
+ std::unordered_map<kir::IterDomain*, Val*> extent_map_;
// This set keeps track of IterDomain's that have had a zero index merged into
// them. This happens if we do something like tv->axis(0)->split(4) then
// IDs that are a result of contiguous merges
std::unordered_set<kir::IterDomain*> contig_ids;
- // Indicates whether we should propagate an index down a particular
- // IterDomain path if there's an option
- std::unordered_set<kir::IterDomain*> preferred_paths_;
-
- // Map from IterDomains to halo-extended extents in corresponding
- // reference tensor
- std::unordered_map<kir::IterDomain*, kir::Val*> reference_halo_extent_map_;
-
public:
- const std::unordered_map<kir::IterDomain*, kir::Val*>& indexMap() const {
+ const std::unordered_map<kir::IterDomain*, Val*> indexMap() const {
return index_map_;
}
- const std::unordered_map<kir::IterDomain*, kir::Val*>& extentMap() const {
+ const std::unordered_map<kir::IterDomain*, Val*> extentMap() const {
return extent_map_;
}
- const std::unordered_set<kir::IterDomain*>& zeroMergedIn() const {
+ std::unordered_set<kir::IterDomain*> zeroMergedIn() const {
return zero_merged_in_;
}
// Propagate back from _td using initial_index_map
IndexCompute(
const TensorDomain* _td,
- std::unordered_map<kir::IterDomain*, kir::Val*> initial_index_map,
- std::unordered_map<kir::IterDomain*, kir::Val*> _extent_map,
+ std::unordered_map<kir::IterDomain*, Val*> initial_index_map,
+ std::unordered_map<kir::IterDomain*, Val*> _extent_map,
std::unordered_set<kir::IterDomain*> _zero_merged_in,
- const std::vector<bool>& _root_contiguity,
- std::unordered_set<kir::IterDomain*> preferred_paths = {},
- std::unordered_map<kir::IterDomain*, kir::Val*>
- reference_halo_extent_map = {});
+ const std::vector<bool>& _root_contiguity);
// Updates index_map, extent_map, and zero_merged_in based on id_map and
- // returns a new IndexCompute ready to be used.
+ // returns a new IndexCompute ready to be used. new_index_entries are not
+ // mapped, but are added to index_map.
IndexCompute updateIndexCompute(
const TensorDomain* new_td,
const std::unordered_map<IterDomain*, IterDomain*>& id_map,
- const std::vector<bool>& _root_contiguity,
- const std::unordered_map<kir::IterDomain*, kir::Val*>&
- reference_halo_extent_map = {});
-
- virtual void run();
-};
-
-//! Apply swizzle and update root indices accordingly
-class IndexSwizzle : public IndexCompute {
- public:
- IndexSwizzle(
- const TensorView* tv,
- std::unordered_map<kir::IterDomain*, kir::Val*> initial_index_map,
- std::unordered_map<kir::IterDomain*, kir::Val*> extent_map,
- std::unordered_set<kir::IterDomain*> zero_merged_in);
-
- void run() override;
-
- protected:
- using IndexCompute::handle;
-
- void handle(Expr* e) override;
-
- private:
- const TensorView* tv_ = nullptr;
- SwizzleType swizzle_type_ = SwizzleType::NoSwizzle;
- std::vector<IterDomain*> ids_to_swizzle_;
- std::unordered_set<IterDomain*> swizzled_ids_;
+ std::unordered_map<kir::IterDomain*, Val*> new_index_entries,
+ const std::vector<bool>& _root_contiguity);
+
+ // Map producer contiguity information to consumer; if entries don't match,
+ // mark them as false
+ static std::vector<bool> contiguityPasC(
+ TensorDomain* producer,
+ TensorDomain* consumer);
+
+ static std::vector<bool> contiguityAnd(
+ const std::vector<bool>& contig1,
+ const std::vector<bool>& contig2);
};
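+// A small sketch of the assumed semantics of the contiguity helpers above:
+//
+//   contiguityAnd({true, true, false}, {true, false, false})
+//       == {true, false, false}  // elementwise AND
+//
+// contiguityPasC maps the producer's contiguity flags onto the consumer's
+// domain, marking entries false wherever the two domains don't line up.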
// Simple interface for IndexCompute
class Index {
private:
// Producer indexing if it's in shared or local memory
- static std::vector<kir::Val*> getNonGlobalProducerStridedIndices(
+ static kir::TensorIndex* getProducerIndex_impl(
TensorView* producer,
- const TensorView* consumer,
+ TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
// Consumer indexing if it's in shared or local memory
- static std::vector<kir::Val*> getNonGlobalConsumerStridedIndices(
- const TensorView* consumer,
+ static kir::TensorIndex* getConsumerIndex_impl(
+ TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
// Producer if it's in global memory
- static std::vector<kir::Val*> getGlobalProducerStridedIndices(
+ static kir::TensorIndex* getGlobalProducerIndex(
TensorView* producer,
- const TensorView* consumer,
+ TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
// Consumer indexing if it's in global memory
- static std::vector<kir::Val*> getGlobalConsumerStridedIndices(
- const TensorView* consumer,
+ static kir::TensorIndex* getGlobalConsumerIndex(
+ TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
public:
// Producer indexing dispatch
static kir::TensorIndex* getProducerIndex(
TensorView* producer,
- const TensorView* consumer,
+ TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
// Consumer index dispatch
static kir::TensorIndex* getConsumerIndex(
- const TensorView* consumer,
- const std::vector<kir::ForLoop*>& loops);
-
- //! Returns a vector of strided indices mapped onto the (rfactor)
- //! root domain of a producer tensor. The size of the returned
- //! vector is guaranteed to be equal to the number of axes of the
- //! indexing root domain.
- static std::vector<kir::Val*> getProducerStridedIndices(
- TensorView* producer,
- const TensorView* consumer,
- const std::vector<kir::ForLoop*>& loops);
-
- //! Returns a vector of strided indices mapped onto the (rfactor)
- //! root domain of a consumer tensor. The size of the returned
- //! vector is guaranteed to be equal to the number of axes of the
- //! indexing root domain.
- static std::vector<kir::Val*> getConsumerStridedIndices(
- const TensorView* consumer,
+ TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
// Consumer indices for predicates; keep all indices matching in the root
// domain, even those not used for physical addressing. Returns a pair <root
// indices, whether the indices are mapped to the rfactor dom>
- static std::pair<std::vector<kir::Val*>, bool> getConsumerRootPredIndices(
- const kir::TensorView* consumer,
+ static std::pair<std::vector<Val*>, bool> getConsumerRootPredIndices(
+ TensorView* consumer,
const std::vector<kir::ForLoop*>& loops,
const std::vector<bool>& root_contiguity,
- bool unswitch = false);
-
- //! Take a consumer tensorview and loop nest and generates predicates
- //! associated with the concrete roots of the loop nest. Returns a list of
- //! predicates, and a list of concrete roots they're associated with. It is
- //! assumed that no predicate is required if index[i] is an index directly
- //! from a for loop. This will not catch all cases if we actually have static
- //! size information for example:
- //!
- //! TV[I].split(4)
- //! would produce the code:
- //! for(i : I/4)
- //! for(j : 4)
- //! if( i * 4 + j < TV.size(0))
- //! TV[i * 4 + j]...
- //!
- //! However if we had TV.size[0] = 16 at "compile time" then we wouldn't need
- //! the predicate. This will be caught by canOmitPredicate in the predicate
- //! lowering
- // TODO: Replace pair of vectors with vector of
- static std::pair<
- std::vector<kir::Bool*>,
- std::vector<std::unordered_set<IterDomain*>>>
- getReferenceRootPredicates(
- const kir::TensorView* kir_consumer_tv,
- const std::vector<kir::ForLoop*>& loops,
- bool unswitch = false);
-
- // Determine if we may run into over reuse of predicates or registers in the
- // compiler. If the loop can be unrolled and the index and domain are not
- // "simple" we likely want the loop protected.
- //
- // Magic zero protection should only be done for global memory and predicates.
- // We should avoid use on registers. Shared memory does not require it, but
- // likely wouldn't hurt.
- static bool protectWithMagicZero(
- kir::ForLoop* loop,
- IterDomain* reference_domain = nullptr,
- kir::Val* ind = nullptr);
+ bool unroll = false);
};
} // namespace cuda
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/index_reference_replay.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-// We're going to replay this split operation on the corresponding ID
-void IndexReferenceReplay::handle(Split* s) {
- auto in = s->in();
-
- auto concrete_in = GpuLower::current()->caIndexMap().getConcreteMappedID(in);
- auto mapped_in_it = concrete_to_id_.find(concrete_in);
- if (mapped_in_it == concrete_to_id_.end()) {
- // If we can't find the concrete IDs in our local map, don't do anything.
- return;
- }
-
- auto mapped_in = mapped_in_it->second;
-
- if (leaf_ids_.find(mapped_in) == leaf_ids_.end()) {
- // If ID has already been replayed, don't do anything.
- return;
- }
-
- auto replayed_outs =
- IterDomain::split(mapped_in, s->factor(), s->innerSplit());
-
- auto concrete_outer =
- GpuLower::current()->caIndexMap().getConcreteMappedID(s->outer());
- auto concrete_inner =
- GpuLower::current()->caIndexMap().getConcreteMappedID(s->inner());
-
- // Update leaf id set and concrete id map
- leaf_ids_.erase(mapped_in);
- leaf_ids_.emplace(replayed_outs.first);
- leaf_ids_.emplace(replayed_outs.second);
- concrete_to_id_[concrete_outer] = replayed_outs.first;
- concrete_to_id_[concrete_inner] = replayed_outs.second;
-}
-
-// We're going to replay this merge operation on the corresponding IDs
-void IndexReferenceReplay::handle(Merge* m) {
- auto in_outer = m->outer();
- auto in_inner = m->inner();
-
- auto concrete_in_outer =
- GpuLower::current()->caIndexMap().getConcreteMappedID(in_outer);
- auto concrete_in_inner =
- GpuLower::current()->caIndexMap().getConcreteMappedID(in_inner);
-
- auto mapped_in_outer_it = concrete_to_id_.find(concrete_in_outer);
- auto mapped_in_inner_it = concrete_to_id_.find(concrete_in_inner);
-
- if (mapped_in_outer_it == concrete_to_id_.end() ||
- mapped_in_inner_it == concrete_to_id_.end()) {
- // If we can't find the concrete IDs in our local map, don't do anything.
- return;
- }
-
- auto mapped_in_outer = mapped_in_outer_it->second;
- auto mapped_in_inner = mapped_in_inner_it->second;
-
- if (leaf_ids_.find(mapped_in_outer) == leaf_ids_.end() &&
- leaf_ids_.find(mapped_in_inner) == leaf_ids_.end()) {
- // If ID has already been replayed, don't do anything.
- return;
- }
- auto replayed = IterDomain::merge(mapped_in_outer, mapped_in_inner);
-
- auto concrete_out =
- GpuLower::current()->caIndexMap().getConcreteMappedID(m->out());
-
- // Update leaf id set and concrete id map
- leaf_ids_.erase(mapped_in_outer);
- leaf_ids_.erase(mapped_in_inner);
- leaf_ids_.emplace(replayed);
- concrete_to_id_[concrete_out] = replayed;
-}
-
-TensorDomain* IndexReferenceReplay::computeReplay() {
- auto gpu_lower = GpuLower::current();
- // Throw an error when two loops are mapped with each other, which
- // violates an assumption that unique mappings between concrete
- // IterDomains and the IterDomains of the loop structure must be
- // established. It should be a reasonable assumption, but fusions
- // like below won't work:
- // tv0 = [I0]
- // tv1 = broadcast(tv0, {true, false});
- // tv2 = broadcast(tv0, {false, true});
- // tv3 = tv1 + tv2
- // Notice that the two axes of each of tv1, tv2 and tv3 are mapped
- // with each other. We believe it is unlikely this limitation
- // becomes a real concern in practice.
- for (auto it_i = loop_structure_.begin(); it_i != loop_structure_.end();
- ++it_i) {
- for (auto it_j = it_i + 1; it_j != loop_structure_.end(); ++it_j) {
- TORCH_INTERNAL_ASSERT(
- !gpu_lower->caIndexMap().areMapped(
- (*it_i)->iter_domain(), (*it_j)->iter_domain()),
- "Unsupported loop structure. Two loops are mapped together.");
- }
- }
-
- // Grab the iter domain's from the loop structure
- std::vector<IterDomain*> fusion_loop_structure;
-
- std::transform(
- loop_structure_.begin(),
- loop_structure_.end(),
- std::back_inserter(fusion_loop_structure),
- [&](kir::ForLoop* fl) {
- auto fid = gpu_lower->caIndexMap().toFusion(fl->iter_domain());
- return fid;
- });
-
- // Get any and all inputs that generated the provided loop structure; some
- // root inputs may be mapped to each other but not identical
- auto all_inputs = InputsOf::outputs(
- FusionGuard::getCurFusion(),
- std::vector<Val*>(
- fusion_loop_structure.begin(), fusion_loop_structure.end()));
-
- // Make sure all inputs are iter domains, ignoring anything like split factor
- // inputs
- auto all_iter_inputs = ir_utils::filterByType<IterDomain>(all_inputs);
-
- // Sort out the inputs as there could be entries that map to each other, and
- // they can be a combination of iteration, reduction, and broadcast. Order as
- // iter, reduction, then broadcast for iterating and removing duplicate mapped
- // entries. Since these are input IterDomains we mainly want to prioritize
- // non-broadcast "versions" of the iter domain if it shows up more than once.
- // We could get both if we have a compute at structure where a consumer has a
- // concrete iter domain but its producer has a broadcast domain, and the
- // compute at axis is across a split on this domain. The producer would give a
- // broadcast input, consumer would have iter domain input.
- // Additionally, we prefer non-reduction iter domains over reduction
- // domains, but this is just optional and not necessary for correctness.
- std::vector<IterDomain*> sorted_inputs;
- std::copy_if(
- all_iter_inputs.begin(),
- all_iter_inputs.end(),
- std::back_inserter(sorted_inputs),
- [](IterDomain* id) { return !id->isBroadcast() && !id->isReduction(); });
- std::copy_if(
- all_iter_inputs.begin(),
- all_iter_inputs.end(),
- std::back_inserter(sorted_inputs),
- [](IterDomain* id) { return id->isReduction(); });
- std::copy_if(
- all_iter_inputs.begin(),
- all_iter_inputs.end(),
- std::back_inserter(sorted_inputs),
- [](IterDomain* id) { return id->isBroadcast(); });
-
- // Produce a non-repetitive set of inputs. Remove "duplicate" IterDomains that
- // map to each other.
- std::vector<IterDomain*> root_axes;
- for (auto root_id : sorted_inputs) {
- auto concrete_id = gpu_lower->caIndexMap().getConcreteMappedID(root_id);
- if (concrete_to_id_.find(concrete_id) != concrete_to_id_.end()) {
- continue;
- }
-
- // Make a copy of the root_id for the reference to "own"
- IterDomain* root_id_copy = root_id->clone();
-
- // Initialize root axes, concrete map, and leaf map for replay.
- root_axes.push_back(root_id_copy);
- concrete_to_id_[concrete_id] = root_id_copy;
- leaf_ids_.emplace(root_id_copy);
- }
-
- // Order is important here, replay expressions from loops outside to inside.
- auto replay_exprs = ExprSort::getExprs(
- FusionGuard::getCurFusion(),
- {fusion_loop_structure.begin(), fusion_loop_structure.end()});
-
- // Run the reference replay
- for (auto expr : replay_exprs) {
- OptInDispatch::handle(expr);
- }
-
- // Construct a tensor that's representative of the replayed loop structure.
- std::vector<IterDomain*> loops_replayed_domain;
-
- // Grab a set of concrete leaf ids to make it easier to search which for loop
- // matches the leaf id from the replay.
- std::unordered_set<IterDomain*> concrete_leaf_ids;
- for (auto entry : concrete_to_id_) {
- if (leaf_ids_.find(entry.second) != leaf_ids_.end()) {
- concrete_leaf_ids.emplace(entry.first);
- }
- }
-
- // Figure out which ID's that were replayed correspond to the respective loops
- // that were replayed.
- std::transform(
- fusion_loop_structure.begin(),
- fusion_loop_structure.end(),
- std::back_inserter(loops_replayed_domain),
- [&](IterDomain* loop_id) {
- for (auto id : concrete_leaf_ids) {
- // Matching has to be done on loop map, though replay was done in ID
- // map, so we need to manually check that things are mapped in the
- // loop map. Cannot simply look up concrete IDs to match them as index
- // map and loop map do not have the same concrete id mapping. We also
- // allow matching explicitly through the index map. Index map is not
- // guaranteed to be contained in loop map, therefore if we generate
- // mappings to concrete id's through the index map, the mapping from
- // those ID's to the ID's we replay is not guaranteed to be in loop
- // map. The reverse is also true, so for validation make sure one of
- // the mappings exist. For reference check the difference between:
- // AdvancedLowering5 test and AdvancedIndexing1.
- if (gpu_lower->caLoopMap().areMapped(id, loop_id) ||
- gpu_lower->caIndexMap().areMapped(id, loop_id)) {
- concrete_leaf_ids.erase(id);
- auto replayed_id = concrete_to_id_.at(id);
- // Propagate parallelization and vectorization. Necessary
- // for indexing. IndexCompute::getExtent depends on the
- // propagated parallelization.
- if (isParallelTypeVectorize(loop_id->getParallelType()) ||
- isParallelTypeThread(loop_id->getParallelType())) {
- replayed_id->parallelize(loop_id->getParallelType());
- }
- return replayed_id;
- }
- }
-
- TORCH_INTERNAL_ASSERT(
- false,
- "Could not find required iter domain in reference replay: ",
- loop_id);
- });
-
- // Add any remaining leaf iter domains, this can happen from rfactor patterns.
- for (auto entry : concrete_leaf_ids) {
- loops_replayed_domain.push_back(concrete_to_id_.at(entry));
- }
- if (replay_exprs.empty()) {
- auto domain = new TensorDomain(
- // If there was no replay only return a domain with a root domain.
- loops_replayed_domain);
- return domain;
- } else {
- auto domain = new TensorDomain(root_axes, loops_replayed_domain);
- return domain;
- }
-}
-
-IndexCompute getReferenceIndexing(
- const std::vector<kir::ForLoop*>& loop_structure,
- TensorDomain* reference_tensor) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- // Create a simple index mapping from loop iter domains to their local index.
- // This is only applicable to global memory buffers.
- std::unordered_map<kir::IterDomain*, kir::Val*> initial_index_map;
-
- TORCH_INTERNAL_ASSERT(loop_structure.size() <= reference_tensor->nDims());
- int magic_zero_loop = -1;
- for (size_t loop_i = 0; loop_i < loop_structure.size(); loop_i++) {
- auto ref_axis = reference_tensor->axis(loop_i);
- auto kir_ref_axis = gpu_lower->lowerValue(ref_axis)->as<kir::IterDomain>();
- auto loop = loop_structure[loop_i];
- auto ind = loop->index();
-
- initial_index_map[kir_ref_axis] = ind;
- if (loop->vectorize()) {
- initial_index_map[kir_ref_axis] = ir_builder.create<kir::Int>(0);
- }
-
- if (Index::protectWithMagicZero(loop, ref_axis, ind)) {
- magic_zero_loop = (int)loop_i;
- }
- }
-
- // Add magic zero to a fairly inner most index
- if (magic_zero_loop >= 0) {
- auto ref_id = gpu_lower->lowerValue(reference_tensor->axis(magic_zero_loop))
- ->as<kir::IterDomain>();
- initial_index_map[ref_id] = ir_builder.addExpr(
- initial_index_map[ref_id], ir_builder.magicZeroVal());
- }
-
- // Send to the other version of reference indexing that directly takes the
- // index map
- return getReferenceIndexing(
- loop_structure, reference_tensor, initial_index_map, {});
-}
-
-IndexCompute getReferenceIndexing(
- const std::vector<kir::ForLoop*>& loop_structure,
- TensorDomain* reference_tensor,
- std::unordered_map<kir::IterDomain*, kir::Val*> index_map,
- std::unordered_set<IterDomain*> preferred_paths,
- std::unordered_map<kir::IterDomain*, kir::Val*> halo_extent_map) {
- auto gpu_lower = GpuLower::current();
-
- // I thought this might be necessary, but turns out it's not. I think it's
- // because of the root ordering above; however, leaving it in case we find
- // out it is necessary in some cases. At the time of committing, cuda-memcheck
- // passed without this.
- //
- // std::unordered_map<kir::IterDomain*,
- // kir::Val*> reference_extent_map; for (auto loop : loop_structure) {
- // // If there's a broadcast merged in the for loop ID we want to track its
- // // extent
- // auto inputs = InputsOf::outputs(
- // FusionGuard::getCurFusion(),
- // {gpu_lower->caIndexMap().toFusion(loop->iter_domain())});
-
- // auto iter_inputs = ir_utils::filterByType<IterDomain>(inputs);
-
- // // If any of the inputs are a broadcast, explicitly mark the loop id's
- // // extent
- // if (std::any_of(iter_inputs.begin(), iter_inputs.end(), [](IterDomain*
- // id) {
- // return id->isBroadcast();
- // })) {
- // reference_extent_map[loop->iter_domain()] =
- // loop->iter_domain()->extent();
- // }
- // }
-
- // Convert to preferred_path to kir::IterDomain for IndexCompute
- std::unordered_set<kir::IterDomain*> kir_preferred_path;
- std::transform(
- preferred_paths.begin(),
- preferred_paths.end(),
- std::inserter(kir_preferred_path, kir_preferred_path.begin()),
- [&gpu_lower](IterDomain* id) {
- return gpu_lower->lowerValue(id)->as<kir::IterDomain>();
- });
-
- IndexCompute compute(
- reference_tensor,
- index_map, // NOLINT
- // reference_extent_map, // Seems this is not necessary, see comment above
- // in this function
- {},
- std::unordered_set<kir::IterDomain*>(),
- reference_tensor->contiguity(),
- kir_preferred_path,
- halo_extent_map);
-
- compute.run();
-
- return compute;
-}
-
-namespace {
-
-// Class to track through the reference what path to take for zero merged in
-// indices if we're indexing shared memory or local memory. Use marked root
-// domains and traverse through the replay to mark paths to get to them during a
-// backward replay.
-class PreferredPathCompute : public IterVisitor {
- private:
- void handle(Expr* e) override {
- // If an input ID is marked, propagate the marking to outputs of the
- // expression
- auto all_iter_inputs = ir_utils::filterByType<IterDomain>(e->inputs());
- if (std::any_of(
- all_iter_inputs.begin(),
- all_iter_inputs.end(),
- [&](IterDomain* inp_id) {
- return this->preferred_path.find(inp_id) !=
- this->preferred_path.end();
- })) {
- auto all_iter_outputs = ir_utils::filterByType<IterDomain>(e->outputs());
- preferred_path.insert(all_iter_outputs.begin(), all_iter_outputs.end());
- }
- }
-
- private:
- // If making a choice these are the iter domains to prefer when traversing
- // backward.
- std::unordered_set<IterDomain*> preferred_path;
-
- public:
- static std::unordered_set<IterDomain*> compute(
- TensorDomain* reference_domain,
- const std::unordered_set<IterDomain*>& preferred_roots) {
- // TODO: assert all provided preferred roots are in the history of reference
- // domain.
-
- PreferredPathCompute compute;
- // Init preferred path
- compute.preferred_path = preferred_roots;
-
- // Propagate
- compute.traverseFrom(
- FusionGuard::getCurFusion(),
- std::vector<Val*>(
- reference_domain->domain().begin(),
- reference_domain->domain().end()));
-
- return compute.preferred_path;
- }
-};
-} // namespace
-
-// External interface for preferred path propagation.
-std::unordered_set<IterDomain*> buildPreferredPaths(
- TensorDomain* reference_tensor,
- const std::unordered_set<IterDomain*>& preferred_roots) {
- return PreferredPathCompute::compute(reference_tensor, preferred_roots);
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-struct ReferenceTensor {
- TensorDomain* domain = nullptr;
-
- // Map from concrete iteration domains in ComputeAtMaps to iter domains
- // including those used to construct domain.
- std::unordered_map<IterDomain*, IterDomain*> concrete_to_id;
-};
-
-class IndexReferenceReplay : public OptInDispatch {
- private:
- IndexReferenceReplay(const std::vector<kir::ForLoop*>& loop_structure)
- : loop_structure_(loop_structure) {}
-
- // We're going to replay this split operation on the corresponding ID
- void handle(Split* s) override;
-
- // We're going to replay this merge operation on the corresponding IDs
- void handle(Merge* m) override;
-
- TensorDomain* computeReplay();
-
- using OptInDispatch::handle;
-
- private:
- const std::vector<kir::ForLoop*>& loop_structure_;
-
- // Replay map
- std::unordered_map<IterDomain*, IterDomain*> concrete_to_id_;
-
- // Replay map
- std::unordered_set<IterDomain*> leaf_ids_;
-
- public:
- static ReferenceTensor getReference(
- const std::vector<kir::ForLoop*>& loop_structure) {
- auto replay = IndexReferenceReplay(loop_structure);
- ReferenceTensor ref;
- ref.domain = replay.computeReplay();
- ref.concrete_to_id = replay.concrete_to_id_;
- return ref;
- }
-};
-
-// Index into the reference based on the provided index map.
-IndexCompute getReferenceIndexing(
- const std::vector<kir::ForLoop*>& loop_structure,
- TensorDomain* reference_domain,
- std::unordered_map<kir::IterDomain*, kir::Val*> index_map,
- std::unordered_set<IterDomain*> preferred_path,
- std::unordered_map<kir::IterDomain*, kir::Val*> halo_extent_map = {});
-
-// Shortcut for global TVs. Index into the reference based on all loop indices
-// in the loop structure.
-IndexCompute getReferenceIndexing(
- const std::vector<kir::ForLoop*>& loop_structure,
- TensorDomain* reference_domain);
-
-// When indexing there is sometimes an option to propagate an index down
-// multiple paths. This will return the IterDomains in the history of the
-// reference domain and mark which paths should be taken (if there's a
-// preference) to reach the roots provided in preferred_roots.
-std::unordered_set<IterDomain*> buildPreferredPaths(
- TensorDomain* reference_domain,
- const std::unordered_set<IterDomain*>& preferred_roots);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
namespace inst {
Trace::Trace() {
- const char* trace_filename = getenv("PYTORCH_NVFUSER_TRACE");
+ const char* trace_filename = getenv("PYTORCH_CUDA_FUSER_TRACE");
if (trace_filename != nullptr) {
log_file_ = fopen(trace_filename, "w");
TORCH_CHECK(log_file_ != nullptr, "Can't open trace file");
#include <torch/csrc/jit/codegen/cuda/utils.h>
-#include <nvToolsExt.h>
-
-// NOLINTNEXTLINE(modernize-deprecated-headers)
-#include <stdio.h>
#include <chrono>
#include <cstdio>
//! This class is not intended to be used directly. Instead, the operations
//! to be traced are marked (for example using the FUSER_PERF_SCOPE macro)
//!
-//! In order to enable tracing, the `PYTORCH_NVFUSER_TRACE` environment
+//! In order to enable tracing, the `PYTORCH_CUDA_FUSER_TRACE` environment
//! variable is set to point to a trace file (ex `test.trace`). The file name
//! may be a relative or an absolute path.
//!
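+//! For example (script name hypothetical), a run such as
+//!
+//!   PYTORCH_CUDA_FUSER_TRACE=test.trace python my_script.py
+//!
+//! writes the begin/end trace events to `test.trace`.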
if (log_file_ != nullptr) {
logEvent('B', name);
}
- nvtxRangePushA(name);
}
void endEvent(const char* name) {
- nvtxRangePop();
if (log_file_ != nullptr) {
logEvent('E', name);
}
void compileFusionGroup(Node* fusion_node) {
TORCH_CHECK(
- getFuserInterface()->fn_compile_n != nullptr,
+ getFuserInterface()->fn_compile_n_ != nullptr,
"Running the CUDA fuser requires a CUDA build.");
- getFuserInterface()->fn_compile_n(fusion_node);
+ getFuserInterface()->fn_compile_n_(fusion_node);
}
void runFusionGroup(const Node* fusion_node, Stack& stack) {
TORCH_CHECK(
- getFuserInterface()->fn_run_n_s != nullptr,
+ getFuserInterface()->fn_run_n_s_ != nullptr,
"Running the CUDA fuser requires a CUDA build.");
- getFuserInterface()->fn_run_n_s(fusion_node, stack);
+ getFuserInterface()->fn_run_n_s_(fusion_node, stack);
}
void fuseGraph(std::shared_ptr<Graph>& graph) {
TORCH_CHECK(
- getFuserInterface()->fn_fuse_graph != nullptr,
+ getFuserInterface()->fn_fuse_graph_ != nullptr,
"Running the CUDA fuser requires a CUDA build.");
- getFuserInterface()->fn_fuse_graph(graph);
+ getFuserInterface()->fn_fuse_graph_(graph);
}
bool canFuseNode(const Node* node) {
- return getFuserInterface()->fn_can_fuse_n != nullptr &&
- getFuserInterface()->fn_can_fuse_n(node);
-}
-
-void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr) {
- if (getFuserInterface()->fn_insert_profile_inodes) {
- getFuserInterface()->fn_insert_profile_inodes(pr);
- }
+ return getFuserInterface()->fn_can_fuse_n_ != nullptr &&
+ getFuserInterface()->fn_can_fuse_n_(node);
}
//! [ Note -- type guard logic in CudaFusionGuard ]
// check a. if num_dimension check fails or scalar type check fails
if (*guard_tensor_type->dim() != static_cast<size_t>(tensor.ndimension()) ||
(guard_tensor_type->scalarType().has_value() &&
- (guard_tensor_type->scalarType().value() != tensor.scalar_type()))) {
+ (guard_tensor_type->scalarType().value() != tensor.scalar_type())) ||
+ tensor.requires_grad()) {
return false;
}
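+ // E.g. (sketch): a guard recorded from a 4-D float tensor rejects inputs
+ // with a different rank or scalar type, and, with the added check above,
+ // any input that requires grad.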
if (j != 0) {
// we use contiguity to collapse dimension, if size == 1, it is
// always collapsible
- // computeStrideProps also default to contiguous when stride == 1
- if (t_sizes[sorted_index] != 1 && t_strides[sorted_index] != 1) {
+ if (t_sizes[sorted_index] != 1) {
TORCH_INTERNAL_ASSERT(
stride_properties[j - 1]->stride_index_.has_value(),
"Counknown index is meaningless");
namespace {
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-RegisterOperators size_eq_guard({
- Operator(
- //"prim::CudaFusionSizeEq(int[] size, int[] ref) -> bool",
- "prim::CudaFusionSizeEq(...) -> bool",
- // prim::CudaFusionGuard returns a fresh Boolean type without aliasing.
- // if we would ever return refined tensor, which would change aliasing
- // analysis, we should update aliasdb pass.
- [](const Node* node) -> Operation {
- return [](Stack* stack) {
- at::ArrayRef<IValue> inputs = last(stack, 2);
- drop(stack, 2);
-
- if (!fuser::cuda::getCudaFusionGuardMode()) {
- push(stack, IValue(true));
- return;
- }
-
- // auto inp = inputs[0].toIntList();
- TORCH_INTERNAL_ASSERT(
- inputs[1].isIntList(), "reference needs to be of int list");
- auto ref = inputs[1].toIntList();
-
- auto ret = true;
- if (ref.empty()) {
- ret = inputs[0].isNone();
- } else {
- if (inputs[0].isIntList()) {
- auto inp = inputs[0].toIntList();
- if (inp.size() != ref.size()) {
- push(stack, IValue(false));
- return;
- }
-
- for (size_t i = 0; i < inp.size(); i++) {
- if (((inp[i] == 1) != (ref[i] == 1))) {
- ret = false;
- break;
- }
- }
- } else {
- ret = false;
- }
- }
-
- push(stack, IValue(ret));
- return;
- };
- },
- aliasAnalysisFromSchema()),
-});
-
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_fusion({
Operator(
prim::CudaFusionGroup,
},
aliasAnalysisFromSchema()),
});
-
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-RegisterOperators reg_add_optional({
- Operator(
- "prim::add_optional(Tensor(a) input, Tensor? bias) -> Tensor(a)",
- [](const Node* node) -> Operation {
- return [](Stack* stack) {
- IValue input, bias;
- pop(stack, input, bias);
- if (bias.isNone()) {
- push(stack, std::move(input));
- } else {
- push(stack, at::add(input.toTensor(), bias.toTensor(), 1.0));
- }
- };
- },
- aliasAnalysisFromSchema()),
-});
} // namespace
} // namespace jit
// dummy struct to allow API registration
struct CudaFuserInterface {
- void (*fn_compile_n)(Node*) = nullptr;
- void (*fn_run_n_s)(const Node*, Stack&) = nullptr;
- void (*fn_fuse_graph)(std::shared_ptr<Graph>&) = nullptr;
- bool (*fn_can_fuse_n)(const Node*) = nullptr;
- void (*fn_insert_profile_inodes)(ProfilingRecord* pr) = nullptr;
+ void (*fn_compile_n_)(Node*) = nullptr;
+ void (*fn_run_n_s_)(const Node*, Stack&) = nullptr;
+ void (*fn_fuse_graph_)(std::shared_ptr<Graph>&) = nullptr;
+ bool (*fn_can_fuse_n_)(const Node*) = nullptr;
};
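+// A minimal registration sketch (function names hypothetical): a CUDA build
+// fills in these hooks once at load time,
+//
+//   auto* ifc = getFuserInterface();
+//   ifc->fn_compile_n_ = &compileCudaFusionGroup;
+//   ifc->fn_run_n_s_ = &runCudaFusionGroup;
+//   ifc->fn_fuse_graph_ = &fuseCudaGraph;
+//   ifc->fn_can_fuse_n_ = &canFuseCudaNode;
+//
+// while a CPU-only build leaves them nullptr so the TORCH_CHECKs fire.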
// Get interface, this is used by registration and user facing API internally
C10_EXPORT void runFusionGroup(const Node* fusion_node, Stack& stack);
C10_EXPORT void fuseGraph(std::shared_ptr<Graph>&);
C10_EXPORT bool canFuseNode(const Node* node);
-C10_EXPORT void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr);
C10_EXPORT bool complyWith(
const at::Tensor& tensor,
}
// When we create a Val we immediately register them with the active fusion.
-Val::Val(ValType _vtype, DataType _dtype, bool register_val)
+Val::Val(ValType _vtype, DataType _dtype, bool register_val, bool lowered)
: vtype_(_vtype), dtype_(_dtype) {
Fusion* fusion = FusionGuard::getCurFusion();
TORCH_CHECK(
fusion != nullptr, "No active fusion group found when creating a Val.");
fusion_ = fusion;
if (register_val) {
- name_ = fusion_->registerVal(this);
+ if (lowered) {
+ name_ = fusion_->registerLoweredVal(this);
+ } else {
+ name_ = fusion_->registerVal(this);
+ }
}
}
-// NOTE: we don't clone the definition_ and uses_ here
-// since they may introduce cloning cycles. Instead, we copy
-// the original pointers and we'll fix them up later as part of the
-// Fusion copy. Neither definition_ nor uses_ is copied through
-// this constructor, leaving them to be resolved by later stages
-//
-Val::Val(const Val* src, IrCloner* ir_cloner)
- : Statement(src, ir_cloner),
- vtype_(src->vtype_),
- dtype_(src->dtype_),
- is_fusion_input_(src->is_fusion_input_),
- is_fusion_output_(src->is_fusion_output_) {}
-
-const std::vector<Expr*>& Val::uses() const {
- if (vtype_ == ValType::TensorView) {
- if (!fusion()->isTVUseInfoValid() && !fusion()->isUpdatingTVUseInfo()) {
- fusion()->resetTvUses();
- }
+namespace {
+
+// TODO(kir): remove this
+ValType lowerValType(ValType vtype) {
+ switch (vtype) {
+ case ValType::Scalar:
+ return ValType::KirScalar;
+ case ValType::NamedScalar:
+ return ValType::KirNamedScalar;
+ case ValType::TensorDomain:
+ return ValType::KirTensorDomain;
+ case ValType::IterDomain:
+ return ValType::KirIterDomain;
+ case ValType::TensorView:
+ return ValType::KirTensorView;
+ default:
+ TORCH_CHECK(false, "Unexpected");
}
- return uses_;
}
-namespace {
+} // namespace
+
+// TODO(kir): remove this
+Val::Val(const Val* fusion_ir_node)
+ : vtype_(lowerValType(fusion_ir_node->vtype_)),
+ dtype_(fusion_ir_node->dtype_) {
+ // The lowered nodes preserve the names from the fusion IR counterparts
+ name_ = fusion_ir_node->name_;
+ fusion_ = fusion_ir_node->fusion_;
+ fusion_->registerLoweredVal(this);
+}
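+// Sketch of the lowering path (assumed from the code above): constructing a
+// kir node from a fusion-IR node maps ValType::Scalar to ValType::KirScalar,
+// keeps the dtype, and registers the result with registerLoweredVal, so a
+// fusion-IR Int named i3 lowers to a kir::Int that keeps the name i3.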
-// Traverse definition of all values involved in constructing the provided val.
+Val::Val(const Val* src, IrCloner* ir_cloner)
+ : Statement(src, ir_cloner), vtype_(src->vtype_), dtype_(src->dtype_) {}
+
+// Traverse origin of all values involved in constructing the provided val.
// Check if all values involved are constant values, meaning the provided
// val is also a constant value.
+namespace {
+
class ConstCheck : OptOutConstDispatch {
private:
bool is_const_ = true;
is_const_ = is_const_ && b->isConst();
}
- void handle(const Double* d) override {
- is_const_ = is_const_ && d->isConst();
+ void handle(const Float* f) override {
+ is_const_ = is_const_ && f->isConst();
+ }
+
+ void handle(const Half* h) override {
+ is_const_ = is_const_ && h->isConst();
}
void handle(const Int* i) override {
is_const_ = is_const_ && false;
}
+ void handle(const kir::Bool* b) override {
+ is_const_ = is_const_ && b->isConst();
+ }
+
+ void handle(const kir::Float* f) override {
+ is_const_ = is_const_ && f->isConst();
+ }
+
+ void handle(const kir::Half* h) override {
+ is_const_ = is_const_ && h->isConst();
+ }
+
+ void handle(const kir::Int* i) override {
+ is_const_ = is_const_ && i->isConst();
+ }
+
+ void handle(const kir::NamedScalar* ns) override {
+ is_const_ = is_const_ && false;
+ }
+
void handle(const Expr* expr) override {
for (auto inp : expr->inputs()) {
handle(inp);
}
void handle(const Val* val) override {
- if (val->definition() != nullptr) {
- handle(val->definition());
- } else {
+ const Expr* orig = FusionGuard::getCurFusion()->origin(val);
+ if (orig != nullptr)
+ handle(orig);
+ else
OptOutConstDispatch::handle(val);
- }
}
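+  // E.g. (sketch): for a val defined as d = a + ns, the traversal recurses
+  // through d's origin into a and ns; a non-constant leaf such as a
+  // kir::NamedScalar forces is_const_ to false, so d is not considered a
+  // constant.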
public:
if (isConstScalar() && isAnInt()) {
if (this->getValType() == ValType::Scalar) {
return this->as<Int>()->value();
+ } else if (this->getValType() == ValType::KirScalar) {
+ return this->as<kir::Int>()->value();
}
}
return c10::optional<int64_t>();
return dtype_;
}
-bool Val::isProducerOf(const Val* other) const {
- TORCH_INTERNAL_ASSERT(other != nullptr);
- TORCH_INTERNAL_ASSERT(fusion() == other->fusion());
-
- if (definition() == nullptr) {
- return false;
- }
- return std::any_of(
- definition()->inputs().begin(),
- definition()->inputs().end(),
- [other](const Val* input) { return input == other; });
+Expr* Val::getOrigin() {
+ return fusion_->origin(this);
}
-bool Val::isConsumerOf(const Val* other) const {
- return other->isProducerOf(this);
+const Expr* Val::getOrigin() const {
+ return fusion_->origin(this);
}
// We don't register with the active fusion in Expr as this needs to be done
// after inputs and outputs are registered with the Expr
-Expr::Expr(ExprType type) : type_{type} {
+Expr::Expr(ExprType _type) : type_{_type} {
Fusion* fusion = FusionGuard::getCurFusion();
if (fusion == nullptr)
TORCH_CHECK(false, "No active fusion group found when creating an Expr.");
inputs_(ir_cloner->clone(src->inputs_)),
outputs_(ir_cloner->clone(src->outputs_)) {}
-bool Expr::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<Expr>()) {
+bool Expr::sameAs(const Expr* const other) const {
+ if (getExprType() != other->getExprType())
return false;
- }
- const Expr* other_expr = other->as<Expr>();
- if (getExprType() != other_expr->getExprType()) {
- return false;
- }
- if (inputs().size() != other_expr->inputs().size() ||
- outputs().size() != other_expr->outputs().size()) {
+ if (inputs().size() != other->inputs().size() ||
+ outputs().size() != other->outputs().size())
return false;
- }
for (const auto i : c10::irange(inputs().size())) {
- if (!input(i)->sameAs(other_expr->input(i))) {
+ if (!input(i)->sameAs(other->input(i)))
return false;
- }
}
return true;
}
#include <torch/csrc/jit/codegen/cuda/utils.h>
#include <cstdint>
+#include <deque>
#include <iostream>
#include <limits>
#include <memory>
using StmtNameType = unsigned int;
-constexpr StmtNameType kInvalidStmName =
+constexpr StmtNameType UNINITIALIZED_STMTNAMETYPE =
std::numeric_limits<unsigned int>::max();
class Fusion;
class IterDomain;
class IrCloner;
-//! Statement is the highest level node representation. Everything that is
-//! considered "IR" will be derived from this class at some point. Both Values
-//! and Expr's are a Statement. If there will ever be any more fundamental
-//! types, they will also derive from Statement.
-//!
-//! We use Statements to pass around nodes of unknown compile type. Therefore it
-//! is also important for the design to have a dispatch system for a Statment.
-//! Basically beinng able to succienctly traverse down the inhereitance stack of
-//! a Statment at runtime. This is currently implemented in dispatch.h
-//!
+/*
+ * Statement is the highest level node representation. Everything that is
+ * considered "IR" will be derived from this class at some point. Both Values
+ * and Expr's are a Statement. If there will ever be any more fundamental types,
+ * they will also derive from Statement.
+ *
+ * We use Statements to pass around nodes of unknown compile type. Therefore it
+ * is also important for the design to have a dispatch system for a Statement.
+ * Basically, we need to be able to succinctly traverse down the inheritance
+ * stack of a Statement at runtime. This is currently implemented in dispatch.h
+ */
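// For example (illustrative), code holding a Statement* can recover the
// concrete node through the PolymorphicBase helpers used throughout this IR:
//   if (stmt->isA<Val>()) {
//     const Val* val = stmt->as<Val>(); // checked downcast
//   }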
class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase {
friend void swap(Fusion&, Fusion&) noexcept;
// Return if this statement is the same as another statement
// TODO: should this run through dispatch on this and other?
- virtual bool sameAs(const Statement* other) const {
+ bool sameAs(const Statement* const other) const {
return this == other;
}
protected:
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
- StmtNameType name_ = kInvalidStmName;
+ StmtNameType name_ = UNINITIALIZED_STMTNAMETYPE;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
Fusion* fusion_ = nullptr;
};
-//! A Val represents a "value." These are objects, like tensors, scalars, and
-//! memory locations, that are inputs and outputs of computations (represented
-//! by Exprs, below)
-//!
-//! Vals are constant and unique and should always be passed
-//! around as a pointer. Val can generally be thought of as representing any
-//! type of data. Some examples: a constant size like convolution filter width a
-//! runtime constant like batch normalizations momentum a "symbolic" tensor like
-//! one passed down from the JIT a memory buffer used in device code
-//!
-//! Adding a Val:
-//! Right now adding a Val is quite involved. Val's can be defined in ir.h or in
-//! their own header file. The following is what is currently needed to add a
-//! new Val:
-//!
-//! 1) Definition inheriting from Val
-//! - Members must be private or protected
-//! - Accessor functions for members
-//! - Must call Val constructor, Val constructor registers with fusion
-//! - Implementation of bool sameAs(...)
-//! - Must implement a "cloning" constructor, ex.
-//! Int::Int(const Int* src, IrCloner* ir_cloner)
-//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
-//! 3) Default mutator function should be added to mutator.cpp
-//! 4a) Printing functions should be added to ir_iostream.h/.cpp
-//! 4b) Graphviz generation must be added to ir_graphviz.h/.cpp
-//! 5) An enum value must be added to ValType in type.h
-//! 6) A string entry must be added in val_type_string_map
-//!
+/*
+ * A Val represents a "value." These are objects, like tensors, scalars, and
+ * memory locations, that are inputs and outputs of computations (represented
+ * by Exprs, below). Vals are constant and unique and should always be passed
+ * around as a pointer. Val can generally be thought of as representing any type
+ * of data. Some examples: a constant size like a convolution filter width, a
+ * runtime constant like batch normalization's momentum, a "symbolic" tensor
+ * like one passed down from the JIT, or a memory buffer used in device code.
+ *
+ * Adding a Val:
+ * Right now adding a Val is quite involved. Val's can be defined in ir.h or in
+ * their own header file. The following is what is currently needed to add a new
+ * Val:
+ * 1) Definition inheriting from Val
+ * - Members must be private or protected
+ * - Accessor functions for members
+ * - Must call Val constructor, Val constructor registers with fusion
+ * - Implementation of bool sameAs(...)
+ * - Must implement a "cloning" constructor, ex.
+ * Int::Int(const Int* src, IrCloner* ir_cloner)
+ * 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
+ * 3) Default mutator function should be added to mutator.cpp
+ * 4a) Printing functions should be added to ir_iostream.h/.cpp
+ * 4b) Graphviz generation must be added to ir_graphviz.h/.cpp
+ * 5) An enum value must be added to ValType in type.h
+ * 6) A string entry must be added in val_type_string_map
+ */
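// Illustrative sketch only (not part of this diff): a hypothetical scalar Val
// named "Char" following the checklist above. The name and the DataType::Char
// entry are invented; steps 2-6 (dispatch, mutator, printing, graphviz, enum
// and string entries) would still be needed elsewhere.
//
// class Char : public Val {
//  public:
//   Char() : Val(ValType::Scalar, DataType::Char), maybe_value_{c10::nullopt} {}
//   explicit Char(char _value)
//       : Val(ValType::Scalar, DataType::Char), maybe_value_{_value} {}
//   // Cloning constructor so IrCloner can copy the node (checklist item 1)
//   Char(const Char* src, IrCloner* ir_cloner)
//       : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
//   bool isConst() const {
//     return maybe_value_.has_value();
//   }
//   bool sameAs(const Char* const other) const {
//     if (isConst() && other->isConst())
//       return *value() == *other->value();
//     return this == other;
//   }
//   c10::optional<char> value() const {
//     return maybe_value_;
//   }
//  private:
//   const c10::optional<char> maybe_value_;
// };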
class TORCH_CUDA_CU_API Val : public Statement {
public:
+ ~Val() override = default;
+
+ Val() = delete;
+
// We may not want to register this value during Val's constructor. The reason
// for this is that if we register the val, then if a derived constructor tries
// to throw, fusion's destructor will get called, but the pointer to this Val
explicit Val(
ValType _vtype,
DataType _dtype = DataType::Null,
- bool register_val = true);
+ bool register_val = true,
+ bool lowered = false);
+
+ // Lowers an existing Fusion IR node into a Kernel IR counterpart
+ explicit Val(const Val* fusion_ir_node);
Val(const Val* src, IrCloner* ir_cloner);
- // TODO: why is this optional?
- //
+ // TODO: Values are unique and not copyable
+ Val(const Val& other) = delete;
+ Val& operator=(const Val& other) = delete;
+
+ Val(Val&& other) = delete;
+ Val& operator=(Val&& other) = delete;
+
c10::optional<ValType> getValType() const override {
return vtype_;
}
// Throws if no DataType is found. Vals must have a DataType
- //
- // TODO: why is this optional?
- //
c10::optional<DataType> getDataType() const override;
bool isScalar() const {
- return vtype_ == ValType::Scalar || vtype_ == ValType::NamedScalar;
+ return vtype_ == ValType::Scalar || vtype_ == ValType::NamedScalar ||
+ vtype_ == ValType::KirScalar || vtype_ == ValType::KirNamedScalar;
}
bool isConstScalar() const;
// Returns the Expr that this value is an output of, returns nullptr if none
// was found
- Expr* definition() const {
- if (is_fusion_input_) {
- return nullptr;
- }
- return definition_;
- }
-
- const std::vector<Expr*>& uses() const;
-
- bool isFusionInput() const {
- return is_fusion_input_;
- }
-
- bool isFusionOutput() const {
- return is_fusion_output_;
- }
-
- //! Returns true when other is a producer of this
- bool isProducerOf(const Val* other) const;
-
- //! Returns true when other is a consumer of this
- bool isConsumerOf(const Val* other) const;
+ Expr* getOrigin();
+ const Expr* getOrigin() const;
bool sameType(const Statement* other) override {
return Statement::sameType(other) &&
// TODO: Make this more sophisticated. A value being the same as another value
// should be evaluated based on the DAG that created it, and that DAG's leaf
// nodes
- bool sameAs(const Statement* other) const override {
+ bool sameAs(const Val* const other) const {
return this == other;
}
static Statement* mutatorDispatch(T mutator, Val*);
protected:
- friend Fusion;
-
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
const ValType vtype_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
const DataType dtype_;
-
- // Following is managed by Fusion and can change.
- void setDefinition(Expr* expr) {
- definition_ = expr;
- }
-
- void setIsFusionInput(bool is_fusion_input) {
- is_fusion_input_ = is_fusion_input;
- }
-
- void setIsFusionOutput(bool is_fusion_output) {
- is_fusion_output_ = is_fusion_output;
- }
-
- void setUses(const std::vector<Expr*>& uses) {
- uses_ = uses;
- }
-
- private:
- // Following is managed by Fusion and can change.
- bool is_fusion_input_ = false;
- bool is_fusion_output_ = false;
-
- Expr* definition_ = nullptr;
- std::vector<Expr*> uses_;
};
-//! A Expr represents a "computation." These are functions that takes inputs
-//! and produce outputs, inputs and outputs all being Vals. There are
-//! specializations of BinaryOp which takes 2 inputs and produces 1 output, and
-//! UnaryOp which takes 1 input and produces 1 output. Exprs are unique and
-//! immutable. Conceptually, Exprs could always be manipulated using unique
-//! pointers, and we could add this later. However, for now Exprs can be
-//! replaced in a fusion, but they cannot be modified in place.
-//!
-//! The IR is static single assignment (SSA). Values can only be defined as an
-//! output of an Expr once. If they are re-defined the original definition is
-//! deleted from the program, as opposed to an ordered redefinition of the
-//! value in the program.
-//!
-//! Note: Registering an Expr with a Fusion is actually 2 parts, one part is
-//! done in the Expr constructor, so that should be called on anything that
-//! inherits Expr. The issue with having registration in Expr's constructor, is
-//! that the constructor of an Expr will set ouputs and inputs. This
-//! information is important for registration with Fuser, so it can track the
-//! dependency chain.
-//!
-//! Adding an Expr:
-//! Right now adding an Expr is quite involved. Expr's can be defined in ir.h
-//! or in their own header file. The following is what is currently needed for
-//! Expr definitions:
-//!
-//! 1) Definition inheriting from Expr.
-//! - Members must be private or protected
-//! - Accessor functions for members
-//! - Constructors need to register with the Fusion after inputs/outputs
-//! are defined
-//! - Implementation of bool sameAs(...)
-//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
-//! 3) Default mutator function should be added to mutator.h/.cpp
-//! 4) Printing functions should be added to ir_iostream.h/.cpp
-//! 5) Lower case convenience functions should be added to arith.h/.cpp (If
-//! user facing)
-//! 6) An enum value must be added to ExprType in type.h
-//! 7) A string entry must be added in expr_type_string_map
-//! 8) Entry added to ir_graphviz .cpp/.h
-//!
+// An Expr represents a "computation." These are functions that take inputs
+// and produce outputs, inputs and outputs all being Vals. There are
+// specializations of BinaryOp which takes 2 inputs and produces 1 output, and
+// UnaryOp which takes 1 input and produces 1 output. Exprs are unique and
+// immutable. Conceptually, Exprs could always be manipulated using unique
+// pointers, and we could add this later. However, for now Exprs can be
+// replaced in a fusion, but they cannot be modified in place.
+
+// The IR is static single assignment (SSA). Values can only be defined as an
+// output of an Expr once. If they are re-defined the original definition is
+// deleted from the program, as opposed to an ordered redefinition of the value
+// in the program.
+
+// Note: Registering an Expr with a Fusion is actually 2 parts, one part is
+// done in the Expr constructor, so that should be called on anything that
+// inherits Expr. The issue with having registration in Expr's constructor, is
+// that the constructor of an Expr will set outputs and inputs. This information
+// is important for registration with Fuser, so it can track the dependency
+// chain.
+
+// Adding an Expr:
+// Right now adding an Expr is quite involved. Expr's can be defined in ir.h or
+// in their own header file. The following is what is currently needed for Expr
+// definitions:
+// 1) Definition inheriting from Expr.
+// - Members must be private or protected
+// - Accessor functions for members
+// - Constructors need to register with the Fusion after inputs/outputs are
+// defined
+// - Implementation of bool sameAs(...)
+// 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
+// 3) Default mutator function should be added to mutator.h/.cpp
+// 4) Printing functions should be added to ir_iostream.h/.cpp
+// 5) Lower case convenience functions should be added to arith.h/.cpp (If user
+// facing)
+// 6) An enum value must be added to ExprType in type.h
+// 7) A string entry must be added in expr_type_string_map
+// 8) Entry added to ir_graphviz .cpp/.h
+
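// Illustrative sketch only (not part of this diff): the skeleton of a
// hypothetical single-input, single-output "ClampOp" following the checklist
// above. The op name and its ExprType entry are invented for illustration;
// note how registration with the Fusion happens only after inputs/outputs are
// recorded, per the note above.
//
// class ClampOp : public Expr {
//  public:
//   ClampOp(Val* _out, Val* _in)
//       : Expr(ExprType::ClampOp), out_(_out), in_(_in) {
//     addOutput(_out);
//     addInput(_in);
//     name_ = FusionGuard::getCurFusion()->registerExpr(this);
//   }
//   Val* out() const {
//     return out_;
//   }
//   Val* in() const {
//     return in_;
//   }
//  private:
//   Val* const out_ = nullptr;
//   Val* const in_ = nullptr;
// };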
class TORCH_CUDA_CU_API Expr : public Statement {
public:
- explicit Expr(ExprType type);
+ Expr() = delete;
+ explicit Expr(ExprType _type);
Expr(const Expr* src, IrCloner* ir_cloner);
+ ~Expr() override = default;
+
+ Expr(const Expr& other) = delete;
+ Expr& operator=(const Expr& other) = delete;
+
+ Expr(Expr&& other) = delete;
+ Expr& operator=(Expr&& other) = delete;
c10::optional<ExprType> getExprType() const override {
return type_;
return type_;
}
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const Expr* const other) const;
// Input/output accessors
const auto& inputs() const {
clone_ = new Bool(b, this);
}
-void IrCloner::handle(const Double* d) {
- clone_ = new Double(d, this);
+void IrCloner::handle(const Float* f) {
+ clone_ = new Float(f, this);
+}
+
+void IrCloner::handle(const Half* h) {
+ clone_ = new Half(h, this);
}
void IrCloner::handle(const Int* i) {
clone_ = new ReductionOp(op, this);
}
-void IrCloner::handle(const WelfordOp* op) {
- clone_ = new WelfordOp(op, this);
-}
-
-void IrCloner::handle(const TransposeOp* op) {
- clone_ = new TransposeOp(op, this);
-}
-
-void IrCloner::handle(const ShiftOp* op) {
- clone_ = new ShiftOp(op, this);
-}
-
-void IrCloner::handle(const GatherOp* op) {
- clone_ = new GatherOp(op, this);
-}
-
void IrCloner::handle(const Split* split) {
clone_ = new Split(split, this);
}
class Fusion;
-//! Clones nodes from an exiting Fusion
-//!
-//! \warning IrCloner machinery is a specialized helper for implementing
-//! Fusion copy operations and it's not intended for any other uses
-//!
+// Clones nodes from an existing Fusion
class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
friend class Statement;
void handle(const IterDomain*) override;
void handle(const Bool*) override;
- void handle(const Double*) override;
+ void handle(const Float*) override;
+ void handle(const Half*) override;
void handle(const Int*) override;
void handle(const NamedScalar*) override;
void handle(const TernaryOp*) override;
void handle(const BroadcastOp*) override;
void handle(const ReductionOp*) override;
- void handle(const WelfordOp*) override;
- void handle(const TransposeOp*) override;
- void handle(const ShiftOp*) override;
- void handle(const GatherOp*) override;
void handle(const Split*) override;
void handle(const Merge*) override;
}
}
- void handle(const Double* d) override {
- if (d->isSymbolic()) {
- label_ << "d" << d->name();
+ void handle(const Float* f) override {
+ if (f->isSymbolic()) {
+ label_ << "f" << f->name();
} else {
if (detail_level_ >= DetailLevel::Explicit) {
- label_ << "d" << d->name() << "=";
+ label_ << "f" << f->name() << "=";
}
- label_ << *d->value();
+ label_ << std::fixed << std::setprecision(2) << *f->value();
+ }
+ }
+
+ void handle(const Half* h) override {
+ if (h->isSymbolic()) {
+ label_ << "h" << h->name();
+ } else {
+ if (detail_level_ >= DetailLevel::Explicit) {
+ label_ << "h" << h->name() << "=";
+ }
+ label_ << *h->value();
}
}
label_ << IrNodeLabel::gen(id->start()) << " : ";
}
label_ << IrNodeLabel::gen(id->extent());
+ if (id->rawExtent() != id->extent()) {
+ label_ << "\\<" << IrNodeLabel::gen(id->rawExtent()) << "\\>";
+ }
label_ << ")";
}
void handle(const Split* split) override {
- label_ << "Split(inner=" << (split->innerSplit() ? "true" : "false")
- << ", factor=" << IrNodeLabel::gen(split->factor()) << ")";
+ label_ << "Split(factor=" << IrNodeLabel::gen(split->factor()) << ")";
}
void handle(const Merge* merge) override {
const DetailLevel detail_level_;
};
-// Small color palette from the X11 theme
-static const char* getColorFromIndex(size_t index) {
- const size_t number_of_colors = 10;
- index = index % number_of_colors;
- switch (index) {
- case 0: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- return "azure";
- case 1: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- return "pink";
- case 2: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- return "green";
- case 3: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- return "grey";
- case 4: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- return "yellow";
- case 5: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- return "lavender";
- case 6: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- return "cyan";
- case 7: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- return "white";
- case 8: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- return "magenta";
- case 9: // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- return "red";
- default:
- break;
- }
- return "";
-}
-
} // anonymous namespace
void IrGraphGenerator::print(
const Fusion* fusion,
const char* filename,
- DetailLevel detail_level,
- ExprColorMap* expr_color_map) {
+ DetailLevel detail_level) {
std::ofstream dot_file(filename);
TORCH_CHECK(dot_file.good(), "Failed to open the IR graph file");
- dot_file << toGraphviz(fusion, detail_level, expr_color_map);
+ dot_file << toGraphviz(fusion, detail_level);
}
std::string IrGraphGenerator::toGraphviz(
const Fusion* fusion,
- DetailLevel detail_level,
- ExprColorMap* expr_color_map) {
- IrGraphGenerator ir_graph(fusion, detail_level, expr_color_map);
+ DetailLevel detail_level) {
+ IrGraphGenerator ir_graph(fusion, detail_level);
return ir_graph.generate();
}
IrGraphGenerator::IrGraphGenerator(
const Fusion* fusion,
- DetailLevel detail_level,
- ExprColorMap* expr_color_map)
- : detail_level_(detail_level),
- fusion_(fusion),
- expr_color_map_(expr_color_map) {
+ DetailLevel detail_level)
+ : detail_level_(detail_level), fusion_(fusion) {
// setup inputs & outputs
// (indexes used to quickly check if a value is fusion input or output)
for (const auto* input : fusion->inputs()) {
void IrGraphGenerator::printExpr(const Expr* expr, const std::string& label) {
graph_def_ << " " << getid(expr) << " "
<< "[label=\"" << label << "\", shape=oval, color=blue, "
- << "style=filled, fillcolor=";
- if (expr_color_map_ != nullptr && expr_color_map_->count(expr)) {
- graph_def_ << getColorFromIndex(expr_color_map_->at(expr));
- } else {
- graph_def_ << "azure";
- }
- graph_def_ << "];\n";
+ << "style=filled, fillcolor=azure];\n";
}
void IrGraphGenerator::printValue(const Val* val, const std::string& label) {
void IrGraphGenerator::handle(const Val* v) {
if (!visited(v)) {
visited_.insert(v);
- if (const auto* def = v->definition()) {
+ if (const auto* def = fusion_->origin(v)) {
handle(def);
}
OptInConstDispatch::handle(v);
addArc(id->start(), id, "[color=gray]");
}
- addArc(id->extent(), id, "[color=gray]");
+ addArc(id->rawExtent(), id, "[color=gray]");
+
+ if (detail_level_ >= DetailLevel::Explicit &&
+ id->rawExtent() != id->extent()) {
+ addArc(id->extent(), id, "[color=gray, style=dashed]");
+ }
}
void IrGraphGenerator::handle(const Bool* b) {
printValue(b, IrNodeLabel::gen(b, detail_level_));
}
-void IrGraphGenerator::handle(const Double* d) {
- printValue(d, IrNodeLabel::gen(d, detail_level_));
+void IrGraphGenerator::handle(const Float* f) {
+ printValue(f, IrNodeLabel::gen(f, detail_level_));
+}
+
+void IrGraphGenerator::handle(const Half* h) {
+ printValue(h, IrNodeLabel::gen(h, detail_level_));
}
void IrGraphGenerator::handle(const Int* i) {
graph_def_ << " " << getid(tv) << " [label=\"" << label.str()
<< "\", shape=Mrecord, color=brown, " << style << "];\n";
+ if (const auto* compute_at_view = tv->getComputeAtView()) {
+ std::stringstream arc_style;
+ arc_style << "[color=red, style=dashed, label=\""
+ << "ComputeAt(" << tv->getRelativeComputeAtAxis() << ")\"]";
+ addArc(tv, compute_at_view, arc_style.str());
+ }
+
tensor_views_.push_back(tv);
}
Verbose, // Includes all values and dead definitions
};
- using ExprColorMap = std::unordered_map<const Expr*, size_t>;
-
public:
static void print(
const Fusion* fusion,
const char* filename,
- DetailLevel detail_level = DetailLevel::Basic,
- ExprColorMap* expr_color_map = nullptr);
+ DetailLevel detail_level = DetailLevel::Basic);
- static std::string toGraphviz(
- const Fusion* fusion,
- DetailLevel detail_level,
- ExprColorMap* expr_color_map = nullptr);
+ static std::string toGraphviz(const Fusion* fusion, DetailLevel detail_level);
private:
- IrGraphGenerator(
- const Fusion* fusion,
- DetailLevel detail_level,
- ExprColorMap* expr_color_map = nullptr);
+ IrGraphGenerator(const Fusion* fusion, DetailLevel detail_level);
~IrGraphGenerator() override = default;
std::string generate();
void handle(const IterDomain*) override;
void handle(const Bool*) override;
- void handle(const Double*) override;
+ void handle(const Float*) override;
+ void handle(const Half*) override;
void handle(const Int*) override;
void handle(const NamedScalar*) override;
std::vector<const TensorView*> tensor_views_;
std::vector<std::string> arcs_;
int next_id_ = 1;
- ExprColorMap* expr_color_map_ = nullptr;
};
} // namespace cuda
#include <torch/csrc/jit/ir/ir.h>
-//! Nodes in here are intended to be "user facing" users in this sense being
-//! those that want to be able to generate CUDA code.
+/*
+ * Nodes in here are intended to be "user facing" users in this sense being
+ * those that want to be able to generate CUDA code.
+ */
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
-class WelfordResult;
-
-//! A Bool value
-//!
-//! This value can be a symbolic value (defined after the kernel
-//! is compiled) or a constant value (inlined into the kernel definition).
-//!
+/*
+ * A Bool value.
+ * This value can be a symbolic value (defined after the kernel
+ * is compiled) or a constant value (inlined into the kernel definition).
+ */
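// e.g. (illustrative) new Bool(true) is a constant inlined into the generated
// kernel, while new Bool() stays symbolic, i.e. defined only after the kernel
// is compiled.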
class TORCH_CUDA_CU_API Bool : public Val {
public:
+ ~Bool() override = default;
+
Bool() : Val(ValType::Scalar, DataType::Bool), maybe_value_{c10::nullopt} {}
- explicit Bool(bool value)
- : Val(ValType::Scalar, DataType::Bool), maybe_value_{value} {}
+ explicit Bool(bool _value)
+ : Val(ValType::Scalar, DataType::Bool), maybe_value_{_value} {}
Bool(const Bool* src, IrCloner* ir_cloner);
+ Bool(const Bool& other) = delete;
+ Bool& operator=(const Bool& other) = delete;
+
+ Bool(Bool&& other) = delete;
+ Bool& operator=(Bool&& other) = delete;
+
bool isSymbolic() const {
return !(maybe_value_.has_value());
}
return maybe_value_;
}
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const Bool* const other) const;
private:
const c10::optional<bool> maybe_value_;
};
-//! A Float64 value. For now we don't have any other type besides
-//! Float64. This value can be a symbolic value (defined after the kernel
-//! is compiled) or a constant value (inlined into the kernel definition).
-class TORCH_CUDA_CU_API Double : public Val {
+/*
+ * A Float32 value. For now we don't have any other type besides
+ * Float32. This value can be a symbolic value (defined after the kernel
+ * is compiled) or a constant value (inlined into the kernel definition).
+ */
+class TORCH_CUDA_CU_API Float : public Val {
public:
using ScalarType = double;
- Double()
- : Val(ValType::Scalar, DataType::Double), maybe_value_{c10::nullopt} {}
+ ~Float() override = default;
- explicit Double(ScalarType value)
- : Val(ValType::Scalar, DataType::Double), maybe_value_{value} {}
+ Float() : Val(ValType::Scalar, DataType::Float), maybe_value_{c10::nullopt} {}
- Double(const Double* src, IrCloner* ir_cloner);
+ explicit Float(ScalarType _value)
+ : Val(ValType::Scalar, DataType::Float), maybe_value_{_value} {}
+
+ Float(const Float* src, IrCloner* ir_cloner);
+
+ Float(const Float& other) = delete;
+ Float& operator=(const Float& other) = delete;
+
+ Float(Float&& other) = delete;
+ Float& operator=(Float&& other) = delete;
bool isSymbolic() const {
return !(maybe_value_.has_value());
return maybe_value_;
}
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const Float* const other) const;
private:
const c10::optional<ScalarType> maybe_value_;
};
-//! An Int64 value. If used for indexing it's set as size_t. Otherwise it's an
-//! inlined literal in the kernel.
+/*
+ * An IEEE 754 Float16 value.
+ * This value can be a symbolic value (defined after the kernel
+ * is compiled) or a constant value (inlined into the kernel definition).
+ */
+class TORCH_CUDA_CU_API Half : public Val {
+ public:
+ ~Half() override = default;
+
+ Half() : Val(ValType::Scalar, DataType::Half), maybe_value_{c10::nullopt} {}
+
+ explicit Half(float _value)
+ : Val(ValType::Scalar, DataType::Half), maybe_value_{_value} {}
+
+ Half(const Half* src, IrCloner* ir_cloner);
+
+ Half(const Half& other) = delete;
+ Half& operator=(const Half& other) = delete;
+
+ Half(Half&& other) = delete;
+ Half& operator=(Half&& other) = delete;
+
+ bool isSymbolic() const {
+ return !(maybe_value_.has_value());
+ }
+ bool isConst() const {
+ return maybe_value_.has_value();
+ }
+ c10::optional<float> value() const {
+ return maybe_value_;
+ }
+
+ bool sameAs(const Half* const other) const;
+
+ private:
+ const c10::optional<float> maybe_value_;
+};
+
+// An Int64 value. If used for indexing it's set as size_t. Otherwise it's an
+// inlined literal in the kernel.
class TORCH_CUDA_CU_API Int : public Val {
public:
using ScalarType = int64_t;
+ ~Int() override = default;
+
Int() : Val(ValType::Scalar, DataType::Int), maybe_value_{c10::nullopt} {}
- explicit Int(ScalarType value)
- : Val(ValType::Scalar, DataType::Int), maybe_value_{value} {}
+ explicit Int(ScalarType _value)
+ : Val(ValType::Scalar, DataType::Int), maybe_value_{_value} {}
Int(const Int* src, IrCloner* ir_cloner);
+ Int(const Int& other) = delete;
+ Int& operator=(const Int& other) = delete;
+
+ Int(Int&& other) = delete;
+ Int& operator=(Int&& other) = delete;
+
bool isSymbolic() const {
return !(maybe_value_.has_value());
}
return maybe_value_;
}
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const Int* const other) const;
private:
const c10::optional<ScalarType> maybe_value_;
};
-//! Mode during propagation of computeAt, standard will throw an error if
-//! computeAt position provided can't be satisfied, best effort will lower the
-//! computeAt position as needed during traversal, most inlined will increase
-//! the compute at position to maximum possible through traversal.
-enum class ComputeAtMode { Standard, BestEffort, MostInlined };
-
class ComputeAt;
-class TransformPropagator;
-class TransformIter;
class TransformReplay;
+class TransformIter;
class OptOutMutator;
+class LoopNestGenerator;
namespace ir_utils {
class TVDomainGuard;
}
-//! TensorView is our primitive Tensor Type used in code generation. It can be
-//! thought of as representing physical memory, however, its dimensionality is
-//! modifed as split/merge/computeAt functions are called. The history of
-//! these transformations are kept and used for generating actual code
-//! referncing physical memory. Generally when users are thinking of code
-//! generation in reference to a Tensor, this is the class they should be
-//! interacting with.
-//!
-//! The reason we need both TensorView and TensorDomain is that we need to have
-//! a record of both what is being computed and how it is being computed. For
-//! example we may have the operation:
-//!
-//! TV3[I, J, K] = TV2[I, J, K] + TV1[I, J, K]
-//!
-//! The mathematical operations here are on the tensor views TV1, TV2, and
-//! TV3. This operation is a pointwise operation. To compute this pointwise
-//! operation we iterate over the 3D TensorDomain [I, J, K], where K is the
-//! fastest changing dimension.
-//!
-//! \todo Need to work on the const model for TensorView, making all functions
-//! that should be const, const. Gave this a try but expanded really quickly.
-//! getComputeAtAxis not being const because it can return a TV that some expect
-//! to be non-const is the biggest headache.
-//!
+// TensorView is our primitive Tensor Type used in code generation. It can be
+// thought of as representing physical memory, however, its dimensionality is
+// modified as split/merge/computeAt functions are called. The history of
+// these transformations is kept and used for generating actual code referencing
+// physical memory. Generally when users are thinking of code generation in
+// reference to a Tensor, this is the class they should be interacting with.
+//
+// The reason we need both TensorView and TensorDomain is that we need to have a
+// record of both what is being computed and how it is being computed. For
+// example we may have the operation: TV3[I, J, K] = TV2[I, J, K] + TV1[I, J, K]
+// The mathematical operations here are on the tensor views TV1, TV2, and TV3.
+// This operation is a pointwise operation. To compute this pointwise operation
+// we iterate over the 3D TensorDomain [I, J, K], where K is the fastest
+// changing dimension.
+//
+// TODO: Need to work on the const model for TensorView, making all functions
+// that should be const, const. Gave this a try but expanded really quickly.
+// getComputeAtAxis not being const because it can return a TV that some expect
+// to be non-const is the biggest headache.
class TORCH_CUDA_CU_API TensorView : public Val {
public:
+ ~TensorView() override = default;
+
+ TensorView(const TensorView& other) = delete;
+ TensorView& operator=(const TensorView& other) = delete;
+
+ TensorView(TensorView&& other) = delete;
+ TensorView& operator=(TensorView&& other) = delete;
+
TensorView(
- TensorDomain* domain,
+ TensorDomain* _domain,
DataType dtype,
MemoryType mtype = MemoryType::Local);
- explicit TensorView(const std::shared_ptr<c10::TensorType>& tensor_type);
+ TensorView(const std::shared_ptr<c10::TensorType>& tensor_type);
- explicit TensorView(const std::shared_ptr<Value>& jit_value)
+ TensorView(const std::shared_ptr<Value>& jit_value)
: TensorView(jit_value->type()->cast<c10::TensorType>()) {}
TensorView(const TensorView* src, IrCloner* ir_cloner);
bool hasReduction() const;
bool hasBlockReduction() const;
bool hasGridReduction() const;
+ bool hasBlockBroadcast() const;
bool hasBroadcast() const;
bool hasRFactor() const;
- //! This is the previous hasReduction logic,
- //! kept here exclusively for lower loop pass will
- //! deprecate when Fusion IR pass can convert
- //! trivial reductions
- bool hasAnyReduction() const;
-
c10::optional<unsigned int> getReductionAxis() const;
const std::vector<IterDomain*>& getRootDomain() const;
IterDomain* axis(int pos) const;
- // Does it share outer axes with other tensors?
+ // Is there an active computeAt TensorView/Axis
bool hasComputeAt() const {
- return compute_at_pos_ > 0;
+ return compute_at_view_ != nullptr;
}
- bool hasMaxProducerPosition() const {
- return max_producer_pos_ > 0;
+ // Return the TensorView we're computing at
+ TensorView* getComputeAtView() const {
+ return compute_at_view_;
}
size_t nDims() const;
- // Returns the position that this tensor is produced at relative to its axes.
- unsigned int getComputeAtPosition() const {
- return compute_at_pos_;
+ // Return compute at axis relative to this domain
+ unsigned int getThisComputeAtAxis() const {
+ return this_compute_at_axis_;
+ }
+
+ // Return compute at axis relative to compute at view
+ unsigned int getRelativeComputeAtAxis() const {
+ return relative_compute_at_axis_;
+ }
+
+  // Return the position in compute_at_view that lines up with this->axis(pos).
+ int getComputeAtRelPos(int pos);
+
+ // Will check if an axis is inside computeAtAxis and will fetch the reference
+ // to be used in code generation.
+ std::pair<int, TensorView*> getComputeAtPos(int pos) {
+ pos = normalizeAxisPos(pos);
+ TORCH_INTERNAL_ASSERT(
+ nDims() > 0, "Tried to access a computeAt axis in a 0-dim TensorView");
+ if (!hasComputeAt() || getThisComputeAtAxis() <= (unsigned int)pos)
+ return std::make_pair(pos, this);
+ return compute_at_view_->getComputeAtPos(getComputeAtRelPos(pos));
+ }
+
+ std::pair<IterDomain*, TensorView*> getComputeAtAxis(int pos) {
+ const auto computeAtPos = getComputeAtPos(pos);
+ return std::make_pair(
+ computeAtPos.second->axis(computeAtPos.first), computeAtPos.second);
}
- // Returns the maximum position of producers are being computed at relative to
- // this tensor. This position dictates the clear expectations of producers.
- unsigned int getMaxProducerPosition() const {
- return max_producer_pos_;
+ // Compute this TensorView relative to another tensor at axis
+ TensorView* computeAt(TensorView* consumer, int axis);
+
+ void clearComputeAt() {
+ this_compute_at_axis_ = 0;
+ relative_compute_at_axis_ = 0;
+ compute_at_view_ = nullptr;
}
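// Illustrative usage (tensor names assumed): given producers tv0 -> tv1 -> tv2,
//   tv0->computeAt(tv2, 1);
// computes tv0 (and intermediates) within the first loop of tv2; afterwards
// tv0->hasComputeAt() is true and tv0->getComputeAtView() names the consumer
// the tensor is computed relative to.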
- //! This is used when we disconnect a tensorview from a reduction
- //! operation and connect it to a non-reduction operator. We need
- //! to remove the reduction ids on the tv in this case.
- //! Currently only used in translate welford, and this function may
- //! be refactored or extended if any more use cases appear.
- void clearReductionIterDomains();
-
- //! Compute this TensorView relative to a consumer position, -1 will
- //! compute tensors inline with each other, 0 doesn't share
- //! any loop nests between the tensors. It's an error when the given
- //! position is not legally viable. Alternatively, when the mode
- //! parameter is ComputeAtMode::BestEffort, the position is lowered
- //! one by one until a valid position is found. When
- //! ComputeAtMode::MostInlined is given, the position parameter is
- //! ignored, and the deepest possible position is searched.
- TensorView* computeAt(
- TensorView* consumer,
- int position,
- ComputeAtMode mode = ComputeAtMode::Standard);
-
- //! Compute this tensor to consumer, at local position, -1 will compute
- //! tensors inline with eachother, 0 doesn't share any loop nests between the
- //! tensors. The mode parameter can be used in the same manner as computeAt.
- TensorView* computeWith(
- TensorView* consumer,
- int position,
- ComputeAtMode mode = ComputeAtMode::Standard);
-
- // Split "axis" into 2 axes
- //! inner_split dictates if the factor section of the split should be inside
- //! the
- //! remainer or outside.
- //! e.g. split(0, 4, inner_split = true) will result in:
- //! tv[id{extent}] -> tv[id{ceilDiv(extent, factor)}, id{factor}]
- //! e.g. split(0, 4, inner_split = false) will result in:
- //! tv[id{extent}] -> tv[id{factor}, id{ceilDiv(extent, factor)}]
- TensorView* split(int axis, unsigned int factor, bool inner_split = true);
+  // Split "axis" into 2 axes where the inner axis is of size "factor"
+  // and the outer axis is of size axis.size() / factor
+ TensorView* split(int axis, unsigned int factor);
// Split "axis" into 2 axes where the inner axis is of size "factor"
// and the outer axis is of size axis.size() / factor. Factor can be a symbolic
// value instead of constant. This requires setting the symbolic value as an
// input, or using a parallel dim from NamedScalar::getParallelDim
- TensorView* split(int axis, Val* factor, bool inner_split = true);
+ TensorView* split(int axis, Val* factor);
// Merge axis_o and axis_i into 1 IterDomain
TensorView* merge(int axis_o, int axis_i);
// Reorder axes according to old2new[old_pos] = new_pos
TensorView* reorder(const std::unordered_map<int, int>& old2new);
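// Illustrative usage (assumed 2D tv[I, J]):
//   tv->split(1, 32);              // tv[I, J/32, 32]
//   tv->merge(0, 1);               // tv[I * (J/32), 32]
//   tv->reorder({{0, 1}, {1, 0}}); // tv[32, I * (J/32)]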
- //! Swizzle indices to improve memory access efficiency.
- //!
- //! Swizzle::Transpose is a pattern commonly used to avoid bank
- //! conflicts in shared memory. It takes two axes and shifts the
- //! second axis by the first axis as ((axis1 + axis2) % extent). The
- //! memory type must be Shared.
- //!
- //! \input type Swizzle pattern such as transpose.
- //! \input axes Axes to swizzle
- TensorView* swizzle(SwizzleType type, const std::vector<int>& axes);
-
// WARNING: rFactor does not return this TensorView, it returns a new
// TensorView consumed by this!
//
//
TensorView* rFactor(const std::vector<int>& axes);
- //! Welford Version of rFactor, semantically similar with
- //! the reduction version except that the rfactor is done
- //! in a multi-output scan pattern
- WelfordResult rFactor(
- const std::vector<int>& axes,
- TensorView* avg,
- TensorView* var,
- TensorView* n);
-
// Create a TensorView before the original tensor. A common use case is to
// write results into shared memory or registers before moving to global
// memory. Analogous to TVM Cache_Write
// read tensor into shared memory or registers. Analogous to TVM Cache_Read
TensorView* cache_after();
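// Illustrative usage (names assumed): with tv_out a fusion output,
//   TensorView* staged = tv_out->cache_before();
//   staged->setMemoryType(MemoryType::Shared);
// stages the computation in shared memory before the final global-memory write.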
- // For a fusion output with other uses, we want to avoid writing to global
- // memory and then reading the output again. We write to global memory
- // separately after an operation. We replace this fusion output with the
- // direct write TensorView.
- TensorView* cache_fork();
-
MemoryType getMemoryType() const {
return memory_type_;
}
void setMemoryType(MemoryType mt);
- SwizzleType swizzleType() const {
- return swizzle_type_;
- }
-
- const std::vector<IterDomain*>& axesToSwizzle() const {
- return axes_to_swizzle_;
- }
-
- friend TORCH_CUDA_CU_API TransformPropagator;
friend TORCH_CUDA_CU_API TransformReplay;
friend TORCH_CUDA_CU_API OptOutMutator;
+ friend TORCH_CUDA_CU_API LoopNestGenerator;
friend ComputeAt;
+ friend void IrFixComputeAt(Fusion*);
friend void adjustMemoryTypes(Fusion* fusion);
friend class ir_utils::TVDomainGuard;
protected:
+ // Make an exact copy of this tensor (similar to clone()), however, also grabs
+ // the same name. Current use of this is for initialization of reductions.
+ // This will break our dependency chain as it is a literal clone of a
+ // TensorView but it has a different dependency chain. We need to improve our
+  // dependency model to allow for initialization of reduction buffers. The only
+ // reason we can get away with this for now is because we don't use dependency
+ // analysis for the IR after we call this.
+ TensorView* unsafeClone() const;
+
void setDomain(TensorDomain* td) {
domain_ = td;
}
- void setComputeAt(unsigned int this_pos, bool decrease = false);
+ void setComputeAt(TensorView* computeAtView, int axis);
- void setMaxProducer(unsigned int this_pos, bool decrease = false);
+ // Set all computeAt members without checking any correctness. Useful for
+  // computeAt with outputs relative to each other
+ void setComputeAt(TensorView* computeAtView, int thisPos, int relPos);
private:
int normalizeAxisPos(int pos) const {
return pos;
}
- //! A helper function to maintain the consistency of welford output
- //! schedules when doing rfactor on welford ops.
- TensorView* welfordRfactorHelper(
- TensorView* tv,
- const std::vector<int>& axes);
-
- private:
- TensorDomain* domain_ = nullptr;
- unsigned int compute_at_pos_ = 0;
- unsigned int max_producer_pos_ = 0;
- MemoryType memory_type_ = MemoryType::Local;
- SwizzleType swizzle_type_ = SwizzleType::NoSwizzle;
- std::vector<IterDomain*> axes_to_swizzle_;
-};
-
-//! A simple TensorView builder
-//!
-//! Example usage:
-//!
-//! auto tv = TensorViewBuilder()
-//! .ndims(ndims)
-//! .dtype(dtype)
-//! .contiguity(contiguity)
-//! .build();
-//!
-class TORCH_CUDA_CU_API TensorViewBuilder {
- public:
- //! Set the number of dimensions of the tensor (default 0, meaning scalar)
- TensorViewBuilder& ndims(size_t ndims);
-
- //! Set the data type of the tensor (default DataType::Float)
- TensorViewBuilder& dtype(DataType dtype);
-
- //! Set the contiguity information (default non-contiguous)
- TensorViewBuilder& contiguity(std::vector<bool> contiguity);
+ // In Cache Before, for the origin expr of the original tensor,
+ // we create a new operation where the original tensor is replaced
+ // with the new cache tensor. This function creates a new expr
+ // given the consumer, the output of the expression.
+ void createExprConsumer(Expr* expr, TensorView* consumer);
- //! Set the shape (default 0 dimensional, ie. scalar)
- TensorViewBuilder& shape(std::vector<int64_t> shape);
+ // In Cache After, for all the uses of the original tensor, we create
+ // a new operation where the original tensor is replaced with the new
+ // cache tensor. This function creates a new expr given a producer,
+ // an input for the expression.
+ void createExprProducer(
+ Expr* expr,
+ TensorView* current,
+ TensorView* producer);
- //! Creates a new TensorView with the specified options
- TensorView* build() const;
+ void setThisComputeAtAxis();
private:
- size_t ndims_ = 0;
- DataType dtype_ = DataType::Float;
- std::vector<bool> contiguity_;
- std::vector<int64_t> shape_;
+ TensorDomain* domain_ = nullptr;
+ TensorView* compute_at_view_ = nullptr;
+ // compute at axis in compute at view
+ unsigned int relative_compute_at_axis_ = 0;
+ unsigned int this_compute_at_axis_ = 0;
+ MemoryType memory_type_ = MemoryType::Local;
};
} // namespace cuda
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
-//! Nodes in here should generally not be used by users. They should be behind
-//! the scenes and users shouldn't have to be aware of what they do to use the
-//! code generator
-//!
-//! \todo improve implementation bool IterDomain::sameAs(const IterDomain*)
-//! \todo Add testing of sameAs functions for these nodes
-//!
+/*
+ * Nodes in here should generally not be used by users. They should be behind
+ * the scenes and users shouldn't have to be aware of what they do to use the
+ * code generator.
+ */
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
-//! Returns true if both v1 and v2 are scalars, are the same type of scalars,
-//! and dispatches to the inherited Val type's `->sameAs` call. e.g. if both
-//! vals are `Int` will dispatch to v1->as<Int>()->sameAs(v2.as<Int>())
+// Returns true if both v1 and v2 are scalars, are the same type of scalars, and
+// dispatches to the inherited Val type's `->sameAs` call. e.g. if both vals
+// are `Int`, this will dispatch to v1->as<Int>()->sameAs(v2.as<Int>())
bool areEqualScalars(Val* v1, Val* v2);
-//! A specialization for Unary operations. Unary operations take in a single
-//! input and produce a single output. Examples include:
-//! 1) Casting operation i.e. float(a_val)
-//! 2) Negation i.e. val * -1
-//! 3) Reduction across a dimension i.e. val.sum(axis=2)
-//! 4) split/merge
+/*
+ * TODO: improve implementation bool IterDomain::sameAs(const IterDomain*) const
+ * TODO: Add testing of sameAs functions for these nodes
+ */
+
+/*
+ * A specialization for Unary operations. Unary operations take in a single
+ * input and produce a single output. Examples include:
+ * 1) Casting operation i.e. float(a_val)
+ * 2) Negation i.e. val * -1
+ * 3) Reduction across a dimension i.e. val.sum(axis=2)
+ * 4) split/merge
+ */
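// e.g. (illustrative, via the arith helpers) unaryOp(UnaryOpType::Neg, x)
// produces a new Val whose origin is a UnaryOp with in() == x.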
class TORCH_CUDA_CU_API UnaryOp : public Expr {
public:
- UnaryOp(UnaryOpType type, Val* out, Val* in);
+ ~UnaryOp() override = default;
+ UnaryOp(UnaryOpType _type, Val* _out, Val* _in);
UnaryOp(const UnaryOp* src, IrCloner* ir_cloner);
+ UnaryOp(const UnaryOp& other) = delete;
+ UnaryOp& operator=(const UnaryOp& other) = delete;
+
+ UnaryOp(UnaryOp&& other) = delete;
+ UnaryOp& operator=(UnaryOp&& other) = delete;
+
Val* out() const {
return out_;
}
return unary_op_type_;
}
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const UnaryOp* const other) const;
private:
const UnaryOpType unary_op_type_;
Val* const in_ = nullptr;
};
-//! A specialization for Binary operations. Binary operations take in two inputs
-//! and produce a single output. Examples include:
-//! 1) Add/mul/div/mod/sub (A * B)
-//! 2) LT (A < B)
+/*
+ * A specialization for Binary operations. Binary operations take in two inputs
+ * and produce a single output. Examples include:
+ * 1) Add/mul/div/mod/sub (A * B)
+ * 2) LT (A < B)
+ */
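// e.g. (illustrative, via the arith helpers) add(a, b) records a
// BinaryOp(BinaryOpType::Add, out, a, b) and returns its out().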
class TORCH_CUDA_CU_API BinaryOp : public Expr {
public:
- BinaryOp(BinaryOpType type, Val* out, Val* lhs, Val* rhs);
+ ~BinaryOp() override = default;
+ BinaryOp(BinaryOpType _type, Val* _out, Val* _lhs, Val* _rhs);
BinaryOp(const BinaryOp* src, IrCloner* ir_cloner);
+ BinaryOp(const BinaryOp& other) = delete;
+ BinaryOp& operator=(const BinaryOp& other) = delete;
+
+ BinaryOp(BinaryOp&& other) = delete;
+ BinaryOp& operator=(BinaryOp&& other) = delete;
+
Val* out() const {
return out_;
}
return binary_op_type_;
}
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const BinaryOp* other) const;
private:
const BinaryOpType binary_op_type_;
Val* const rhs_ = nullptr;
};
-//! Broadcast in to match out. is_broadcast_dims are relative to out. Where
-//! is_broadcast_dims.size() == out->nDims().
+/*
+ * Broadcast _in to match _out. broadcast_dims are relative to _out, where
+ * broadcast_dims.size() + _in->nDims() == _out->nDims().
+ */
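// e.g. (illustrative) broadcasting a 2D _in[I, J] into a 3D _out[I, B, J]
// introduces one broadcast dim, so broadcast_dims.size() (1) + _in->nDims()
// (2) == _out->nDims() (3).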
class TORCH_CUDA_CU_API BroadcastOp : public Expr {
public:
- //! \param out The output tensor
- //! \param in The input tensor
- //! \param is_broadcast_dims True when output dim is a new broadcast domain
- BroadcastOp(Val* out, Val* in, std::vector<bool> is_broadcast_dims);
+ ~BroadcastOp() override = default;
+ BroadcastOp(Val* _out, Val* _in);
BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner);
+ BroadcastOp(const BroadcastOp& other) = delete;
+ BroadcastOp& operator=(const BroadcastOp& other) = delete;
+
+ BroadcastOp(BroadcastOp&& other) = delete;
+ BroadcastOp& operator=(BroadcastOp&& other) = delete;
+
Val* out() const {
return out_;
}
return in_;
}
- bool isBroadcastDim(size_t dim) const {
- return is_broadcast_dims_.at(dim);
- }
-
- const std::vector<bool>& getBroadcastDimFlags() const {
- return is_broadcast_dims_;
- }
-
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const BroadcastOp* const other) const;
private:
Val* const out_ = nullptr;
Val* const in_ = nullptr;
-
- //! The same list passed to the broadcast arithmetic op. Each
- //! element corresponds to an IterDomain of the output tensor and is
- //! true when the IterDomain is a new broadcast domain. Note
- //! that the output tensor may have other broadcast domains whose
- //! flags are false because the input tensor may already have
- //! broadcast domains.
- const std::vector<bool> is_broadcast_dims_;
};
-//! Reduction operation. Out is first initialized to _init. Then
-//! reduction_op_type is used to update out as out = reductionOp(out, in).
-//! Output's axes marked as reduction will be reduced to produce an output
-//! tensor. The output tensors size will be the size of all
-//! non-reduction/non-broadcast dimensions.
+/*
+ * Reduction operation. Out is first initialized to _init. Then
+ * _reduction_op_type is used to update out as out = reductionOp(out, in).
+ * Output's axes marked as reduction will be reduced to produce an output
+ * tensor. The output tensor's size will be the size of all
+ * non-reduction/non-broadcast dimensions.
+ */
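// e.g. (illustrative) summing tv[I, J] over J: out is initialized to _init
// (0 for Add), then out = reductionOp(out, in) folds elements along J,
// yielding out[I, rJ] with rJ marked as a reduction axis.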
class TORCH_CUDA_CU_API ReductionOp : public Expr {
public:
- ReductionOp(BinaryOpType reduction_op_type, Val* init, Val* out, Val* in);
+ ~ReductionOp() override = default;
+ ReductionOp(BinaryOpType _reduction_op_type, Val* _init, Val* _out, Val* _in);
ReductionOp(const ReductionOp* src, IrCloner* ir_cloner);
+ ReductionOp(const ReductionOp& other) = delete;
+ ReductionOp& operator=(const ReductionOp& other) = delete;
+
+ ReductionOp(ReductionOp&& other) = delete;
+ ReductionOp& operator=(ReductionOp&& other) = delete;
+
Val* out() const {
return out_;
}
return reduction_op_type_;
}
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const ReductionOp* const other) const;
private:
const BinaryOpType reduction_op_type_;
Val* const in_ = nullptr;
};
-//! Welford Scan operation.
-class TORCH_CUDA_CU_API WelfordOp : public Expr {
- public:
- WelfordOp(
- Val* out_avg,
- Val* out_var,
- Val* out_N,
- Val* init_avg,
- Val* init_var,
- Val* init_N,
- Val* in_avg,
- Val* in_var,
- Val* in_N);
-
- WelfordOp(const WelfordOp* src, IrCloner* ir_cloner);
-
- Val* out() const {
- return out_avg_;
- }
-
- Val* in() const {
- return in_avg_;
- }
-
- Val* init() const {
- return init_avg_;
- }
-
- bool sameAs(const Statement* const other) const override;
-
- // Welford Accessors
- // TODO clean up
- Val* outAvg() const {
- return out_avg_;
- }
-
- Val* outVar() const {
- return out_var_;
- }
-
- Val* outN() const {
- return out_N_;
- }
-
- Val* inAvg() const {
- return in_avg_;
- }
-
- Val* inVar() const {
- return in_var_;
- }
-
- Val* inN() const {
- return in_N_;
- }
-
- Val* initAvg() const {
- return init_avg_;
- }
-
- Val* initVar() const {
- return init_var_;
- }
-
- Val* initN() const {
- return init_N_;
- }
-
- bool singleValue() const {
- return in_N_->isOneInt();
- }
-
- bool hasInit() const {
- return !init_N_->isZeroInt();
- }
-
- private:
- Val* const out_avg_;
- Val* const out_var_;
- Val* const out_N_;
- Val* const init_avg_;
- Val* const init_var_;
- Val* const init_N_;
- Val* const in_avg_;
- Val* const in_var_;
- Val* const in_N_;
-};
-
-class TORCH_CUDA_CU_API TransposeOp : public Expr {
- public:
- TransposeOp(TensorView* out, TensorView* in, std::vector<int> new2old);
-
- TransposeOp(const TransposeOp* src, IrCloner* ir_cloner);
-
- TensorView* out() const {
- return out_;
- }
-
- TensorView* in() const {
- return in_;
- }
-
- const std::vector<int>& new2old() const {
- return new2old_;
- }
-
- private:
- TensorView* const out_ = nullptr;
- TensorView* const in_ = nullptr;
- const std::vector<int> new2old_;
-};
-
class TORCH_CUDA_CU_API TernaryOp : public Expr {
public:
- TernaryOp(TernaryOpType type, Val* out, Val* in1, Val* in2, Val* in3);
+ ~TernaryOp() override = default;
+ TernaryOp(TernaryOpType _type, Val* _out, Val* _in1, Val* _in2, Val* _in3);
TernaryOp(const TernaryOp* src, IrCloner* ir_cloner);
+ TernaryOp(const TernaryOp& other) = delete;
+ TernaryOp& operator=(const TernaryOp& other) = delete;
+
+ TernaryOp(TernaryOp&& other) = delete;
+ TernaryOp& operator=(TernaryOp&& other) = delete;
+
Val* out() const {
return out_;
}
return ternary_op_type_;
}
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const TernaryOp* other) const;
private:
const TernaryOpType ternary_op_type_;
Val* const in3_ = nullptr;
};
-//! Shift
-class TORCH_CUDA_CU_API ShiftOp : public Expr {
- public:
- //! \param out
- //! \param in
- //! \param offsets
- ShiftOp(Val* out, Val* in, std::vector<int> offsets);
-
- ShiftOp(const ShiftOp* src, IrCloner* ir_cloner);
-
- Val* out() const {
- return out_;
- }
- Val* in() const {
- return in_;
- }
-
- int offset(size_t dim) const {
- return offsets_.at(dim);
- }
-
- const std::vector<int>& offsets() const {
- return offsets_;
- }
-
- bool sameAs(const Statement* other) const override;
-
- private:
- Val* const out_ = nullptr;
- Val* const in_ = nullptr;
- //! Each of the root axes is shifted by the corresponding value of
- //! offsets_. The sign of each value indicates the direction of
- //! shifting.
- const std::vector<int> offsets_;
-};
-
-//! Gather a window around each element.
-class TORCH_CUDA_CU_API GatherOp : public Expr {
- public:
- GatherOp(
- Val* out,
- Val* in,
- std::vector<Int*> window_shape,
- std::vector<std::vector<Int*>> pad_width);
-
- GatherOp(const GatherOp* src, IrCloner* ir_cloner);
-
- Val* out() const {
- return out_;
- }
- Val* in() const {
- return in_;
- }
-
- const auto& windowShape() const {
- return window_shape_;
- }
-
- //! Returns the gather axis that corresponds to an input axis
- int gatherAxis(int axis) const;
-
- const auto& padWidth() const {
- return pad_width_;
- }
-
- bool sameAs(const Statement* other) const override;
-
- private:
- Val* const out_ = nullptr;
- Val* const in_ = nullptr;
- //! Shape of a window gathered for each element.
- std::vector<Int*> window_shape_;
- //! The size of zero-padding of each axis.
- std::vector<std::vector<Int*>> pad_width_;
-};
-
-// Friends for direct access to split
-class TensorDomain;
-class ReplayTransformations;
-class IndexReferenceReplay;
-//! Simply a representation of an annotated 1D iterable from start to extent.
-//! TensorDomains which represent how to iterate over a tensor is made up of
-//! IterDomains to form an ND iterable. We directly set parallization strategies
-//! on IterDomains.
+// Simply a representation of an annotated 1D iterable from start to extent.
+// A TensorDomain, which represents how to iterate over a tensor, is made up of
+// IterDomains forming an ND iterable. We directly set parallelization
+// strategies on IterDomains.
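// e.g. (illustrative) after tv->split(0, 32), tv->axis(1) is an IterDomain of
// extent 32 that can be mapped to threads via
// tv->axis(1)->parallelize(ParallelType::TIDx).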
class TORCH_CUDA_CU_API IterDomain : public Val {
public:
IterDomain(
- Val* start,
- Val* extent,
- ParallelType parallel_type = ParallelType::Serial,
- IterType iter_type = IterType::Iteration,
- bool is_rfactor_domain = false);
+ Val* _start,
+ Val* _extent,
+ ParallelType _parallel_type = ParallelType::Serial,
+ IterType _iter_type = IterType::Iteration,
+ bool _is_rfactor_domain = false);
IterDomain(const IterDomain* src, IrCloner* ir_cloner);
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const IterDomain* const other) const;
// Returns a new IterDomain matching properties of this
// TODO: parallel_method->getParallelType
isRFactorProduct());
}
- //! Clone a vector domains
- static std::vector<IterDomain*> clone(
- const std::vector<IterDomain*>& domains);
-
static IterDomain* merge(IterDomain* outer, IterDomain* inner);
- //! Run concretization pass and return the concretized domain of broadcast id
+ // TODO: Make protected and friend TensorDomain so only it can call into this
+ // directly, users should not be able to use this call
+ static std::pair<IterDomain*, IterDomain*> split(IterDomain* in, Val* factor);
+
+ // Run concretization pass and return the concretized domain of broadcast id
static const IterDomain* concretizeDomain(IterDomain* bcast_dom);
+ // Attempt to prove 2 IterDomains are equal in start and rawExtent
+ static bool proveEquivalent(IterDomain* a, IterDomain* b);
+
bool isReduction() const {
return getIterType() == IterType::Reduction;
}
getIterType() == IterType::BroadcastWithoutStride;
}
- bool isGather() const {
- return getIterType() == IterType::Gather;
- }
-
bool isParallelized() const {
return getParallelType() != ParallelType::Serial;
}
- //! Return if this iter domain is mapped to a grid dimension
+ // Return if this iter domain is mapped to a grid dimension
bool isBlockDim() const {
- return isParallelTypeBlockDim(getParallelType());
+ return (
+ getParallelType() == ParallelType::BIDz ||
+ getParallelType() == ParallelType::BIDy ||
+ getParallelType() == ParallelType::BIDx);
}
- //! Return if this iter domain is mapped to a block dimension
+ // Return if this iter domain is mapped to a block dimension
bool isThreadDim() const {
- return isParallelTypeThreadDim(getParallelType());
+ return (
+ getParallelType() == ParallelType::TIDz ||
+ getParallelType() == ParallelType::TIDy ||
+ getParallelType() == ParallelType::TIDx);
}
- //! Return if this iter domain is either mapped to a block or grid dimension
+ // Return if this iter domain is either mapped to a block or grid dimension
bool isThread() const {
return (isBlockDim() || isThreadDim());
}
- //! Convert to strided broadcast, used for supporting broadcast on output
- void toStridedBroadcast() {
- TORCH_INTERNAL_ASSERT(
- isBroadcast(),
- "toStridedBroadCast: converting an non-broadcast iterdomain",
- this);
- iter_type_ = IterType::BroadcastWithStride;
- }
+ void parallelize(ParallelType t) {
+ parallel_type_ = t;
- // Convert a serial iterdomain to broadcast, used for implicit broadcast
- void convertToBroadcast() {
- TORCH_INTERNAL_ASSERT(
- !isBroadcast() && !isReduction(),
- "convertToBroadcast: converting an non-serial iterdomain",
- this);
+ TORCH_CHECK(
+ t != ParallelType::Vectorize, "Vectorization not yet supported.");
- iter_type_ = IterType::BroadcastWithStride;
+ if (t == ParallelType::Unroll)
+ TORCH_CHECK(
+ start()->isZeroInt() && extent()->isConstScalar(),
+ "Unrolling only supported with start = 0 and extent as a const int, but got ",
+ "a start of ",
+ start(),
+ " and extent ",
+ extent(),
+ " .");
}
- void parallelize(ParallelType t);
-
ParallelType getParallelType() const {
return parallel_type_;
}
Val* start() const {
return start_;
}
+ Val* extent() const;
- Val* extent() const {
- TORCH_INTERNAL_ASSERT(extent_ != nullptr);
+ Val* rawExtent() const {
return extent_;
}
- //! Check if IterDomain is a broadcast axis with compile-time
- //! known extent. This is the case with all size-1 IterDomains on
- //! a TensorView's root domain when the TensorView is created.
- bool isImplicitBroadcast() const {
- return isBroadcast() && extent()->isOneInt();
- }
-
- //! Check if IterDomain is a reduction axis with size of 1, i.e.
- //! a "squeeze" operator.
- //!
- //! NOTE: Detection of trivial reduction here is not
- //! comprehensive. See detectTrivialReductionDerivedDomains for more
- //! comprehensive analysis. We typically use this for root domain trivial
- //! reduction checks. So we ship to the correct scheduler. It may
- //! not be incredibly robust, but it makes sense to keep it for now.
- bool isTrivialReduction() const {
- return isReduction() && extent()->isOneInt();
- }
-
- protected:
- friend TensorDomain;
- friend ReplayTransformations;
- friend IndexReferenceReplay;
+ IterDomain(const IterDomain& other) = delete;
+ IterDomain& operator=(const IterDomain& other) = delete;
- static std::pair<IterDomain*, IterDomain*> split(
- IterDomain* in,
- Val* factor,
- bool inner_split);
+ IterDomain(IterDomain&& other) = delete;
+ IterDomain& operator=(IterDomain&& other) = delete;
private:
Val* const start_ = nullptr;
bool is_rfactor_domain_ = false;
};
-//! TensorDomain holds a vector of IterDomains. It holds an IterDomain for every
-//! logical axis in its associated tensor. TensorDomain does not directly hold
-//! the Tensor it is associated with, and in theory could be associated with
-//! multiple tensors. TensorDomain's primary responsibility is to provide a
-//! mechanism to access history of transformations that were used to generate
-//! it. This is done through the normal interaction of Expr/Val in Fusion. i.e.
-//! if we want to know the previous operation generating a particular
-//! TensorDomain we can simply call:
-//!
-//! FusionGuard::getCurFusion()->definition(a_tensor_domain)
-//!
-//! which should give us an operation in the list [split, merge] or similar
-//! operations that take in a TensorDomain, applies a transformation and outputs
-//! a tensor domain.
+/*
+ * TensorDomain holds a vector of IterDomains. It holds an IterDomain for every
+ * logical axis in its associated tensor. TensorDomain does not directly hold
+ * the Tensor it is associated with, and in theory could be associated with
+ * multiple tensors. TensorDomain's primary responsibility is to provide a
+ * mechanism to access history of transformations that were used to generate it.
+ * This is done through the normal interaction of Expr/Val in Fusion. i.e. if we
+ * want to know the previous operation generating a particular TensorDomain we
+ * can simply call FusionGuard::getCurFusion()->origin(a_tensor_domain) which
+ * should give us an operation in the list [split, merge] or similar
+ * operations that take in a TensorDomain, applies a transformation and outputs
+ * a tensor domain.
+ */
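+// A sketch of the history query described above (tv is a hypothetical
+// TensorView; assumes an active Fusion):
+//   TensorDomain* td = tv->domain();
+//   Expr* def = FusionGuard::getCurFusion()->origin(td);
+//   // def, if non-null, is the Split/Merge (or similar) expression that
+//   // produced td.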
class TORCH_CUDA_CU_API TensorDomain : public Val {
public:
+ TensorDomain() = delete;
+ ~TensorDomain() override = default;
+
+ TensorDomain(const TensorDomain& other) = delete;
+ TensorDomain& operator=(const TensorDomain& other) = delete;
+
+ TensorDomain(TensorDomain&& other) = delete;
+ TensorDomain& operator=(TensorDomain&& other) = delete;
+
explicit TensorDomain(
- std::vector<IterDomain*> root_domain,
- std::vector<bool> contiguity = std::vector<bool>());
+ std::vector<IterDomain*> _domain,
+ std::vector<bool> _contiguity = std::vector<bool>());
TensorDomain(
- std::vector<IterDomain*> root_domain,
- std::vector<IterDomain*> domain,
- std::vector<bool> contiguity = std::vector<bool>());
+ std::vector<IterDomain*> _root_domain,
+ std::vector<IterDomain*> _domain,
+ std::vector<bool> _contiguity = std::vector<bool>());
TensorDomain(
- std::vector<IterDomain*> root_domain,
- std::vector<IterDomain*> rfactor_domain,
- std::vector<IterDomain*> domain,
- std::vector<bool> contiguity = std::vector<bool>());
+ std::vector<IterDomain*> _root_domain,
+ std::vector<IterDomain*> _rfactor_domain,
+ std::vector<IterDomain*> _domain,
+ std::vector<bool> _contiguity = std::vector<bool>());
TensorDomain(const TensorDomain* src, IrCloner* ir_cloner);
return domain_.size();
}
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const TensorDomain* const other) const;
static bool sameAs(
const std::vector<IterDomain*>& lhs,
bool hasReduction() const;
bool hasBlockReduction() const;
bool hasGridReduction() const;
+ bool hasBlockBroadcast() const;
bool hasBroadcast() const;
bool hasRFactor() const;
- bool hasVectorize() const;
c10::optional<unsigned int> getReductionAxis() const;
void resetDomains() {
no_reduction_domain_ = noReductions(domain_);
no_bcast_domain_ = noBroadcasts(domain_);
- has_nontrivial_reduction_ = hasNontrivialReduction(domain_);
}
// i here is int, as we want to accept negative value and ::size_type can be a
size_t posOf(IterDomain* id) const;
- // Split "axis" into 2 axes
- //! inner_split dictates if the factor section of the split should be inside
- //! the
- //! remainer or outside.
- //! e.g. split(0, 4, inner_split = true) will result in:
- //! tv[id{extent}] -> tv[id{ceilDiv(extent, factor)}, id{factor}]
- //! e.g. split(0, 4, inner_split = false) will result in:
- //! tv[id{extent}] -> tv[id{factor}, id{ceilDiv(extent, factor)}]
- void split(int axis_, Val* factor, bool inner_split);
+  // Split "axis" into 2 axes where the inner axis is of size "factor" and the
+  // outer axis is of size axis.size() / factor. The factor may be a symbolic
+  // value rather than a constant.
+  // TODO: Make protected and friend TensorDomain so only it can call into this
+  // directly; users should not be able to use this call
+ void split(int axis_, Val* factor);
// Merge axis_o and axis_i. axis_i is the fast changing dimension. Resulting
// axis is by default placed at original position axis_o
static bool hasBroadcast(const std::vector<IterDomain*>&);
static bool hasReduction(const std::vector<IterDomain*>&);
- static bool hasNontrivialReduction(const std::vector<IterDomain*>&);
+
+  // Return a vector of std::pair<producer_pos, consumer_pos> representing
+  // the mapping between corresponding axes. Not all axes have a
+  // corresponding mapping; e.g., a broadcast axis in the consumer
+  // does not have any corresponding axis in the producer.
+ static std::vector<std::pair<int, int>> mapDomainPandC(
+ const std::vector<IterDomain*>& producer,
+ const std::vector<IterDomain*>& consumer);
+
+ // Create a map between producer root IterDomains and consumer root
+ // IterDomains.
+ static std::vector<std::pair<IterDomain*, IterDomain*>> mapRootPandC(
+ const TensorDomain* producer,
+ const TensorDomain* consumer);
+
+ // Create a map from consumer root IterDomains -> producer root IterDomains.
+  // Only root consumer IDs present in consumer_root_dims_to_map are mapped
+  // to their corresponding producer IDs.
+ static std::unordered_map<IterDomain*, IterDomain*> mapRootCtoP(
+ const TensorDomain* consumer,
+ const TensorDomain* producer,
+ const std::unordered_set<IterDomain*>& consumer_root_dims_to_map);
+
+ static std::unordered_map<IterDomain*, IterDomain*> mapRootCtoP(
+ const TensorDomain* consumer,
+ const TensorDomain* producer) {
+ return mapRootCtoP(
+ consumer,
+ producer,
+ std::unordered_set<IterDomain*>(
+ consumer->getRootDomain().begin(),
+ consumer->getRootDomain().end()));
+ }
+
+ // Create a map from producer root IterDomains -> consumer root IterDomains.
+  // Only root producer IDs present in producer_maybe_rfactor_dims_to_map
+  // are mapped to their corresponding consumer IDs.
+ static std::unordered_map<IterDomain*, IterDomain*> mapRootPtoC(
+ const TensorDomain* producer,
+ const TensorDomain* consumer,
+ const std::unordered_set<IterDomain*>&
+ producer_maybe_rfactor_dims_to_map);
+
+ static std::unordered_map<IterDomain*, IterDomain*> mapRootPtoC(
+ const TensorDomain* producer,
+ const TensorDomain* consumer) {
+ auto p_root = producer->getMaybeRFactorDomain();
+ return mapRootPtoC(
+ producer,
+ consumer,
+ std::unordered_set<IterDomain*>(p_root.begin(), p_root.end()));
+ }
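+
+  // Usage sketch for the root-domain mapping helpers above (p_td and c_td
+  // are hypothetical producer/consumer TensorDomains):
+  //   auto c2p = TensorDomain::mapRootCtoP(c_td, p_td);
+  //   for (const auto& kv : c2p) {
+  //     // kv.first is a consumer root IterDomain, kv.second the producer
+  //     // root IterDomain it maps to; unmatched axes are simply absent.
+  //   }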
  // The returned pair is ordered such that the second element is the consumer
  // of the first
std::pair<TensorDomain*, TensorDomain*> rFactor(const std::vector<int>& axes);
std::vector<IterDomain*> no_reduction_domain_;
const std::vector<IterDomain*> rfactor_domain_;
const std::vector<bool> contiguity_;
- bool has_nontrivial_reduction_;
};
-//! Representation a split on an IterDomain by "factor"
-//! inner_split dictates if the factor section of the split should be inside the
-//! remainer or outside.
+/*
+ * Represents a split of an IterDomain by "factor"
+ * TODO: Implement split by nparts
+ */
class TORCH_CUDA_CU_API Split : public Expr {
public:
- Split(
- IterDomain* outer,
- IterDomain* inner,
- IterDomain* in,
- Val* factor,
- bool inner_split = true);
+ ~Split() override = default;
+
+ Split(const Split& other) = delete;
+ Split& operator=(const Split& other) = delete;
+
+ Split(Split&& other) = delete;
+ Split& operator=(Split&& other) = delete;
+
+ Split(IterDomain* _outer, IterDomain* _inner, IterDomain* _in, Val* _factor);
Split(const Split* src, IrCloner* ir_cloner);
Val* factor() const {
return factor_;
}
-
- bool innerSplit() const {
- return inner_split_;
- }
-
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const Split* const other) const;
private:
IterDomain* const outer_ = nullptr;
IterDomain* const inner_ = nullptr;
IterDomain* const in_ = nullptr;
Val* const factor_ = nullptr;
- bool inner_split_ = true;
};
-//! Merge the IterDomains outer and inner into one domain, outer and inner
-//! dictate which will be traversed first (inner). Both IterDomains must be of
-//! the same iter or reduction type, as well as the same parallelization
-//! strategy if there is one
-//!
-//! \todo Should this be a unary op type?
-//!
+/*
+ * Merge the IterDomains outer and inner into one domain, outer and inner
+ * dictate which will be traversed first (inner). Both IterDomains must be of
+ * the same iter or reduction type, as well as the same parallelization strategy
+ * if there is one.
+ * TODO: Should this be a unary op type?
+ */
class TORCH_CUDA_CU_API Merge : public Expr {
public:
- Merge(IterDomain* out, IterDomain* outer, IterDomain* inner);
+ ~Merge() override = default;
+ Merge(IterDomain* _out, IterDomain* _outer, IterDomain* _inner);
Merge(const Merge* src, IrCloner* ir_cloner);
+ Merge(const Merge& other) = delete;
+ Merge& operator=(const Merge& other) = delete;
+
+ Merge(Merge&& other) = delete;
+ Merge& operator=(Merge&& other) = delete;
+
IterDomain* out() const {
return out_;
}
return inner_;
}
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const Merge* const other) const;
private:
IterDomain* const out_ = nullptr;
IterDomain* const inner_ = nullptr;
};
-//! Integer value which has a special name
-//!
-//! These could be:
-//! - threadIdx.x
-//! - blockIdx.y
-//! - blockDim.z
-//! - T3.stride[2]
-//!
+/*
+ * Integer value which has a special name. These could be:
+ * - threadIdx.x
+ * - blockIdx.y
+ * - blockDim.z
+ * - T3.stride[2]
+ */
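+// For example (a sketch; the names follow the CUDA built-ins):
+//   NamedScalar::getParallelDim(ParallelType::TIDx) names "blockDim.x",
+//   NamedScalar::getParallelIndex(ParallelType::TIDx) names "threadIdx.x".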
class TORCH_CUDA_CU_API NamedScalar : public Val {
public:
+ ~NamedScalar() override = default;
+ NamedScalar() = delete;
+
// NOLINTNEXTLINE(modernize-pass-by-value)
- NamedScalar(std::string name, DataType dtype)
- : Val(ValType::NamedScalar, dtype), name_(name) {}
+ NamedScalar(std::string _name, DataType dtype)
+ : Val(ValType::NamedScalar, dtype), name_(_name) {}
NamedScalar(const NamedScalar* src, IrCloner* ir_cloner);
+ NamedScalar(const NamedScalar& other) = delete;
+ NamedScalar& operator=(const NamedScalar& other) = delete;
+
+ NamedScalar(NamedScalar&& other) = delete;
+ NamedScalar& operator=(NamedScalar&& other) = delete;
+
const std::string& name() const {
return name_;
}
- bool sameAs(const Statement* other) const override;
+ bool sameAs(const NamedScalar* const other) const {
+ return other->name().compare(name()) == 0;
+ }
- //! Return the named scalar extent of a parallel dimension (e.g. blockDim.x)
+ // Return the named scalar extent of a parallel dimension (e.g. blockDim.x)
static NamedScalar* getParallelDim(ParallelType p_type);
- //! Return the named scalar index of a parallel dimension (e.g. threadIdx.x)
+ // Return the named scalar index of a parallel dimension (e.g. threadIdx.x)
static NamedScalar* getParallelIndex(ParallelType p_type);
- //! Return the parallel type of this NamedScalar if it is an extent of a
- //! parallel dimension
+ // Return the parallel type of this NamedScalar if it is an extent of a
+ // parallel dimension
c10::optional<ParallelType> getParallelDim() const;
- //! Return the parallel type of this NamedScalar if it is an index of a
- //! parallel dimension
+ // Return the parallel type of this NamedScalar if it is an index of a
+ // parallel dimension
c10::optional<ParallelType> getParallelIndex() const;
private:
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <c10/util/irange.h>
void IrPrinter::handle(const TensorView* tv) {
if (tv->nDims() == 0) {
- os_ << typePrefix(tv->getDataType().value()) << tv->name();
- } else {
- os_ << "T" << tv->name();
- switch (tv->getMemoryType()) {
- case MemoryType::Global:
- os_ << "_g";
+ switch (tv->getDataType().value()) {
+ case DataType::Bool:
+ os_ << "b";
+ break;
+ case DataType::Float:
+ os_ << "f";
break;
- case MemoryType::Shared:
- os_ << "_s";
+ case DataType::Half:
+ os_ << "h";
break;
- case MemoryType::Local:
- os_ << "_l";
+ case DataType::Int:
+ os_ << "i";
break;
+ default:
+ TORCH_INTERNAL_ASSERT(
+ false, "Did not recognize type ", tv->getDataType().value());
}
+ os_ << tv->name();
+
+ } else {
+ os_ << "T" << tv->name();
handle(tv->domain());
- if (tv->getComputeAtPosition() > 0) {
- os_ << " ca_pos( ";
- os_ << tv->getComputeAtPosition();
- os_ << " )";
- }
- if (tv->getMaxProducerPosition() > 0) {
- os_ << " produce_pos( ";
- os_ << tv->getMaxProducerPosition();
- os_ << ")";
+ if (tv->getComputeAtView() != nullptr) {
+ os_ << " compute_at( ";
+ os_ << "T" << tv->getComputeAtView()->name();
+ os_ << ", " << tv->getRelativeComputeAtAxis() << " )";
}
}
}
}
void IrPrinter::handle(const Bool* b) {
- if (print_inline_ && b->definition() != nullptr) {
+ if (print_inline_ && FusionGuard::getCurFusion()->origin(b) != nullptr) {
os_ << "( ";
- handle(b->definition());
+ handle(FusionGuard::getCurFusion()->origin(b));
os_ << " )";
return;
}
}
}
-void IrPrinter::handle(const Double* d) {
- if (print_inline_ && d->definition() != nullptr) {
+void IrPrinter::handle(const Float* f) {
+ if (print_inline_ && FusionGuard::getCurFusion()->origin(f) != nullptr) {
os_ << "( ";
- handle(d->definition());
+ handle(FusionGuard::getCurFusion()->origin(f));
os_ << " )";
return;
}
- if (d->isSymbolic()) {
- os_ << "d" << d->name();
+ if (f->isSymbolic()) {
+ os_ << "f" << f->name();
} else {
- os_ << "double("
+ os_ << "float("
<< std::setprecision(
- std::numeric_limits<Double::ScalarType>::max_digits10)
- << *(d->value()) << ")";
+ std::numeric_limits<Float::ScalarType>::max_digits10)
+ << *(f->value()) << ")";
+ }
+}
+
+void IrPrinter::handle(const Half* h) {
+ if (print_inline_ && FusionGuard::getCurFusion()->origin(h) != nullptr) {
+ os_ << "( ";
+ handle(FusionGuard::getCurFusion()->origin(h));
+ os_ << " )";
+ return;
+ }
+
+ if (h->isSymbolic()) {
+ os_ << "h" << h->name();
+ } else {
+ os_ << "__float2half(" << *(h->value()) << ")";
}
}
void IrPrinter::handle(const Int* i) {
if (print_inline_) {
- if (auto def = i->definition()) {
+ if (auto def = FusionGuard::getCurFusion()->origin(i)) {
os_ << "( ";
handle(def);
os_ << " )";
os_ << i->name();
}
+void IrPrinter::handle(const kir::Bool* b) {
+ os_ << "kir::Bool (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::Float* f) {
+ os_ << "kir::Float (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::Half* h) {
+ os_ << "kir::Half (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::Int* i) {
+ os_ << "kir::Int (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::NamedScalar*) {
+ os_ << "kir::NamedScalar (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::TensorIndex*) {
+ os_ << "kir::TensorIndex (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::IterDomain*) {
+ os_ << "kir::IterDomain (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::TensorDomain*) {
+ os_ << "kir::TensorDomain (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::TensorView*) {
+ os_ << "kir::TensorView (use kir::toString() to print Kernel IR nodes)";
+}
+
static bool isTV(const Val* val) {
- return val->getValType().value() == ValType::TensorView;
+ return val->getValType().value() == ValType::TensorView ||
+ val->getValType().value() == ValType::TensorIndex;
}
// Check if we're a TensorView op that we can generate code for.
checkInlineable(uop);
}
- auto op_type = uop->getUnaryOpType();
-
- if (auto inline_uop = inline_op_str(op_type)) {
+ if (auto inline_uop = inline_op_str(uop->getUnaryOpType())) {
os_ << inline_uop.value();
handle(uop->in());
} else {
- if (op_type == UnaryOpType::Cast) {
+ if (uop->getUnaryOpType() == UnaryOpType::Cast) {
c10::optional<std::string> cast_str = cast_func_str(std::make_pair(
uop->in()->getDataType().value(), uop->out()->getDataType().value()));
TORCH_INTERNAL_ASSERT(cast_str != c10::nullopt, "Unsupported Cast");
os_ << cast_str.value();
} else {
- if (alsoBooleanOperator(op_type) &&
- uop->out()->getDataType().value() == DataType::Bool) {
- os_ << stringifyBooleanOp(op_type);
- } else {
- os_ << op_type;
- }
- if (uop->out()->getDataType().value() == DataType::Float &&
- needFloatSuffix(op_type)) {
- os_ << "f";
- }
+ os_ << uop->getUnaryOpType();
}
- if (op_type == UnaryOpType::RandLike) {
- os_ << "(";
- handle(uop->in());
- } else {
- os_ << "(";
+ os_ << "(";
+ if (uop->getUnaryOpType() == UnaryOpType::RandLike)
+ os_ << "rnd";
+ else
handle(uop->in());
- }
os_ << ")";
}
checkInlineable(bop);
}
- auto op_type = bop->getBinaryOpType();
- if (auto inline_bop = inline_op_str(op_type)) {
+ if (auto inline_bop = inline_op_str(bop->getBinaryOpType())) {
handle(bop->lhs());
if (istvop) {
os_ << "\n";
os_ << " " << inline_bop.value() << " ";
handle(bop->rhs());
} else {
- if (alsoBooleanOperator(op_type) &&
- bop->out()->getDataType().value() == DataType::Bool) {
- os_ << stringifyBooleanOp(op_type);
- } else {
- os_ << op_type;
- }
- if (bop->out()->getDataType().value() == DataType::Float &&
- needFloatSuffix(op_type)) {
- os_ << "f";
- }
- os_ << "(";
+ os_ << bop->getBinaryOpType() << "(";
handle(bop->lhs());
if (istvop) {
os_ << "\n";
os_ << ";\n";
}
+void IrPrinter::handle(const kir::UnaryOp* uop) {
+ os_ << "kir::UnaryOp (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::BinaryOp* bop) {
+ os_ << "kir::BinaryOp (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::TernaryOp* top) {
+ os_ << "kir::TernaryOp (use kir::toString() to print Kernel IR nodes)";
+}
+
void IrPrinter::handle(const ReductionOp* rop) {
+ TORCH_CHECK(rop->out()->getValType() != ValType::TensorIndex);
indent();
os_ << rop->out() << " = reduction( " << rop->in()
<< ", op = " << rop->getReductionOpType()
<< ", initial value = " << rop->init() << " )\n";
}
-void IrPrinter::handle(const WelfordOp* wop) {
- indent();
- os_ << wop->outAvg() << "(Avg),\n"
- << wop->outVar() << "(Var),\n"
- << wop->outN() << "(Count)"
- << "\n = Welford ( ";
- if (wop->singleValue()) {
- os_ << wop->inAvg() << "(Avg), ";
- } else {
- os_ << wop->inAvg() << "(Avg)\n " << wop->inVar() << "(Var)\n "
- << wop->inN() << "(Count)";
- }
- if (wop->hasInit()) {
- os_ << "\n initial value = " << wop->initAvg() << "(Avg)\n "
- << wop->initVar() << "(Var)\n " << wop->initN() << "(N)";
- }
- os_ << " )\n";
+void IrPrinter::handle(const kir::ReductionOp* rop) {
+ os_ << "kir::ReductionOp (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::GridReduction* gr) {
+ os_ << "kir::GridReduction (use kir::toString() to print Kernel IR nodes)";
}
void IrPrinter::handle(const BroadcastOp* bop) {
+ TORCH_CHECK(bop->out()->getValType() != ValType::TensorIndex);
indent();
os_ << bop->out() << " = broadcast( " << bop->in() << " )\n";
}
-void IrPrinter::handle(const TransposeOp* top) {
- indent();
- os_ << top->out() << " = transpose( " << top->in() << " )\n";
+void IrPrinter::handle(const kir::BroadcastOp*) {
+ os_ << "kir::BroadcastOp (use kir::toString() to print Kernel IR nodes)";
}
-void IrPrinter::handle(const ShiftOp* sop) {
- indent();
- os_ << sop->out() << " = shift( " << sop->in() << ", {" << sop->offsets()
- << "} )\n";
+void IrPrinter::handle(const kir::ForLoop* fl) {
+ os_ << "kir::ForLoop (use kir::toString() to print Kernel IR nodes)";
}
-void IrPrinter::handle(const GatherOp* op) {
- indent();
- os_ << op->out() << " = gather( " << op->in() << ", {";
- bool no_comma = true;
- for (const auto& s : op->windowShape()) {
- if (!no_comma) {
- os_ << ", ";
- }
- os_ << s;
- no_comma = false;
- }
- os_ << "}, {";
- no_comma = true;
- for (const auto& pad : op->padWidth()) {
- if (!no_comma) {
- os_ << ", ";
- }
- os_ << "{" << pad[0] << ", " << pad[1] << "}";
- no_comma = false;
- }
- os_ << "} )\n";
+void IrPrinter::handle(const kir::IfThenElse* ite) {
+ os_ << "kir::IfThenElse (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::Allocate* a) {
+ os_ << "kir::Allocate (use kir::toString() to print Kernel IR nodes)";
+}
+
+void IrPrinter::handle(const kir::Sync* a) {
+ os_ << "kir::Sync (use kir::toString() to print Kernel IR nodes)";
}
void IrPrinter::handle(const Split* s) {
- os_ << (s->innerSplit() ? "Split: " : "Outer split: ");
+ os_ << "Split: ";
handle(s->in());
os_ << " by factor " << s->factor() << " -> ";
handle(s->outer());
os_ << "\n";
}
-void IrTransformPrinter::handle(Fusion* f) {
- auto all_vals = f->usedMathVals();
-
- for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
- IrPrinter::handle(tv);
- os() << "\n";
- printTransforms(tv);
- }
-}
-
-void IrTransformPrinter::printTransforms(TensorView* tv) {
- auto root_domain = tv->getMaybeRFactorDomain();
- auto all_exp = DependencyCheck::getAllExprsBetween(
- {root_domain.begin(), root_domain.end()},
- {tv->domain()->domain().begin(), tv->domain()->domain().end()});
-
- os() << " root domain : (";
- for (size_t root_idx = 0; root_idx < root_domain.size(); root_idx++) {
- IrPrinter::handle(root_domain[root_idx]);
- if (root_idx + 1 < root_domain.size()) {
- os() << ",";
- }
- }
- os() << ")\n";
-
- for (auto exp : all_exp) {
- os() << " ";
- IrPrinter::handle(exp);
- }
-}
-
std::ostream& operator<<(std::ostream& os, const Statement* stmt) {
IrPrinter p(os);
p.handle(stmt);
void handle(const IterDomain*) override;
void handle(const Bool*) override;
- void handle(const Double*) override;
+ void handle(const Float*) override;
+ void handle(const Half*) override;
void handle(const Int*) override;
void handle(const NamedScalar*) override;
void handle(const BinaryOp*) override;
void handle(const TernaryOp*) override;
void handle(const ReductionOp*) override;
- void handle(const WelfordOp*) override;
void handle(const BroadcastOp*) override;
- void handle(const TransposeOp*) override;
- void handle(const ShiftOp*) override;
- void handle(const GatherOp*) override;
+
+ void handle(const kir::Bool*) override;
+ void handle(const kir::Float*) override;
+ void handle(const kir::Half*) override;
+ void handle(const kir::Int*) override;
+ void handle(const kir::NamedScalar*) override;
+
+ void handle(const kir::TensorIndex*) override;
+ void handle(const kir::IterDomain*) override;
+ void handle(const kir::TensorDomain*) override;
+ void handle(const kir::TensorView*) override;
+
+ void handle(const kir::UnaryOp*) override;
+ void handle(const kir::BinaryOp*) override;
+ void handle(const kir::TernaryOp*) override;
+ void handle(const kir::ReductionOp*) override;
+ void handle(const kir::BroadcastOp*) override;
+
+ void handle(const kir::GridReduction*) override;
+ void handle(const kir::ForLoop*) override;
+ void handle(const kir::IfThenElse*) override;
+ void handle(const kir::Allocate*) override;
+ void handle(const kir::Sync*) override;
void handle(const Split*) override;
void handle(const Merge*) override;
print_inline_ = prev;
}
- protected:
- std::ostream& os() {
- return os_;
- }
-
private:
std::ostream& os_;
bool print_inline_ = false;
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream& os, Fusion* f);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream& os, Fusion& f);
+// TODO(kir): catch accidental << printing of Kernel IR nodes
+// (use kir::toString(node) instead)
+std::ostream& operator<<(std::ostream& os, const kir::Bool*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::Float*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::Half*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::Int*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::NamedScalar*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::TensorIndex*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::IterDomain*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::TensorDomain*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::TensorView*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::UnaryOp*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::BinaryOp*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::TernaryOp*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::ReductionOp*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::BroadcastOp*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::GridReduction*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::ForLoop*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::IfThenElse*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::Allocate*) = delete;
+std::ostream& operator<<(std::ostream& os, const kir::Sync*) = delete;
+
} // namespace cuda
} // namespace fuser
} // namespace jit
#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
namespace {
-class ScalarCheck : OptInConstDispatch {
+class ScalarCheck : OptInDispatch {
public:
- static bool sameAs(const Val* v1, const Val* v2) {
+ static bool sameAs(Val* v1, Val* v2) {
if (v1 == v2)
return true;
}
private:
- void handle(const Bool* b) override {
+ void handle(Bool* b) override {
same_ = v1_->as<Bool>()->sameAs(v2_->as<Bool>());
}
- void handle(const Double* d) override {
- same_ = v1_->as<Double>()->sameAs(v2_->as<Double>());
+ void handle(Float* f) override {
+ same_ = v1_->as<Float>()->sameAs(v2_->as<Float>());
}
- void handle(const Int* i) override {
+ void handle(Half* h) override {
+ same_ = v1_->as<Half>()->sameAs(v2_->as<Half>());
+ }
+
+ void handle(Int* i) override {
same_ = v1_->as<Int>()->sameAs(v2_->as<Int>());
}
- void handle(const NamedScalar* ns) override {
+ void handle(NamedScalar* ns) override {
same_ = v1_->as<NamedScalar>()->sameAs(v2_->as<NamedScalar>());
}
- ScalarCheck(const Val* _v1, const Val* _v2) : v1_(_v1), v2_(_v2) {
- OptInConstDispatch::handle(v1_);
+ ScalarCheck(Val* _v1, Val* _v2) : v1_(_v1), v2_(_v2) {
+ OptInDispatch::handle(v1_);
}
private:
- const Val* v1_ = nullptr;
- const Val* v2_ = nullptr;
+ Val* v1_ = nullptr;
+ Val* v2_ = nullptr;
bool same_ = false;
};
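+// e.g. ScalarCheck::sameAs(new Int(3), new Int(3)) is true since both are
+// constant with equal values, while two distinct symbolic Ints compare
+// unequal (a sketch; assumes an active Fusion for node registration).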
Bool::Bool(const Bool* src, IrCloner* ir_cloner)
: Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
-bool Bool::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<Bool>()) {
- return false;
- }
- const auto other_bool = other->as<Bool>();
- if (isConst() && other_bool->isConst()) {
- return *value() == *(other_bool->value());
- }
- return false;
+bool Bool::sameAs(const Bool* const other) const {
+ if (isConst() && other->isConst())
+ return *value() == *(other->value());
+ return this == other;
}
-Double::Double(const Double* src, IrCloner* ir_cloner)
+Float::Float(const Float* src, IrCloner* ir_cloner)
: Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
-bool Double::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<Double>()) {
- return false;
- }
- const auto other_double = other->as<Double>();
- if (isConst() && other_double->isConst())
- return *value() == *(other_double->value());
- return false;
+bool Float::sameAs(const Float* const other) const {
+ if (isConst() && other->isConst())
+ return *value() == *(other->value());
+ return this == other;
+}
+
+Half::Half(const Half* src, IrCloner* ir_cloner)
+ : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
+
+bool Half::sameAs(const Half* const other) const {
+ if (isConst() && other->isConst())
+ return *value() == *(other->value());
+ return this == other;
}
Int::Int(const Int* src, IrCloner* ir_cloner)
: Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
-bool Int::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<Int>()) {
- return false;
- }
- const auto other_int = other->as<Int>();
- if (isConst() && other_int->isConst()) {
- return *value() == *(other_int->value());
- }
- return false;
+bool Int::sameAs(const Int* const other) const {
+ if (isConst() && other->isConst())
+ return *value() == *(other->value());
+ return this == other;
}
-UnaryOp::UnaryOp(UnaryOpType type, Val* out, Val* in)
- : Expr(ExprType::UnaryOp), unary_op_type_{type}, out_{out}, in_{in} {
- addOutput(out);
- addInput(in);
+UnaryOp::UnaryOp(UnaryOpType _type, Val* _out, Val* _in)
+ : Expr(ExprType::UnaryOp), unary_op_type_{_type}, out_{_out}, in_{_in} {
+ addOutput(_out);
+ addInput(_in);
name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)) {}
-bool UnaryOp::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<UnaryOp>()) {
+bool UnaryOp::sameAs(const UnaryOp* const other) const {
+ if (type() != other->type())
return false;
- }
- const auto other_op = other->as<UnaryOp>();
- if (getUnaryOpType() != other_op->getUnaryOpType())
- return false;
- return Expr::sameAs(other);
+ return as<Expr>()->sameAs(other);
}
-BinaryOp::BinaryOp(BinaryOpType type, Val* out, Val* lhs, Val* rhs)
+BinaryOp::BinaryOp(BinaryOpType _type, Val* _out, Val* _lhs, Val* _rhs)
: Expr(ExprType::BinaryOp),
- binary_op_type_{type},
- out_{out},
- lhs_{lhs},
- rhs_{rhs} {
- addOutput(out);
- addInput(lhs);
- addInput(rhs);
+ binary_op_type_{_type},
+ out_{_out},
+ lhs_{_lhs},
+ rhs_{_rhs} {
+ addOutput(_out);
+ addInput(_lhs);
+ addInput(_rhs);
name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
lhs_(ir_cloner->clone(src->lhs_)),
rhs_(ir_cloner->clone(src->rhs_)) {}
-bool BinaryOp::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<BinaryOp>()) {
+bool BinaryOp::sameAs(const BinaryOp* other) const {
+ if (getBinaryOpType() != other->getBinaryOpType())
return false;
- }
- const auto other_op = other->as<BinaryOp>();
- if (getBinaryOpType() != other_op->getBinaryOpType())
+ if (!(lhs()->sameAs(other->lhs()) && rhs()->sameAs(other->rhs())))
return false;
- return Expr::sameAs(other);
+ return true;
}
-TernaryOp::TernaryOp(TernaryOpType type, Val* out, Val* in1, Val* in2, Val* in3)
+TernaryOp::TernaryOp(
+ TernaryOpType _type,
+ Val* _out,
+ Val* _in1,
+ Val* _in2,
+ Val* _in3)
: Expr(ExprType::TernaryOp),
- ternary_op_type_{type},
- out_{out},
- in1_{in1},
- in2_{in2},
- in3_{in3} {
- addOutput(out);
- addInput(in1);
- addInput(in2);
- addInput(in3);
+ ternary_op_type_{_type},
+ out_{_out},
+ in1_{_in1},
+ in2_{_in2},
+ in3_{_in3} {
+ addOutput(_out);
+ addInput(_in1);
+ addInput(_in2);
+ addInput(_in3);
name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
in2_(ir_cloner->clone(src->in2_)),
in3_(ir_cloner->clone(src->in3_)) {}
-bool TernaryOp::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<TernaryOp>()) {
+bool TernaryOp::sameAs(const TernaryOp* other) const {
+ if (getTernaryOpType() != other->getTernaryOpType())
return false;
- }
- const auto other_op = other->as<TernaryOp>();
- if (getTernaryOpType() != other_op->getTernaryOpType())
+ if (!(in1()->sameAs(other->in1()) && in2()->sameAs(other->in2()) &&
+ in3()->sameAs(other->in3())))
return false;
- return Expr::sameAs(other);
+ return true;
}
-BroadcastOp::BroadcastOp(Val* out, Val* in, std::vector<bool> is_broadcast_dims)
- : Expr(ExprType::BroadcastOp),
- out_(out),
- in_(in),
- is_broadcast_dims_(std::move(is_broadcast_dims)) {
- // clang-tidy complains about out_ that it may be null.
- TORCH_INTERNAL_ASSERT(out_ != nullptr);
- TORCH_INTERNAL_ASSERT(in_ != nullptr);
-
- auto out_type = out->getValType().value();
- auto in_type = in->getValType().value();
+BroadcastOp::BroadcastOp(Val* _out, Val* _in)
+ : Expr(ExprType::BroadcastOp), out_(_out), in_(_in) {
+ auto out_type = _out->getValType().value();
+ auto in_type = _in->getValType().value();
  TORCH_INTERNAL_ASSERT(
      out_type == ValType::TensorView && in_type == ValType::TensorView,
      "Cannot broadcast a non-tensor object.");
- addOutput(out);
- addInput(in);
- name_ = FusionGuard::getCurFusion()->registerExpr(this);
-
// This is a generic check that root dims of a consumer and producer match.
// Maybe we shouldn't relegate it to this constructor.
- const auto c_tv = out_->as<TensorView>();
- const auto p_tv = in_->as<TensorView>();
+ const auto c_tv = out()->as<TensorView>();
+ const auto p_tv = in()->as<TensorView>();
const auto& c_root = c_tv->getRootDomain();
const auto& p_root = p_tv->getMaybeRFactorDomain();
- const auto root_p2c =
- PairwiseRootDomainMap(p_tv, c_tv)
- .mapProducerToConsumer(p_tv->domain(), c_tv->domain());
-
- for (auto id : p_root) {
- if (root_p2c.find(id) == root_p2c.end()) {
- TORCH_INTERNAL_ASSERT(
- id->isReduction(),
- "Invalid broadcast op: ",
- id,
- ". Non-reduction input dim does't match to output.");
- }
- }
+ const auto root_p2c = TensorDomain::mapDomainPandC(p_root, c_root);
+
+ std::vector<bool> c_mapped(c_root.size(), false);
+ std::vector<bool> p_mapped(p_root.size(), false);
- std::unordered_set<IterDomain*> c_mapped;
for (auto pair_entry : root_p2c) {
- c_mapped.insert(pair_entry.second);
+ auto p_i = pair_entry.first;
+ p_mapped[p_i] = true;
+ auto c_i = pair_entry.second;
+ c_mapped[c_i] = true;
}
- for (size_t i = 0; i < c_root.size(); ++i) {
- const auto c_id = c_root[i];
- if (c_mapped.find(c_id) != c_mapped.end()) {
- continue;
+ bool bad_mismatch = false;
+
+ for (const auto i : c10::irange(c_root.size())) {
+ if (!c_mapped[i]) {
+ if (!c_root[i]->isBroadcast()) {
+ bad_mismatch = true;
+ }
+ }
+ }
+
+ for (const auto i : c10::irange(p_root.size())) {
+ if (!p_mapped[i]) {
+ if (!p_root[i]->isReduction()) {
+ bad_mismatch = true;
+ }
}
- TORCH_INTERNAL_ASSERT(
- c_id->isBroadcast() && is_broadcast_dims_[i],
- "Invalid broadcast op: ",
- c_id,
- ". Non-broadcasted output dim isn't matched from input.");
}
+
+ TORCH_INTERNAL_ASSERT(
+ !bad_mismatch,
+ "Invalid broadcast op. Non-broadcasted dims don't match from input to output.");
+
+ addOutput(_out);
+ addInput(_in);
+ name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
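+// Sketch of the check above with hypothetical tensors: broadcasting tv1[i0]
+// into tv2[i0, b1] is valid because the unmapped consumer root axis b1 is a
+// broadcast axis; an unmapped non-broadcast consumer axis (or an unmapped
+// non-reduction producer axis) would trip the assert.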
BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
out_(ir_cloner->clone(src->out_)),
- in_(ir_cloner->clone(src->in_)),
- is_broadcast_dims_(src->is_broadcast_dims_) {}
+ in_(ir_cloner->clone(src->in_)) {}
-bool BroadcastOp::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<BroadcastOp>()) {
- return false;
- }
- const auto other_op = other->as<BroadcastOp>();
- if (getBroadcastDimFlags() != other_op->getBroadcastDimFlags()) {
- return false;
- }
- return Expr::sameAs(other);
+bool BroadcastOp::sameAs(const BroadcastOp* const other) const {
+ return other->in() == in() && other->out() == out();
}
ReductionOp::ReductionOp(
- BinaryOpType reduction_op_type,
- Val* init,
- Val* out,
- Val* in)
+ BinaryOpType _reduction_op_type,
+ Val* _init,
+ Val* _out,
+ Val* _in)
: Expr(ExprType::ReductionOp),
- reduction_op_type_(reduction_op_type),
- init_(init),
- out_(out),
- in_(in) {
- TORCH_CHECK(out->getValType().value() == ValType::TensorView);
-
- TORCH_INTERNAL_ASSERT(
- in->getValType() == ValType::TensorView &&
- out->getValType() == ValType::TensorView,
- "Reduction operation was created that does not have tensor inputs and outputs.");
-
- TORCH_INTERNAL_ASSERT(
- TensorDomain::noReductions(in->as<TensorView>()->getMaybeRFactorDomain())
- .size() == out->as<TensorView>()->getRootDomain().size(),
- "Reduction operation created with mismatched domains.");
-
- TORCH_INTERNAL_ASSERT(
- init->isConstScalar(),
- "Tried to create a reduction operation whith an initial value that isn't a constant.");
-
- addOutput(out);
- addInput(in);
- name_ = FusionGuard::getCurFusion()->registerExpr(this);
-}
-
-WelfordOp::WelfordOp(
- Val* out_avg,
- Val* out_var,
- Val* out_N,
- Val* init_avg,
- Val* init_var,
- Val* init_N,
- Val* in_avg,
- Val* in_var,
- Val* in_N)
- : Expr(ExprType::WelfordOp),
- out_avg_(out_avg),
- out_var_(out_var),
- out_N_(out_N),
- init_avg_(init_avg),
- init_var_(init_var),
- init_N_(init_N),
- in_avg_(in_avg),
- in_var_(in_var),
- in_N_(in_N) {
- // Check output type
- TORCH_INTERNAL_ASSERT(out_avg->getValType().value() == ValType::TensorView);
- TORCH_INTERNAL_ASSERT(out_var->getValType().value() == ValType::TensorView);
- TORCH_INTERNAL_ASSERT(out_N->getValType().value() == ValType::TensorView);
-
- // check initial value
- TORCH_INTERNAL_ASSERT(init_N->getValType().value() == ValType::Scalar);
- if (!init_N->isZeroInt()) {
- // when initial count is zero, no initial variance or average is needed
- // initial value with a count of 1 is un-common enough that I'll push
- // the responsibility of creating all-zero var tensors to the user
- TORCH_INTERNAL_ASSERT(
- init_avg && init_avg->getValType().value() == ValType::TensorView);
+ reduction_op_type_(_reduction_op_type),
+ init_(_init),
+ out_(_out),
+ in_(_in) {
+ if (_out->getValType().value() == ValType::TensorView) {
TORCH_INTERNAL_ASSERT(
- init_var && init_var->getValType().value() == ValType::TensorView);
- }
+ _in->getValType() == ValType::TensorView &&
+ _out->getValType() == ValType::TensorView,
+ "Reduction operation was created that does not have tensor inputs and outputs.");
- TORCH_INTERNAL_ASSERT(
- in_avg && in_avg->getValType().value() == ValType::TensorView);
- // check input
- TORCH_INTERNAL_ASSERT(
- in_N->getValType().value() == ValType::Scalar ||
- in_N->getValType().value() == ValType::TensorView);
- if (!in_N->isOneInt()) {
- // when input is only one value, only the value is required through avg
- // input the var part is implicitly 0 and codegen will handle that.
TORCH_INTERNAL_ASSERT(
- in_var && in_var->getValType().value() == ValType::TensorView);
- }
-
- addOutput(out_avg);
- addOutput(out_var);
- addOutput(out_N);
+ TensorDomain::noReductions(
+ _in->as<TensorView>()->getMaybeRFactorDomain())
+ .size() == _out->as<TensorView>()->getRootDomain().size(),
+ "Reduction operation created with mismatched domains.");
- addInput(in_avg);
- // Conditionally adding this input?
- if (!in_N->isOneInt()) {
- addInput(in_var);
+ } else {
+ TORCH_INTERNAL_ASSERT(
+ _in->getValType() == ValType::TensorIndex &&
+ _out->getValType() == ValType::TensorIndex,
+ "Reduction operation was created that does not have tensor inputs and outputs.");
}
- addInput(in_N);
+  TORCH_INTERNAL_ASSERT(
+      _init->isConstScalar(),
+      "Tried to create a reduction operation with an initial value that isn't a constant.");
+ addOutput(_out);
+ addInput(_in);
name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
-WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner)
- : Expr(src, ir_cloner),
- out_avg_(ir_cloner->clone(src->out_avg_)),
- out_var_(ir_cloner->clone(src->out_var_)),
- out_N_(ir_cloner->clone(src->out_N_)),
- init_avg_(src->init_avg_ ? ir_cloner->clone(src->init_avg_) : nullptr),
- init_var_(src->init_var_ ? ir_cloner->clone(src->init_var_) : nullptr),
- init_N_(ir_cloner->clone(src->init_N_)),
- in_avg_(ir_cloner->clone(src->in_avg_)),
- in_var_(src->in_var_ ? ir_cloner->clone(src->in_var_) : nullptr),
- in_N_(ir_cloner->clone(src->in_N_)) {}
-
-namespace {
-inline bool sameOptionalVal(Val* a, Val* b) {
- return ((a == nullptr && b == nullptr)) || ((a && b) && (a->sameAs(b)));
-}
-} // namespace
-
-bool WelfordOp::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (auto other_wop = dynamic_cast<const WelfordOp*>(other)) {
- return in_avg_->sameAs(other_wop->in_avg_) &&
- sameOptionalVal(in_var_, other_wop->in_var_) &&
- in_N_->sameAs(other_wop->in_N_) &&
- sameOptionalVal(init_avg_, other_wop->init_avg_) &&
- sameOptionalVal(init_var_, other_wop->init_var_) &&
- init_N_->sameAs(other_wop->init_N_);
- }
- return false;
-}
-
ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
reduction_op_type_(src->reduction_op_type_),
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)) {}
-bool ReductionOp::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<ReductionOp>()) {
- return false;
- }
- const auto other_op = other->as<ReductionOp>();
- // Note that init is not part of input vals, so it must be checked separately.
+bool ReductionOp::sameAs(const ReductionOp* other) const {
return (
- Expr::sameAs(other) &&
- getReductionOpType() == other_op->getReductionOpType() &&
- init()->sameAs(other_op->init()));
-}
-
-TransposeOp::TransposeOp(
- TensorView* out,
- TensorView* in,
- std::vector<int> new2old)
- : Expr(ExprType::TransposeOp),
- out_(out),
- in_(in),
- new2old_(std::move(new2old)) {
- // Sanity check of the input parameters. Maybe not necessary as they
- // should be checked at function transpose.
-
- TORCH_INTERNAL_ASSERT(
- !in->hasRFactor(), "Transposing rFactor tensors is not supported.");
-
- TORCH_INTERNAL_ASSERT(
- TensorDomain::noReductions(in->getRootDomain()).size() ==
- out->getRootDomain().size());
-
- TORCH_INTERNAL_ASSERT(new2old_.size() == out->getRootDomain().size());
-
- // Make sure the entries of new2old are unique and range from 0 to
- // N-1, where N == new2old.size().
- std::set<int> old_positions(new2old_.begin(), new2old_.end());
- TORCH_INTERNAL_ASSERT(old_positions.size() == new2old_.size());
- // old_positions is sorted, so the first entry must be 0.
- TORCH_INTERNAL_ASSERT(
- *(old_positions.begin()) == 0,
- "Invalid new2old vector detected: ",
- new2old_);
- // The last entry must be N-1, since old_positions is sorted, starts
- // with 0, and its length is N.
- TORCH_INTERNAL_ASSERT(
- *(old_positions.rbegin()) == (int)(new2old_.size() - 1),
- "Invalid new2old vector detected: ",
- new2old_);
-
- addOutput(out);
- addInput(in);
- name_ = FusionGuard::getCurFusion()->registerExpr(this);
-}
-
-TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner)
- : Expr(src, ir_cloner),
- out_(ir_cloner->clone(src->out_)),
- in_(ir_cloner->clone(src->in_)),
- new2old_(src->new2old_) {}
-
-ShiftOp::ShiftOp(Val* out, Val* in, std::vector<int> offsets)
- : Expr(ExprType::ShiftOp),
- out_(out),
- in_(in),
- offsets_(std::move(offsets)) {
- // clang-tidy complains about out_ that it may be null.
- TORCH_INTERNAL_ASSERT(out_ != nullptr);
- TORCH_INTERNAL_ASSERT(in_ != nullptr);
-
- auto out_type = out->getValType().value();
- auto in_type = in->getValType().value();
-
- TORCH_INTERNAL_ASSERT(
- out_type == ValType::TensorView && in_type == ValType::TensorView,
- "Cannot shift a non-tensor object.");
-
- TORCH_INTERNAL_ASSERT(
- offsets_.size() ==
- TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain())
- .size(),
- "Invalid offset vector: ",
- offsets_);
-
- addOutput(out);
- addInput(in);
- name_ = FusionGuard::getCurFusion()->registerExpr(this);
-}
-
-ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner)
- : Expr(src, ir_cloner),
- out_(ir_cloner->clone(src->out_)),
- in_(ir_cloner->clone(src->in_)),
- offsets_(src->offsets_) {}
-
-bool ShiftOp::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<ShiftOp>()) {
- return false;
- }
- const auto other_op = other->as<ShiftOp>();
- if (offsets() != other_op->offsets()) {
- return false;
- }
- return Expr::sameAs(other);
-}
-
-GatherOp::GatherOp(
- Val* out,
- Val* in,
- std::vector<Int*> window_shape,
- std::vector<std::vector<Int*>> pad_width)
- : Expr(ExprType::GatherOp),
- out_(out),
- in_(in),
- window_shape_(std::move(window_shape)),
- pad_width_(std::move(pad_width)) {
- // clang-tidy complains about out_ that it may be null.
- TORCH_INTERNAL_ASSERT(out_ != nullptr);
- TORCH_INTERNAL_ASSERT(in_ != nullptr);
-
- auto out_type = out->getValType().value();
- auto in_type = in->getValType().value();
-
- TORCH_INTERNAL_ASSERT(
- out_type == ValType::TensorView && in_type == ValType::TensorView,
- "Cannot shift a non-tensor object.");
-
- const auto ndims =
- TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain()).size();
-
- TORCH_INTERNAL_ASSERT(
- window_shape_.size() == ndims,
- "Invalid window_shape vector: ",
- window_shape_);
- TORCH_INTERNAL_ASSERT(
- pad_width_.size() == ndims, "Invalid pad_width vector: ", pad_width_);
-
- for (const auto& pad : pad_width_) {
- TORCH_INTERNAL_ASSERT(
- pad.size() == 2, "Padding size for each axis must have two Int vals.");
- }
-
- addOutput(out);
- addInput(in);
- name_ = FusionGuard::getCurFusion()->registerExpr(this);
-}
-
-GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner)
- : Expr(src, ir_cloner),
- out_(ir_cloner->clone(src->out_)),
- in_(ir_cloner->clone(src->in_)) {
- std::transform(
- src->window_shape_.begin(),
- src->window_shape_.end(),
- std::back_inserter(window_shape_),
- [&ir_cloner](const auto& x) { return ir_cloner->clone(x); });
- for (const auto& pad : src->pad_width_) {
- std::vector<Int*> pad_clone;
- std::transform(
- pad.begin(),
- pad.end(),
- std::back_inserter(pad_clone),
- [&ir_cloner](const auto& x) { return ir_cloner->clone(x); });
- pad_width_.push_back(pad_clone);
- }
-}
-
-bool GatherOp::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<GatherOp>()) {
- return false;
- }
- const auto other_op = other->as<GatherOp>();
- if (windowShape().size() != other_op->windowShape().size()) {
- return false;
- }
- for (size_t i = 0; i < windowShape().size(); ++i) {
- if (!windowShape()[i]->sameAs(other_op->windowShape()[i])) {
- return false;
- }
- }
- if (padWidth().size() != other_op->padWidth().size()) {
- return false;
- }
- for (size_t i = 0; padWidth().size(); ++i) {
- if (!padWidth()[i][0]->sameAs(other_op->padWidth()[i][0]) ||
- !padWidth()[i][1]->sameAs(other_op->padWidth()[i][1])) {
- return false;
- }
- }
- return Expr::sameAs(other);
-}
-
-int GatherOp::gatherAxis(int axis) const {
- if (axis < 0) {
- axis += out()->as<TensorView>()->nDims();
- }
- TORCH_INTERNAL_ASSERT(
- axis >= 0 && axis < (int)windowShape().size(), "Invalid axis: ", axis);
- return int(windowShape().size()) + axis;
+ in()->sameAs(other->in()) &&
+ getReductionOpType() == other->getReductionOpType() &&
+ init()->sameAs(other->init()));
}
IterDomain::IterDomain(
- Val* start,
- Val* extent,
- ParallelType parallel_type,
- IterType iter_type,
- bool is_rfactor_domain)
+ Val* _start,
+ Val* _extent,
+ ParallelType _parallel_type,
+ IterType _iter_type,
+ bool _is_rfactor_domain)
: Val(ValType::IterDomain, DataType::Int, false),
- start_(start),
- extent_(extent),
- parallel_type_(parallel_type),
- iter_type_(iter_type),
- is_rfactor_domain_(is_rfactor_domain) {
+ start_(_start),
+ extent_(_extent),
+ parallel_type_(_parallel_type),
+ iter_type_(_iter_type),
+ is_rfactor_domain_(_is_rfactor_domain) {
TORCH_CHECK(
!(isRFactorProduct() && isBroadcast()),
"IterDomain cannot be both a broadcast and rfactor domain.");
TORCH_INTERNAL_ASSERT(
- extent->isAnInt(),
+ _extent->isAnInt(),
"Cannot create an iter domain over an extent that is not an int but received ",
- extent,
+ _extent,
" .");
TORCH_INTERNAL_ASSERT(
- start->isAnInt(),
+ _start->isAnInt(),
"Cannot create an iter domain with a start that is not an int but received ",
- start,
+      _start,
" .");
// Check that all for-loops iterate from zero to some positive integer
// lower_insert_syncs uses this assumption for correctness.
TORCH_INTERNAL_ASSERT(
- start->isZeroInt(),
+ _start->isZeroInt(),
"Cannot create an iter domain with a start that is non-zero but received ",
- start,
+      _start,
" .");
+  TORCH_INTERNAL_ASSERT(
+      !_extent->isZeroInt(),
+      "Cannot create an iter domain with an extent that is zero but received ",
+      _extent,
+      " .");
+
+ // TORCH_INTERNAL_ASSERT(!kir::isLoweredVal(_extent));
+
name_ = fusion_->registerVal(this);
}
iter_type_(src->iter_type_),
is_rfactor_domain_(src->is_rfactor_domain_) {}
-bool IterDomain::sameAs(const Statement* other) const {
- if (other == this) {
+bool IterDomain::sameAs(const IterDomain* const other) const {
+ if (other == this)
return true;
- }
-
- if (!other->isA<IterDomain>()) {
- return false;
- }
-
- const IterDomain* other_id = other->as<IterDomain>();
- bool is_same = isReduction() == other_id->isReduction() &&
- getParallelType() == other_id->getParallelType();
- is_same = is_same && ScalarCheck::sameAs(extent(), other_id->extent());
- is_same = is_same && ScalarCheck::sameAs(start(), other_id->start());
+ bool is_same = isReduction() == other->isReduction() &&
+ getParallelType() == other->getParallelType();
+ is_same = is_same && ScalarCheck::sameAs(extent(), other->extent());
+ is_same = is_same && ScalarCheck::sameAs(start(), other->start());
return is_same;
}
-std::vector<IterDomain*> IterDomain::clone(
- const std::vector<IterDomain*>& domains) {
- std::vector<IterDomain*> cloned_domains;
- std::transform(
- domains.begin(),
- domains.end(),
- std::back_inserter(cloned_domains),
- [](auto id) { return id->clone(); });
- return cloned_domains;
-}
-
IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
TORCH_CHECK(
outer->start()->isZeroInt() && inner->start()->isZeroInt(),
"Merging IterDomains with starting values that aren't 0 is not supported at this time.");
TORCH_CHECK(
- !outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(),
- "Merging IterDomains with ending values that are 0 is not supported at this time.");
- TORCH_CHECK(
- outer->isReduction() == inner->isReduction() ||
- (!outer->isReduction() && inner->extent()->isOneInt()) ||
- (outer->extent()->isOneInt() && !inner->isReduction()),
+ outer->isReduction() == inner->isReduction(),
"Merging IterDomains requires that their iteration types match.");
TORCH_CHECK(
- (outer->isGather() && inner->isGather()) ||
- (!outer->isGather() && !inner->isGather()),
- "Merging gather and non-gather domains is not supported.");
+ outer->getParallelType() == inner->getParallelType(),
+ "Merging IterDomains requires that their parallel types match.");
Val* merged_id_size = mul(outer->extent(), inner->extent());
itype = IterType::Iteration;
}
- // Merging trivial reduction with iter domain, that's fine, just make it an
- // iter domain.
- if ((outer->isReduction() || inner->isReduction()) &&
- (!outer->isReduction() || !inner->isReduction())) {
- itype = IterType::Iteration;
- }
-
IterDomain* merged_id = new IterDomain(
new Int(0),
merged_id_size->as<Int>(),
std::pair<IterDomain*, IterDomain*> IterDomain::split(
IterDomain* in,
- Val* factor,
- bool inner_split) {
+ Val* factor) {
TORCH_CHECK(
in->start()->isZeroInt(),
"Splitting IterDomains with starting values that aren't 0 is not supported at this time.");
- TORCH_CHECK(
- !in->extent()->isZeroInt(),
- "Splitting IterDomains with ending values that are 0 is not supported at this time.");
+ if (in->getParallelType() != ParallelType::Serial)
+ TORCH_CHECK(
+ false,
+ "Splitting an axis of non-Serial iteration is not supported at this time."
+ " Parallelization strategy must be set after calling split.");
TORCH_CHECK(factor->isAnInt(), "Cannot split by non-integer value ", factor);
}
// outer loop size
- Val* remainder = ceilDiv(in->extent(), factor);
+ Val* vo = ceilDiv(in->extent(), factor);
// outer loop IterDomain
IterDomain* ido = new IterDomain(
new Int(0),
- inner_split ? remainder->as<Int>() : factor,
+ vo->as<Int>(),
in->getParallelType(),
in->getIterType(),
in->isRFactorProduct());
// inner loop IterDomain
IterDomain* idi = new IterDomain(
new Int(0),
- inner_split ? factor : remainder->as<Int>(),
+ factor,
in->getParallelType(),
in->getIterType(),
in->isRFactorProduct());
- new Split(ido, idi, in, factor, inner_split);
+ new Split(ido, idi, in, factor);
return {ido, idi};
}
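+// e.g. splitting id{extent} by a factor yields the pair
+//   (id{ceilDiv(extent, factor)}, id{factor}).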
-// TODO: We should change parallelize interface to be on tensorview or at least
-// vectorize should be done on tensorview. This would let us check that we don't
-// vectorize to the left of the computeAt domain, and could allow us to do some
-// simple validation of vectorize as it's inputs are right most and contiguous.
-void IterDomain::parallelize(ParallelType t) {
- parallel_type_ = t;
- if (t == ParallelType::Unroll || isParallelTypeVectorize(t)) {
- TORCH_CHECK(
- start()->isZeroInt() && extent()->isConstScalar(),
- "Vectorization, unrolling, and unswitching are only supported with start = 0 and extent as a const int, but got ",
- "a start of ",
- start(),
- " and extent ",
- extent(),
- " .");
+// TODO(kir): review if this is still needed in the Fusion IR
+Val* IterDomain::extent() const {
+ if (isThread()) {
+ // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
+ if (extent_->getValType() == ValType::Scalar)
+ if (extent_->as<Int>()->isConst())
+ return extent_;
+
+ return NamedScalar::getParallelDim(getParallelType());
}
+ return extent_;
}
TensorDomain::TensorDomain(
- std::vector<IterDomain*> root_domain,
- std::vector<bool> contiguity)
- : Val(ValType::TensorDomain, DataType::Null, false),
- root_domain_(std::move(root_domain)),
+ std::vector<IterDomain*> _domain,
+ std::vector<bool> _contiguity)
+ : Val(ValType::TensorDomain),
+ root_domain_(std::move(_domain)),
contiguity_(
- contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
- : std::move(contiguity)) {
+ _contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
+ : std::move(_contiguity)) {
  TORCH_CHECK(
      contiguity_.size() == root_domain_.size(),
      "Invalid contiguity information provided, incorrect size. Received vector of size ",
" but needed one of size ",
root_domain_.size());
- // Just due to clang-tidy, correct value set in resetDomains
- has_nontrivial_reduction_ = false;
domain_ = root_domain_;
resetDomains();
- name_ = fusion_->registerVal(this);
}
TensorDomain::TensorDomain(
- std::vector<IterDomain*> root_domain,
- std::vector<IterDomain*> domain,
- std::vector<bool> contiguity)
+ std::vector<IterDomain*> _root_domain,
+ std::vector<IterDomain*> _domain,
+ std::vector<bool> _contiguity)
: Val(ValType::TensorDomain, DataType::Null, false),
- root_domain_(std::move(root_domain)),
- domain_(std::move(domain)),
+ root_domain_(std::move(_root_domain)),
+ domain_(std::move(_domain)),
contiguity_(
- contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
- : std::move(contiguity)) {
+ _contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
+ : std::move(_contiguity)) {
TORCH_CHECK(
contiguity_.size() == root_domain_.size(),
"Invalid contiguity information provided, incorrect size. Recieved vector of size ",
std::vector<Val*> domain_vals(domain_.begin(), domain_.end());
auto inps = IterVisitor::getInputsTo(domain_vals);
- // Validate that the root domain consists of all inputs to domain
+ // Validate that the root domain consists of all inputs to _domain
// Uncertain if this will hold for RFactor
std::unordered_set<Val*> root_vals(root_domain_.begin(), root_domain_.end());
" is an input of domain, but it is not found in the root domain.");
});
- // Just due to clang-tidy, correct value set in resetDomains
- has_nontrivial_reduction_ = false;
resetDomains();
+
name_ = fusion_->registerVal(this);
}
TensorDomain::TensorDomain(
- std::vector<IterDomain*> root_domain,
- std::vector<IterDomain*> rfactor_domain,
- std::vector<IterDomain*> domain,
- std::vector<bool> contiguity)
+ std::vector<IterDomain*> _root_domain,
+ std::vector<IterDomain*> _rfactor_domain,
+ std::vector<IterDomain*> _domain,
+ std::vector<bool> _contiguity)
: Val(ValType::TensorDomain, DataType::Null, false),
- root_domain_(std::move(root_domain)),
- domain_(std::move(domain)),
- rfactor_domain_(std::move(rfactor_domain)),
+ root_domain_(std::move(_root_domain)),
+ domain_(std::move(_domain)),
+ rfactor_domain_(std::move(_rfactor_domain)),
contiguity_(
- contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
- : std::move(contiguity)) {
+ _contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
+ : std::move(_contiguity)) {
TORCH_CHECK(
contiguity_.size() == root_domain_.size(),
"Invalid contiguity information provided, incorrect size. Recieved vector of size ",
auto inps = IterVisitor::getInputsTo(
std::vector<Val*>(domain_.begin(), domain_.end()));
- // Validate that the root domain consists of all inputs to domain
+ // Validate that the root domain consists of all inputs to _domain
// Uncertain if this will hold for RFactor
std::unordered_set<Val*> root_vals(root_domain_.begin(), root_domain_.end());
" is an input of the rfactor domain, but it is not found in the root domain.");
});
- // Just due to clang-tidy, correct value set in resetDomains
- has_nontrivial_reduction_ = false;
resetDomains();
name_ = fusion_->registerVal(this);
}
no_bcast_domain_(ir_cloner->clone(src->no_bcast_domain_)),
no_reduction_domain_(ir_cloner->clone(src->no_reduction_domain_)),
rfactor_domain_(ir_cloner->clone(src->rfactor_domain_)),
- contiguity_(src->contiguity()),
- has_nontrivial_reduction_(src->has_nontrivial_reduction_) {}
+ contiguity_(src->contiguity()) {}
bool TensorDomain::operator==(const TensorDomain& other) const {
// Checks equality of each class field. Should not be necessary to
contiguity_ == other.contiguity_;
}
-bool TensorDomain::sameAs(const Statement* const other) const {
- if (this == other) {
- return true;
- }
-
- if (!other->isA<TensorDomain>()) {
+bool TensorDomain::sameAs(const TensorDomain* const other) const {
+ if (nDims() != other->nDims())
return false;
- }
-
- const TensorDomain* other_td = other->as<TensorDomain>();
-
- if (nDims() != other_td->nDims()) {
+ if (getRootDomain().size() != other->getRootDomain().size())
return false;
- }
- if (getRootDomain().size() != other_td->getRootDomain().size()) {
- return false;
- }
- if (getRFactorDomain().size() != other_td->getRFactorDomain().size()) {
+ if (getRFactorDomain().size() != other->getRFactorDomain().size())
return false;
- }
- for (size_t i = 0; i < nDims(); i++) {
- if (!(axis(i)->sameAs(other_td->axis(i)))) {
+ for (const auto i : c10::irange(nDims()))
+ if (!(axis(i)->sameAs(other->axis(i))))
return false;
- }
- }
- for (size_t i = 0; i < getRootDomain().size(); i++) {
- if (!(getRootDomain()[i]->sameAs(other_td->getRootDomain()[i]))) {
+ for (const auto i : c10::irange(getRootDomain().size()))
+ if (!(getRootDomain()[i]->sameAs(other->getRootDomain()[i])))
return false;
- }
- }
- for (size_t i = 0; i < getRFactorDomain().size(); i++) {
- if (!(getRFactorDomain()[i]->sameAs(other_td->getRFactorDomain()[i]))) {
+ for (const auto i : c10::irange(getRFactorDomain().size()))
+ if (!(getRFactorDomain()[i]->sameAs(other->getRFactorDomain()[i])))
return false;
- }
- }
return true;
}
}
bool TensorDomain::hasReduction() const {
- return has_nontrivial_reduction_;
+ return no_reduction_domain_.size() != domain_.size();
}
bool TensorDomain::hasBlockReduction() const {
});
}
+bool TensorDomain::hasBlockBroadcast() const {
+ return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
+ return id->isBroadcast() && id->isThreadDim();
+ });
+}
+
bool TensorDomain::hasBroadcast() const {
return no_bcast_domain_.size() != domain_.size();
}
return !rfactor_domain_.empty();
}
-bool TensorDomain::hasVectorize() const {
- return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
- return id->getParallelType() == ParallelType::Vectorize ||
- id->getParallelType() == ParallelType::MisalignedVectorize;
- });
-}
-
c10::optional<unsigned int> TensorDomain::getReductionAxis() const {
auto it = std::find_if(domain_.begin(), domain_.end(), [](const auto& id) {
return id->isReduction();
TORCH_CHECK(false, "Provided id is not part of this domain.");
}
-void TensorDomain::split(int axis_, Val* factor, bool inner_split) {
+void TensorDomain::split(int axis_, Val* factor) {
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim domain");
if (axis_ < 0)
axis_ += nDims();
"Tried to split on axis outside TensorDomain's range.");
IterDomain* id = axis(axis_);
- auto split_ids = IterDomain::split(id, factor, inner_split);
+ auto split_ids = IterDomain::split(id, factor);
domain_.erase(domain_.begin() + axis_);
domain_.insert(domain_.begin() + axis_, split_ids.second);
domain_.insert(domain_.begin() + axis_, split_ids.first);
// Even though these checks are already in TensorView, we want to redo them as
// we can enter this function from other places, not through TensorView
- auto new2old = ir_utils::normalizeOld2New(old2new_, dom.size());
+ // adjust based on negative values (any negative value gets nDims added
+ // to it)
+ std::unordered_map<int, int> old2new;
+ auto ndims = dom.size();
+ std::transform(
+ old2new_.begin(),
+ old2new_.end(),
+ std::inserter(old2new, old2new.begin()),
+ [ndims](std::unordered_map<int, int>::value_type entry) {
+ return std::unordered_map<int, int>::value_type({
+ entry.first < 0 ? entry.first + ndims : entry.first,
+ entry.second < 0 ? entry.second + ndims : entry.second,
+ });
+ });
+
+ // Check if any adjusted values are < 0 or >= nDims, which are invalid
+
+ TORCH_CHECK(
+ std::none_of(
+ old2new.begin(),
+ old2new.end(),
+ [ndims](std::unordered_map<int, int>::value_type entry) {
+ return entry.first < 0 || (unsigned int)entry.first >= ndims ||
+ entry.second < 0 || (unsigned int)entry.second >= ndims;
+ }),
+ "Reorder axes are not within the number of dimensions of the provided domain.");
+
+ // Going to use sets to see if any duplicate values are in the map.
+
+ std::set<int> old_pos_set;
+ std::transform(
+ old2new.begin(),
+ old2new.end(),
+ std::inserter(old_pos_set, old_pos_set.begin()),
+ [](std::unordered_map<int, int>::value_type entry) {
+ return entry.first;
+ });
+
+ std::set<int> new_pos_set;
+ std::transform(
+ old2new.begin(),
+ old2new.end(),
+ std::inserter(new_pos_set, new_pos_set.begin()),
+ [](std::unordered_map<int, int>::value_type entry) {
+ return entry.second;
+ });
+
+ // Error out if duplicate values are found.
+ TORCH_CHECK(
+ old_pos_set.size() == old2new.size() &&
+ new_pos_set.size() == old2new.size(),
+ "Duplicate entries in transformation map sent to TensorView reorder.");
+
+ // END VALIDATION CHECKS
+
+ std::vector<int> new2old(ndims, -1);
+
+ // Record each specified old->new position mapping (all entries validated above)
+ for (std::pair<int, int> elem : old2new) {
+ int old_pos = elem.first;
+ int new_pos = elem.second;
+ new2old[new_pos] = old_pos;
+ }
+
+ // old_positions that already have a new position
+ std::set<int> old_positions(new2old.begin(), new2old.end());
+ old_positions.erase(-1);
+
+ // All available new positions
+ std::set<int> all_positions;
+ for (const auto i : c10::irange(ndims))
+ all_positions.insert(i);
+
+ // Check what positions haven't been specified.
+ std::set<int> positions_left;
+ std::set_difference(
+ all_positions.begin(),
+ all_positions.end(),
+ old_positions.begin(),
+ old_positions.end(),
+ std::inserter(positions_left, positions_left.end()));
+
+ // Fill in positions that weren't specified, in relative order,
+ // in empty spots in the set of new positions.
+ // new2old[new_position] = old_position
+ auto it = positions_left.begin(); // old positions left
+ std::transform(
+ new2old.begin(), new2old.end(), new2old.begin(), [&it](int i) -> int {
+ return i == -1 ? *it++ : i;
+ });
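+ // Worked example (illustrative): with ndims = 4 and old2new_ = {{0, -1}},
+ // the adjusted map is {{0, 3}}, so new2old starts as [-1, -1, -1, 0];
+ // the unspecified slots are then filled in relative order, giving
+ // new2old = [1, 2, 3, 0], i.e. the first axis moves to the end.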
std::vector<IterDomain*> reordered_domain;
std::transform(
std::vector<IterDomain*> TensorDomain::noReductions(
const std::vector<IterDomain*>& td) {
size_t size_out = 0;
- for (auto id : td)
+ for (const auto& id : td)
if (!id->isReduction())
size_out++;
std::vector<IterDomain*> noReductionDomain(size_out);
int it = 0;
- for (auto id : td)
+ for (const auto& id : td)
if (!id->isReduction())
noReductionDomain[it++] = id;
std::vector<IterDomain*> TensorDomain::noBroadcasts(
const std::vector<IterDomain*>& td) {
size_t size_out = 0;
- for (auto id : td)
+ for (const auto& id : td)
if (!id->isBroadcast())
size_out++;
std::vector<IterDomain*> noBroadcastDomain(size_out);
int it = 0;
- for (auto id : td)
+ for (const auto& id : td)
if (!id->isBroadcast())
noBroadcastDomain[it++] = id;
}
bool TensorDomain::hasBroadcast(const std::vector<IterDomain*>& td) {
- for (auto id : td)
+ for (const auto& id : td)
if (id->isBroadcast())
return true;
return false;
}
-
bool TensorDomain::hasReduction(const std::vector<IterDomain*>& td) {
- for (auto id : td)
+ for (const auto& id : td)
if (id->isReduction())
return true;
return false;
}
-bool TensorDomain::hasNontrivialReduction(const std::vector<IterDomain*>& td) {
- for (auto id : td) {
- if (id->isReduction() && !id->isTrivialReduction()) {
- return true;
+std::vector<std::pair<int, int>> TensorDomain::mapDomainPandC(
+ const std::vector<IterDomain*>& producer,
+ const std::vector<IterDomain*>& consumer) {
+ std::vector<std::pair<int, int>> dom_map;
+
+ size_t itc = 0, itp = 0;
+ while (itc < consumer.size() && itp < producer.size()) {
+ if (consumer[itc]->isBroadcast() && !producer[itp]->isBroadcast()) {
+ itc++;
+ continue;
+ }
+ if (producer[itp]->isReduction()) {
+ itp++;
+ continue;
+ }
+
+ dom_map.emplace_back(std::make_pair(itp, itc));
+ itc++;
+ itp++;
+ }
+ return dom_map;
+}
+
+std::vector<std::pair<IterDomain*, IterDomain*>> TensorDomain::mapRootPandC(
+ const TensorDomain* producer,
+ const TensorDomain* consumer) {
+ auto consumer_root = consumer->getRootDomain();
+ auto producer_root = producer->getMaybeRFactorDomain();
+ std::vector<std::pair<IterDomain*, IterDomain*>> root_id_map;
+ for (const auto& m : mapDomainPandC(producer_root, consumer_root)) {
+ auto producer_axis = producer_root[m.first];
+ auto consumer_axis = consumer_root[m.second];
+ root_id_map.emplace_back(std::make_pair(producer_axis, consumer_axis));
+ }
+ return root_id_map;
+}
+
+std::unordered_map<IterDomain*, IterDomain*> TensorDomain::mapRootCtoP(
+ const TensorDomain* consumer,
+ const TensorDomain* producer,
+ const std::unordered_set<IterDomain*>& consumer_root_dims_to_map) {
+ std::unordered_map<IterDomain*, IterDomain*> root_id_map;
+ for (const auto& kv : mapRootPandC(producer, consumer)) {
+ auto producer_axis = kv.first;
+ auto consumer_axis = kv.second;
+ if (consumer_root_dims_to_map.find(consumer_axis) !=
+ consumer_root_dims_to_map.end()) {
+ root_id_map[consumer_axis] = producer_axis;
}
}
- return false;
+ return root_id_map;
}
-// TODO: Rfactor a Welford
+std::unordered_map<IterDomain*, IterDomain*> TensorDomain::mapRootPtoC(
+ const TensorDomain* producer,
+ const TensorDomain* consumer,
+ const std::unordered_set<IterDomain*>& producer_maybe_rfactor_dims_to_map) {
+ std::unordered_map<IterDomain*, IterDomain*> root_id_map;
+ for (const auto& kv : mapRootPandC(producer, consumer)) {
+ auto producer_axis = kv.first;
+ auto consumer_axis = kv.second;
+ if (producer_maybe_rfactor_dims_to_map.find(producer_axis) !=
+ producer_maybe_rfactor_dims_to_map.end()) {
+ root_id_map[producer_axis] = consumer_axis;
+ }
+ }
+ return root_id_map;
+}
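+// A small worked example of the pairing above (illustrative): for a
+// producer root [i0, r1, i2] (r1 a reduction) and a consumer root
+// [j0, b1, j2] (b1 a broadcast with no producer counterpart),
+// mapDomainPandC() skips r1 on the producer side and b1 on the consumer
+// side, producing index pairs (0, 0) and (2, 2); mapRootPandC() then maps
+// those indices to the IterDomain pairs (i0, j0) and (i2, j2).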
// The pair is ordered so that the second element is the consumer of the first
std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor(
std::vector<int> axes(axes_.size());
- auto ndims = nDims();
+ const auto ndims = nDims();
std::transform(axes_.begin(), axes_.end(), axes.begin(), [ndims](int i) {
return i < 0 ? i + ndims : i;
});
bool rfactor_found = false;
bool reduction_found = false;
- for (decltype(nDims()) i{0}; i < nDims(); i++) {
+ for (const auto i : c10::irange(nDims())) {
if (axis(i)->isReduction()) {
if (axes_set.find(i) != axes_set.end()) {
rfactor_found = true;
namespace {
+//! Container class DisjointSet models equivalence relationships
+//!
+//! Each instance of this class keeps a set of equivalence classes
+//! DisjointSet::join(a,b) makes the full classes of a and b equivalent
+//! DisjointSet::areEquivalent(a,b) checks if a and b belong to the same class
+//!
+//! \note The template type T is assumed to be hashable
+template <typename T>
+class DisjointSet {
+ public:
+ DisjointSet() = default;
+
+ //! Joins the equivalence classes that a and b belong to;
+ //! afterwards areEquivalent(a',b') is true for every a' in
+ //! a's class and every b' in b's class
+ //!
+ //! \param a An element from an equivalence class;
+ //! a new class is created if a does not belong to any
+ //! \param b An element from another equivalence class;
+ //! a new class is created if b does not belong to any
+ void join(T a, T b) {
+ // cases where either equiv class doesn't exist yet
+ if (!entry_map.count(a) && !entry_map.count(b)) {
+ createPoint(a);
+ entry_map[b] = fixedPoint(a);
+ } else if (!entry_map.count(a)) {
+ entry_map[a] = fixedPoint(b);
+ } else if (!entry_map.count(b)) {
+ entry_map[b] = fixedPoint(a);
+ } else {
+ // case where both equiv classes exist and need to join
+ const int i0 = fixedPoint(a);
+ const int i1 = fixedPoint(b);
+ int new_parent = 0;
+ int new_child = 0;
+
+ // Either order here is correct, but joining the larger class to the
+ // smaller class tends to be faster
+ std::tie(new_parent, new_child) = (weights[i0] < weights[i1])
+ ? std::make_pair(i0, i1)
+ : std::make_pair(i1, i0);
+ weights[new_parent] += weights[new_child];
+ set_map[new_child] = new_parent;
+ }
+ }
+
+ //! Checks if a and b belong to the same equivalence class
+ //!
+ //! \param a An element from an equivalence class
+ //! \param b An element from another equivalence class
+ //! \returns Boolean value representing if a and b are
+ //! recorded to be in the same equivalence class;
+ //! returns false if either a or b doesn't
+ //! have an equivalence class recorded
+ bool areEquivalent(T a, T b) const {
+ if (!entry_map.count(a) || !entry_map.count(b)) {
+ return false;
+ }
+ return fixedPoint(a) == fixedPoint(b);
+ }
+
+ private:
+ // Internal fixed point implementation:
+ // Returns the equivalent class that e belongs to
+ int fixedPoint(int e) const {
+ TORCH_INTERNAL_ASSERT(static_cast<int>(set_map.size()) > e);
+ while (set_map[e] != e) {
+ // Chasing to fixed point
+ e = set_map[e];
+ }
+ return e;
+ }
+
+ //! Utility to look up the class e belongs to:
+ //!
+ //! \param e element e to find the equiv class for
+ //! \returns the equivalence class that e belongs to
+ //!
+ int fixedPoint(T e) const {
+ // e must already have an equivalence class recorded
+ TORCH_INTERNAL_ASSERT(entry_map.count(e));
+
+ // Use fixed point as a representation for the equiv class
+ return fixedPoint(entry_map.at(e));
+ }
+
+ //! Utility to create a new equiv class for i
+ //!
+ //! \param i Element i to create the equiv class for
+ void createPoint(T i) {
+ entry_map[i] = next_index_;
+ set_map.push_back(next_index_++);
+ weights.push_back(1);
+ }
+
+ private:
+ // Internal representation of the equivalence class as integers
+ // set_map implements the "parent" relationship
+ std::vector<int> set_map;
+ // Weights are used for a preliminary perf optimization
+ std::vector<int> weights;
+
+ // Map the input of type T to its equivalence class
+ std::unordered_map<T, int> entry_map;
+
+ // Running counter for generating a new index when
+ // creating new equiv classes
+ int next_index_ = 0;
+};
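+// A minimal usage sketch of DisjointSet (illustrative; plain ints stand
+// in for any hashable element type):
+//
+//   DisjointSet<int> ds;
+//   ds.join(1, 2); // 1 and 2 now share a class
+//   ds.join(2, 3); // 3 is merged in transitively
+//   ds.areEquivalent(1, 3); // true
+//   ds.areEquivalent(1, 4); // false: 4 has no recorded class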
+
//! Concretize broadcast axes, i.e. identifying a non-broadcast
//! IterDomain that the broadcast IterDomain can map to.
//!
}
}
+//! Models equivalence provable by the graph
+//!
+//! This traversal processes root domains only,
+//! deriving equalities, e.g.:
+//! T2 [i0,i1] = T1[i2,i3] + T0[i4,i5]
+//! will prove that i2 and i4 are equal in the sense that
+//! i2.start = i4.start, i2.extent = i4.extent
+//! Depends on ConcretizeDomain, and equalities involving
+//! broadcast domains are defined based on the concretized version
+class ProveValEqual : private IterVisitor {
+ public:
+ explicit ProveValEqual(Fusion* fusion) : cd_(fusion) {
+ traverseFrom(fusion, fusion->outputs(), false);
+ }
+
+ //! Checks if two scalars are equal
+ //!
+ //! First checks if ScalarCheck has them equal,
+ //! then tries to prove them equal from
+ //! the graph traversal result
+ //!
+ //! \param a A symbolic value
+ //! \param b Another value from the same fusion
+ //! \returns Boolean representing if they are proven to be
+ //! equal based on scalar check and graph traversal
+ bool areEqual(Val* a, Val* b) const {
+ if (ScalarCheck::sameAs(a, b)) {
+ return true;
+ }
+ if (eq_set_.areEquivalent(a, b)) {
+ return true;
+ }
+ return false;
+ }
+
+ //! Checks if two iterdomains are equal
+ //!
+ //! Equality is defined as equal start and equal extent;
+ //! true means a and b are equal, while
+ //! false only means that they cannot be proven equal based
+ //! on scalar check and graph traversal
+ //!
+ //! \param a An iterdomain
+ //! \param b Another iterdomain from the same fusion
+ //! \returns Boolean representing if they are proven to be
+ //! equivalent in the sense that they have equal
+ //! start and extent
+ bool areEquivalent(IterDomain* a, IterDomain* b) const {
+ if (a->sameAs(b)) {
+ return true;
+ }
+
+ // Abort on un-concretized domains; these can appear once we
+ // allow broadcasts on fusion outputs
+ if (!cd_.canConcretize(a) || !cd_.canConcretize(b)) {
+ return false;
+ }
+
+ auto ac = cd_.concretized(a);
+ auto bc = cd_.concretized(b);
+ return areEqual(ac->start(), bc->start()) &&
+ areEqual(ac->rawExtent(), bc->rawExtent());
+ }
+
+ private:
+ // Utility to record a newly found equality
+ void proveId(IterDomain* a, IterDomain* b) {
+ if (!a->sameAs(b)) {
+ eq_set_.join(a->start(), b->start());
+ eq_set_.join(a->rawExtent(), b->rawExtent());
+ }
+ }
+
+ // Inspect a pointwise op and record the identified equality
+ void provePwOp(Expr* e) {
+ if (e->output(0)->getValType() != ValType::TensorView) {
+ return;
+ }
+
+ TORCH_INTERNAL_ASSERT(e->outputs().size() == 1);
+ TensorView* tv = e->output(0)->as<TensorView>();
+ const std::vector<IterDomain*>& io = tv->getRootDomain();
+
+ // Record equalities from the output to all the inputs,
+ // ignoring un-concretizable broadcasts
+ for (auto* i : ir_utils::filterByType<TensorView>(e->inputs())) {
+ std::vector<IterDomain*> ii =
+ TensorDomain::noReductions(i->getMaybeRFactorDomain());
+
+ for (const auto it : c10::irange(ii.size()))
+ if (cd_.canConcretize(ii[it]) && cd_.canConcretize(io[it]))
+ proveId(cd_.concretized(ii[it]), cd_.concretized(io[it]));
+ }
+ }
+
+ using IterVisitor::handle;
+
+ void handle(ReductionOp* rop) override {
+ provePwOp(rop);
+ }
+
+ void handle(UnaryOp* uop) override {
+ provePwOp(uop);
+ }
+
+ void handle(BinaryOp* bop) override {
+ provePwOp(bop);
+ }
+
+ void handle(TernaryOp* top) override {
+ provePwOp(top);
+ }
+
+ private:
+ ConcretizeDomain cd_;
+ DisjointSet<const Val*> eq_set_;
+};
+
} // namespace
// API call to return the concretized axis of a broadcast axis
return ConcretizeDomain::getConcreteDomain(bcast_dom);
}
+// API call to check if two IterDomains are equal
+// checks start and extent using both a scalar check and graph traversal
+// broadcast domains are concretized before comparing
+bool IterDomain::proveEquivalent(IterDomain* a, IterDomain* b) {
+ TORCH_INTERNAL_ASSERT(a->fusion() == b->fusion());
+ ProveValEqual pve(a->fusion());
+ return pve.areEquivalent(a, b);
+}
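+// A short usage sketch (illustrative; tv0 and tv1 stand for hypothetical
+// TensorViews from the same fusion):
+//
+//   bool eq = IterDomain::proveEquivalent(
+//       tv0->getRootDomain()[0], tv1->getRootDomain()[0]);
+//   // eq is true only if equal start/extent can be proven via the
+//   // scalar check or the graph traversal above; false means
+//   // "not provable", not "provably different".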
+
Split::Split(
- IterDomain* outer,
- IterDomain* inner,
- IterDomain* in,
- Val* factor,
- bool inner_split)
+ IterDomain* _outer,
+ IterDomain* _inner,
+ IterDomain* _in,
+ Val* _factor)
: Expr(ExprType::Split),
- outer_{outer},
- inner_{inner},
- in_{in},
- factor_{factor},
- inner_split_{inner_split} {
+ outer_{_outer},
+ inner_{_inner},
+ in_{_in},
+ factor_{_factor} {
TORCH_INTERNAL_ASSERT(
factor_->isAnInt(),
"Attempted to create a Split node with a non-integer factor.");
- addOutput(outer);
- addOutput(inner);
- addInput(in);
- // TODO add factor as an input, need to check Split::Split during validation
- // and need to check BestEffortReplay::findFirstMismatchedID addInput(factor);
+ addOutput(_outer);
+ addOutput(_inner);
+ addInput(_in);
name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
outer_(ir_cloner->clone(src->outer_)),
inner_(ir_cloner->clone(src->inner_)),
in_(ir_cloner->clone(src->in_)),
- factor_(ir_cloner->clone(src->factor_)),
- inner_split_(src->inner_split_) {}
+ factor_(ir_cloner->clone(src->factor_)) {}
-bool Split::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<Split>()) {
- return false;
- }
- return Expr::sameAs(other) &&
- factor()->sameAs(other->as<Split>()->factor()) &&
- innerSplit() == other->as<Split>()->innerSplit();
+bool Split::sameAs(const Split* const other) const {
+ return (
+ outer()->sameAs(other->outer()) && inner()->sameAs(other->inner()) &&
+ in()->sameAs(other->in()) && factor()->sameAs(other->factor()));
}
-Merge::Merge(IterDomain* out, IterDomain* outer, IterDomain* inner)
- : Expr(ExprType::Merge), out_{out}, outer_{outer}, inner_{inner} {
- addOutput(out);
- addInput(outer);
- addInput(inner);
+Merge::Merge(IterDomain* _out, IterDomain* _outer, IterDomain* _inner)
+ : Expr(ExprType::Merge), out_{_out}, outer_{_outer}, inner_{_inner} {
+ addOutput(_out);
+ addInput(_outer);
+ addInput(_inner);
name_ = FusionGuard::getCurFusion()->registerExpr(this);
}
outer_(ir_cloner->clone(src->outer_)),
inner_(ir_cloner->clone(src->inner_)) {}
-bool Merge::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<Merge>()) {
- return false;
- }
- return Expr::sameAs(other);
+bool Merge::sameAs(const Merge* const other) const {
+ return (
+ out()->sameAs(other->out()) && outer()->sameAs(other->outer()) &&
+ inner()->sameAs(other->inner()));
}
NamedScalar::NamedScalar(const NamedScalar* src, IrCloner* ir_cloner)
: Val(src, ir_cloner), name_(src->name_) {}
-bool NamedScalar::sameAs(const Statement* other) const {
- if (this == other) {
- return true;
- }
- if (!other->isA<NamedScalar>()) {
- return false;
- }
- return other->as<NamedScalar>()->name().compare(name()) == 0;
-}
-
NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) {
std::string parallel_dim = stringifyThreadSize(p_type);
return new NamedScalar(parallel_dim, DataType::Int);
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <iostream>
public:
IrTransformPrinter(std::ostream& os) : IrPrinter(os) {}
- void handle(Fusion* f) override;
+ void handle(const UnaryOp* const uop) override {
+ if (printInline()) {
+ IrPrinter::handle(uop);
+ }
+ }
+
+ void handle(const BinaryOp* const bop) override {
+ if (printInline()) {
+ IrPrinter::handle(bop);
+ }
+ }
- private:
- void printTransforms(TensorView* tv);
+ void handle(Fusion* f) override {
+ IrPrinter::handle(f);
+ }
};
} // namespace cuda
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-
-#include <set>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-namespace ir_utils {
-
-std::vector<int> normalizeOld2New(
- const std::unordered_map<int, int>& old2new_in,
- size_t ndims) {
- // adjust based on negative values (any negative values gets nDims added to
- // it)
- std::unordered_map<int, int> old2new;
- std::transform(
- old2new_in.begin(),
- old2new_in.end(),
- std::inserter(old2new, old2new.begin()),
- [ndims](std::unordered_map<int, int>::value_type entry) {
- return std::unordered_map<int, int>::value_type({
- entry.first < 0 ? entry.first + ndims : entry.first,
- entry.second < 0 ? entry.second + ndims : entry.second,
- });
- });
-
- // Check if any adjusted values are < 0, or >= nDims, which are invalid
-
- TORCH_CHECK(
- std::none_of(
- old2new.begin(),
- old2new.end(),
- [ndims](std::unordered_map<int, int>::value_type entry) {
- return entry.first < 0 || (unsigned int)entry.first >= ndims ||
- entry.second < 0 || (unsigned int)entry.second >= ndims;
- }),
- "Reorder axes are not within the number of dimensions of the provided domain.");
-
- // Going to use sets, to see if any duplicate values are in the map.
-
- std::set<int> old_pos_set;
- std::transform(
- old2new.begin(),
- old2new.end(),
- std::inserter(old_pos_set, old_pos_set.begin()),
- [](std::unordered_map<int, int>::value_type entry) {
- return entry.first;
- });
-
- std::set<int> new_pos_set;
- std::transform(
- old2new.begin(),
- old2new.end(),
- std::inserter(new_pos_set, new_pos_set.begin()),
- [](std::unordered_map<int, int>::value_type entry) {
- return entry.second;
- });
-
- // Error out if duplicate values are found.
- TORCH_CHECK(
- old_pos_set.size() == old2new.size() &&
- new_pos_set.size() == old2new.size(),
- "Duplicate entries in transformation map sent to TensorView reorder.");
-
- // END VALIDATION CHECKS
-
- std::vector<int> new2old(ndims, -1);
-
- // Go through each old and new position, make sure they're within [0, ndims)
- for (std::pair<int, int> elem : old2new) {
- int old_pos = elem.first;
- int new_pos = elem.second;
- new2old[new_pos] = old_pos;
- }
-
- // old_positions that already have a new position
- std::set<int> old_positions(new2old.begin(), new2old.end());
- old_positions.erase(-1);
-
- // All available new positions
- std::set<int> all_positions;
- for (decltype(ndims) i{0}; i < ndims; i++)
- all_positions.insert(i);
-
- // Check what positions haven't been specified.
- std::set<int> positions_left;
- std::set_difference(
- all_positions.begin(),
- all_positions.end(),
- old_positions.begin(),
- old_positions.end(),
- std::inserter(positions_left, positions_left.end()));
-
- // Fill in positions that weren't specified, in relative order,
- // in empty spots in the set of new positions.
- // new2old[new_position] = old_position
- auto it = positions_left.begin(); // old positions left
- std::transform(
- new2old.begin(), new2old.end(), new2old.begin(), [&it](int i) -> int {
- return i == -1 ? *it++ : i;
- });
-
- return new2old;
-}
-
-namespace ValReplacement {
-// Create New Expr given producer - [an input for the expression]
-// Creates a new Expr substituting current with producer
-struct SubstituteInExpr : public OptInDispatch {
- public:
- static Expr* subsitute(Expr* expr, Val* reference, Val* substitute) {
- TORCH_INTERNAL_ASSERT(
- expr != nullptr && reference != nullptr && substitute != nullptr,
- "Nullptr arg found.");
- SubstituteInExpr sie(reference, substitute);
- sie.handle(expr);
- TORCH_INTERNAL_ASSERT(
- sie.expr_ != nullptr,
- "Substitution failed of ",
- reference,
- " with ",
- substitute);
- return sie.expr_;
- }
-
- private:
- explicit SubstituteInExpr(Val* reference, Val* substitute)
- : reference_(reference), substitute_(substitute) {}
-
- void handle(Expr* expr) final {
- OptInDispatch::handle(expr);
- }
-
- void handle(UnaryOp* unary_expr) final {
- auto in =
- reference_->sameAs(unary_expr->in()) ? substitute_ : unary_expr->in();
- auto out =
- reference_->sameAs(unary_expr->out()) ? substitute_ : unary_expr->out();
- expr_ = new UnaryOp(unary_expr->getUnaryOpType(), out, in);
- }
-
- void handle(BinaryOp* binary_expr) final {
- auto lhs = reference_->sameAs(binary_expr->lhs()) ? substitute_
- : binary_expr->lhs();
- auto rhs = reference_->sameAs(binary_expr->rhs()) ? substitute_
- : binary_expr->rhs();
- auto out = reference_->sameAs(binary_expr->out()) ? substitute_
- : binary_expr->out();
-
- expr_ = new BinaryOp(binary_expr->getBinaryOpType(), out, lhs, rhs);
- }
-
- void handle(TernaryOp* ternary_expr) final {
- auto in1 = reference_->sameAs(ternary_expr->in1()) ? substitute_
- : ternary_expr->in1();
- auto in2 = reference_->sameAs(ternary_expr->in2()) ? substitute_
- : ternary_expr->in2();
- auto in3 = reference_->sameAs(ternary_expr->in3()) ? substitute_
- : ternary_expr->in3();
- auto out = reference_->sameAs(ternary_expr->out()) ? substitute_
- : ternary_expr->out();
- expr_ = new TernaryOp(ternary_expr->getTernaryOpType(), out, in1, in2, in3);
- }
-
- void handle(ReductionOp* reduction_expr) final {
- auto init = reference_->sameAs(reduction_expr->init())
- ? substitute_
- : reduction_expr->init();
- auto out = reference_->sameAs(reduction_expr->out())
- ? substitute_
- : reduction_expr->out();
- auto in = reference_->sameAs(reduction_expr->in()) ? substitute_
- : reduction_expr->in();
-
- expr_ =
- new ReductionOp(reduction_expr->getReductionOpType(), init, out, in);
- }
-
- void handle(BroadcastOp* broadcast_expr) final {
- auto out = reference_->sameAs(broadcast_expr->out())
- ? substitute_
- : broadcast_expr->out();
- auto in = reference_->sameAs(broadcast_expr->in()) ? substitute_
- : broadcast_expr->in();
-
- expr_ = new BroadcastOp(out, in, broadcast_expr->getBroadcastDimFlags());
- }
-
- void handle(TransposeOp* transpose_expr) final {
- TORCH_INTERNAL_ASSERT(
- substitute_->isA<TensorView>(),
- "All args to transpose must be tensor view, but received a non-TensorView for replacement: ",
- substitute_);
- auto out = reference_->sameAs(transpose_expr->out())
- ? substitute_->as<TensorView>()
- : transpose_expr->out();
- auto in = reference_->sameAs(transpose_expr->in())
- ? substitute_->as<TensorView>()
- : transpose_expr->in();
- expr_ = new TransposeOp(out, in, transpose_expr->new2old());
- }
-
- void handle(ShiftOp* shift_expr) final {
- auto out =
- reference_->sameAs(shift_expr->out()) ? substitute_ : shift_expr->out();
- auto in =
- reference_->sameAs(shift_expr->in()) ? substitute_ : shift_expr->in();
-
- expr_ = new ShiftOp(out, in, shift_expr->offsets());
- }
-
- void handle(GatherOp* gather_expr) final {
- auto out = reference_->sameAs(gather_expr->out()) ? substitute_
- : gather_expr->out();
- auto in =
- reference_->sameAs(gather_expr->in()) ? substitute_ : gather_expr->in();
-
- expr_ = new GatherOp(
- out, in, gather_expr->windowShape(), gather_expr->padWidth());
- }
-
- void handle(WelfordOp* welford_expr) final {
- auto out_avg = reference_->sameAs(welford_expr->outAvg())
- ? substitute_->as<TensorView>()
- : welford_expr->outAvg();
- auto out_var = reference_->sameAs(welford_expr->outVar())
- ? substitute_->as<TensorView>()
- : welford_expr->outVar();
- auto out_N = reference_->sameAs(welford_expr->outN())
- ? substitute_->as<TensorView>()
- : welford_expr->outN();
- auto in_avg = reference_->sameAs(welford_expr->inAvg())
- ? substitute_->as<TensorView>()
- : welford_expr->inAvg();
- auto in_var =
- welford_expr->inVar() && reference_->sameAs(welford_expr->inVar())
- ? substitute_->as<TensorView>()
- : welford_expr->inVar();
- auto in_N = reference_->sameAs(welford_expr->inN()) ? substitute_
- : welford_expr->inN();
- auto init_avg =
- welford_expr->initAvg() && reference_->sameAs(welford_expr->initAvg())
- ? substitute_->as<TensorView>()
- : welford_expr->initAvg();
- auto init_var =
- welford_expr->initVar() && reference_->sameAs(welford_expr->initVar())
- ? substitute_->as<TensorView>()
- : welford_expr->initVar();
- auto init_N =
- welford_expr->initN() && reference_->sameAs(welford_expr->initN())
- ? substitute_
- : welford_expr->initN();
- expr_ = new WelfordOp(
- out_avg,
- out_var,
- out_N,
- init_avg,
- init_var,
- init_N,
- in_avg,
- in_var,
- in_N);
- }
-
- private:
- Val* reference_ = nullptr;
- Val* substitute_ = nullptr;
- Expr* expr_ = nullptr;
-};
-
-} // namespace ValReplacement
-
-Expr* replaceValInExpr(Expr* expr, Val* reference, Val* substitute) {
- FusionGuard fg(expr->fusion());
- return ValReplacement::SubstituteInExpr::subsitute(
- expr, reference, substitute);
-}
-
-TensorView* rfactorHelper(
- TensorView* reduction_tv,
- const std::vector<int>& axes) {
- TORCH_INTERNAL_ASSERT(reduction_tv->definition() != nullptr);
- const bool is_welford = reduction_tv->definition()->isA<WelfordOp>();
- if (!is_welford) {
- return reduction_tv->rFactor(axes);
- }
- auto welford = reduction_tv->definition()->as<WelfordOp>();
- auto w_avg = welford->outAvg()->as<TensorView>();
- auto w_var = welford->outVar()->as<TensorView>();
- auto w_n = welford->outN()->as<TensorView>();
-
- WelfordResult rtvs = reduction_tv->rFactor(axes, w_avg, w_var, w_n);
-
- // TODO: this can be more generic, using avg because
- // WelfordOp::out() returns the avg
- return rtvs.avg;
-}
-
-namespace {
-
-std::vector<TensorView*> uniqueEntries(
- const std::vector<TensorView*>& tv_deuqe) {
- std::vector<TensorView*> unique_entries;
- std::unordered_set<TensorView*> inserted;
- for (auto tv_entry : tv_deuqe) {
- if (inserted.emplace(tv_entry).second) {
- unique_entries.emplace_back(tv_entry);
- }
- }
- return unique_entries;
-}
-
-} // namespace
-
-std::vector<TensorView*> producerTvsOf(TensorView* tv) {
- if (tv->definition() == nullptr) {
- return {};
- }
- auto producer_vals =
- ir_utils::filterByType<TensorView>(tv->definition()->inputs());
- return uniqueEntries({producer_vals.begin(), producer_vals.end()});
-}
-
-std::vector<TensorView*> consumerTvsOf(TensorView* tv) {
- std::vector<TensorView*> consumer_tvs;
- for (auto use_expr : tv->uses()) {
- auto outputs = ir_utils::filterByType<TensorView>(use_expr->outputs());
- consumer_tvs.insert(consumer_tvs.end(), outputs.begin(), outputs.end());
- }
- return uniqueEntries(consumer_tvs);
-}
-
-std::vector<TensorView*> producerTvsOf(const std::vector<TensorView*>& tvs) {
- std::vector<TensorView*> all_producer_tvs;
- for (auto tv : tvs) {
- auto producer_tvs = producerTvsOf(tv);
- all_producer_tvs.insert(
- all_producer_tvs.end(), producer_tvs.begin(), producer_tvs.end());
- }
-
- return uniqueEntries(all_producer_tvs);
-}
-
-std::vector<TensorView*> consumerTvsOf(const std::vector<TensorView*>& tvs) {
- std::vector<TensorView*> all_consumer_tvs;
- for (auto tv : tvs) {
- auto consumer_tvs = consumerTvsOf(tv);
- all_consumer_tvs.insert(
- all_consumer_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
- }
-
- return uniqueEntries(all_consumer_tvs);
-}
-
-std::vector<TensorView*> inputTvsOf(TensorView* tv) {
- return inputTvsOf(std::vector<TensorView*>{tv});
-}
-
-std::vector<TensorView*> outputTvsOf(TensorView* tv) {
- return outputTvsOf(std::vector<TensorView*>{tv});
-}
-
-std::vector<TensorView*> inputTvsOf(std::vector<TensorView*> tvs) {
- auto inp_vals = IterVisitor::getInputsTo({tvs.begin(), tvs.end()});
- auto filtered = ir_utils::filterByType<TensorView>(inp_vals);
- std::vector<TensorView*> inp_tvs(filtered.begin(), filtered.end());
- return uniqueEntries(inp_tvs);
-}
-
-std::vector<TensorView*> outputTvsOf(std::vector<TensorView*> tvs) {
- auto out_vals = DependencyCheck::getAllOutputsOf({tvs.begin(), tvs.end()});
- auto filtered = ir_utils::filterByType<TensorView>(out_vals);
- std::vector<TensorView*> out_tvs(filtered.begin(), filtered.end());
- return uniqueEntries(out_tvs);
-}
-
-std::vector<TensorView*> allTvs(Fusion* fusion) {
- auto used_vals = fusion->usedMathVals();
- auto used_tvs = ir_utils::filterByType<TensorView>(used_vals);
- return uniqueEntries({used_tvs.begin(), used_tvs.end()});
-}
-
-} // namespace ir_utils
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
#pragma once
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
#include <iterator>
-#include <unordered_map>
namespace torch {
namespace jit {
private:
Iterator current_;
- Iterator end_;
+ const Iterator end_;
};
// An iterable view to a given container of Val pointers. Only returns
return filterByType<FilterType>(inputs.cbegin(), inputs.cend());
}
-//! Returns a list of new-to-old mappings.
-//!
-//! The input map does not need to be complete. Missing axes are
-//! assumed not to be affected.
-//!
-//! This is used to preprocess broadcast and transpose arguments.
-//!
-//! Example: (N := ndims)
-//! {{0, 1}} -> [1, 0, ...., N-1]
-//! Transposes the first two axes with no other change.
-//!
-//! {{0, -1}} -> [N-1, ...., 0]
-//! Swaps the first and last axes.
-std::vector<int> normalizeOld2New(
- const std::unordered_map<int, int>& old2new_in,
- size_t ndims);
-
-// Replace all uses of reference with substitute in expr. Return the Expr.
-// Warning: Invalidates provided Expr.
-// Warning: Removes connection of reference through provided Expr.
-// Warning: Creates new Expr connecting substitue.
-// Reference is found through direct pointer comparison.
-Expr* replaceValInExpr(Expr* expr, Val* reference, Val* substitute);
-
-// Makes rfactor generic with reduction ops and Welford
-TensorView* rfactorHelper(TensorView* red_tv, const std::vector<int>& axes);
-
-// Return immediate producers of tv
-std::vector<TensorView*> producerTvsOf(TensorView* tv);
-
-// Return immediate consumers of tv
-std::vector<TensorView*> consumerTvsOf(TensorView* tv);
-
-// Return immediate producers of tvs (can return tvs input)
-std::vector<TensorView*> producerTvsOf(const std::vector<TensorView*>& tvs);
-
-// Return immediate consumers of tvs (can return tvs input)
-std::vector<TensorView*> consumerTvsOf(const std::vector<TensorView*>& tvs);
-
-// Returns producers of tv that are inputs of fusion
-std::vector<TensorView*> inputTvsOf(TensorView* tv);
-
-// Returns consumers of tv that are outputs of fusion
-std::vector<TensorView*> outputTvsOf(TensorView* tv);
-
-// Returns producers of tvs that are inputs of fusion
-std::vector<TensorView*> inputTvsOf(std::vector<TensorView*> tvs);
-
-// Returns consumers of tvs that are outputs of fusion
-std::vector<TensorView*> outputTvsOf(std::vector<TensorView*> tvs);
-
-// returns all tensor views in fusion that are used between outputs and inputs.
-TORCH_CUDA_CU_API std::vector<TensorView*> allTvs(Fusion* fusion);
-
} // namespace ir_utils
} // namespace cuda
} // namespace fuser
} // namespace
-std::vector<Statement*> IterVisitor::next(Statement* stmt) {
- if (stmt->isVal()) {
- return next(stmt->as<Val>());
- } else if (stmt->isExpr()) {
- return next(stmt->as<Expr>());
- } else {
- TORCH_INTERNAL_ASSERT(
- false, "IterVisitor could not detect type in next_dispatch.");
- }
-}
-
-std::vector<Statement*> IterVisitor::next(Val* v) {
- FusionGuard::getCurFusion()->assertInFusion(v, "Cannot traverse val, ");
- if (v->definition() != nullptr) {
- return {v->definition()};
- }
- return {};
-}
-
-std::vector<Statement*> IterVisitor::next(Expr* expr) {
- FusionGuard::getCurFusion()->assertInFusion(expr, "Cannot traverse expr, ");
- std::vector<Statement*> next_stmts{
- expr->inputs().begin(), expr->inputs().end()};
- return next_stmts;
-}
-
-// This handle function is called on every Statement* in topological order,
-// starting from outputs to inputs.
-void IterVisitor::handle(Statement* s) {
- OptOutDispatch::handle(s);
-}
-
-// This handle function is called on every Expr* in topological order,
-// starting from outputs to inputs.
-void IterVisitor::handle(Expr* e) {
- OptOutDispatch::handle(e);
-}
-
-// This handle function is called on every Val* in topological order,
-// starting from outputs to inputs.
-void IterVisitor::handle(Val* v) {
- OptOutDispatch::handle(v);
-}
-
// Implementation details:
// We start with an entry in stmt_stack that is the outputs we want to
// process. We cannot process these outputs until all Stmts in their history
// If we just popped a stmt_stack level, we can finally visit it!
if (all_inputs_visited) {
- // stmt may have already been visited.
- if (traverseAllPaths || visited.find(stmt) == visited.end()) {
- // Mark visited
- visited.insert(stmt);
+ // Mark visited
+ visited.insert(stmt);
- // Actually visit stmt
- handle(stmt);
- }
+ // Actually visit stmt
+ handle(stmt);
// Remove last value just visited
current_inputs.pop_back();
}
}
-void IterVisitor::traverseHelper(Fusion* fusion, bool traverse_all_paths) {
+void IterVisitor::traverse_(
+ Fusion* fusion,
+ bool from_outputs_only,
+ bool traverse_all_paths) {
FusionGuard fg(fusion);
- auto term_val_outs = fusion->getTerminatingOutputs();
- if (!term_val_outs.empty()) {
- traverseFrom(fusion, term_val_outs, traverse_all_paths);
+ if (from_outputs_only) {
+ auto term_val_outs = fusion->getTerminatingOutputs();
+ if (!term_val_outs.empty()) {
+ traverseFrom(fusion, term_val_outs, traverse_all_paths);
+ }
+ return;
+ }
+
+ std::vector<Val*> leaves;
+ // Search for Vals with no uses (output edges)
+ for (Val* val : fusion->deterministic_vals())
+ if (!fusion->used(val)) {
+ leaves.push_back(val);
+ }
+
+ if (!leaves.empty()) {
+ traverseFrom(fusion, leaves, traverse_all_paths);
}
}
-void IterVisitor::traverse(Fusion* fusion) {
- traverseHelper(fusion, false);
+void IterVisitor::traverse(Fusion* fusion, bool from_outputs_only) {
+ traverse_(fusion, from_outputs_only, false);
}
-void IterVisitor::traverseAllPaths(Fusion* fusion) {
- traverseHelper(fusion, true);
+void IterVisitor::traverseAllPaths(Fusion* fusion, bool from_outputs_only) {
+ traverse_(fusion, from_outputs_only, true);
}
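// For example (illustrative), the flag selects between sorting only the
// exprs reachable from terminating outputs and sorting from all leaves:
//
//   auto out_exprs = ExprSort::getExprs(fusion, /*from_outputs_only=*/true);
//   auto all_exprs = ExprSort::getExprs(fusion, /*from_outputs_only=*/false);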
namespace {
// expressions.
class Inputs : public IterVisitor {
private:
- std::vector<Val*> inputs_;
+ std::unordered_set<Val*> inputs;
void handle(Val* val) override {
- if (val->definition() == nullptr) {
- if (std::find(inputs_.begin(), inputs_.end(), val) == inputs_.end()) {
- inputs_.push_back(val);
- }
+ if (val->getOrigin() == nullptr) {
+ inputs.emplace(val);
}
}
public:
- static std::vector<Val*> getInputs(const std::vector<Val*>& of) {
+ static std::unordered_set<Val*> getInputs(const std::vector<Val*>& of) {
if (of.empty()) {
- return {};
+ return std::unordered_set<Val*>();
}
Inputs inps;
inps.traverseFrom(of[0]->fusion(), of);
- return inps.inputs_;
+ return inps.inputs;
}
};
} // namespace
-std::vector<Val*> IterVisitor::getInputsTo(const std::vector<Val*>& vals) {
+std::unordered_set<Val*> IterVisitor::getInputsTo(
+ const std::vector<Val*>& vals) {
return Inputs::getInputs(vals);
}
return next_stmts;
}
-void BackwardVisitor::handle(Statement* stmt) {
- OptOutDispatch::handle(stmt);
-}
-
-void BackwardVisitor::handle(Expr* expr) {
- OptOutDispatch::handle(expr);
-}
-
-void BackwardVisitor::handle(Val* val) {
- OptOutDispatch::handle(val);
-}
-
void BackwardVisitor::traverseFrom(
Fusion* fusion,
const std::vector<Val*>& from,
// All stmts we've called handle on
std::unordered_set<Statement*> visited_stmts_;
- if (must_cover_all_expr_outputs_) {
- for (auto traversal_pair : traversal_exprs_) {
- for (auto out : traversal_pair.first->outputs()) {
- TORCH_INTERNAL_ASSERT(
- vals.find(out) != vals.end(),
- "Invalid backward traversal found. Some output paths were not provided:",
- out);
- }
+ for (auto traversal_pair : traversal_exprs_) {
+ for (auto out : traversal_pair.first->outputs()) {
+ TORCH_INTERNAL_ASSERT(
+ vals.find(out) != vals.end(),
+ "Invalid backward traversal found. Some output paths were not provided.");
}
}
// Looks for and returns all values in between dependencies and vals, including
// them.
struct Dependencies : public IterVisitor {
- private:
- //! A given set of dependency Vals
- const std::unordered_set<Val*> dependencies_;
- //! Vals that are found between dependencies_ and of. Topologically
- //! ordered.
- std::vector<Val*> vals_;
- //! Exprs that are found between dependencies_ and of. Topologically
- //! ordered.
- std::vector<Expr*> exprs_;
- //! A set version of vals_
- std::unordered_set<Val*> dependent_vals_;
- //! A set version of exprs_
- std::unordered_set<Expr*> dependent_exprs_;
+ std::unordered_set<Val*> dependencies_;
+ std::unordered_set<Val*> vals_;
- private:
std::vector<Statement*> next(Val* v) override {
- if (dependencies_.find(v) != dependencies_.end()) {
+ if (dependencies_.find(v) != dependencies_.end())
return std::vector<Statement*>();
- }
return IterVisitor::next(v);
}
void handle(Val* val) override {
- // val is included if:
- // 1. it is one of the dependencies, or
- // 2. its defining expression is included in the dependent expr set
- if (dependencies_.find(val) != dependencies_.end()) {
- TORCH_INTERNAL_ASSERT(
- dependent_vals_.find(val) == dependent_vals_.end(),
- "Trying to add already added val: ",
- val);
- vals_.push_back(val);
- dependent_vals_.insert(val);
- } else {
- auto def = val->definition();
- if (def != nullptr &&
- dependent_exprs_.find(def) != dependent_exprs_.end()) {
- TORCH_INTERNAL_ASSERT(
- dependent_vals_.find(val) == dependent_vals_.end(),
- "Trying to add already added val: ",
- val);
- vals_.push_back(val);
- dependent_vals_.insert(val);
- }
- }
- }
-
- void handle(Expr* expr) override {
- // Track which exprs are dependent on the dependencies_ vals.
- if (std::any_of(
- expr->inputs().begin(), expr->inputs().end(), [&](Val* input_val) {
- return dependent_vals_.find(input_val) != dependent_vals_.end();
- })) {
- if (!dependent_exprs_.count(expr)) {
- exprs_.push_back(expr);
- dependent_exprs_.insert(expr);
- }
- }
+ vals_.emplace(val);
}
Dependencies(
};
public:
- static std::vector<Val*> getAllVals(
+ static std::unordered_set<Val*> getAllVals(
const std::unordered_set<Val*>& dependencies,
const std::vector<Val*>& of) {
if (of.empty()) {
- return {};
+ return std::unordered_set<Val*>();
}
Dependencies deps(dependencies, of);
return deps.vals_;
}
-
- static std::vector<Expr*> getAllExprs(
- const std::unordered_set<Val*>& dependencies,
- const std::vector<Val*>& of) {
- if (of.empty()) {
- return {};
- }
-
- Dependencies deps(dependencies, of);
- return deps.exprs_;
- }
};
// Looks for and returns all output values with dependencies on `of`.
void handle(Val* val) override {
if (of_.find(val) != of_.end()) {
Statement* out_stmt = stmt_stack.front().back();
- TORCH_INTERNAL_ASSERT(out_stmt->isVal());
- auto out_val = out_stmt->as<Val>();
- if (of_.find(out_val) == of_.end()) {
- outs_.emplace(out_val);
+ if (out_stmt->isVal()) {
+ auto out_val = out_stmt->as<Val>();
+ if (of_.find(out_val) == of_.end()) {
+ outs_.emplace(out_val);
+ }
}
}
}
- // TODO: Simply traverse through uses from of. Would be a lot faster than
- // tracing all paths like this.
FindOutputs(const std::unordered_set<Val*>& _of) : of_(_of) {
auto fusion = (*of_.begin())->fusion();
- traverseFrom(fusion, fusion->outputs(), true);
+ traverseFrom(fusion, fusion->outputs(), false);
};
static std::unordered_set<Val*> getAllOutputsOf(
}
};
-// Looks for and returns all values that depends on `of`.
-class DependentVals : public IterVisitor {
- private:
- // Which nodes to find dependencies of
- const std::unordered_set<Val*>& of_;
-
- // Dependencies we have so far
- std::unordered_set<Val*> outs_;
-
- // Boundary where we want to stop searching beyond
- std::unordered_set<Val*> boundary_;
-
- std::vector<Statement*> next(Val* v) override {
- if (boundary_.find(v) != boundary_.end())
- return std::vector<Statement*>();
- return IterVisitor::next(v);
- }
-
- void handle(Val* val) override {
- if (val->isFusionInput() || val->definition() == nullptr ||
- of_.count(val) || outs_.count(val)) {
- return;
- }
-
- for (auto v : val->definition()->inputs()) {
- if (of_.count(v) || outs_.count(v)) {
- outs_.emplace(val);
- return;
- }
- }
- }
-
- // optimization to limit search path
- void createBoundary() {
- for (auto v_of : of_) {
- for (auto v_expr : v_of->uses()) {
- for (auto v_in : v_expr->inputs()) {
- boundary_.emplace(v_in);
- }
- }
- }
- }
-
- DependentVals(const std::unordered_set<Val*>& _of) : of_(_of) {
- createBoundary();
- auto fusion = (*of_.begin())->fusion();
- traverseFrom(fusion, fusion->outputs(), false);
- };
-
- public:
- static std::unordered_set<Val*> getAllDependentVals(
- const std::unordered_set<Val*>& of) {
- if (of.empty()) {
- return std::unordered_set<Val*>();
- }
- DependentVals dependencies(of);
- return dependencies.outs_;
- }
-};
-
class DependencyChains : public IterVisitor {
public:
std::deque<std::deque<Val*>> dep_chains;
DependencyChains(Val* _dependency, bool all_chains_ = false)
: dependencies_({_dependency}) {
if (all_chains_) {
- traverseAllPaths(_dependency->fusion());
+ traverseAllPaths(_dependency->fusion(), false);
} else {
- traverse(_dependency->fusion());
+ traverse(_dependency->fusion(), false);
}
}
}
if (all_chains_) {
- traverseAllPaths((*dependencies_.begin())->fusion());
+ traverseAllPaths((*dependencies_.begin())->fusion(), false);
} else {
- traverse((*dependencies_.begin())->fusion());
+ traverse((*dependencies_.begin())->fusion(), false);
}
}
return DependencyChains::getAllUseChains(producer);
}
-std::vector<Val*> DependencyCheck::getAllValsBetween(
+std::unordered_set<Val*> DependencyCheck::getAllValsBetween(
const std::unordered_set<Val*>& dependencies,
const std::vector<Val*>& of) {
return Dependencies::getAllVals(dependencies, of);
}
-std::vector<Expr*> DependencyCheck::getAllExprsBetween(
- const std::unordered_set<Val*>& dependencies,
- const std::vector<Val*>& of) {
- return Dependencies::getAllExprs(dependencies, of);
-}
-
std::unordered_set<Val*> DependencyCheck::getAllOutputsOf(
const std::unordered_set<Val*>& of) {
if (of.empty()) {
return FindOutputs::getAllOutputsOf(of);
}
-std::unordered_set<Val*> DependencyCheck::getAllDependentVals(
- const std::unordered_set<Val*>& of) {
- if (of.empty()) {
- return std::unordered_set<Val*>();
- }
- FusionGuard fg((*of.begin())->fusion());
- return DependentVals::getAllDependentVals(of);
-}
-
void ExprSort::handle(Expr* expr) {
exprs.push_back(expr);
}
-std::vector<Expr*> ExprSort::getExprs(Fusion* fusion) {
+std::vector<Expr*> ExprSort::getExprs(Fusion* fusion, bool from_outputs_only) {
ExprSort es;
- es.traverse(fusion);
+ es.traverse(fusion, from_outputs_only);
return es.exprs;
}
}
void InputsOf::handle(Val* v) {
- if (v->definition() == nullptr) {
- if (grabbed_inputs.emplace(v).second) {
- ordered_inputs.push_back(v);
- }
- }
+ if (FusionGuard::getCurFusion()->origin(v) == nullptr)
+ inputs.emplace(v);
}
-std::vector<Val*> InputsOf::output(Fusion* fusion, Val* output_) {
- return outputs(fusion, {output_});
-}
-
-std::vector<Val*> InputsOf::outputs(
- Fusion* fusion,
- const std::vector<Val*>& outputs_) {
+std::unordered_set<Val*> InputsOf::output(Fusion* fusion, Val* output_) {
InputsOf io;
- io.traverseFrom(fusion, outputs_, false);
- return io.ordered_inputs;
+ io.traverseFrom(FusionGuard::getCurFusion(), {output_}, false);
+ return io.inputs;
}
} // namespace cuda
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+
+#include <torch/csrc/jit/codegen/cuda/fusion.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
#include <deque>
namespace fuser {
namespace cuda {
-class Fusion;
-class Statement;
-class Expr;
-class Val;
-
/*
* IterVisitor starts from leaf nodes, fusion outputs, or the provided values.
* It walks the DAG backwards from the starting nodes to roots. Each node in
// These functions will start at outputs and propagate up through the DAG
// to inputs based on depth first traversal. Next could be called on a node
// multiple times.
- virtual std::vector<Statement*> next(Statement* stmt);
-
- virtual std::vector<Statement*> next(Val* v);
-
- virtual std::vector<Statement*> next(Expr* expr);
+ virtual std::vector<Statement*> next(Statement* stmt) {
+ // NOLINTNEXTLINE(bugprone-branch-clone)
+ if (stmt->isVal()) {
+ return next(stmt->as<Val>());
+ } else if (stmt->isExpr()) {
+ return next(stmt->as<Expr>());
+ } else {
+ TORCH_INTERNAL_ASSERT(
+ false, "IterVisitor could not detect type in next_dispatch.");
+ }
+ }
+
+ virtual std::vector<Statement*> next(Val* v) {
+ FusionGuard::getCurFusion()->assertInFusion(v, "Cannot traverse val, ");
+ if (FusionGuard::getCurFusion()->origin(v) != nullptr) {
+ return {FusionGuard::getCurFusion()->origin(v)};
+ }
+ return {};
+ }
+
+ virtual std::vector<Statement*> next(Expr* expr) {
+ FusionGuard::getCurFusion()->assertInFusion(expr, "Cannot traverse expr, ");
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ std::vector<Statement*> next_stmts{
+ expr->inputs().begin(), expr->inputs().end()};
+ return next_stmts;
+ }
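+ // Subclasses can override next() to prune traversal; e.g. a visitor can
+ // stop at a given set of vals by returning no further statements, as the
+ // Dependencies helper in iter_visitor.cpp does:
+ //
+ //   std::vector<Statement*> next(Val* v) override {
+ //     if (dependencies_.find(v) != dependencies_.end())
+ //       return std::vector<Statement*>();
+ //     return IterVisitor::next(v);
+ //   }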
// This handle function is called on every Statement* in topological order,
// starting from outputs to inputs.
- void handle(Statement* s) override;
-
+ void handle(Statement* s) override {
+ OptOutDispatch::handle(s);
+ }
// This handle function is called on every Expr* in topological order,
// starting from outputs to inputs.
- void handle(Expr* e) override;
-
+ void handle(Expr* e) override {
+ OptOutDispatch::handle(e);
+ }
// This handle function is called on every Val* in topological order,
// starting from outputs to inputs.
- void handle(Val* v) override;
+ void handle(Val* v) override {
+ OptOutDispatch::handle(v);
+ }
// The entire stack during traversal. stmt_stack.back().back() is the node
// that is being called in handle(). stmt_stack.back() contains siblings (not
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::unordered_set<Statement*> termination_stmts;
- void traverseHelper(Fusion* fusion, bool traverse_all_paths = false);
+ void traverse_(
+ Fusion* fusion,
+ bool from_outputs_only = false,
+ bool traverse_all_paths = false);
public:
// Starts at nodes provided in from, traverses from these nodes to inputs.
const std::vector<Val*>& from,
bool traverseAllPaths = false);
- // Iterates from terminating outputs registered with the fusion. Terminating
- // means value is not used to generate any other value used in producing
- // registered outputs.
- void traverse(Fusion* fusion);
+ // from_outputs_only = true: start from outputs registered with the fusion;
+ // from_outputs_only = false: start from all leaf nodes. Calls into
+ // traverseFrom.
+ void traverse(Fusion* fusion, bool from_outputs_only = false);
- // Same as traverse put it traverses every edge, meaning it will traverse
- // values more than once.
- void traverseAllPaths(Fusion* fusion);
+ // from_outputs_only = true: start from outputs registered with the fusion;
+ // from_outputs_only = false: start from all leaf nodes. Calls into
+ // traverseFrom.
+ void traverseAllPaths(Fusion* fusion, bool from_outputs_only = false);
- static std::vector<Val*> getInputsTo(const std::vector<Val*>& vals);
+ static std::unordered_set<Val*> getInputsTo(const std::vector<Val*>& vals);
};
/*
*
* The first step of BackwardVisitor is to make sure we've specified enough
* outputs to guarantee that we will traverse all outputs of all exprs during
- * the backward traversal. In the case where we don't require visiting all
- * outputs of some exprs, example being the `N` output of welford ops.
- * `must_cover_all_expr_outputs` is added to disable the check, and in
- * this case the visitor pass need be aware
- * 1. Exprs with any output that has a use chain that ends with a final
- * consumer in the `from` list `will be` visited.
- * 2. Vals that doesn't have a use chain that ends with a final
- * consumer in the `from` list `will not be` visited, even though its
- * definition expr might be visited. An example is if the `N` output
- * of an welford op is unused, but other outputs are, the welford op
- * will be visited but the `N` output will not.
- *
+ * the backward traversal.
*/
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API BackwardVisitor : public OptOutDispatch {
- protected:
- // NOLINTNEXTLINE(modernize-use-override)
- virtual ~BackwardVisitor() = default;
+ public:
+ ~BackwardVisitor() override = default;
- BackwardVisitor(bool must_cover_all_expr_outputs = true)
- : must_cover_all_expr_outputs_(must_cover_all_expr_outputs) {}
+ BackwardVisitor() = default;
BackwardVisitor(const BackwardVisitor& other) = default;
BackwardVisitor& operator=(const BackwardVisitor& other) = default;
  // This handle function is called on every Statement* in topological order,
// starting from outputs to inputs.
- // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
- virtual void handle(Statement* stmt) override;
-
+ void handle(Statement* stmt) override {
+ OptOutDispatch::handle(stmt);
+ }
  // This handle function is called on every Expr* in topological order,
// starting from outputs to inputs.
- // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
- virtual void handle(Expr* expr) override;
-
+ void handle(Expr* expr) override {
+ OptOutDispatch::handle(expr);
+ }
  // This handle function is called on every Val* in topological order,
// starting from outputs to inputs.
- // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
- virtual void handle(Val* val) override;
+ void handle(Val* val) override {
+ OptOutDispatch::handle(val);
+ }
// All exprs that need to be visited in this traversal. Labeled in topological
// order (size_t).
Fusion* fusion,
const std::vector<Val*>& from,
bool traverseAllPaths = false);
-
- bool must_cover_all_expr_outputs_ = true;
};
class TORCH_CUDA_CU_API DependencyCheck {
// Returns an empty deque if there are no uses of dependency found.
static std::deque<std::deque<Val*>> getAllUseChains(Val* dependency);
- // Grab all values that exist between and including provided
- // vals. Returned values are topologicaly ordered.
- static std::vector<Val*> getAllValsBetween(
- const std::unordered_set<Val*>& dependencies,
- const std::vector<Val*>& of);
-
- // Returns all dependent exprs that exist between
- // the provided vals
- static std::vector<Expr*> getAllExprsBetween(
+ // Grab all values that exist between and including provided vals
+ static std::unordered_set<Val*> getAllValsBetween(
const std::unordered_set<Val*>& dependencies,
const std::vector<Val*>& of);
// Return registered outputs of the fusion that are a dependency of any val of
static std::unordered_set<Val*> getAllOutputsOf(
const std::unordered_set<Val*>& of);
-
- // Return all Vals that depend on the given Vals
- static std::unordered_set<Val*> getAllDependentVals(
- const std::unordered_set<Val*>& of);
};
// Expr sort will take a fusion and return a topologically sorted list of
// expressions.
class ExprSort : public IterVisitor {
- protected:
+ private:
std::vector<Expr*> exprs;
void handle(Expr* expr) override;
public:
- static std::vector<Expr*> getExprs(Fusion* fusion);
+ static std::vector<Expr*> getExprs(Fusion* fusion, bool from_outputs_only);
static std::vector<Expr*> getExprs(
Fusion* fusion,
class InputsOf : public IterVisitor {
private:
- std::unordered_set<Val*> grabbed_inputs;
- std::vector<Val*> ordered_inputs;
+ std::unordered_set<Val*> inputs;
void handle(Val* v) final;
public:
- static std::vector<Val*> output(Fusion* fusion, Val* output_);
- static std::vector<Val*> outputs(
- Fusion* fusion,
- const std::vector<Val*>& outputs_);
+ static std::unordered_set<Val*> output(Fusion* fusion, Val* output_);
};
} // namespace cuda
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
+
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <iostream>
#include <unordered_set>
namespace jit {
namespace fuser {
namespace cuda {
-namespace kir {
namespace {
//! Scan all primary expressions in the Kernel IR and build
-//! lists of specialized nodes and other interesting information
-class KernelIrScanner : private kir::IrVisitor {
+//! a list of specialized nodes
+//!
+//! \note primary expressions are expressions that are not subexpressions
+//!   of a larger expression (constructs like ForLoop or IfThenElse are not
+//!   real expressions)
+//!
+class KernelIrScanner : private OptOutDispatch {
public:
- explicit KernelIrScanner(const Kernel* kernel) {
- for (const auto& ir_node : kernel->irNodes()) {
- ir_node->accept(this);
+ // Use expression count to uniquely identify each expression
+ size_t all_expression_count = 0;
+
+  // Map expression id to WAR hazard sync
+ std::unordered_map<size_t, kir::Sync*> war_hazard_syncs;
+
+ std::vector<kir::Allocate*> global_allocations;
+ std::vector<kir::Allocate*> dynamic_allocations;
+ std::vector<kir::Allocate*> static_allocations;
+ std::unordered_set<Expr*> primary_expressions;
+
+ public:
+ explicit KernelIrScanner(const std::vector<Expr*>& exprs) {
+ TORCH_INTERNAL_ASSERT(!exprs.empty());
+ for (auto expr : exprs) {
+ handle(expr);
}
}
- const auto& summary() const {
- return summary_;
+ private:
+ void handle(Expr* expr) final {
+ TORCH_CHECK(primary_expressions.insert(expr).second);
+ ++all_expression_count;
+ OptOutDispatch::handle(expr);
}
- private:
- void visit(const kir::Sync* sync) final {
+ void handle(kir::Sync* sync) final {
// TODO: Move to a dedicated validation pass
// which is not on the common execution/compilation path
if (sync->isWarHazardSync()) {
- ++summary_.war_hazard_syncs_count;
+ war_hazard_syncs[all_expression_count] = sync;
}
}
- void visit(const kir::Allocate* allocate) final {
- switch (allocate->memoryType()) {
+ void handle(kir::ForLoop* fl) final {
+ for (auto expr : fl->body().exprs()) {
+ handle(expr);
+ }
+ }
+
+ void handle(kir::IfThenElse* ite) final {
+ for (auto expr : ite->thenBody().exprs()) {
+ handle(expr);
+ }
+ for (auto expr : ite->elseBody().exprs()) {
+ handle(expr);
+ }
+ }
+
+ void handle(kir::Allocate* a) final {
+ switch (a->getMemoryType()) {
case MemoryType::Global:
- summary_.global_allocations.push_back(allocate);
+ global_allocations.push_back(a);
break;
case MemoryType::Shared:
- if (ExpressionEvaluator::isConst(allocate->size())) {
- summary_.static_smem_allocations.push_back(allocate);
+ if (a->size()->isConstScalar()) {
+ static_allocations.push_back(a);
} else {
- summary_.dynamic_smem_allocations.push_back(allocate);
+ dynamic_allocations.push_back(a);
}
break;
case MemoryType::Local:
- if (!ExpressionEvaluator::isConst(allocate->size())) {
- summary_.has_dynamic_local_memory_allocations = true;
- summary_.dynamic_lmem_allocations.emplace_back(allocate);
- }
break;
}
}
+};
- void visit(const kir::UnaryOp* unary_op) final {
- if (unary_op->operation() == UnaryOpType::RandLike) {
- // This kernel is using random numbers
- summary_.is_stochastic = true;
- }
- }
-
- void visit(const kir::TensorIndex* tensor_index) final {
- const auto tv = tensor_index->view();
- const auto domain = tv->domain();
-
- // Do we have any reductions?
- summary_.has_block_reductions =
- summary_.has_block_reductions || domain->hasBlockReduction();
-
- // Do we have block broadcasts?
- summary_.has_block_broadcasts =
- summary_.has_block_broadcasts || domain->hasBlockBroadcast();
-
- // Update the largest smem data type
- if (domain->hasBlockReduction() || domain->hasGridReduction() ||
- tv->memoryType() == MemoryType::Shared) {
- const auto data_type = tv->dtype();
- const size_t type_size = dataTypeSize(data_type);
- if (type_size > max_smem_type_size_) {
- max_smem_type_size_ = type_size;
- summary_.largest_smem_data_type = data_type;
- }
- }
-
- // Update Welford
- if (tensor_index->definition() != nullptr &&
- tensor_index->definition()->isA<kir::WelfordOp>()) {
- summary_.has_welford = true;
- summary_.has_block_welford =
- summary_.has_block_welford || domain->hasBlockReduction();
- summary_.has_grid_welford =
- summary_.has_grid_welford || domain->hasGridReduction();
- }
- }
-
- void visit(const kir::GridWelford* grid_welford) final {
- const auto dom = grid_welford->welford_op()
- ->out()
- ->as<kir::TensorIndex>()
- ->view()
- ->domain();
- updateGridReductionInLoop(dom);
- }
-
- void visit(const kir::GridReduction* grid_reduction) final {
- const auto dom = grid_reduction->reduction_op()
- ->out()
- ->as<kir::TensorIndex>()
- ->view()
- ->domain();
- updateGridReductionInLoop(dom);
- }
+} // namespace
- private:
- size_t max_smem_type_size_ = 0;
- KernelSummary summary_;
+// TODO(kir): Kernel IR validation
+void Kernel::finalize(
+ std::vector<Expr*> top_level_exprs,
+ ThreadPredicateMap predicate_map) {
+ TORCH_CHECK(top_level_exprs_.empty());
+ TORCH_CHECK(!predicate_map_);
+ top_level_exprs_ = std::move(top_level_exprs);
+ predicate_map_ =
+ std::make_unique<ThreadPredicateMap>(std::move(predicate_map));
+ analyze();
+}
- private:
- void updateGridReductionInLoop(TensorDomain* dom) {
- ++summary_.number_of_grid_reductions;
+void Kernel::analyze() {
+ FUSER_PERF_SCOPE("Kernel::analyze");
- const auto gpu_lower = GpuLower::current();
- for (size_t i = 0; i < dom->nDims(); ++i) {
- const auto id =
- gpu_lower->caParallelMap().getConcreteMappedID(dom->domain()[i]);
- summary_.has_grid_reduction_in_loop =
- summary_.has_grid_reduction_in_loop ||
- !(id->isThread() || id->extent()->isOneInt());
- }
- }
-};
+ const KernelIrScanner ir_scanner(top_level_exprs_);
-//! Make sure tensors have valid allocations even when parallelized
-//! loops potentially have larger iteration counts than the number of
-//! threads.
-//!
-//! When an IterDomain of a tensor is parallelized, the IterDomain
-//! may not contribute to the allocation of the tensor. For example,
-//! it is assumed that an allocation of a local-memory tensor does not
-//! need to be accounted for an parallelied IterDomain. This is true
-//! when it is guaranteed that each thread only needs to execute the
-//! loop body once. However, if not, the allocation is invalid as it
-//! only has a space for one value per thread.
-//!
-//! ValidateAllocation checks all tensor allocations and sees if any
-//! tensor may have a parallelized loop whose iteration count may
-//! be larger than the number of threads. If so, an error is thrown if
-//! the tensor is not allocated on thread-shared memories. Note that
-//! when allocated on a shared memory (i.e., MemoryType::Shared or
-//! MemoryType::Global for tensors parallelized with threadIdx, or
-//! MemoryType::Global for tensors parallelized with blockIdx), it is
-//! assumed that allocation is properly extended for the iteration
-//! count.
-class ValidateAllocation : private kir::IrVisitor {
- public:
- static void validate(const Kernel* kernel) {
- ValidateAllocation validate_allocation(kernel);
- }
+ // Cache the list of buffers used within the kernel
+ summary_.war_hazard_syncs = ir_scanner.war_hazard_syncs;
+ summary_.global_allocations = ir_scanner.global_allocations;
+ summary_.dynamic_smem_allocations = ir_scanner.dynamic_allocations;
+ summary_.static_smem_allocations = ir_scanner.static_allocations;
- private:
- explicit ValidateAllocation(const Kernel* kernel) {
- live_allocations_.emplace_back(std::vector<const Allocate*>());
- for (const auto& ir_node : kernel->topLevelExprs()) {
- ir_node->accept(this);
+ // Figure out if the kernel uses random numbers
+ for (auto expr : ir_scanner.primary_expressions) {
+ if (expr->getExprType() == ExprType::KirUnaryOp) {
+ if (expr->as<kir::UnaryOp>()->getUnaryOpType() == UnaryOpType::RandLike) {
+ summary_.is_stochastic = true;
+ break;
+ }
}
- live_allocations_.pop_back();
- TORCH_INTERNAL_ASSERT(live_allocations_.empty());
- }
-
- void visit(const kir::Allocate* allocate) final {
- TORCH_INTERNAL_ASSERT(!live_allocations_.empty());
- live_allocations_.back().push_back(allocate);
}
- // for_loop is parallelized and its stop value is not guaranteed to
- // be <= the number of threads, which breaks an assumption made
- // during in the allocation lowering if it's thread-parallel and not
- // allocated on shared or global memories, or if it's block-parallel
- // ando not allocated on global memory.
- void validate(const kir::ForLoop* for_loop) {
- const auto loop_id = for_loop->iter_domain();
- const auto gpu_lower = GpuLower::current();
- for (const auto& allocations : live_allocations_) {
- for (const auto& allocate : allocations) {
- const auto tv = allocate->buffer()->as<kir::TensorView>();
- for (const auto& axis : tv->domain()->domain()) {
- if (!gpu_lower->caParallelMap().areMapped(loop_id, axis)) {
- continue;
- }
- if (isParallelTypeThreadDim(loop_id->parallelType())) {
- TORCH_INTERNAL_ASSERT(
- tv->memoryType() == MemoryType::Shared ||
- tv->memoryType() == MemoryType::Global);
- } else if (isParallelTypeBlockDim(loop_id->parallelType())) {
- TORCH_INTERNAL_ASSERT(tv->memoryType() == MemoryType::Global);
+ // Look for reductions and shared memory buffers
+ size_t max_smem_type_size = 0;
+ for (auto expr : ir_scanner.primary_expressions) {
+ for (auto out : expr->outputs()) {
+ if (out->getValType() == ValType::TensorIndex) {
+ const auto tv = out->as<kir::TensorIndex>()->view();
+ const auto domain = tv->domain();
+
+ // Do we have any reductions?
+ summary_.has_block_reductions |= domain->hasBlockReduction();
+ summary_.has_grid_reductions |= domain->hasGridReduction();
+
+ // Do we have block broadcasts?
+ summary_.has_block_broadcasts |= domain->hasBlockBroadcast();
+
+ // Update the largest smem data type
+ if (domain->hasBlockReduction() || domain->hasGridReduction() ||
+ tv->memoryType() == MemoryType::Shared) {
+ const auto data_type = tv->getDataType().value();
+ const size_t type_size = dataTypeSize(data_type);
+ if (type_size > max_smem_type_size) {
+ max_smem_type_size = type_size;
+ summary_.largest_smem_data_type = data_type;
}
}
}
}
}
-
- void visit(const kir::ForLoop* for_loop) final {
- if (for_loop->stop() != for_loop->iter_domain()->extent() &&
- isParallelTypeThread(for_loop->iter_domain()->parallelType())) {
- validate(for_loop);
- }
-
- live_allocations_.emplace_back(std::vector<const Allocate*>());
- for (const auto& expr : for_loop->body().exprs()) {
- expr->accept(this);
- }
- live_allocations_.pop_back();
- }
-
- void visit(const kir::IfThenElse* ite) final {
- for (const auto& expr : ite->thenBody().exprs()) {
- expr->accept(this);
- }
- for (const auto& expr : ite->elseBody().exprs()) {
- expr->accept(this);
- }
- }
-
- private:
- std::vector<std::vector<const Allocate*>> live_allocations_;
-};
-
-} // namespace
-
-// TODO(kir): Kernel IR validation
-void Kernel::finalize(std::vector<kir::Expr*> top_level_exprs) {
- TORCH_CHECK(top_level_exprs_.empty());
- top_level_exprs_ = std::move(top_level_exprs);
- predicate_map_ = std::make_unique<ThreadPredicateMap>(
- GpuLower::current()->threadPredMap());
- ValidateAllocation::validate(this);
- analyze();
-}
-
-void Kernel::analyze() {
- FUSER_PERF_SCOPE("Kernel::analyze");
-
- const KernelIrScanner ir_scanner(this);
- summary_ = ir_scanner.summary();
}
void Kernel::print() const {
ir_printer.printKernel(this);
}
-} // namespace kir
} // namespace cuda
} // namespace fuser
} // namespace jit
namespace jit {
namespace fuser {
namespace cuda {
-namespace kir {
//! Summary of interesting facts about the kernel
+//!
+//! TODO(kir): const node ptrs
+//!
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct KernelSummary {
- //! Count of WAR (write-after-read) hazard barriers
- int war_hazard_syncs_count = 0;
+ //! List of Write-After-Read (WAR) synchronization barriers
+ std::unordered_map<size_t, kir::Sync*> war_hazard_syncs;
//! List of global buffers
- std::vector<const kir::Allocate*> global_allocations;
+ std::vector<kir::Allocate*> global_allocations;
//! List of dynamic shared memory buffers
- std::vector<const kir::Allocate*> dynamic_smem_allocations;
+ std::vector<kir::Allocate*> dynamic_smem_allocations;
//! List of static shared memory buffers
- std::vector<const kir::Allocate*> static_smem_allocations;
+ std::vector<kir::Allocate*> static_smem_allocations;
//! Indicate the need to generate random numbers
bool is_stochastic = false;
//! Do we have any block reductions?
bool has_block_reductions = false;
- //! Number of static grid reductions
- int number_of_grid_reductions = 0;
-
- //! Do we have any grid reduction in a loop?
- bool has_grid_reduction_in_loop = false;
+ //! Do we have any grid reductions?
+ bool has_grid_reductions = false;
//! Do we have any block broadcasts?
bool has_block_broadcasts = false;
- //! Do we have any welford op?
- bool has_welford = false;
-
- //! Do we have any welford op?
- bool has_block_welford = false;
-
- //! Do we have any welford op?
- bool has_grid_welford = false;
-
//! Largest shared memory buffer base type
DataType largest_smem_data_type = DataType::Null;
-
- //! Do we have allocations of dynamic local memory?
- bool has_dynamic_local_memory_allocations = false;
-
- //! List of dynamic local memory buffers.
- //! Only used for debugging.
- std::vector<const kir::Allocate*> dynamic_lmem_allocations;
};
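//! Illustration (hypothetical consumer): downstream codegen can branch on
//! these facts, e.g.
//!   if (kernel->summary().has_block_reductions) { /* emit smem buffers */ }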
//! Container for a lowered Kernel IR
//! At this point we have a complete kernel definition and we can
//! run analysis passes to build a KernelSummary
//!
- void finalize(std::vector<kir::Expr*> top_level_exprs);
+ void finalize(
+ std::vector<Expr*> top_level_exprs,
+ ThreadPredicateMap predicate_map);
//! Register input as an input of the kernel
void addInput(Val* input) {
inputs_.push_back(input);
- input_set_.insert(input);
}
//! Register output as an output of the kernel
void addOutput(Val* output) {
outputs_.push_back(output);
- output_set_.insert(output);
}
const auto& inputs() const {
return outputs_;
}
- bool isInput(Val* val) const {
- return input_set_.find(val) != input_set_.end();
- }
-
- bool isOutput(Val* val) const {
- return output_set_.find(val) != output_set_.end();
- }
-
const auto& topLevelExprs() const {
return top_level_exprs_;
}
- const auto& irNodes() const {
- return ir_nodes_;
- }
-
const KernelSummary& summary() const {
return summary_;
}
//! \note This is a specialized helper for kir::IrBuilder, not
  //! intended for general use
//!
- void registerIrNode(kir::Passkey passkey, std::unique_ptr<kir::Node> node) {
- TORCH_CHECK(passkey.kernel == this);
+ void registerIrNode(std::unique_ptr<Statement> node) {
ir_nodes_.push_back(std::move(node));
}
- //! Allocates a new value identifier
- kir::ValueId newValueId(kir::Passkey passkey) {
- TORCH_CHECK(passkey.kernel == this);
- return next_value_id_++;
- }
-
//! Debug dump of the Kernel IR
void print() const;
private:
// Kernel IR nodes
- std::vector<std::unique_ptr<kir::Node>> ir_nodes_;
+ std::vector<std::unique_ptr<Statement>> ir_nodes_;
- // Top level statements
- std::vector<kir::Expr*> top_level_exprs_;
+ // Map from value to its definition expression
+ std::unordered_map<const Val*, Expr*> definitions_;
+
+ // Top level expressions
+ std::vector<Expr*> top_level_exprs_;
// Kernel inputs and outputs
std::vector<Val*> inputs_;
std::vector<Val*> outputs_;
- std::unordered_set<Val*> input_set_;
- std::unordered_set<Val*> output_set_;
-
- // Used to allocate unique value IDs
- kir::ValueId next_value_id_ = 1;
// Summary of interesting kernel data
KernelSummary summary_;
std::unique_ptr<ThreadPredicateMap> predicate_map_;
};
-} // namespace kir
} // namespace cuda
} // namespace fuser
} // namespace jit
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/parser.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
+#include <torch/csrc/jit/codegen/cuda/scheduler.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <c10/util/irange.h>
if (index != -1 && index != cur_index) {
return -1;
}
- index = (int)cur_index; // NOLINT
+ // NOLINTNEXTLINE(bugprone-signed-char-misuse)
+ index = cur_index;
}
return index;
}
}
#pragma clang diagnostic pop
-at::DimVector graphReductionAxes(
- const std::shared_ptr<Graph>& graph,
- bool& simple_reduction) {
+at::DimVector graphReductionAxes(const std::shared_ptr<Graph>& graph) {
FUSER_PERF_SCOPE("graphReductionAxes");
- simple_reduction = true;
at::DimVector reduction_axes;
  // TODO: let's check that we have only a single reduction node in the graph.
- int reduction_count = 0;
for (const auto& n : graph->nodes()) {
- if (isReductionToSizeNode(n)) {
- // TODO: we don't support permutation with ReductionToSize;
- simple_reduction = false;
- reduction_axes.clear();
- return reduction_axes;
- } else if (isReductionNode(n)) {
+ if (isReductionNode(n)) {
// TODO: we should return empty when `keepdim` is True?
auto dims_list = constant_as<c10::List<int64_t>>(n->input(1));
TORCH_INTERNAL_ASSERT(
for (const auto dim : dims_list->vec()) {
reduction_axes.emplace_back(static_cast<int>(dim));
}
- ++reduction_count;
// we should return here, but we don't!
      // We continue the traversal and check for other reduction nodes. Because
- // our permutation doesn't really support intermediate reduction, hence we
- // mark simple_reduction as false;
- if (reduction_count != 1) {
- simple_reduction = false;
- return reduction_axes;
- }
+      // our permutation doesn't really support intermediate reductions. Continuing
+      // the traversal would trigger the `TORCH_INTERNAL_ASSERT`; it's not ideal,
+      // but at least it's not a silent error.
}
- // TODO: this doesn't apply any more, clean it up
}
return reduction_axes;
}
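// Illustration (hypothetical graph): a single reduction node such as
// aten::sum(t, /*dims=*/{0, 2}, /*keepdim=*/false) would yield
// reduction_axes == {0, 2}.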
}
}
-void encodeBuffer(size_t value, std::string& buffer) {
- const char* v = reinterpret_cast<char*>(&value);
- for (size_t i = 0; i < sizeof(size_t); i++) {
- buffer.push_back(*(v++));
- }
-}
-
} // namespace
InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId(
- const at::ArrayRef<IValue>& inputs,
- const SchedulerRuntimeInfo* additional_info) {
+ const at::ArrayRef<IValue>& inputs) {
IdLookupReturn ret;
-
- // lock mutex_ because we are touching encoding_
- std::lock_guard<std::mutex> guard(mutex_);
- encoding_.clear();
+ std::stringstream encoded_inputs;
for (const auto& input : inputs) {
if (input.isTensor()) {
auto& input_tensor = input.toTensor();
+ encoded_inputs << ";";
+ auto sep = "";
for (auto size : input_tensor.sizes()) {
- encodeBuffer(size, encoding_);
- encoding_.push_back(' ');
+ encoded_inputs << sep << size;
+ sep = ",";
}
- encoding_.push_back('X');
- encoding_.push_back(' ');
+ encoded_inputs << "@";
+ sep = "";
for (auto stride : input_tensor.strides()) {
- encodeBuffer(stride, encoding_);
- encoding_.push_back(' ');
+ encoded_inputs << sep << stride;
+ sep = ",";
}
- encoding_.push_back('d');
- encodeBuffer(input_tensor.device().index(), encoding_);
+ encoded_inputs << "@" << input_tensor.device().str();
} else {
// encode s for scalar;
- encoding_.push_back('s');
+ encoded_inputs << ";s";
}
- encoding_.push_back(';');
- }
- if (additional_info) {
- encodeBuffer(additional_info->getCommonAlignmentSize(), encoding_);
}
+ auto& id_iter_pair = encoding_lookup_[encoded_inputs.str()];
- auto& entry = encoding_lookup_[encoding_];
+ // short-cut to leave LRU entry as is;
+ if (id_iter_pair.lru_iter == used_entry_.begin()) {
+ ret.id = id_iter_pair.id;
+ return ret;
+ }
- if (entry.id == 0) {
+ if (id_iter_pair.id == 0) {
// no entry existed for given input set, set id for given entry
- entry.id = current_id_++;
+ id_iter_pair.id = current_id_++;
if (used_entry_.size() == max_cache_size_) {
// pop least recently used cache;
const auto& remove_iter = encoding_lookup_.find(used_entry_.back());
encoding_lookup_.erase(remove_iter);
}
} else {
- // short-cut to leave LRU entry as is
- if (entry.lru_iter == used_entry_.begin()) {
- ret.id = entry.id;
- return ret;
- }
-
- used_entry_.erase(entry.lru_iter);
+ used_entry_.erase(id_iter_pair.lru_iter);
}
- ret.id = entry.id;
- entry.lru_iter = used_entry_.insert(used_entry_.begin(), encoding_);
+ ret.id = id_iter_pair.id;
+ id_iter_pair.lru_iter =
+ used_entry_.insert(used_entry_.begin(), encoded_inputs.str());
return ret;
}
-FusionExecutorCache::FusionExecutorCache(std::unique_ptr<Fusion> fusion)
- : fusion_(std::move(fusion)) {}
+// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+FusionExecutorCache::FusionExecutorCache(std::unique_ptr<Fusion>&& fusion)
+ : fusion_(std::move(fusion)) {
+ FUSER_PERF_SCOPE("FusionExecutorCache::FusionExecutorCache");
+ // avoid putting `has_reduction_` in the initializer list
+ has_reduction_ = fusion_->hasReduction();
+}
std::vector<at::Tensor> FusionExecutorCache::runFusionWithInputs(
const at::ArrayRef<IValue>& inputs) {
- FUSER_PERF_SCOPE("FusionExecutorCache::runFusionWithInputs");
+ FUSER_PERF_SCOPE("runFusionWithInputs");
- SchedulerRuntimeInfo runtime_info(fusion(), inputs);
-
- auto id_lookup_ret = inputs_id_lookup_.lookupId(inputs, &runtime_info);
+ // get unique id `unique_id` for given input set `inputs`;
+ auto id_lookup_ret = inputs_id_lookup_.lookupId(inputs);
if (id_lookup_ret.eviction) {
evictCache(id_lookup_ret.evict_id);
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const size_t unique_id = id_lookup_ret.id;
- auto kernel_runtime = getKernelRuntimeFor(inputs, unique_id);
- most_recent_runtime_ = kernel_runtime;
- return kernel_runtime->runWithInput(inputs, unique_id);
-}
-
-void FusionExecutorCache::evictCache(size_t cache_id) {
- auto it = id_to_kernel_runtime_.find(cache_id);
- TORCH_INTERNAL_ASSERT(it != id_to_kernel_runtime_.end());
- it->second->evictCache(cache_id);
- id_to_kernel_runtime_.erase(it);
-}
-
-FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor(
- const at::ArrayRef<IValue>& inputs,
- size_t unique_id) {
- // Check for id hit case
- auto id_it = id_to_kernel_runtime_.find(unique_id);
- if (id_it != id_to_kernel_runtime_.end()) {
- return id_it->second;
- }
-
- // Access kernels associated with the common device id
- auto dev_id = getCommonDeviceCUDA(inputs);
- TORCH_INTERNAL_ASSERT(dev_id >= 0);
- auto& kernel_runtimes = kernel_runtimes_[dev_id];
-
- // Check for re-use hit case
- // a kernel runtime is re-usable if all the compiled
- // kernels have the same heuristic parameters
- std::unique_ptr<FusionHeuristics> new_heuristics;
-
- auto reuse_it = std::find_if(
- kernel_runtimes.begin(),
- kernel_runtimes.end(),
- [&inputs, &new_heuristics](auto& kernel_runtime) {
- auto maybe_heuristics = kernel_runtime->getMaybeHeuristicsFor(inputs);
- if (!maybe_heuristics.has_value()) {
- return false;
- }
- new_heuristics = std::move(maybe_heuristics.value());
- return true;
- });
-
- FusionKernelRuntime* kernel_runtime;
- if (reuse_it != kernel_runtimes.end()) {
- kernel_runtime = reuse_it->get();
- kernel_runtime->updateHeuristicsLaunchParams(new_heuristics.get());
- } else {
- // graph miss, need to re-build an optimized graph for this case
- kernel_runtimes.emplace_back(
- std::make_unique<FusionKernelRuntime>(fusion_.get(), inputs));
- kernel_runtime = kernel_runtimes.back().get();
- if (profiling_) {
- kernel_runtime->profile(true);
- }
- }
-
- id_to_kernel_runtime_[unique_id] = kernel_runtime;
- return kernel_runtime;
-}
-
-FusionKernelRuntime::FusionKernelRuntime(
- Fusion* fusion,
- const at::ArrayRef<IValue>& inputs) {
- FUSER_PERF_SCOPE("FusionKernelRuntime::FusionKernelRuntime");
-
- // Make a copy of fusion and do segmentation and translation
- // on this copy
- auto fusion_copy = std::make_unique<Fusion>(*fusion);
-
- // Run segmentation on the copied fusion
- SchedulerRuntimeInfo runtime_info(fusion_copy.get(), inputs, true);
-
- //! Try to schedule the complete fusion
- const auto maybe_complete_fusion_heuristic =
- SchedulerEntry::proposeHeuristics(fusion_copy.get(), runtime_info);
-
- //! Decide if this fusion is segmented or not
- const bool segmented = !maybe_complete_fusion_heuristic.has_value();
-
- if (segmented) {
- // Take ownership and segment transformed fusion
- segmented_fusion_ =
- SegmentCandidateFinder::segment(std::move(fusion_copy), inputs);
- heuristics_ = segmented_fusion_->makeInitialHeuristics(inputs);
- executors_ =
- std::vector<FusionExecutor>(segmented_fusion_->groups().size());
- if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) {
- segmented_fusion_->print();
- }
- } else {
- auto complete_fusion_heuristic = maybe_complete_fusion_heuristic.value();
-
- // Translate welfords if apply
- if (fusion_copy->hasWelford()) {
- bool translated = SegmentCandidateFinder::TranslateWelfordInFusion(
- fusion_copy.get(), inputs);
- if (translated) {
- complete_fusion_heuristic = ScheduleHeuristic::Normalization;
- }
- }
- // Take ownership of the transformed fusion
- single_kernel_fusion_ = std::move(fusion_copy);
-
- single_kernel_fusion_data_cache_ = std::make_unique<HeuristicSummary>(
- single_kernel_fusion_.get(), complete_fusion_heuristic, runtime_info);
-
- heuristics_ = std::make_unique<FusionHeuristics>(
- complete_fusion_heuristic,
- runtime_info,
- single_kernel_fusion_data_cache_.get());
-
- executors_ = std::vector<FusionExecutor>(1);
- // In the case that the fusion isn't segmented but user
- // wants segmented fusion in the debug print. Will
- // print math of the composite fusion as placeholder
- if (isDebugDumpEnabled(DebugDumpOption::FusionSegments)) {
- single_kernel_fusion_->printMath();
- }
- }
-
- is_segmented_ = segmented;
-}
-
-std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
- const at::ArrayRef<IValue>& inputs,
- size_t input_id,
- SegmentedGroup* sg) {
- FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput");
- // This function will be called once on un-segmented fusion,
- // for segmented fusion, this function will be called on each segment
- // In the case of segmented fusion, segmented group needs to be given so
- // a kernel is compiled and run for a segmented group
- // In the case of complete fusion, sg = nullptr, and the original fusion
- // is complied and run
- auto group_id = sg ? sg->groupId() : 0;
const int device_index = getCommonDeviceCUDA(inputs);
TORCH_CHECK(device_index >= 0, "device is not coherent for fusion inputs");
LaunchParams launch_params;
+ if (code_to_fe_lookup_.count(unique_id) == 0) {
+ // enter when we get a new input set. We need to search for compatible
+    // entries in cached `FusionExecutor` or compile a new one as needed.
+
+ // caching strategy is different for pw-fusion and reduction-fusion.
+ if (has_reduction_) {
+ // Grab the fusion to analyze for heuristics
+ FusionGuard fg(fusion_.get());
+
+ TensorView* reduction_tv = nullptr;
+ // Use dependency check to find the reduction tv as it returns used values
+ // instead of exprs.
+
+      // The call is relatively heavyweight; consider caching
+ auto used_vals = DependencyCheck::getAllValsBetween(
+ {fusion_->inputs().begin(), fusion_->inputs().end()},
+ fusion_->outputs());
+
+ // Find the reduction tensor view, make sure there's only one
+ for (auto val : used_vals) {
+ if (val->getValType().value() == ValType::TensorView) {
+ auto tv = val->as<TensorView>();
+ if (tv->hasReduction()) {
+ TORCH_INTERNAL_ASSERT(
+ reduction_tv == nullptr,
+ "Already found a reduction tensorview, cannot handle fusion of multiple reductions.");
+ reduction_tv = tv;
+ }
+ }
+ }
- auto scheduler_entry = schedulers()[group_id].get();
-
- // Check that the heuristics are matched, in the case of segmented fusion
- TORCH_INTERNAL_ASSERT(!sg || scheduler_entry->heuristc() == sg->heuristic());
-
- if (!executors_[group_id].compiled()) {
- FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput::Compile");
- std::unique_ptr<Fusion> fusion_to_run;
- if (sg) {
- // Running a segment group as a single kernel,
- // make a fusion to run from segmented fusion
- fusion_to_run = segmented_fusion_->makeFusion(sg);
- } else {
- // Without a segmented group defaults to compiling the
- // complete fusion
- fusion_to_run = std::make_unique<Fusion>(*single_kernel_fusion_);
- }
- CompileOptions options;
- options.device = c10::Device(DeviceType::CUDA, device_index);
- options.index_mode = scheduler_entry->indexMode();
- FusionGuard fg(fusion_to_run.get());
- scheduler_entry->schedule(fusion_to_run.get());
- // Load launch params for reduction and normalization kernels
- if (scheduler_entry->hasReductionParam()) {
- launch_params = scheduler_entry->reductionParams().lparams;
- } else {
- launch_params = scheduler_entry->pointwiseParams().lparams;
- }
- executors_[group_id].compileFusion(
- fusion_to_run.get(), options, inputs, launch_params);
- } else {
- FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput::FetchFromCache");
- // Load launch params for reduction and normalization kernels
- if (scheduler_entry->hasReductionParam()) {
- launch_params = scheduler_entry->reductionParams().lparams;
- } else {
- launch_params = scheduler_entry->pointwiseParams().lparams;
- }
- }
-
- if (profiling_) {
- FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput::profiling_");
- most_recent_executor_log_.fusion_executor = &executors_[group_id];
- most_recent_executor_log_.launch_constraints = launch_params;
- if (scheduler_entry->hasReductionParam()) {
- most_recent_executor_log_.reduction_params =
- scheduler_entry->reductionParams();
- } else {
- most_recent_executor_log_.pointwise_params =
- scheduler_entry->pointwiseParams();
- }
- }
-
- return executors_[group_id].runFusion(inputs, launch_params, input_id);
-}
+ TORCH_INTERNAL_ASSERT(
+ reduction_tv != nullptr,
+ "Could not find the reduction tensor view in the fusion.");
-std::vector<at::Tensor> FusionKernelRuntime::runMultiKernelWithInput(
- const at::ArrayRef<IValue>& inputs,
- size_t input_id) {
- FUSER_PERF_SCOPE("FusionKernelRuntime::runMultiKernelWithInput");
+ // Generate the reduction parameters
+ auto reduction_params =
+ getReductionHeuristics(fusion_.get(), inputs, reduction_tv);
- TORCH_INTERNAL_ASSERT(
- inputs.size() == segmented_fusion_->inputs().size(),
- "Inputs were not set up correctly, recieved ",
- inputs.size(),
- " inputs but expecting ",
- segmented_fusion_->inputs().size());
-
- // Map to keep track of currently available tensors
- std::unordered_map<Val*, IValue> tensor_map;
-
- // Bind input in the tensor_map
- for (size_t i = 0; i < inputs.size(); i++) {
- tensor_map.emplace(segmented_fusion_->inputs()[i], inputs[i]);
-
- // Bind tensorview inputs values in case some segmented group
- // needs it down the road.
- // TODO: we probably have done this already up to this point
- // should consider caching the expression evaluators, both
- // more convenient and safer than replication
- if (inputs[i].isTensor()) {
- auto aten_tensor = inputs[i].toTensor();
TORCH_INTERNAL_ASSERT(
- segmented_fusion_->inputs()[i]->getValType() == ValType::TensorView);
- auto input_tv = segmented_fusion_->inputs()[i]->as<TensorView>();
- auto root_dom = TensorDomain::noReductions(input_tv->getRootDomain());
- for (size_t dim = 0; dim < root_dom.size(); dim++) {
- const auto extent = root_dom[dim]->extent();
- const auto value = aten_tensor.sizes()[dim];
- tensor_map.emplace(extent, value);
- }
- }
- }
+ reduction_params.has_value(),
+ "Error getting reduction heuristics for scheduling.");
- // Keep track of groups that has run
- std::vector<bool> group_ran(segmented_fusion_->groups().size(), false);
+ launch_params = reduction_params.value().lparams;
- while (!std::all_of(
- group_ran.begin(), group_ran.end(), [](bool b) { return b; })) {
- bool one_ran = false;
+ auto fusion_executor =
+ &red_fusion_executor_cache_[device_index][reduction_params.value()];
- // Find the first segment with all inputs available to run
- for (size_t group_i = 0; group_i < segmented_fusion_->groups().size();
- group_i++) {
- auto& group = segmented_fusion_->groups()[group_i];
- if (group_ran[group_i]) {
- continue;
- }
- const auto& group_inputs = group->inputs();
- bool ready_to_run = std::all_of(
- group_inputs.begin(), group_inputs.end(), [&tensor_map](Val* val) {
- return tensor_map.find(val) != tensor_map.end();
- });
-
- if (ready_to_run) {
- std::vector<IValue> group_runtime_inputs;
- group_runtime_inputs.reserve(group_inputs.size());
-
- // Prepare input vector
- for (auto input : group_inputs) {
- group_runtime_inputs.push_back(tensor_map.at(input));
- }
+ if (!fusion_executor->compiled()) {
+ // HEURISTIC NOT COMPILED, COMPILE A KERNEL
+ Fusion fusion = *fusion_;
- // Run graph segment
- auto group_runtime_outputs =
- runKernelWithInput(group_runtime_inputs, input_id, group);
+ FusionGuard fg(&fusion);
- const auto& group_outputs = group->outputs();
+        // Heavyweight call
+ auto used_vals = DependencyCheck::getAllValsBetween(
+ {fusion.inputs().begin(), fusion.inputs().end()}, fusion.outputs());
- // Insert graph segment output to tensor map
- for (size_t group_out_i = 0; group_out_i < group_outputs.size();
- group_out_i++) {
- tensor_map.emplace(
- group_outputs[group_out_i], group_runtime_outputs[group_out_i]);
- }
- group_ran[group_i] = true;
- one_ran = true;
- }
- }
- TORCH_INTERNAL_ASSERT(
- one_ran,
- "Couldn't run all groups, something must have gone wrong in segmentation.");
- }
+ TensorView* reduction_tv = nullptr;
- // Produce final global output
- std::vector<IValue> fusion_outputs;
- for (auto output : segmented_fusion_->outputs()) {
- const auto iter = tensor_map.find(output);
- if (iter != tensor_map.end()) {
- fusion_outputs.push_back(iter->second);
- } else {
- // This is the check for an empty tensor;
- TORCH_INTERNAL_ASSERT(
- output->as<TensorView>()->nDims() == 0 &&
- output->getDataType().has_value() &&
- output->getDataType().value() == DataType::Float,
- "Non empty tensor cannot be found at tensor_map in ",
- __FUNCTION__);
- fusion_outputs.emplace_back(at::Tensor());
- }
- }
+ for (auto val : used_vals) {
+ if (val->getValType().value() == ValType::TensorView) {
+ auto tv = val->as<TensorView>();
+ if (tv->hasReduction()) {
+ TORCH_INTERNAL_ASSERT(
+ reduction_tv == nullptr,
+ "Already found a reduction tensorview, cannot handle fusion of multiple reductions.");
+ reduction_tv = tv;
+ }
+ }
+ }
- std::vector<at::Tensor> fusion_output_tensors;
- std::transform(
- fusion_outputs.begin(),
- fusion_outputs.end(),
- std::back_inserter(fusion_output_tensors),
- [](IValue ival) {
TORCH_INTERNAL_ASSERT(
- ival.isTensor(), "Cannot output non-tensor objects from a fusion.");
- return ival.toTensor();
- });
-
- return fusion_output_tensors;
-}
-
-const std::vector<FusionKernelRuntime::SchedulerEntryPtr>& FusionKernelRuntime::
- schedulers() {
- return heuristics_->heuristicsList();
-}
+ reduction_tv != nullptr,
+ "Could not find the reduction tensor view in the fusion.");
+
+        // Heavyweight call
+ auto outputsOfReduction =
+ DependencyCheck::getAllOutputsOf({reduction_tv});
+
+ auto tv_entries =
+ ir_utils::filterByType<TensorView>(outputsOfReduction);
+
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ std::vector<TensorView*> tvOutputsOfReduction(
+ tv_entries.begin(), tv_entries.end());
+
+ scheduleReduction(
+ &fusion,
+ reduction_params.value(),
+ reduction_tv,
+ tvOutputsOfReduction);
+
+ // This means we have not found a previously generated kernel that's
+ // compatible with the new reduction params. We need to finish codegen.
+ CompileOptions options;
+ options.device = c10::Device(DeviceType::CUDA, device_index);
+ fusion_executor->compileFusion(&fusion, options);
+ }
+ // record new short cut to `FusionExecutor`
+ code_to_fe_lookup_[unique_id] = fusion_executor;
-void FusionKernelRuntime::updateHeuristicsLaunchParams(
- FusionHeuristics* update_heuristics) {
- FUSER_PERF_SCOPE("FusionKernelRuntime::updateHeuristicsLaunchParams");
- auto scheduler_list_length = heuristics_->heuristicsList().size();
- TORCH_INTERNAL_ASSERT(
- update_heuristics->heuristicsList().size() == scheduler_list_length);
- for (size_t i = 0; i < scheduler_list_length; i++) {
- auto& schedulerPtr = heuristics_->heuristicsList()[i];
- if (schedulerPtr->hasReductionParam()) {
- schedulerPtr->updateLaunchConstraint(
- update_heuristics->heuristicsList()[i]->reductionParams().lparams);
} else {
- schedulerPtr->updateLaunchConstraint(
- update_heuristics->heuristicsList()[i]->pointwiseParams().lparams);
- }
- }
-}
-
-c10::optional<FusionKernelRuntime::HeuristicsPtr> FusionKernelRuntime::
- getMaybeHeuristicsFor(const at::ArrayRef<IValue>& inputs) {
- FUSER_PERF_SCOPE("FusionKernelRuntime::getMaybeHeuristicsFor");
- auto complete_fusion = is_segmented_ ? segmented_fusion_->completeFusion()
- : single_kernel_fusion_.get();
- SchedulerRuntimeInfo runtime_info(complete_fusion, inputs, true);
-
- c10::optional<FusionKernelRuntime::HeuristicsPtr> ret;
- // Segmented case, need to iterate over all segmented groups
- if (is_segmented_) {
- ret = std::make_unique<FusionHeuristics>();
- size_t total_groups = segmented_fusion_->groups().size();
- for (size_t group_index = 0; group_index < total_groups; group_index++) {
- auto group = segmented_fusion_->groups()[group_index];
-
- auto maybe_scheduler_entry = group->getMaybeSchedulerEntry(runtime_info);
- if (!maybe_scheduler_entry.has_value()) {
- return c10::nullopt;
+ // Handle pointwise operations
+ if (pw_fusion_executor_cache_.count(device_index) == 0) {
+ pw_fusion_executor_cache_[device_index] =
+ std::make_unique<FusionExecutor>();
+ CompileOptions options;
+ options.device = c10::Device(DeviceType::CUDA, device_index);
+ // no need to copy fusion_, as we are not generating more than 1 kernel
+ // for PW.
+ scheduleFusion(fusion_.get(), inputs);
+ pw_fusion_executor_cache_[device_index]->compileFusion(
+ fusion_.get(), options);
}
- auto scheduler_entry = std::move(maybe_scheduler_entry.value());
- if (!scheduler_entry->sameAs(
- heuristics_->heuristicsList()[group_index].get())) {
- return c10::nullopt;
- }
- ret.value()->emplaceBack(std::move(scheduler_entry));
+ // record new short cut to `FusionExecutor`
+ code_to_fe_lookup_[unique_id] =
+ pw_fusion_executor_cache_[device_index].get();
}
-
- return ret;
}
- // Un-segmented case, just check the complete fusion
- auto& complete_fusion_scheduler = schedulers()[0];
- auto complete_fusion_heuristic = complete_fusion_scheduler->heuristc();
- if (!SchedulerEntry::canSchedule(
- complete_fusion_heuristic,
- complete_fusion,
- runtime_info,
- single_kernel_fusion_data_cache_.get())) {
- return c10::nullopt;
- }
-
- ret = std::make_unique<FusionHeuristics>(
- complete_fusion_heuristic,
- runtime_info,
- single_kernel_fusion_data_cache_.get());
- if (!complete_fusion_scheduler->sameAs(
- ret.value()->heuristicsList()[0].get())) {
- return c10::nullopt;
- }
-
- return ret;
+ return code_to_fe_lookup_[unique_id]->runFusion(
+ inputs, launch_params, unique_id);
}
bool GraphCache::requiresPermutation() {
- if (!support_permutation_) {
- return false;
- }
-
const size_t input_rank = input_permutation_.size();
for (const auto i : c10::irange(input_rank)) {
if (input_permutation_[i] != (long)i) {
permuted_vec_optional_stride,
type->requires_grad());
}; // closing lambda
+
for (auto input : graph->inputs()) {
if (auto input_type = input->type()->cast<TensorType>()) {
input->setType(type_permute_fn(input_type));
// 2. adjust reduction axes for the permutation;
// permute changes the semantics of axes, we need to update the reduction
// axes in the graph in order to match the behavior;
- reduction_axes_ = graphReductionAxes(graph, support_permutation_);
-
- // TODO: reduction with permutation is tricky now as we might support complex
- // topology in graph with segmented fusion.
- if (support_permutation_) {
- // run over inputs to extract common types;
- TensorTypePtr acc_type = TensorType::get();
- for (const auto& input : graph->inputs()) {
- // only check tensor types;
- if (auto input_type = input->type()->cast<TensorType>()) {
- if (acc_type->dim().has_value()) {
- // TODO: I think merge cannot handle broadcast - Go verify it later;
- // TODO: Since we are only handling permutation here, we should just
- // merge the stride_index_;
- acc_type = acc_type->merge(*input_type);
- } else {
- acc_type = input_type;
- }
+ reduction_axes_ = graphReductionAxes(graph);
+
+ // run over inputs to extract common types;
+ TensorTypePtr acc_type = TensorType::get();
+ for (const auto& input : graph->inputs()) {
+ // only check tensor types;
+ if (auto input_type = input->type()->cast<TensorType>()) {
+ if (acc_type->dim().has_value()) {
+ // TODO: I think merge cannot handle broadcast - Go verify it later;
+ // TODO: Since we are only handling permutation here, we should just
+ // merge the stride_index_;
+ acc_type = acc_type->merge(*input_type);
+ } else {
+ acc_type = input_type;
}
}
- extractPermutation(acc_type);
}
+ extractPermutation(acc_type);
createFusion(graph);
}
std::vector<at::Tensor> GraphCache::runGraphWithInputs(
const at::ArrayRef<IValue>& inputs) {
- FUSER_PERF_SCOPE("GraphCache::runGraphWithInputs");
+ FUSER_PERF_SCOPE("runGraphWithInputs");
// GraphCache need to permute inputs/outputs to accommodate dimension
// coalescing
#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
+#include <torch/csrc/jit/codegen/cuda/scheduler.h>
#include <c10/util/ArrayRef.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <mutex>
#include <type_traits>
#include <unordered_map>
namespace fuser {
namespace cuda {
-class SegmentedGroup;
-class FusionHeuristics;
-class SchedulerRuntimeInfo;
-
-// Utilities for benchmarking and profiling
-struct ExecutorLog {
- c10::optional<ReductionParams> reduction_params = c10::nullopt;
- c10::optional<PointwiseParams> pointwise_params = c10::nullopt;
- c10::optional<LaunchParams> launch_constraints = c10::nullopt;
- FusionExecutor* fusion_executor = nullptr;
-};
-
-//! FusionKernelRuntime is the unified interface from fusion graphs into
-//! caching, compilation into kernels, and kernel launches.
-//!
-//! Each instance is also a cache entry tracked by FusionKernelRuntimeCache.
-//!
-//! Two types of instance can be created, one for complete/single-kernel fusion
-//! and one for segmented/multi-kernel fusion.
-//! Conceptually this is a generalization of FusionExecutor that supports both
-//! single-kernel and multi-kernel caching/compiling/launching
-class TORCH_CUDA_CU_API FusionKernelRuntime {
- public:
- explicit FusionKernelRuntime(
- Fusion* fusion,
- const at::ArrayRef<IValue>& inputs);
-
- //! Type notations within FusionKernelRuntime Context
- using HashType = size_t;
- using SchedulerEntryPtr = std::unique_ptr<SchedulerEntry>;
-
- //! Evicts internally cached parameters based on input sizes.
- //! An interface used by runtime caches.
- void evictCache(size_t input_id) {
- for (auto& fe : executors_) {
- fe.evictCache(input_id);
- }
- }
-
- //! Unified interface to run the managed kernels with given input
- std::vector<at::Tensor> runWithInput(
- const at::ArrayRef<IValue>& inputs,
- size_t input_id) {
- if (is_segmented_) {
- return runMultiKernelWithInput(inputs, input_id);
- } else {
- return runKernelWithInput(inputs, input_id);
- }
- }
-
- //! Turn On/Off profiling
- void profile(bool to_profile = true) {
- profiling_ = to_profile;
- }
-
- //! Returns if this runtime is segmented
- bool isSegmented() {
- return is_segmented_;
- }
-
- //! Returns the fusion segments if applicable
- SegmentedFusion* fusionSegments() {
- TORCH_INTERNAL_ASSERT(is_segmented_);
- return segmented_fusion_.get();
- }
-
- //! Returns the single kernel fusion if applicable
- Fusion* singleKernelFusion() {
- TORCH_INTERNAL_ASSERT(!is_segmented_);
- return single_kernel_fusion_.get();
- }
-
- //! Returns the list of heuristics in this runtime
- FusionHeuristics* schedulerHeuristics() {
- return heuristics_.get();
- }
-
- //! Return the most recently used executor, corresponding to the
- //! most recent kernel launch.
- //! TODO: have a interface for grabbing all recent logs. Need to put a buffer
- //! space for recent logs
- ExecutorLog getMostRecentExecutorLog() {
- TORCH_INTERNAL_ASSERT(
- profiling_, "Executor log is only produced in profiling mode");
- return most_recent_executor_log_;
- }
-
- // Try to compute heuristics based on the SegmentedFusion managed
- // in this kernel runtime, and will return a nullopt if either
- // any segment cannot be scheduled or the parameters don't match
- using HeuristicsPtr = std::unique_ptr<FusionHeuristics>;
- c10::optional<HeuristicsPtr> getMaybeHeuristicsFor(
- const at::ArrayRef<IValue>& inputs);
-
- //! Copy the launch params given in the parameter heuristics to prepare
- //! for kernel launch for a new input dimension but same heuristics
- void updateHeuristicsLaunchParams(FusionHeuristics* update_heuristics);
-
- private:
- //! Interface to run a single kernel, either one kernel for single-kernel
- //! fusions,
- //! or a kernel for a segmentedGrouup in a segmented fusion. Returns the
- //! kernel outputs.
- std::vector<at::Tensor> runKernelWithInput(
- const at::ArrayRef<IValue>& inputs,
- size_t input_id,
- SegmentedGroup* sg = nullptr);
-
- //! Interface to run a the whole graph in a segmented fusion and return the
- //! complete
- //! fusion outputs.
- std::vector<at::Tensor> runMultiKernelWithInput(
- const at::ArrayRef<IValue>& inputs,
- size_t input_id);
-
- //! Access the list of schedulers maintained in this runtime instance
- const std::vector<SchedulerEntryPtr>& schedulers();
-
- private:
- //! Entries indexed by groupID:
- //! Executors holding compiled kernels
- std::vector<FusionExecutor> executors_;
-
- //! Heuristics object holding scheduler entries for all segments
- std::unique_ptr<FusionHeuristics> heuristics_;
-
- // Checks if this runtime instance is for a single-kernel fusion (false) or a
- // segmented fusion (true).
- bool is_segmented_ = true;
-
- //! Multi-Kernel fusion segment when applies
- std::unique_ptr<SegmentedFusion> segmented_fusion_ = nullptr;
-
- //! Single-Kernel fusion when applies
- //! TODO: unify the segmented and un-segmented code-path
- std::unique_ptr<Fusion> single_kernel_fusion_ = nullptr;
-
- //! Graph traversal datacache for the single kernel fusion
- //! TODO: unify the segmented and un-segmented code-path
- std::unique_ptr<HeuristicSummary> single_kernel_fusion_data_cache_ = nullptr;
-
- // States for profiling support
- bool profiling_ = false;
-
- // The heuristics and executor for most recent kernel launch
- ExecutorLog most_recent_executor_log_;
-};
-
//! Encoding an input set to unique id, which is used to short-cut cache entry
//! selection in our nested cache implementation to cut off overhead.
//!
//! \note the uniqueness of the id generated for a given input set is only
//! local to the instance of `InputsIdLookup`.
//!
-class TORCH_CUDA_CU_API InputsIdLookup : public NonCopyable {
+class TORCH_CUDA_CU_API InputsIdLookup {
public:
//! constructor where maximum cache size is fixed during init
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
- explicit InputsIdLookup(size_t max_cache_size = 100)
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+ explicit InputsIdLookup(size_t max_cache_size = 10)
: max_cache_size_(max_cache_size){};
//! struct to hold return value for lookupId.
//! within the lookup cache. This is needed because lookup shortcut is also
//! cached in nested `GraphCache`, `FusionExecutorCache` and `FusionExecutor`.
//! see [ Note -- 2 level cache implementation ]
- IdLookupReturn lookupId(
- const at::ArrayRef<IValue>& inputs,
- const SchedulerRuntimeInfo* additional_info = nullptr);
+ IdLookupReturn lookupId(const at::ArrayRef<IValue>& inputs);
//! debugging API that returns the size of lookup table
size_t size() const {
}
private:
- // string to store encoded input meta information. Reuse the buffer instead of
- // stringtream gives few us perf gain.
- std::string encoding_; // Note: shared state, guarded by mutex_
-
- // mutex_ used to guard reused encoding_
- std::mutex mutex_;
-
//! entry stored in `encoding_lookup_` to implement LRU
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct EncodingEntry {
- size_t id = 0;
+ size_t id;
std::list<std::string>::iterator lru_iter;
};
//! c) broadcasting semantics (size-1 or not);
//! d) rank;
//! e) scalar type;
-//!
-//!
-//! [ Note -- Segmented Fusion Tentative Design ]
-//! Segmentation adds an extra dimension in caching. Initial implementation,
-//! assumed graph partition strategy is independent of input pattern, which we
-//! can revisit once we have more advanced graph segmentation logic Each
-//! FusionExecutorCache corresponds to one graph and one graph segmentation.
-//!
-//!
-class TORCH_CUDA_CU_API FusionExecutorCache {
+
+class FusionExecutorCache {
public:
//! create new fusion executor cache at a given device to handle kernel
//! generation of dynamic sizes;
  //! fusion executor takes ownership of `fusion`;
- explicit FusionExecutorCache(std::unique_ptr<Fusion> fusion);
+ explicit FusionExecutorCache(std::unique_ptr<Fusion>&& fusion);
//! Execute fusion graph with given inputs, create `FusionExecutor` as needed;
std::vector<at::Tensor> runFusionWithInputs(
const at::ArrayRef<IValue>& inputs);
- Fusion* fusion() {
- return fusion_.get();
- }
-
- void printFusion() {
- fusion_->printMath();
- }
-
- FusionKernelRuntime* getMostRecentKernelRuntime() {
- return most_recent_runtime_;
- }
-
- // TODO: in a follow up we need a global logging structure
- // to capture runtime profiling info. We also need to define
- // a suitable profiling window / buffer size.
- ExecutorLog getMostRecentExecutorInfo() {
- TORCH_INTERNAL_ASSERT(most_recent_runtime_ != nullptr);
- return most_recent_runtime_->getMostRecentExecutorLog();
- }
-
- void profile(bool to_profile) {
- profiling_ = to_profile;
- for (auto& it : kernel_runtimes_) {
- for (auto& kernel_runtime : it.second) {
- kernel_runtime->profile(to_profile);
- }
- }
- }
-
private:
//! evict cached short cut entry in `code_to_fe_lookup_` as well as cached
//! entry in `FusionExecutor`
- void evictCache(size_t cache_id);
-
- FusionKernelRuntime* getKernelRuntimeFor(
- const at::ArrayRef<IValue>& inputs,
- size_t id);
+ void evictCache(size_t cache_id) {
+ auto iter = code_to_fe_lookup_.find(cache_id);
+ TORCH_INTERNAL_ASSERT(
+ iter != code_to_fe_lookup_.end(),
+ "evict cache failed to find an entry");
+ // evict nested lookup entry in nested `FusionExecutor`
+ (iter->second)->evictCache(cache_id);
+ code_to_fe_lookup_.erase(iter);
+ };
private:
//! original un-scheduled `Fusion`;
std::unique_ptr<Fusion> fusion_;
+  // I'm trading the const model in favor of assigning `has_reduction_` in the
+  // body of the constructor instead of the initializer list, because of the
+  // move used in the constructor: it's tricky to maintain the code if
+  // `has_reduction_` is a const member initialized in the initializer list,
+  // where the order of initialization is controlled by the order of
+  // declaration instead of the order in the list.
+ //
+ //! cache fusion->hasReduction() because it's expensive;
+ bool has_reduction_;
+
+ //! TODO: ugly logic for now. We should integrate the hashing of cache for
+ //! different kernels. (alternatively we could do so in scheduler).
+ //! ugly bits now:
+  //! The fact that we have heuristics only for reductions, but use a general
+  //! kernel for all point-wise fusions, ended up with this:
+ //! 1. For point-wise fusion, we have a single `FusionExecutor` in
+ //! `pw_fusion_executor_cache_`
+ //! 2. For reduction fusion we have a hash table with ReductionParams as entry
+ //! pointing to the actual `FusionExecutor` in `red_fusion_executor_cache_`
+ //!
+  //! Both caches are keyed on device_index, because a `FusionExecutor` is
+  //! designated to a single device
+ std::unordered_map<int, std::unique_ptr<FusionExecutor>>
+ pw_fusion_executor_cache_;
+ std::unordered_map<
+ int,
+ std::unordered_map<ReductionParams, FusionExecutor, ReductionParamsHash>>
+ red_fusion_executor_cache_;
+
+ //! short cut to FusionExecutor for input set encoded with id;
+ std::unordered_map<size_t, FusionExecutor*> code_to_fe_lookup_;
+
//! inputs to unique_id lookup table;
InputsIdLookup inputs_id_lookup_;
-
- //! Graphs after input dependent transfoms
- std::unordered_map<size_t, std::vector<std::unique_ptr<FusionKernelRuntime>>>
- kernel_runtimes_;
-
- //! Logging state for most recent compilation
- bool profiling_ = false;
-
- //! Logging state for most recent compilation
- ExecutorLog most_recent_executor_log_;
-
- //! short-cut for cache hit
- std::unordered_map<size_t, FusionKernelRuntime*> id_to_kernel_runtime_;
-
- //! Profiling info:
- //! TODO: this can be largely expanded to look at complete
- //! caching profiles. Currently it just makes it easier to test
- FusionKernelRuntime* most_recent_runtime_ = nullptr;
};
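// A minimal usage sketch (hypothetical, for illustration only): the cache
// takes ownership of an unscheduled Fusion and compiles or reuses kernels
// per unique input set on each call:
//
//   auto fusion = std::make_unique<Fusion>();
//   // ... build the fusion IR ...
//   FusionExecutorCache executor_cache(std::move(fusion));
//   auto outputs = executor_cache.runFusionWithInputs({t0, t1});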
class GraphCache {
std::shared_ptr<Graph> graph_;
//! TODO: poor name, we should use `eliminated_axes_` instead;
at::DimVector reduction_axes_;
- bool support_permutation_;
//! helper function used at run-time to check whether a common permutation is
//! present, this is used to take the short-cut to skip permutation logic.
+++ /dev/null
-
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-
-#include <iostream>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-namespace kir {
-
-void ExpressionEvaluator::bind(
- const Val* value,
- Int::ScalarType concrete_value) {
- TORCH_CHECK(value->isScalar());
- TORCH_CHECK(value->dtype() == DataType::Int);
- TORCH_CHECK(!value->isConst(), "Tried to bind to a constant value");
- TORCH_CHECK(
- value->definition() == nullptr,
- "Tried to bind to a value that is computed in the kernel IR");
- known_values_[value] = concrete_value;
-}
-
-c10::optional<Int::ScalarType> ExpressionEvaluator::evaluate(const Val* value) {
- FUSER_PERF_SCOPE("kir::ExpressionEvaluator::evaluate");
-
- TORCH_CHECK(value->isScalar());
- TORCH_CHECK(value->dtype() == DataType::Int);
-
- // Const scalar?
- if (value->isScalar() && value->isConst()) {
- return value->as<Int>()->value();
- }
-
- // Is the value known (either explicit binding or memoized)?
- const auto pre_eval_it = known_values_.find(value);
- if (pre_eval_it != known_values_.end()) {
- return pre_eval_it->second;
- }
-
- value->accept(this);
-
- const auto post_eval_it = known_values_.find(value);
- return post_eval_it != known_values_.end()
- ? c10::optional<Int::ScalarType>(post_eval_it->second)
- : c10::nullopt;
-}
-
-bool ExpressionEvaluator::isConst(const Val* value) {
- return ExpressionEvaluator().evaluate(value).has_value();
-}
-
-void ExpressionEvaluator::print() const {
- std::cout << "\nEvaluation context\n";
- std::cout << "--------------------\n";
- for (const auto& kv : known_values_) {
- std::cout << toString(kv.first) << " = " << kv.second << "\n";
- }
- std::cout << "--------------------\n\n";
-}
-
-void ExpressionEvaluator::unhandled(const void*) {
- TORCH_INTERNAL_ASSERT(
- false, "Kernel IR expression evaluation reached an unsupported node");
-}
-
-void ExpressionEvaluator::visit(const Int* value) {
- TORCH_INTERNAL_ASSERT(!value->isConst());
- if (auto def = value->definition()) {
- def->accept(this);
- }
-}
-
-void ExpressionEvaluator::visit(const NamedScalar* named_scalar) {
-  // It's a legal expression node so we must handle it
-}
-
-void ExpressionEvaluator::visit(const UnaryOp* unary_op) {
- const auto in = evaluate(unary_op->in());
- if (in.has_value()) {
- switch (unary_op->operation()) {
- case UnaryOpType::Neg:
- known_values_[unary_op->out()] = -*in;
- break;
- case UnaryOpType::Cast:
- known_values_[unary_op->out()] = *in;
- break;
- default:
- TORCH_CHECK(!"Unexpected operator type");
- }
- }
-}
-
-void ExpressionEvaluator::visit(const BinaryOp* binary_op) {
- const auto lhs = evaluate(binary_op->lhs());
- const auto rhs = evaluate(binary_op->rhs());
- if (lhs.has_value() && rhs.has_value()) {
- switch (binary_op->operation()) {
- case BinaryOpType::Add:
- known_values_[binary_op->out()] = *lhs + *rhs;
- break;
- case BinaryOpType::Sub:
- known_values_[binary_op->out()] = *lhs - *rhs;
- break;
- case BinaryOpType::Mul:
- known_values_[binary_op->out()] = *lhs * *rhs;
- break;
- case BinaryOpType::Div:
- TORCH_CHECK(*rhs != 0);
- known_values_[binary_op->out()] = *lhs / *rhs;
- break;
- case BinaryOpType::Mod:
- TORCH_CHECK(*rhs != 0);
- known_values_[binary_op->out()] = *lhs % *rhs;
- break;
- case BinaryOpType::CeilDiv:
- TORCH_CHECK(*rhs != 0);
- known_values_[binary_op->out()] = (*lhs + *rhs - 1) / *rhs;
- break;
- case BinaryOpType::And:
- known_values_[binary_op->out()] = Int::ScalarType(*lhs && *rhs);
- break;
- default:
- TORCH_CHECK(!"Unexpected operator type");
- }
- }
-}
-
-} // namespace kir
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
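For reference, the deleted kir::ExpressionEvaluator above follows a common memoized const-folding pattern: bind concrete values to symbolic leaves, then evaluate expressions bottom-up, caching every intermediate result. A minimal standalone sketch of the same pattern, using hypothetical node types rather than the fuser IR:

```cpp
#include <cstdint>
#include <optional>
#include <unordered_map>

// Hypothetical expression node: a constant, an unbound symbol, or a binary op.
struct Node {
  enum class Kind { Const, Symbol, Add, Mul } kind;
  int64_t value = 0;         // for Const
  const Node* lhs = nullptr; // for Add/Mul
  const Node* rhs = nullptr;
};

class Evaluator {
 public:
  // Bind a concrete value to a symbolic leaf, like ExpressionEvaluator::bind.
  void bind(const Node* symbol, int64_t concrete) {
    known_[symbol] = concrete;
  }

  // Evaluate bottom-up; results are memoized so shared subtrees are visited
  // once, mirroring the known_values_ map in the deleted implementation.
  std::optional<int64_t> evaluate(const Node* n) {
    if (n->kind == Node::Kind::Const) {
      return n->value;
    }
    if (auto it = known_.find(n); it != known_.end()) {
      return it->second; // explicit binding or memoized result
    }
    if (n->kind == Node::Kind::Symbol) {
      return std::nullopt; // unbound symbol: not evaluable
    }
    auto l = evaluate(n->lhs);
    auto r = evaluate(n->rhs);
    if (!l || !r) {
      return std::nullopt;
    }
    int64_t v = (n->kind == Node::Kind::Add) ? *l + *r : *l * *r;
    known_[n] = v;
    return v;
  }

 private:
  std::unordered_map<const Node*, int64_t> known_;
};

// usage: Node x{Node::Kind::Symbol}; Node two{Node::Kind::Const, 2};
// Node e{Node::Kind::Add, 0, &x, &two};
// Evaluator ev; ev.bind(&x, 40); /* ev.evaluate(&e) == 42 */
```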
+++ /dev/null
-
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <c10/util/Optional.h>
-
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-namespace kir {
-
-//! Calculate Kernel IR expressions
-//!
-//! How to evaluate Kernel IR expressions:
-//!
-//! ```cpp
-//! kir::ExpressionEvaluator eval;
-//! eval.bind(symbolic_value, concrete_value);
-//! ... bind more values ...
-//! const auto result = eval.evaluate(interesting_value);
-//! if (result.has_value()) {
-//! ... we have successfully calculated the result ...
-//! } else {
-//! ... expression can't be evaluated ...
-//! }
-//! ```
-//!
-class TORCH_CUDA_CU_API ExpressionEvaluator : private IrVisitor {
- public:
- //! Set a concrete value for a symbolic value
- void bind(const Val* value, Int::ScalarType concrete_value);
-
- //! Try to evaluate a Kernel IR value
- c10::optional<Int::ScalarType> evaluate(const Val* value);
-
- //! Returns true if `value` is known before binding kernel inputs
- static bool isConst(const Val* value);
-
- //! Debugging helper, prints all the currently known values
- void print() const;
-
- private:
- void unhandled(const void*) final;
- void visit(const Int* value) final;
- void visit(const NamedScalar* named_scalar) final;
- void visit(const UnaryOp* unary_op) final;
- void visit(const BinaryOp* binary_op) final;
-
- private:
- std::unordered_map<const Val*, Int::ScalarType> known_values_;
-};
-
-} // namespace kir
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
-#include <torch/csrc/jit/codegen/cuda/kernel.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
-#include <iostream>
-
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace kir {
-void Node::print() const {
- std::cout << "\n";
- IrPrinter(std::cout).printNode(this);
- std::cout << "\n";
-}
-
-Val::Val(Passkey passkey, DataType dtype) : Node(passkey), dtype_(dtype) {
- // NOLINTNEXTLINE: https://bugs.llvm.org/show_bug.cgi?id=48534
- id_ = passkey.kernel->newValueId(passkey);
-}
-
-namespace {
-
-// Traverse definition of all values involved in constructing the provided val.
-// Check if all values involved are constant values, meaning the provided
-// val is also a constant value.
-class ConstCheck : IrVisitor {
- private:
- bool is_const_ = true;
-
- using IrVisitor::visit;
-
- void visit(const Bool* b) {
- is_const_ = is_const_ && b->isConst();
- }
-
- void visit(const Double* d) {
- is_const_ = is_const_ && d->isConst();
- }
-
- void visit(const Int* i) {
- is_const_ = is_const_ && i->isConst();
- }
-
- void visit(const NamedScalar* ns) {
- is_const_ = is_const_ && false;
- }
-
- void visit(const Expr* expr) {
- for (auto inp : expr->inputs()) {
- visit(inp);
- }
- }
-
- void visit(const Val* val) {
- if (val->definition() != nullptr) {
- visit(val->definition());
- } else {
- val->accept(this);
- }
- }
-
- public:
- static bool isConst(const Val* val) {
- ConstCheck cc;
- cc.visit(val);
- return cc.is_const_;
- }
-};
-
-} // namespace
-
-bool Val::isConstScalar() const {
- if (!isScalar())
- return false;
- return ConstCheck::isConst(this);
-}
-
-Expr* Expr::parentScope() const {
- if (scope()) {
- return scope()->owner();
- } else {
- return nullptr;
- }
-}
-
NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) {
std::string parallel_dim = stringifyThreadSize(p_type);
kir::IrBuilder ir_builder(GpuLower::current()->kernel());
return c10::nullopt;
}
-IterDomain::IterDomain(Passkey passkey, Val* start, Val* extent)
- : Val(passkey, DataType::Int), start_(start), extent_(extent) {}
+IterDomain::IterDomain(Passkey, Val* start, Val* extent)
+ : Val(ValType::KirIterDomain, DataType::Int, true, true),
+ start_(start),
+ extent_(extent) {}
-IterDomain::IterDomain(
- Passkey passkey,
- const fuser::cuda::IterDomain* iter_domain)
- : Val(passkey, iter_domain->getDataType().value()),
- start_(GpuLower::current()->lowerValue(iter_domain->start())),
- extent_(GpuLower::current()->lowerValue(iter_domain->extent())),
+IterDomain::IterDomain(Passkey, const fuser::cuda::IterDomain* iter_domain)
+ : Val(iter_domain),
+ start_(GpuLower::lowerValue(iter_domain->start())),
+ extent_(GpuLower::lowerValue(iter_domain->rawExtent())),
parallel_type_(iter_domain->getParallelType()),
iter_type_(iter_domain->getIterType()),
- is_rfactor_domain_(iter_domain->isRFactorProduct()),
- is_simple_(iter_domain->definition() == nullptr) {
- // preserve the fusion node's name
- setName(iter_domain->name());
-}
+ is_rfactor_domain_(iter_domain->isRFactorProduct()) {}
-//! Note that the parallel dimension, if available, may be different
-//! from the actual extent of this IterDomain as the parallel
-//! dimension is determined by the largest extent of IterDomains
-//! sharing the same loop.
Val* IterDomain::extent() const {
- TORCH_INTERNAL_ASSERT(extent_ != nullptr);
+ TORCH_CHECK(isLoweredVal(extent_));
+ if (isThread()) {
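+    // Note: thread-mapped domains take their extent from the launch
+    // dimension (e.g. blockDim.x) unless the extent is a compile-time
+    // constant, since the launch dimension is shared by every domain
+    // bound to that parallel type.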
+ // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
+ if (extent_->getValType() == ValType::KirScalar) {
+ if (extent_->as<kir::Int>()->isConst()) {
+ return extent_;
+ }
+ }
+ return NamedScalar::getParallelDim(getParallelType());
+ }
return extent_;
}
-TensorDomain::TensorDomain(Passkey passkey, std::vector<IterDomain*> domain)
- : Val(passkey, DataType::Null), root_domain_(std::move(domain)) {
+TensorDomain::TensorDomain(Passkey, std::vector<IterDomain*> domain)
+ : Val(ValType::KirTensorDomain), root_domain_(std::move(domain)) {
domain_ = root_domain_;
resetDomains();
}
TensorDomain::TensorDomain(
- Passkey passkey,
+ Passkey,
const fuser::cuda::TensorDomain* tensor_domain)
- : Val(passkey, DataType::Null), contiguity_(tensor_domain->contiguity()) {
- // preserve the fusion node's name
- setName(tensor_domain->name());
-
+ : Val(tensor_domain), contiguity_(tensor_domain->contiguity()) {
const auto lowerIterDomains =
[](const std::vector<fuser::cuda::IterDomain*>& domains) {
std::vector<IterDomain*> lowered_domains;
lowered_domains.reserve(domains.size());
for (const auto iter_domain : domains) {
lowered_domains.push_back(
- GpuLower::current()->lowerValue(iter_domain)->as<IterDomain>());
+ GpuLower::lowerValue(iter_domain)->as<IterDomain>());
}
return lowered_domains;
};
return !rfactor_domain_.empty();
}
-bool TensorDomain::hasVectorize() const {
- return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
- return id->parallelType() == ParallelType::Vectorize ||
- id->parallelType() == ParallelType::MisalignedVectorize;
- });
-}
-
IterDomain* TensorDomain::axis(int i) const {
TORCH_INTERNAL_ASSERT(i >= 0 && i < int(domain_.size()));
return domain_[i];
return no_broadcast_domains;
}
-TensorView::TensorView(Passkey passkey, const fuser::cuda::TensorView* tv)
- : Val(passkey, tv->getDataType().value()), fuser_tv_(tv) {
- setName(tv->name());
- domain_ = GpuLower::current()->lowerValue(tv->domain())->as<TensorDomain>();
+TensorView::TensorView(Passkey, const fuser::cuda::TensorView* tv)
+ : Val(tv), fuser_tv_(tv) {
+ domain_ = GpuLower::lowerValue(tv->domain())->as<TensorDomain>();
memory_type_ = tv->getMemoryType();
}
-TensorView::TensorView(
- Passkey passkey,
- DataType dtype,
- TensorDomain* domain,
- MemoryType memory_type)
- : Val(passkey, dtype), domain_(domain), memory_type_(memory_type) {}
-
-UnaryOp::UnaryOp(Passkey passkey, UnaryOpType operation, Val* out, Val* in)
- : Expr(passkey), operation_(operation), out_(out), in_(in) {
+UnaryOp::UnaryOp(Passkey, UnaryOpType type, Val* out, Val* in)
+ : Expr(ExprType::KirUnaryOp), unary_op_type_{type}, out_{out}, in_{in} {
addOutput(out);
addInput(in);
+ name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
}
-BinaryOp::BinaryOp(
- Passkey passkey,
- BinaryOpType operation,
- Val* out,
- Val* lhs,
- Val* rhs)
- : Expr(passkey), operation_(operation), out_(out), lhs_(lhs), rhs_(rhs) {
+BinaryOp::BinaryOp(Passkey, BinaryOpType type, Val* out, Val* lhs, Val* rhs)
+ : Expr(ExprType::KirBinaryOp),
+ binary_op_type_{type},
+ out_{out},
+ lhs_{lhs},
+ rhs_{rhs} {
addOutput(out);
addInput(lhs);
addInput(rhs);
+ name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
}
TernaryOp::TernaryOp(
- Passkey passkey,
- TernaryOpType operation,
+ Passkey,
+ TernaryOpType type,
Val* out,
Val* in1,
Val* in2,
Val* in3)
- : Expr(passkey),
- operation_(operation),
- out_(out),
- in1_(in1),
- in2_(in2),
- in3_(in3) {
+ : Expr(ExprType::KirTernaryOp),
+ ternary_op_type_{type},
+ out_{out},
+ in1_{in1},
+ in2_{in2},
+ in3_{in3} {
addOutput(out);
addInput(in1);
addInput(in2);
addInput(in3);
+ name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
}
ReductionOp::ReductionOp(
- Passkey passkey,
- BinaryOpType operation,
+ Passkey,
+ BinaryOpType reduction_op_type,
Val* init,
Val* out,
- Val* in)
- : Expr(passkey), operation_(operation), init_(init), out_(out), in_(in) {
+ Val* in,
+ Bool* pred)
+ : Expr(ExprType::KirReductionOp),
+ reduction_op_type_(reduction_op_type),
+ init_(init),
+ out_(out),
+ in_(in),
+ pred_(pred) {
addOutput(out);
addInput(in);
-}
-
-WelfordOp::WelfordOp(
- Passkey passkey,
- Val* out_var,
- Val* out_avg,
- Val* out_N,
- Val* init_var,
- Val* init_avg,
- Val* init_N,
- Val* in_var,
- Val* in_avg,
- Val* in_N)
- : Expr(passkey),
- out_var_(out_var),
- out_avg_(out_avg),
- out_N_(out_N),
- init_var_(init_var),
- init_avg_(init_avg),
- init_N_(init_N),
- in_var_(in_var),
- in_avg_(in_avg),
- in_N_(in_N) {
- addOutput(out_avg);
- addOutput(out_var);
- addOutput(out_N);
-
- if (!in_N->isOneInt()) {
- addInput(in_var);
- }
- addInput(in_avg);
- addInput(in_N);
-}
-
-std::vector<IterDomain*> WelfordOp::getReductionDomains() const {
- // out is a TensorIndex after lowering
- const auto out_val = out()->as<kir::TensorIndex>()->view();
-
- auto vec_domain = out_val->as<TensorView>()->domain()->domain();
-
- vec_domain.erase(
- std::remove_if(
- vec_domain.begin(),
- vec_domain.end(),
- [](IterDomain* id) { return !id->isReduction(); }),
- vec_domain.end());
- return vec_domain;
-}
-
-std::unordered_map<ParallelType, IterDomain*, TypeHash> WelfordOp::
- getParallelReductionDomains() const {
- std::unordered_map<ParallelType, IterDomain*, TypeHash> parallel_domains;
- for (auto d : getReductionDomains()) {
- if (d->isThread()) {
- parallel_domains.insert(std::make_pair(d->parallelType(), d));
- }
- }
- return parallel_domains;
+ name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
}
std::vector<IterDomain*> ReductionOp::getReductionDomains() const {
std::unordered_map<ParallelType, IterDomain*, TypeHash> parallel_domains;
for (auto d : getReductionDomains()) {
if (d->isThread()) {
- parallel_domains.insert(std::make_pair(d->parallelType(), d));
+ parallel_domains.insert(std::make_pair(d->getParallelType(), d));
}
}
return parallel_domains;
}
-BroadcastOp::BroadcastOp(Passkey passkey, Val* out, Val* in)
- : Expr(passkey), out_(out), in_(in) {
- TORCH_CHECK(in->isA<TensorIndex>() || in->isA<TensorView>());
- TORCH_CHECK(out->isA<TensorIndex>() || out->isA<TensorView>());
+BroadcastOp::BroadcastOp(Passkey, Val* out, Val* in)
+ : Expr(ExprType::KirBroadcastOp), out_(out), in_(in) {
+ TORCH_CHECK(in->getValType().value() == ValType::TensorIndex);
+ TORCH_CHECK(out->getValType().value() == ValType::TensorIndex);
addOutput(out);
addInput(in);
+ name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
}
TensorIndex::TensorIndex(
- Passkey passkey,
+ Passkey,
const fuser::cuda::TensorView* view,
std::vector<Val*> indices)
- : Val(passkey, view->getDataType().value()),
- view_(GpuLower::current()->lowerValue(view)->as<TensorView>()),
+ : Val(ValType::TensorIndex, view->getDataType().value(), true, true),
+ view_(GpuLower::lowerValue(view)->as<TensorView>()),
indices_(indices) {
TORCH_INTERNAL_ASSERT(
std::all_of(
indices.begin(),
indices.end(),
- [](Val* v) { return v->dtype() == DataType::Int; }),
+ [](Val* v) {
+ return (v->getValType() == ValType::KirScalar ||
+ v->getValType() == ValType::KirNamedScalar) &&
+ v->getDataType() == DataType::Int;
+ }),
"Cannot index with a value other than an int.");
- indices_.erase(
- std::remove_if(
- indices_.begin(),
- indices_.end(),
- [](Val* index) { return index->isZeroInt(); }),
- indices_.end());
-  // If indices become empty, just put one ZeroInt
- if (indices_.empty()) {
- indices_.push_back(kir::IrBuilder(GpuLower::current()->kernel()).zeroVal());
- }
}
-Sync::Sync(Passkey passkey, bool war_sync)
- : Expr(passkey), war_sync_(war_sync) {}
-
-InitMagicZero::InitMagicZero(Passkey passkey) : Expr(passkey) {}
-
-UpdateMagicZero::UpdateMagicZero(Passkey passkey) : Expr(passkey) {}
-
-void Scope::insert(std::vector<Expr*>::const_iterator pos, Expr* expr) {
- exprs_.insert(pos, expr);
- expr->setScope(this);
+Sync::Sync(Passkey, bool war_sync) : Expr(ExprType::Sync), war_sync_(war_sync) {
+ name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
}
void Scope::insert_before(Expr* ref, Expr* expr) {
- const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
- TORCH_INTERNAL_ASSERT(
- it != exprs_.end(),
- "Tried to insert ",
- expr,
- " before the reference: ",
- ref,
- " however the reference was not found in this scope.");
- insert(it, expr);
+ auto it = exprs_.begin();
+ while (it != exprs_.end()) {
+ if ((*it)->sameAs(ref))
+ break;
+ it++;
+ }
+ if (it != exprs_.end())
+ exprs_.insert(it, expr);
}
void Scope::insert_after(Expr* ref, Expr* expr) {
- const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
- TORCH_INTERNAL_ASSERT(
- it != exprs_.end(),
- "Tried to insert ",
- expr,
- " after the reference: ",
- ref,
- " however the reference was not found in this scope.");
- insert(it + 1, expr);
-}
-
-void Scope::insert(size_t pos, Expr* expr) {
- const auto it = exprs_.begin() + pos;
- insert(it, expr);
-}
-
-void Scope::erase(std::vector<Expr*>::const_iterator pos) {
- // Remove the scope of the expr if this is the scope
- auto expr = *pos;
- TORCH_INTERNAL_ASSERT(
- expr->scope() == this,
- "Inconsistent scoping of expression detected: ",
- kir::toString(expr));
- expr->setScope(nullptr);
- exprs_.erase(pos);
+ auto it = exprs_.begin();
+ while (it != exprs_.end()) {
+ if (*it == ref)
+ break;
+ it++;
+ }
+ if (it != exprs_.end())
+ exprs_.insert(++it, expr);
}
void Scope::erase(Expr* ref) {
- const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
- if (it != exprs_.end()) {
- erase(it);
+ auto it = exprs_.begin();
+ while (it != exprs_.end()) {
+ if (*it == ref)
+ break;
+ it++;
}
-}
-
-void Scope::erase(size_t pos) {
- TORCH_INTERNAL_ASSERT(pos < size());
- erase(exprs_.begin() + pos);
+ if (it != exprs_.end())
+ exprs_.erase(it);
}
bool Scope::contains(Expr* expr) const {
- const auto it = std::find(exprs_.begin(), exprs_.end(), expr);
- return it != exprs_.end();
+ for (auto e : exprs_)
+ if (e == expr)
+ return true;
+ return false;
}
void Scope::clear() {
- exprs_.clear();
+ exprs_ = std::vector<Expr*>();
}
ForLoop::ForLoop(
- Passkey passkey,
- IterDomain* iter_domain,
+ Passkey,
Val* index,
- Val* start,
- Val* stop,
- Val* step,
- bool vectorize,
- Val* vectorize_shift)
- : Expr(passkey),
+ IterDomain* iter_domain,
+ Expr* parent_scope)
+ : Expr(ExprType::ForLoop),
+ index_{index},
iter_domain_{iter_domain},
- index_(index),
- start_(start),
- stop_(stop),
- step_(step),
- vectorize_(vectorize),
- vectorize_shift_(vectorize_shift),
- body_(this) {
- TORCH_INTERNAL_ASSERT(index->dtype() == DataType::Int);
+ parent_scope_{parent_scope} {
+ TORCH_INTERNAL_ASSERT(index->isAnInt());
+ TORCH_INTERNAL_ASSERT(isLoweredScalar(index));
addInput(index);
addInput(iter_domain);
- if (start_ == nullptr && iter_domain->isThread()) {
- start_ =
- IrBuilder(GpuLower::current()->kernel())
- .create<kir::NamedScalar>(
- stringifyThread(iter_domain->parallelType()), DataType::Int);
- }
- if (step_ == nullptr) {
- if (iter_domain->isThread()) {
- step_ = IrBuilder(GpuLower::current()->kernel())
- .create<kir::NamedScalar>(
- stringifyThreadSize(iter_domain->parallelType()),
- DataType::Int);
- } else {
- step_ = IrBuilder(GpuLower::current()->kernel()).oneVal();
- }
- }
-}
-
-ForLoop::ForLoop(Passkey passkey, IterDomain* iter_domain)
- : ForLoop(
- passkey,
- iter_domain,
- iter_domain->isBroadcast()
- ? IrBuilder(GpuLower::current()->kernel()).zeroVal()
- : IrBuilder(GpuLower::current()->kernel())
- .create<kir::Int>(c10::nullopt),
- nullptr,
- nullptr,
- nullptr,
- isParallelTypeVectorize(iter_domain->parallelType()),
- nullptr) {}
-
-ForLoop::ForLoop(Passkey passkey, const ForLoop* other)
- : ForLoop(
- passkey,
- other->iter_domain(),
- other->index(),
- other->start(),
- other->stop(),
- other->step(),
- other->vectorize(),
- other->vectorize_shift()) {}
-
-Val* ForLoop::start() const {
- if (start_ != nullptr) {
- return start_;
- } else {
- // clang-tidy complains without this
- TORCH_INTERNAL_ASSERT(iter_domain_ != nullptr);
- return iter_domain_->start();
- }
+ name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
}
-Val* ForLoop::stop() const {
- if (stop_ != nullptr) {
- return stop_;
- } else {
- // clang-tidy complains without this
- TORCH_INTERNAL_ASSERT(iter_domain_ != nullptr);
- return iter_domain_->extent();
- }
+void ForLoop::setParentScope(Expr* scope) {
+ TORCH_INTERNAL_ASSERT(
+ !scope_utils::exprInScope(parentScope(), this),
+ "Cannot change parent scope if not already removed from previous parent.");
+ parent_scope_ = scope;
}
-Val* ForLoop::step() const {
- TORCH_INTERNAL_ASSERT(step_ != nullptr);
- return step_;
+IfThenElse::IfThenElse(Passkey, Bool* cond, Expr* parent_scope)
+ : Expr(ExprType::IfThenElse), cond_{cond}, parent_scope_(parent_scope) {
+ addInput(cond);
+ name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
}
-IfThenElse::IfThenElse(Passkey passkey, Predicate* cond)
- : Expr(passkey), then_body_(this), else_body_(this) {
- setPredicate(cond);
- addInput(cond);
+void IfThenElse::setParentScope(Expr* scope) {
+ TORCH_INTERNAL_ASSERT(
+ !scope_utils::exprInScope(parentScope(), this),
+ "Cannot change parent scope if not already removed from previous parent.");
+ parent_scope_ = scope;
}
Val* TensorIndex::index(int i) const {
nDims() > 0, "Tried to get an index of a 0-dim TensorIndex");
if (i < 0)
i += nDims();
- TORCH_INTERNAL_ASSERT(i >= 0 && i < int(nDims()));
+  assert(i >= 0 && i < static_cast<int>(nDims()));
return indices_[i];
}
Allocate::Allocate(
- Passkey passkey,
+ Passkey,
Val* buffer,
MemoryType memory_type,
- std::vector<Val*> shape,
+ Val* size,
bool zero_init)
- : Expr(passkey),
+ : Expr(ExprType::Allocate),
buffer_(buffer),
memory_type_(memory_type),
- shape_(std::move(shape)),
+ size_(size),
zero_init_(zero_init) {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- if (!shape_.empty()) {
+ if (size_ != nullptr) {
TORCH_INTERNAL_ASSERT(
- (shape_.size() == 1 && shape_[0]->isOneInt()) ||
- buffer_->isA<TensorView>());
+ size_->isOneInt() ||
+ buffer_->getValType().value() == ValType::KirTensorView,
+ "Cannot allocate a non-TensorView buffer with a size != 1, received buffer: ",
+ buffer_);
} else {
- TORCH_INTERNAL_ASSERT(buffer_->isA<TensorView>());
+ TORCH_INTERNAL_ASSERT(
+ buffer_->getValType().value() == ValType::KirTensorView);
TORCH_INTERNAL_ASSERT(
buffer_->as<TensorView>()->memoryType() == memory_type_);
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
const auto domain = buffer_->as<TensorView>()->domain();
- for (auto axis : domain->noReductions()) {
- shape_.push_back(axis->extent());
+ size_ = domain->nDims() == 0 ? ir_builder.create<Int>(1)
+ : domain->axis(0)->extent();
+ for (size_t i = 1; i < domain->nDims(); i++) {
+ size_ = ir_builder.mulExpr(size_, domain->axis(i)->extent());
}
}
- for (auto s : shape_) {
- if (size_ == nullptr) {
- size_ = s;
- } else {
- size_ = ir_builder.mulExpr(size_, s);
+ if (memory_type_ == MemoryType::Local) {
+ if (!size_->isConstScalar()) {
+ TORCH_INTERNAL_ASSERT(
+ false,
+ "Allocations must be based on constant integers for the memory type ",
+ memory_type_,
+ " but tried to alloc ",
+ buffer_,
+ " with symbolic size.");
}
}
- if (size_ == nullptr) {
- size_ = ir_builder.oneVal();
- }
-
addInput(size_);
+ name_ = FusionGuard::getCurFusion()->registerLoweredExpr(this);
}
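The size folding above reduces a (possibly 0-dimensional) domain to a single flat allocation size by multiplying the axis extents, with an empty domain treated as size 1. A self-contained sketch of the same computation on plain integers:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Fold a multi-dimensional shape into a flat allocation size, as the
// Allocate constructor above does over kir::IterDomain extents.
int64_t foldAllocationSize(const std::vector<int64_t>& extents) {
  int64_t size = 1; // a 0-dim buffer still allocates one element
  for (int64_t e : extents) {
    assert(e >= 0);
    size *= e;
  }
  return size;
}

// foldAllocationSize({}) == 1; foldAllocationSize({4, 8, 2}) == 64
```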
-Allocate::Allocate(
- Passkey passkey,
- Val* buffer,
- MemoryType memory_type,
- Val* size,
- bool zero_init)
- : Allocate(
- passkey,
- buffer,
- memory_type,
- size == nullptr ? std::vector<Val*>{} : std::vector<Val*>{size},
- zero_init) {}
-
-GridReduction::GridReduction(Passkey passkey, ReductionOp* reduction_op)
- : Expr(passkey), reduction_op_(reduction_op) {
+GridReduction::GridReduction(Passkey, ReductionOp* reduction_op)
+ : Expr(ExprType::GridReduction), reduction_op_(reduction_op) {
TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
GridReduction::GridReduction(
- Passkey passkey,
+ Passkey,
ReductionOp* reduction_op,
Allocate* reduction_buffer,
- Allocate* sync_buffer)
- : Expr(passkey),
+ Allocate* sync_buffer,
+ Bool* pred)
+ : Expr(ExprType::GridReduction),
reduction_op_(reduction_op),
reduction_buffer_(reduction_buffer),
- sync_buffer_(sync_buffer) {}
+ sync_buffer_(sync_buffer),
+ pred_(pred) {}
std::string GridReduction::getPredicateFlagName(const TensorView* val) {
std::stringstream ss;
return ss.str();
}
-GridWelford::GridWelford(
- Passkey passkey,
- WelfordOp* welford_op,
- Allocate* var_buffer,
- Allocate* avg_buffer,
- Allocate* n_buffer,
- Allocate* sync_buffer)
- : Expr(passkey),
- welford_op_(welford_op),
- var_buffer_(var_buffer),
- avg_buffer_(avg_buffer),
- n_buffer_(n_buffer),
- sync_buffer_(sync_buffer) {}
-
-std::string GridWelford::getPredicateFlagName(const TensorView* val) {
- std::stringstream ss;
- ss << "T" << val->name() << "_pred";
- return ss.str();
-}
-
-// TODO(kir): remove this
-std::string GridWelford::getPredicateFlagName(
- const fuser::cuda::TensorView* val) {
- std::stringstream ss;
- ss << "T" << val->name() << "_pred";
- return ss.str();
-}
-
} // namespace kir
} // namespace cuda
} // namespace fuser
#pragma once
#include <torch/csrc/jit/codegen/cuda/type.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
// TODO(kir): remove these once the Kernel IR is separated from Fusion IR
+#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_internal_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>
#include <c10/util/Optional.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>
namespace kir {
class IrBuilder;
-class Kernel;
-
-// Abstract nodes
-class Node;
-class Val;
-class Expr;
-
-// Values
-class NamedScalar;
-class Predicate;
-class Bool;
-class Double;
-class Int;
-class IterDomain;
-class TensorDomain;
-class TensorView;
-class TensorIndex;
-
-// Expressions
-class UnaryOp;
-class BinaryOp;
-class TernaryOp;
-class ReductionOp;
-class WelfordOp;
-class BroadcastOp;
-
-// Statements
-class Allocate;
-class Sync;
-class InitMagicZero;
-class UpdateMagicZero;
-class ForLoop;
-class IfThenElse;
-class GridReduction;
-class GridWelford;
-
-// Expr container
-class Scope;
-
-using ValueId = int32_t;
-
-//! Token used to restrict the access to Kernel IR creation
-//!
-//! A token is associated with a kernel, which is passed with the key
-//! (Passkey::kernel)
+
+//! Token used to restrict access to Kernel IR constructors
//!
-//! It is a "granular friendship" token, used to implement the "passkey" idiom:
+//! Granular "friendship" token, used to implement the "passkey" idiom:
//! https://www.spiria.com/en/blog/desktop-software/passkey-idiom-and-better-friendship-c
//! https://arne-mertz.de/2016/10/passkey-idiom
//!
class Passkey {
friend class IrBuilder;
-
- public:
- Kernel* const kernel = nullptr;
-
- private:
- explicit Passkey(Kernel* kernel) : kernel(kernel) {}
-};
-
-//! Kernel IR visitor interface
-class TORCH_CUDA_CU_API IrVisitor : public PolymorphicBase {
- public:
- // TODO(kir): use Node* instead of void*
- virtual void unhandled(const void* node) {}
-
- // Values
- virtual void visit(const NamedScalar* named_scalar) {
- unhandled(named_scalar);
- }
- virtual void visit(const Predicate* value) {
- unhandled(value);
- }
- virtual void visit(const Bool* value) {
- unhandled(value);
- }
- virtual void visit(const Double* value) {
- unhandled(value);
- }
- virtual void visit(const Int* value) {
- unhandled(value);
- }
- virtual void visit(const IterDomain* iter_domain) {
- unhandled(iter_domain);
- }
- virtual void visit(const TensorDomain* tensor_domain) {
- unhandled(tensor_domain);
- }
- virtual void visit(const TensorView* tensor_view) {
- unhandled(tensor_view);
- }
- virtual void visit(const TensorIndex* tensor_index) {
- unhandled(tensor_index);
- }
-
- // Expressions
- virtual void visit(const UnaryOp* node) {
- unhandled(node);
- }
- virtual void visit(const BinaryOp* node) {
- unhandled(node);
- }
- virtual void visit(const TernaryOp* node) {
- unhandled(node);
- }
- virtual void visit(const ReductionOp* node) {
- unhandled(node);
- }
- virtual void visit(const WelfordOp* node) {
- unhandled(node);
- }
- virtual void visit(const BroadcastOp* node) {
- unhandled(node);
- }
-
- // Statements
- virtual void visit(const Allocate* node) {
- unhandled(node);
- }
- virtual void visit(const Sync* node) {
- unhandled(node);
- }
- virtual void visit(const InitMagicZero* node) {
- unhandled(node);
- }
- virtual void visit(const UpdateMagicZero* node) {
- unhandled(node);
- }
- virtual void visit(const ForLoop* node) {
- unhandled(node);
- }
- virtual void visit(const IfThenElse* node) {
- unhandled(node);
- }
- virtual void visit(const GridReduction* node) {
- unhandled(node);
- }
- virtual void visit(const GridWelford* node) {
- unhandled(node);
- }
-};
-
-//! Kernel IR visitor interface
-class TORCH_CUDA_CU_API MutableIrVisitor : public PolymorphicBase {
- public:
- // TODO(kir): use Node* instead of void*
- virtual void unhandled(const void*) {}
-
- // Values
- virtual void visit(NamedScalar* named_scalar) {
- unhandled(named_scalar);
- }
- virtual void visit(Predicate* value) {
- unhandled(value);
- }
- virtual void visit(Bool* value) {
- unhandled(value);
- }
- virtual void visit(Double* value) {
- unhandled(value);
- }
- virtual void visit(Int* value) {
- unhandled(value);
- }
- virtual void visit(IterDomain* iter_domain) {
- unhandled(iter_domain);
- }
- virtual void visit(TensorDomain* tensor_domain) {
- unhandled(tensor_domain);
- }
- virtual void visit(TensorView* tensor_view) {
- unhandled(tensor_view);
- }
- virtual void visit(TensorIndex* tensor_index) {
- unhandled(tensor_index);
- }
-
- // Expressions
- virtual void visit(UnaryOp* node) {
- unhandled(node);
- }
- virtual void visit(BinaryOp* node) {
- unhandled(node);
- }
- virtual void visit(TernaryOp* node) {
- unhandled(node);
- }
- virtual void visit(ReductionOp* node) {
- unhandled(node);
- }
- virtual void visit(BroadcastOp* node) {
- unhandled(node);
- }
-
- virtual void visit(WelfordOp* node) {
- unhandled(node);
- }
-
- // Statements
- virtual void visit(Allocate* node) {
- unhandled(node);
- }
- virtual void visit(Sync* node) {
- unhandled(node);
- }
- virtual void visit(InitMagicZero* node) {
- unhandled(node);
- }
- virtual void visit(UpdateMagicZero* node) {
- unhandled(node);
- }
- virtual void visit(ForLoop* node) {
- unhandled(node);
- }
- virtual void visit(IfThenElse* node) {
- unhandled(node);
- }
- virtual void visit(GridReduction* node) {
- unhandled(node);
- }
-
- virtual void visit(GridWelford* node) {
- unhandled(node);
- }
-};
-
-//! Base class for Kernel IR nodes
-class TORCH_CUDA_CU_API Node : public NonCopyable, public PolymorphicBase {
- public:
- explicit Node(Passkey) {}
-
- //! IR Visitor double-dispatch interface
- //! (https://en.wikipedia.org/wiki/Visitor_pattern)
- virtual void accept(IrVisitor* visitor) const = 0;
-
- //! Non constant IR Visitor
- virtual void accept(MutableIrVisitor* visitor) = 0;
-
- //! Debug helper, prints the textual representation of an IR node
- void print() const;
-};
-
-//! Generic value (scalar or tensor)
-class TORCH_CUDA_CU_API Val : public Node {
- public:
- Val(Passkey passkey, DataType dtype);
-
- // TODO(kir): consider renaming
- StmtNameType name() const {
- return name_;
- }
-
- void setName(StmtNameType name) {
- name_ = name;
- }
-
- ValueId id() const {
- return id_;
- }
-
- DataType dtype() const {
- return dtype_;
- }
-
- Expr* definition() const {
- return definition_;
- }
-
- void setDefinition(Expr* expr) {
- // TODO(kir): extra checks on changing existing definitions?
- definition_ = expr;
- }
-
- virtual bool isScalar() const {
- return false;
- }
-
- bool isConstScalar() const;
-
- virtual bool isConst() const {
- return false;
- }
-
- // TODO(kir): revisit and find a better interface
- virtual bool isZeroInt() const {
- return false;
- }
-
- virtual bool isOneInt() const {
- return false;
- }
-
- private:
- const DataType dtype_;
-
- // The expression which defines this value, or nullptr
- Expr* definition_ = nullptr;
-
- // This is a value name preserved from the Fusion IR (optional)
- StmtNameType name_ = kInvalidStmName;
-
- // All Kernel IR values have IDs (unique within the same Kernel)
- ValueId id_ = -1;
-};
-
-//! Base class for expressions and statements
-//!
-//! Expressions consume inputs and produce outputs (depending on the context
-//! this may imply assignments). Currently some of the expressions
-//! don't actually produce any outputs (ForLoop, IfThenElse) and they
-//! model statements to be executed.
-//!
-//! TODO(kir): split the expressions, assignments and statements?
-//!
-class TORCH_CUDA_CU_API Expr : public Node {
- public:
- explicit Expr(Passkey passkey) : Node(passkey) {}
-
- const auto& inputs() const {
- return inputs_;
- }
-
- const auto& outputs() const {
- return outputs_;
- }
-
- Scope* scope() const {
- return scope_;
- }
-
- //! Set the current scope
- void setScope(Scope* scope) {
- scope_ = scope;
- }
-
- Expr* parentScope() const;
-
- Predicate* predicate() const {
- return predicate_;
- }
-
- void setPredicate(Predicate* predicate) {
- predicate_ = predicate;
- }
-
- Predicate* writePredicate() const {
- return write_predicate_;
- }
-
- void setWritePredicate(Predicate* write_predicate) {
- write_predicate_ = write_predicate;
- }
-
- protected:
- // TODO(kir): try to avoid this protected interface
- void addInput(Val* input) {
- inputs_.push_back(input);
- }
-
- void addOutput(Val* output) {
- output->setDefinition(this);
- outputs_.push_back(output);
- }
-
- private:
- // TODO(kir): can we avoid this?
- std::vector<Val*> inputs_;
- std::vector<Val*> outputs_;
-
- // TODO(kir): revisit scope/nesting data structures
- Scope* scope_ = nullptr;
-
- Predicate* predicate_ = nullptr;
- // Only used for reduction-related expressions
- Predicate* write_predicate_ = nullptr;
+ Passkey() = default;
};
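The idiom referenced above can be shown in isolation. A minimal sketch with hypothetical names (Key, Builder, and Node are stand-ins, not the fuser classes): Node's constructor is nominally public, but only a friend of Key can actually produce the token it requires.

```cpp
#include <memory>

class Builder; // forward declaration for the friendship below

// The passkey: not constructible by arbitrary code.
class Key {
 private:
  Key() = default;
  friend class Builder; // granular friendship: only Builder mints keys
};

// Anyone may call this constructor, but only with a Key in hand.
class Node {
 public:
  explicit Node(Key) {}
};

class Builder {
 public:
  std::unique_ptr<Node> makeNode() {
    return std::make_unique<Node>(Key{}); // Builder can create the token
  }
};
```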
-class TORCH_CUDA_CU_API NamedScalar final : public Val {
+class TORCH_CUDA_CU_API NamedScalar : public Val {
public:
- // NOLINTNEXTLINE(modernize-pass-by-value)
- NamedScalar(Passkey passkey, std::string name, DataType dtype)
- : Val(passkey, dtype), name_(name) {}
+ NamedScalar(Passkey, std::string name, DataType dtype)
+ : Val(ValType::KirNamedScalar, dtype, true, true),
+ name_(std::move(name)) {}
- explicit NamedScalar(Passkey passkey, const fuser::cuda::NamedScalar* node)
- : Val(passkey, node->getDataType().value()) {
- name_ = node->name();
- }
-
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
+ explicit NamedScalar(Passkey, const fuser::cuda::NamedScalar* node)
+ : Val(node), name_(node->name()) {}
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
-
- bool isScalar() const override {
- return true;
- }
-
- // TODO(kir): this is hiding and redefining Val::name()
const std::string& name() const {
return name_;
}
std::string name_;
};
-class TORCH_CUDA_CU_API Predicate final : public Val {
+class TORCH_CUDA_CU_API Bool : public Val {
public:
- explicit Predicate(
- Passkey passkey,
- PredicateType ptype,
- const Expr* expr = nullptr,
- Bool* thread_pred = nullptr)
- : Val(passkey, DataType::Bool),
- ptype_(ptype),
- expr_(expr),
- thread_pred_(thread_pred) {
- TORCH_INTERNAL_ASSERT(
- ptype != PredicateType::Unswitch && ptype != PredicateType::Manual);
- }
-
- explicit Predicate(Passkey passkey, ForLoop* unrolled_loop)
- : Val(passkey, DataType::Bool),
- ptype_(PredicateType::Unswitch),
- unrolled_loop_(unrolled_loop) {
- TORCH_INTERNAL_ASSERT(unrolled_loop != nullptr);
- }
-
- explicit Predicate(Passkey passkey, Bool* value)
- : Val(passkey, DataType::Bool),
- ptype_(PredicateType::Manual),
- value_(value) {
- TORCH_INTERNAL_ASSERT(value != nullptr);
- }
+ explicit Bool(Passkey, const c10::optional<bool>& value)
+ : Val(ValType::KirScalar, DataType::Bool, true, true),
+ maybe_value_(value) {}
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
+ explicit Bool(Passkey, const fuser::cuda::Bool* node)
+ : Val(node), maybe_value_(node->value()) {}
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
+ bool isSymbolic() const {
+ return !(maybe_value_.has_value());
}
-
- PredicateType predicate_type() const {
- return ptype_;
- }
-
- const Expr* expr() const {
- TORCH_INTERNAL_ASSERT(
- ptype_ != PredicateType::Unswitch &&
- ptype_ != PredicateType::Vectorize && ptype_ != PredicateType::Manual);
- return expr_;
- }
-
- Bool* thread_pred() {
- TORCH_INTERNAL_ASSERT(
- ptype_ == PredicateType::Inline ||
- ptype_ == PredicateType::Misaligned || ptype_ == PredicateType::Shift ||
- ptype_ == PredicateType::Padding ||
- ptype_ == PredicateType::ReductionWrite);
- return thread_pred_;
- }
-
- ForLoop* unrolled_loop() const {
- TORCH_INTERNAL_ASSERT(ptype_ == PredicateType::Unswitch);
- return unrolled_loop_;
- }
-
- bool hasValue() const {
- return value_ != nullptr;
- }
-
- Bool* value() const {
- TORCH_INTERNAL_ASSERT(
- value_ != nullptr,
- "The conditional expression for this Predicate is invalid.");
- return value_;
+ bool isConst() const {
+ return maybe_value_.has_value();
}
-
- void setValue(Bool* value) {
- TORCH_INTERNAL_ASSERT(value != nullptr, "The Bool expression is invalid.");
- value_ = value;
+ c10::optional<bool> value() const {
+ return maybe_value_;
}
private:
- PredicateType ptype_ = PredicateType::Manual;
-
- // For PredicateCompute::getInlinePredicate,
- // ShiftPredicateInserter::getShiftPredicate and getPaddingPredicate
- const Expr* expr_ = nullptr;
-
- // For PredicateCompute::getInlinePredicate
- Bool* thread_pred_ = nullptr;
-
- // For ParallelType::Unswitch - UnswitchPredicate::get
- ForLoop* unrolled_loop_ = nullptr;
-
- // The Bool conditional value
- // The value is nullptr until lower_predicate pass
- Bool* value_ = nullptr;
+ const c10::optional<bool> maybe_value_;
};
-class TORCH_CUDA_CU_API Bool final : public Val {
+class TORCH_CUDA_CU_API Float : public Val {
public:
- explicit Bool(Passkey passkey, const c10::optional<bool>& value)
- : Val(passkey, DataType::Bool), maybe_value_(value) {}
-
- explicit Bool(Passkey passkey, const fuser::cuda::Bool* node)
- : Val(passkey, DataType::Bool), maybe_value_(node->value()) {
- setName(node->name());
- }
+ using ScalarType = double;
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
+ explicit Float(Passkey, const c10::optional<ScalarType>& value)
+ : Val(ValType::KirScalar, DataType::Float, true, true),
+ maybe_value_(value) {}
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
+ explicit Float(Passkey, const fuser::cuda::Float* node)
+ : Val(node), maybe_value_(node->value()) {}
- bool isScalar() const override {
- return true;
+ bool isSymbolic() const {
+ return !(maybe_value_.has_value());
}
-
- bool isConst() const override {
+ bool isConst() const {
return maybe_value_.has_value();
}
-
- c10::optional<bool> value() const {
+ c10::optional<ScalarType> value() const {
return maybe_value_;
}
private:
- const c10::optional<bool> maybe_value_;
+ const c10::optional<ScalarType> maybe_value_;
};
-class TORCH_CUDA_CU_API Double final : public Val {
+class TORCH_CUDA_CU_API Half : public Val {
public:
- using ScalarType = double;
+ explicit Half(Passkey, const c10::optional<float>& value)
+ : Val(ValType::KirScalar, DataType::Half, true, true),
+ maybe_value_(value) {}
- explicit Double(Passkey passkey, const c10::optional<ScalarType>& value)
- : Val(passkey, DataType::Double), maybe_value_(value) {}
+ explicit Half(Passkey, const fuser::cuda::Half* node)
+ : Val(node), maybe_value_(node->value()) {}
- explicit Double(Passkey passkey, const fuser::cuda::Double* node)
- : Val(passkey, DataType::Double), maybe_value_(node->value()) {
- setName(node->name());
+ bool isSymbolic() const {
+ return !(maybe_value_.has_value());
}
-
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
-
- bool isScalar() const override {
- return true;
- }
-
- bool isConst() const override {
+ bool isConst() const {
return maybe_value_.has_value();
}
-
- c10::optional<ScalarType> value() const {
+ c10::optional<float> value() const {
return maybe_value_;
}
private:
- const c10::optional<ScalarType> maybe_value_;
+ const c10::optional<float> maybe_value_;
};
-class TORCH_CUDA_CU_API Int final : public Val {
+class TORCH_CUDA_CU_API Int : public Val {
public:
using ScalarType = int64_t;
- explicit Int(Passkey passkey, const c10::optional<ScalarType>& value)
- : Val(passkey, DataType::Int), maybe_value_(value) {}
+ explicit Int(Passkey, const c10::optional<ScalarType>& value)
+ : Val(ValType::KirScalar, DataType::Int, true, true),
+ maybe_value_(value) {}
- // SFINAE constructor to avoid 0 constant pointer ambiguity
- template <
- typename T,
- typename = typename std::enable_if<
- std::is_pointer<T>::value &&
- std::is_convertible<T, const fuser::cuda::Int*>::value>::type>
- explicit Int(Passkey passkey, T node)
- : Val(passkey, DataType::Int), maybe_value_(node->value()) {
- setName(node->name());
- }
-
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
+ explicit Int(
+ Passkey,
+ const fuser::cuda::Int* node,
+ bool /*avoid_zero_ambiguity*/)
+ : Val(node), maybe_value_(node->value()) {}
- bool isScalar() const override {
- return true;
+ bool isSymbolic() const {
+ return !(maybe_value_.has_value());
}
-
- bool isConst() const override {
+ bool isConst() const {
return maybe_value_.has_value();
}
-
- bool isZeroInt() const override {
- return maybe_value_.has_value() && *maybe_value_ == 0;
- }
-
- bool isOneInt() const override {
- return maybe_value_.has_value() && *maybe_value_ == 1;
- }
-
c10::optional<ScalarType> value() const {
return maybe_value_;
}
const c10::optional<ScalarType> maybe_value_;
};
-class TORCH_CUDA_CU_API IterDomain final : public Val {
+class TORCH_CUDA_CU_API IterDomain : public Val {
public:
- IterDomain(Passkey passkey, Val* start, Val* extent);
+ IterDomain(Passkey, Val* start, Val* extent);
explicit IterDomain(Passkey, const fuser::cuda::IterDomain* iter_domain);
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
-
bool isReduction() const {
- return iterType() == IterType::Reduction;
+ return getIterType() == IterType::Reduction;
}
bool isRFactorProduct() const {
}
bool isBroadcast() const {
- return iterType() == IterType::BroadcastWithStride ||
- iterType() == IterType::BroadcastWithoutStride;
- }
-
- bool isGather() const {
- return iterType() == IterType::Gather;
+ return getIterType() == IterType::BroadcastWithStride ||
+ getIterType() == IterType::BroadcastWithoutStride;
}
bool isParallelized() const {
- return parallelType() != ParallelType::Serial;
+ return getParallelType() != ParallelType::Serial;
}
// Return if this iter domain is mapped to a grid dimension
bool isBlockDim() const {
- return parallelType() == ParallelType::BIDz ||
- parallelType() == ParallelType::BIDy ||
- parallelType() == ParallelType::BIDx;
+ return (
+ getParallelType() == ParallelType::BIDz ||
+ getParallelType() == ParallelType::BIDy ||
+ getParallelType() == ParallelType::BIDx);
}
// Return if this iter domain is mapped to a block dimension
bool isThreadDim() const {
- return parallelType() == ParallelType::TIDz ||
- parallelType() == ParallelType::TIDy ||
- parallelType() == ParallelType::TIDx;
+ return (
+ getParallelType() == ParallelType::TIDz ||
+ getParallelType() == ParallelType::TIDy ||
+ getParallelType() == ParallelType::TIDx);
}
// Return if this iter domain is either mapped to a block or grid dimension
bool isThread() const {
- return isBlockDim() || isThreadDim();
+ return (isBlockDim() || isThreadDim());
}
- ParallelType parallelType() const {
+ ParallelType getParallelType() const {
return parallel_type_;
}
- IterType iterType() const {
+ IterType getIterType() const {
return iter_type_;
}
Val* extent() const;
- bool isSimple() const {
- return is_simple_;
+ Val* rawExtent() const {
+ return extent_;
}
private:
ParallelType parallel_type_ = ParallelType::Serial;
IterType iter_type_ = IterType::Iteration;
bool is_rfactor_domain_ = false;
-
- // An IterDomain is "simple" if the original Fusion IterDomain
- // doesn't have a definition ("definition" expression)
- //
- // TODO(kir): this feels like a hack, revisit
- //
- bool is_simple_ = true;
};
-// TODO(kir): is this really a value?
-class TORCH_CUDA_CU_API TensorDomain final : public Val {
+class TORCH_CUDA_CU_API TensorDomain : public Val {
public:
explicit TensorDomain(Passkey, std::vector<IterDomain*> domain);
explicit TensorDomain(
- Passkey passkey,
+ Passkey,
const fuser::cuda::TensorDomain* tensor_domain);
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
-
std::vector<IterDomain*>::size_type nDims() const {
return domain_.size();
}
- // TODO(kir): rename this
const std::vector<IterDomain*>& domain() const {
return domain_;
}
bool hasBlockBroadcast() const;
bool hasBroadcast() const;
bool hasRFactor() const;
- bool hasVectorize() const;
const std::vector<IterDomain*>& noReductions() const {
return no_reduction_domain_;
const std::vector<bool> contiguity_;
};
-class TORCH_CUDA_CU_API TensorView final : public Val {
+class TORCH_CUDA_CU_API TensorView : public Val {
public:
explicit TensorView(Passkey, const fuser::cuda::TensorView* tv);
- TensorView(
- Passkey,
- DataType dtype,
- TensorDomain* domain,
- MemoryType memory_type);
-
TensorDomain* domain() const {
return domain_;
}
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
-
MemoryType memoryType() const {
return memory_type_;
}
- fuser::cuda::TensorView* fuserTv() const {
+ const fuser::cuda::TensorView* fuserTv() const {
TORCH_INTERNAL_ASSERT(fuser_tv_ != nullptr);
- // TODO(kir): remove the need for const_cast
- return const_cast<fuser::cuda::TensorView*>(fuser_tv_); // NOLINT
+ return fuser_tv_;
}
private:
const fuser::cuda::TensorView* fuser_tv_ = nullptr;
};
-class TORCH_CUDA_CU_API UnaryOp final : public Expr {
+class TORCH_CUDA_CU_API UnaryOp : public Expr {
public:
- UnaryOp(Passkey passkey, UnaryOpType operation, Val* out, Val* in);
-
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
+ UnaryOp(Passkey, UnaryOpType type, Val* out, Val* in);
Val* out() const {
return out_;
return in_;
}
- UnaryOpType operation() const {
- return operation_;
+ UnaryOpType getUnaryOpType() const {
+ return unary_op_type_;
}
private:
- const UnaryOpType operation_;
+ const UnaryOpType unary_op_type_;
Val* const out_ = nullptr;
Val* const in_ = nullptr;
};
-class TORCH_CUDA_CU_API BinaryOp final : public Expr {
+class TORCH_CUDA_CU_API BinaryOp : public Expr {
public:
- BinaryOp(
- Passkey passkey,
- BinaryOpType operation,
- Val* out,
- Val* lhs,
- Val* rhs);
-
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
+ BinaryOp(Passkey, BinaryOpType type, Val* out, Val* lhs, Val* rhs);
Val* out() const {
return out_;
return rhs_;
}
- BinaryOpType operation() const {
- return operation_;
+ BinaryOpType getBinaryOpType() const {
+ return binary_op_type_;
}
private:
- const BinaryOpType operation_;
+ const BinaryOpType binary_op_type_;
Val* const out_ = nullptr;
Val* const lhs_ = nullptr;
Val* const rhs_ = nullptr;
};
-class TORCH_CUDA_CU_API TernaryOp final : public Expr {
+class TORCH_CUDA_CU_API TernaryOp : public Expr {
public:
TernaryOp(
- Passkey passkey,
- TernaryOpType operation,
+ Passkey,
+ TernaryOpType type,
Val* out,
Val* in1,
Val* in2,
Val* in3);
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
-
Val* out() const {
return out_;
}
return in3_;
}
- TernaryOpType operation() const {
- return operation_;
+ TernaryOpType getTernaryOpType() const {
+ return ternary_op_type_;
}
private:
- const TernaryOpType operation_;
+ const TernaryOpType ternary_op_type_;
Val* const out_ = nullptr;
Val* const in1_ = nullptr;
Val* const in2_ = nullptr;
Val* const in3_ = nullptr;
};
-class TORCH_CUDA_CU_API ReductionOp final : public Expr {
+class TORCH_CUDA_CU_API ReductionOp : public Expr {
public:
ReductionOp(
- Passkey passkey,
- BinaryOpType operation,
+ Passkey,
+ BinaryOpType reduction_op_type,
Val* init,
Val* out,
- Val* in);
-
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
+ Val* in,
+ Bool* pred = nullptr);
Val* out() const {
return out_;
return init_;
}
- BinaryOpType operation() const {
- return operation_;
+ Bool* pred() const {
+ return pred_;
+ }
+
+ BinaryOpType getReductionOpType() const {
+ return reduction_op_type_;
}
std::unordered_map<ParallelType, IterDomain*, TypeHash>
std::vector<IterDomain*> getReductionDomains() const;
private:
- const BinaryOpType operation_;
+ const BinaryOpType reduction_op_type_;
Val* const init_ = nullptr;
Val* const out_ = nullptr;
Val* const in_ = nullptr;
+ Bool* const pred_ = nullptr;
};
-class TORCH_CUDA_CU_API WelfordOp final : public Expr {
- public:
- WelfordOp(
- Passkey passkey,
- Val* out_var,
- Val* out_avg,
- Val* out_N,
- Val* init_var,
- Val* init_avg,
- Val* init_N,
- Val* in_var,
- Val* in_avg,
- Val* in_N);
-
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
-
- Val* out() const {
- return out_avg_;
- }
-
- Val* in() const {
- return in_avg_;
- }
-
- // Welford Specific accessors
- // Almost wanted to add a new struct for {var, avg, N}
- Val* outVar() const {
- return out_var_;
- }
-
- Val* outAvg() const {
- return out_avg_;
- }
-
- Val* outN() const {
- return out_N_;
- }
-
- Val* initVar() const {
- return init_var_;
- }
-
- Val* initAvg() const {
- return init_avg_;
- }
-
- Val* initN() const {
- return init_N_;
- }
-
- Val* inVar() const {
- return in_var_;
- }
-
- Val* inAvg() const {
- return in_avg_;
- }
-
- Val* inN() const {
- return in_N_;
- }
-
- std::unordered_map<ParallelType, IterDomain*, TypeHash>
- getParallelReductionDomains() const;
-
- private:
- std::vector<IterDomain*> getReductionDomains() const;
-
- private:
- Val* const out_var_;
- Val* const out_avg_;
- Val* const out_N_;
- Val* const init_var_;
- Val* const init_avg_;
- Val* const init_N_;
- Val* const in_var_;
- Val* const in_avg_;
- Val* const in_N_;
-};
-
-class TORCH_CUDA_CU_API TensorIndex final : public Val {
+class TORCH_CUDA_CU_API TensorIndex : public Val {
public:
TensorIndex(
Passkey,
const fuser::cuda::TensorView* view,
std::vector<Val*> indices);
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
-
std::vector<Val*>::size_type nDims() const {
return indices_.size();
}
return indices_;
}
- TensorView* view() const {
- TORCH_INTERNAL_ASSERT(view_ != nullptr);
- // TODO(kir): remove the need for const_cast
- return const_cast<fuser::cuda::kir::TensorView*>(view_); // NOLINT
+ const TensorView* view() const {
+ return view_;
}
private:
std::vector<Val*> indices_;
};
-class TORCH_CUDA_CU_API BroadcastOp final : public Expr {
+class TORCH_CUDA_CU_API BroadcastOp : public Expr {
public:
- BroadcastOp(Passkey passkey, Val* out, Val* in);
-
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
+ BroadcastOp(Passkey, Val* out, Val* in);
Val* out() const {
return out_;
Val* const in_ = nullptr;
};
-//! Allocate is a lower level Node that describes a buffer of memory that
-//! is required as an intermediate within a kernel. The extent is the expression
-//! of the size of the buffer that is generated from the TensorView that
-//! describes the output of an operation.
-//!
-//! TODO(kir): The components of Allocate like Type and Name could be separated
-//! from the associated TensorView. Perhaps that is more appropriate?
-//!
-class TORCH_CUDA_CU_API Allocate final : public Expr {
+// Allocate is a lower level Node that describes a buffer of memory that
+// is required as an intermediate within a kernel. The extent is the expression
+// of the size of the buffer that is generated from the TensorView that
+// describes the output of an operation.
+//
+// TODO: The components of Allocate like Type and Name could be separated from
+// the associated TensorView. Perhaps that is more appropriate?
+class TORCH_CUDA_CU_API Allocate : public Expr {
public:
- //! Allocation of a multi-dimensional buffer
- //!
- //! param shape Size of each dimension
explicit Allocate(
- Passkey passkey,
- Val* buffer,
- MemoryType memory_type,
- std::vector<Val*> shape = {},
- bool zero_init = false);
-
- //! Allocation of a non-dimensional buffer
- //!
- //! param size Size of allocation
- explicit Allocate(
- Passkey passkey,
+ Passkey,
Val* buffer,
- MemoryType memory_type,
- Val* size,
+ MemoryType memory_type = MemoryType::Local,
+ Val* size = nullptr,
bool zero_init = false);
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
-
Val* buffer() const {
return buffer_;
}
- MemoryType memoryType() const {
+ MemoryType getMemoryType() const {
return memory_type_;
}
return size_;
}
- const std::vector<Val*>& shape() const {
- return shape_;
- }
-
bool zeroInit() const {
return zero_init_;
}
- const Allocate* alias() const {
+ DataType buffer_type() const {
+ return buffer_->getDataType().value();
+ }
+
+ Allocate* alias() const {
return alias_;
}
- void setAlias(const Allocate* alias) {
- TORCH_INTERNAL_ASSERT(alias != this);
- TORCH_INTERNAL_ASSERT(alias->memoryType() == memory_type_);
+ void setAlias(Allocate* alias) {
+ TORCH_INTERNAL_ASSERT(alias->getMemoryType() == memory_type_);
alias_ = alias;
}
private:
Val* buffer_ = nullptr;
MemoryType memory_type_ = MemoryType::Local;
- //! Size of each dimension
- std::vector<Val*> shape_;
- bool zero_init_ = false;
- //! Total size
Val* size_ = nullptr;
+ bool zero_init_ = false;
// This alias tracks the next Allocate node in a linked chain of aliases
// If the alias is nullptr, then the Allocate node uses memory in the kernel
- const Allocate* alias_ = nullptr;
+ Allocate* alias_ = nullptr;
};
// Sync represents __syncthreads barrier for block level coordination.
-//
-// TODO(kir): change name to SyncThreads as we could have other barriers.
-//
-class TORCH_CUDA_CU_API Sync final : public Expr {
+class TORCH_CUDA_CU_API Sync : public Expr {
public:
- explicit Sync(Passkey passkey, bool war_sync = false);
-
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
+ explicit Sync(Passkey, bool war_sync = false);
bool isWarHazardSync() const {
return war_sync_;
bool war_sync_ = false;
};
-// Simply prints "DEFINE_MAGIC_ZERO" in the code in accordance with magic_zero
-// in helpers.cu
-class TORCH_CUDA_CU_API InitMagicZero final : public Expr {
+// TODO(kir): promote to IR node
+// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+class TORCH_CUDA_CU_API Scope {
public:
- explicit InitMagicZero(Passkey passkey);
+ Scope() = default;
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
+ const std::vector<Expr*>& exprs() const {
+ return exprs_;
}
-};
-
-// Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero
-// in helpers.cu
-class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr {
- public:
- explicit UpdateMagicZero(Passkey passkey);
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
+ void push_back(Expr* e) {
+ exprs_.push_back(e);
}
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
+ void insert(size_t pos, Expr* expr) {
+ exprs_.insert(exprs_.begin() + pos, expr);
}
-};
-// TODO(kir): promote to IR node
-class TORCH_CUDA_CU_API Scope {
- public:
- explicit Scope(Expr* owner) : owner_(owner) {}
-
- const std::vector<Expr*>& exprs() const {
- return exprs_;
+ void erase(size_t pos) {
+ exprs_.erase(exprs_.begin() + pos);
}
bool empty() const {
return exprs_[i];
}
- // Insert expr before expression at pos
- void insert(size_t pos, Expr* expr);
-
// Insert expr before ref
void insert_before(Expr* ref, Expr* expr);
// Insert expr after ref
void insert_after(Expr* ref, Expr* expr);
- void push_back(Expr* e) {
- exprs_.push_back(e);
- e->setScope(this);
- }
-
- // Erase expr at pos
- void erase(size_t pos);
+ bool contains(Expr* expr) const;
- // Erase expr ref
void erase(Expr* ref);
- bool contains(Expr* expr) const;
-
void clear();
- Expr* owner() const {
- return owner_;
- }
-
- private:
- // Insert expr before pos
- void insert(std::vector<Expr*>::const_iterator pos, Expr* expr);
-
- // Erase expr at pos
- void erase(std::vector<Expr*>::const_iterator pos);
-
private:
std::vector<Expr*> exprs_;
-
-  //! Owner expression of this scope, e.g., IfThenElse
- Expr* owner_ = nullptr;
};
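
// Illustrative sketch (assumption): basic Scope editing via the inline
// push_back/insert/erase members above; `expr_a` and `expr_b` are placeholder
// kir::Expr pointers.
//
//   kir::Scope scope;
//   scope.push_back(expr_a);   // [expr_a]
//   scope.insert(0, expr_b);   // [expr_b, expr_a]
//   scope.erase(1);            // [expr_b]
//   bool has_b = scope.contains(expr_b);  // true
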
-//! ForLoop provides scoping around an int iterator from 0 to range. Exprs
-//! placed in its body are considered inside the scope of the for loop. In the
-//! future the implementation should look quite different so that we can do
-//! proper dependency analysis like in Fusion.
-//!
-//! TODO(kir): this is not a real expression
-//!
-//! ForLoop may represent a part of an iteration domain represented
-//! by iter_domain_. In that case, the loop extent field, extent_, may
-//! be smaller than the extent of iter_domain_.
-class TORCH_CUDA_CU_API ForLoop final : public Expr {
+// ForLoop provides scoping around an int iterator from 0 to range. Exprs placed
+// in its body are considered inside the scope of the for loop. In the future
+// the implementation should look quite different so that we can do proper
+// dependency analysis like in Fusion.
+//
+// TODO(kir): this is not a real expression
+//
+class TORCH_CUDA_CU_API ForLoop : public Expr {
public:
- //! By default, start and stop are the same as those of iter_domain.
- //! Step is one by default.
- //!
- //! TODO: cleaner way to set options?
- ForLoop(
- Passkey passkey,
- IterDomain* iter_domain,
- Val* index,
- Val* start,
- Val* stop,
- Val* step,
- bool vectorize,
- Val* vectorize_shift);
-
- ForLoop(Passkey passkey, IterDomain* iter_domain);
-
- ForLoop(Passkey passkey, const ForLoop* other);
-
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
+ ForLoop(Passkey, Val* index, IterDomain* iter_domain, Expr* parent_scope);
Val* index() const {
return index_;
}
- Val* start() const;
-
- Val* stop() const;
-
- Val* step() const;
-
- Val* vectorize_shift() const {
- return vectorize_shift_;
- }
-
IterDomain* iter_domain() const {
return iter_domain_;
}
return body_;
}
- bool vectorize() const {
- return vectorize_;
+ Expr* parentScope() const {
+ return parent_scope_;
}
- // Returns if a loop could be unrolled. Start and stop must be constant, it
- // must not be a broadcast dimension, cannot be bound to a parallel dimension,
- // and returns false if start is 0 and stop is 1.
- bool isUnrollable() const {
- return start()->isConstScalar() && stop()->isConstScalar() &&
- !iter_domain()->isThread() && !iter_domain()->isBroadcast() &&
- !(start()->isZeroInt() && stop()->isOneInt()) &&
- iter_domain()->parallelType() != ParallelType::Vectorize;
- }
+ void setParentScope(Expr* scope);
private:
- IterDomain* const iter_domain_ = nullptr;
-
- Val* index_ = nullptr;
- Val* start_ = nullptr;
- Val* stop_ = nullptr;
- Val* step_ = nullptr;
-
- // vectorize is true when the for-loop contains a vectorize set
- // the flag is used to omit the for-loop from the kernel
- bool vectorize_ = false;
- // [pre | vectorize | post] <= inner-most, merged root domain
- // shift_ is applied to vectorize and post sections.
- Val* vectorize_shift_ = nullptr;
-
+ Val* const index_ = nullptr;
+ IterDomain* const iter_domain_;
Scope body_;
+ Expr* parent_scope_ = nullptr;
};
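
// Illustrative sketch (assumption): building a two-level loop nest with the
// simplified constructor above; the index and IterDomain values are
// placeholders created elsewhere.
//
//   auto outer = ir_builder.create<kir::ForLoop>(outer_idx, outer_id, nullptr);
//   auto inner = ir_builder.create<kir::ForLoop>(inner_idx, inner_id, outer);
//   outer->body().push_back(inner);
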
-//! IfThenElse provides scoping for a boolean operator. Exprs placed in its
-//! body are considered inside the scope of the if statement. In the future the
-//! implementation should look quite different so that we can do proper
-//! dependency analysis like in Fusion.
-//!
-//! TODO(kir): this is not a real expression
-//!
-class TORCH_CUDA_CU_API IfThenElse final : public Expr {
+// IfThenElse provides scoping for a boolean operator. Exprs placed in its body
+// are considered inside the scope of the if statement. In the future the
+// implementation should look quite different so that we can do proper
+// dependency analysis like in Fusion.
+//
+// TODO(kir): this is not a real expression
+//
+class TORCH_CUDA_CU_API IfThenElse : public Expr {
public:
- explicit IfThenElse(Passkey passkey, Predicate* cond);
+ explicit IfThenElse(Passkey, Bool* cond, Expr* parent_scope);
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
+ Bool* cond() const {
+ return cond_;
}
Scope& thenBody() {
return !else_body_.empty();
}
+ Expr* parentScope() const {
+ return parent_scope_;
+ }
+
+ void setParentScope(Expr* scope);
+
private:
+ Bool* const cond_ = nullptr;
Scope then_body_;
Scope else_body_;
+ Expr* parent_scope_ = nullptr;
};
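
// Illustrative sketch (assumption): guarding an expression with an IfThenElse
// nested inside a loop; `pred` is a placeholder kir::Bool*.
//
//   auto ite = ir_builder.create<kir::IfThenElse>(pred, loop);
//   ite->thenBody().push_back(guarded_expr);
//   loop->body().push_back(ite);
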
-//! Grid reduction operation
-//!
-//! This node is used only after lowering a fusion to explicitly mark a grid
-//! reduction and the buffer allocation needed to do it.
-//!
-//! This node provides FusionExecutor the information it needs to allocate the
-//! reduction and sync buffers.
-class TORCH_CUDA_CU_API GridReduction final : public Expr {
+// Grid reduction operation, this node is used only after lowering a fusion to
+// explicitly mark a grid reduction and the buffer allocation needed to do it.
+// This node provides FusionExecutor the information it needs to allocate the
+// reduction and sync buffers.
+class TORCH_CUDA_CU_API GridReduction : public Expr {
public:
- explicit GridReduction(Passkey passkey, ReductionOp* reduction_op);
-
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
+ explicit GridReduction(Passkey, ReductionOp* reduction_op);
GridReduction(
- Passkey passkey,
+ Passkey,
ReductionOp* reduction_op,
Allocate* reduction_buffer,
- Allocate* sync_buffer);
+ Allocate* sync_buffer,
+ Bool* pred = nullptr);
ReductionOp* reduction_op() const {
return reduction_op_;
return sync_buffer_;
}
- const ParallelTypeBitmap& threadPredicate() const {
- return thread_predicate_;
- }
-
- void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
- thread_predicate_ = thread_predicate;
+ Bool* pred() const {
+ return pred_;
}
static std::string getPredicateFlagName(const TensorView* val);
ReductionOp* reduction_op_ = nullptr;
Allocate* reduction_buffer_ = nullptr;
Allocate* sync_buffer_ = nullptr;
- // gridReduce has template flags for thread predicates. In order to
- // use them, the thread predicate is held here separately from
- // Expr::predicate_.
- ParallelTypeBitmap thread_predicate_;
-};
-
-//! Grid welford operation
-//!
-//! This node is used only after lowering a fusion to explicitly mark a grid
-//! reduction and the buffer allocation needed to do it.
-//!
-//! This node provides FusionExecutor the information it needs to allocate the
-//! reduction and sync buffers.
-class TORCH_CUDA_CU_API GridWelford final : public Expr {
- public:
- void accept(IrVisitor* visitor) const override {
- visitor->visit(this);
- }
-
- void accept(MutableIrVisitor* visitor) override {
- visitor->visit(this);
- }
-
- GridWelford(
- Passkey passkey,
- WelfordOp* welford_op,
- Allocate* var_buffer,
- Allocate* avg_buffer,
- Allocate* n_buffer,
- Allocate* sync_buffer);
-
- WelfordOp* welford_op() const {
- return welford_op_;
- }
-
- Allocate* var_buffer() const {
- return var_buffer_;
- }
-
- Allocate* avg_buffer() const {
- return avg_buffer_;
- }
-
- Allocate* N_buffer() const {
- return n_buffer_;
- }
-
- Allocate* sync_buffer() const {
- return sync_buffer_;
- }
-
- const ParallelTypeBitmap& threadPredicate() const {
- return thread_predicate_;
- }
-
- void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
- thread_predicate_ = thread_predicate;
- }
-
- static std::string getPredicateFlagName(const TensorView* val);
- static std::string getPredicateFlagName(const fuser::cuda::TensorView* val);
-
- private:
- WelfordOp* welford_op_ = nullptr;
- Allocate* var_buffer_ = nullptr;
- Allocate* avg_buffer_ = nullptr;
- Allocate* n_buffer_ = nullptr;
- Allocate* sync_buffer_ = nullptr;
- // gridReduce has template flags for thread predicates. In order to
- // use them, the thread predicate is held here separately from
- // Expr::predicate_.
- ParallelTypeBitmap thread_predicate_;
+ Bool* pred_ = nullptr;
};
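
// Illustrative sketch (assumption): marking a lowered reduction as a grid
// reduction; the ReductionOp and the two Allocate nodes are placeholders
// created by earlier lowering steps.
//
//   auto grid_reduction = ir_builder.create<kir::GridReduction>(
//       reduction_op, reduction_buffer, sync_buffer, pred);
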
} // namespace kir
namespace cuda {
namespace kir {
-Val* IrBuilder::newResult(DataType dtype) {
- switch (dtype) {
+bool isLoweredScalar(const Val* val) {
+ switch (val->getValType().value()) {
+ case ValType::KirNamedScalar:
+ case ValType::KirScalar:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool isLoweredVal(const Val* val) {
+ switch (val->getValType().value()) {
+ case ValType::TensorIndex:
+ case ValType::KirNamedScalar:
+ case ValType::KirScalar:
+ case ValType::KirTensorDomain:
+ case ValType::KirIterDomain:
+ case ValType::KirTensorView:
+ return true;
+ default:
+ return false;
+ }
+}
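+
+// Illustrative usage (assumption, not part of this change): lowering passes
+// can use these helpers to assert that a value has already been lowered to
+// Kernel IR before it is consumed, e.g.
+//
+//   TORCH_INTERNAL_ASSERT(kir::isLoweredVal(val));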
+
+Val* IrBuilder::newResult(const Val* lhs, const Val* rhs) {
+ TORCH_CHECK(isLoweredScalar(lhs));
+ TORCH_CHECK(isLoweredScalar(rhs));
+ TORCH_CHECK(lhs->getDataType() == rhs->getDataType());
+
+ // Allocate a compatible result value
+ switch (lhs->getDataType().value()) {
case DataType::Bool:
return create<Bool>(c10::nullopt);
- case DataType::Double:
- return create<Double>(c10::nullopt);
+ case DataType::Float:
+ return create<Float>(c10::nullopt);
+ case DataType::Half:
+ return create<Half>(c10::nullopt);
case DataType::Int:
return create<Int>(c10::nullopt);
default:
}
Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
- TORCH_CHECK(lhs->dtype() == rhs->dtype(), "Incompatible operand types");
- auto result = newResult(lhs->dtype());
+ auto result = newResult(lhs, rhs);
create<BinaryOp>(op_type, result, lhs, rhs);
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
return result;
return result;
}
-Val* IrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) {
- TORCH_CHECK(lhs->dtype() == rhs->dtype(), "Incompatible operand types");
- auto result = newResult(lhs->dtype());
- create<TernaryOp>(TernaryOpType::Where, result, pred, lhs, rhs);
- return result;
-}
-
-Val* IrBuilder::negExpr(Val* val) {
- auto result = newResult(val->dtype());
- create<UnaryOp>(UnaryOpType::Neg, result, val);
- return result;
-}
-
-Val* IrBuilder::setExprNamedScalar(const std::string& name, Val* val) {
- auto result = create<NamedScalar>(name, val->dtype());
- create<UnaryOp>(UnaryOpType::Set, result, val);
- return result;
-}
-
-Val* IrBuilder::addressExprNamedScalar(const std::string& name, Val* val) {
- auto result = create<NamedScalar>(name, DataType::Int);
- create<UnaryOp>(UnaryOpType::Address, result, val);
- return result;
-}
-
Val* IrBuilder::andExpr(Val* lhs, Val* rhs) {
return newLogicExpr(BinaryOpType::And, lhs, rhs);
}
return newLogicExpr(BinaryOpType::Eq, lhs, rhs);
}
-Val* IrBuilder::gtExpr(Val* lhs, Val* rhs) {
- return newLogicExpr(BinaryOpType::GT, lhs, rhs);
-}
-
Val* IrBuilder::ltExpr(Val* lhs, Val* rhs) {
return newLogicExpr(BinaryOpType::LT, lhs, rhs);
}
-Val* IrBuilder::leExpr(Val* lhs, Val* rhs) {
- return newLogicExpr(BinaryOpType::LE, lhs, rhs);
-}
-
-Val* IrBuilder::geExpr(Val* lhs, Val* rhs) {
- return newLogicExpr(BinaryOpType::GE, lhs, rhs);
-}
-
Val* IrBuilder::addExpr(Val* lhs, Val* rhs) {
return newArithmeticExpr(BinaryOpType::Add, lhs, rhs);
}
return newArithmeticExpr(BinaryOpType::Mod, lhs, rhs);
}
-Val* IrBuilder::maxExpr(Val* lhs, Val* rhs) {
- return newArithmeticExpr(BinaryOpType::Max, lhs, rhs);
-}
-
-Val* IrBuilder::minExpr(Val* lhs, Val* rhs) {
- return newArithmeticExpr(BinaryOpType::Min, lhs, rhs);
-}
-
-Int* IrBuilder::zeroVal() {
- if (zero_ == nullptr) {
- zero_ = create<kir::Int>(0);
- }
- return zero_;
-}
-
-Int* IrBuilder::oneVal() {
- if (one_ == nullptr) {
- one_ = create<kir::Int>(1);
- }
- return one_;
-}
-
-Bool* IrBuilder::falseVal() {
- if (false_ == nullptr) {
- false_ = create<kir::Bool>(false);
- }
- return false_;
-}
-
-Bool* IrBuilder::trueVal() {
- if (true_ == nullptr) {
- true_ = create<kir::Bool>(true);
- }
- return true_;
-}
-
-NamedScalar* IrBuilder::magicZeroVal() {
- if (magic_zero_ == nullptr) {
- magic_zero_ = create<kir::NamedScalar>("nvfuser_zero", DataType::Int);
- }
- return magic_zero_;
-}
-
} // namespace kir
} // namespace cuda
} // namespace fuser
namespace cuda {
namespace kir {
+// Simple classification helpers
+bool isLoweredScalar(const Val* val);
+bool isLoweredVal(const Val* val);
+
//! Kernel IR builder interface
//!
//! The only way to create new Kernel IR nodes is through the
//!   auto new_node = ir_builder.create<kir::Int>(1);
//! auto result = ir_builder.mulExpr(lhs, rhs);
//!
-class TORCH_CUDA_CU_API IrBuilder {
+class IrBuilder {
public:
explicit IrBuilder(Kernel* kernel) : kernel_(kernel) {}
//! to the appropriate constructor
template <class T, class... Args>
T* create(Args&&... args) {
- const kir::Passkey passkey(kernel_);
- const auto node = new T(passkey, std::forward<Args>(args)...);
- kernel_->registerIrNode(passkey, std::unique_ptr<T>(node));
- return node;
+ // TODO(kir): switch this to Kernel registration
+ return new T(kir::Passkey(), std::forward<Args>(args)...);
}
- // Unary operations
- Val* negExpr(Val* val);
- Val* setExprNamedScalar(const std::string& name, Val* val);
- Val* addressExprNamedScalar(const std::string& name, Val* val);
-
- // Binary operations
+ // Binary expressions
Val* andExpr(Val* lhs, Val* rhs);
Val* eqExpr(Val* lhs, Val* rhs);
- Val* gtExpr(Val* lhs, Val* rhs);
Val* ltExpr(Val* lhs, Val* rhs);
- Val* leExpr(Val* lhs, Val* rhs);
- Val* geExpr(Val* lhs, Val* rhs);
Val* addExpr(Val* lhs, Val* rhs);
Val* subExpr(Val* lhs, Val* rhs);
Val* mulExpr(Val* lhs, Val* rhs);
Val* divExpr(Val* lhs, Val* rhs);
Val* ceilDivExpr(Val* lhs, Val* rhs);
Val* modExpr(Val* lhs, Val* rhs);
- Val* maxExpr(Val* lhs, Val* rhs);
- Val* minExpr(Val* lhs, Val* rhs);
-
- // Ternary operations
- Val* whereExpr(Val* pred, Val* lhs, Val* rhs);
-
- // Shortcuts for frequently used vals
- Int* zeroVal();
- Int* oneVal();
- Bool* falseVal();
- Bool* trueVal();
-
- NamedScalar* magicZeroVal();
private:
- Val* newResult(DataType dtype);
+ Val* newResult(const Val* lhs, const Val* rhs);
Val* newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs);
Val* newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs);
private:
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunused-private-field"
// Non-owning pointer to the kernel to be modified
Kernel* kernel_ = nullptr;
- // Frequently used constant vals
- Int* zero_ = nullptr;
- Int* one_ = nullptr;
- Bool* false_ = nullptr;
- Bool* true_ = nullptr;
-
- // Magic zero corresponds to runtime/helpers.cu magic_zero
- NamedScalar* magic_zero_ = nullptr;
+#pragma clang diagnostic pop
};
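
// Illustrative sketch (assumption): composing a bounds-check predicate with
// the helpers above; `index` and `extent` are placeholder kir::Val* handles.
//
//   kir::IrBuilder ir_builder(kernel);
//   auto in_bounds = ir_builder.ltExpr(index, extent);      // index < extent
//   auto guard = ir_builder.andExpr(in_bounds, other_cond);
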
} // namespace kir
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
-#include <utility>
+#include <sstream>
namespace torch {
namespace jit {
namespace cuda {
namespace kir {
-namespace {
-
-const char* boolLiteral(bool value) {
+static std::string boolLiteral(bool value) {
return value ? "true" : "false";
}
-std::string varName(const kir::Val* val, const char* prefix) {
- std::stringstream value_name;
- if (val == nullptr) {
- value_name << "$nullptr";
- } else if (val->name() != kInvalidStmName) {
- value_name << prefix << val->name();
- } else {
- value_name << "k" << prefix << val->id();
- }
- return value_name.str();
-}
-
-} // namespace
-
-void IrPrinter::printNode(const kir::Node* node) {
- os_ << gen(node, true);
+void IrPrinter::printNode(const Statement* stmt) {
+ handle(stmt);
}
void IrPrinter::printKernel(const Kernel* kernel) {
// kernel body
startBlock();
for (auto expr : kernel->topLevelExprs()) {
- os_ << gen(expr, true);
+ handle(expr);
}
endBlock();
os_ << "END.\n\n";
std::ostream& IrPrinter::indent() {
for (const auto i : c10::irange(indent_level_)) {
(void)i; // Suppress unused variable warning
- ir_str_ << kTab;
- }
- ir_str_ << margin_;
- return ir_str_;
-}
-
-std::string IrPrinter::gen(const kir::Node* node, bool top_level) {
- if (node == nullptr) {
- return "$nullptr";
- }
-
-  // If we're generating a top level statement, we expect to start
- // with an empty set of uses
- TORCH_INTERNAL_ASSERT(!implicit_definition_ || uses_.empty() || !top_level);
-
- // Mark the node as generated
- visited_.insert(node);
-
- // Generate the node itself
- std::stringstream node_str;
- std::swap(node_str, ir_str_);
- node->accept(this);
- std::swap(node_str, ir_str_);
-
- if (!implicit_definition_) {
- return node_str.str();
- }
-
- if (top_level) {
- // Implicitly mark top level nodes as used, so we
- // get their definitions printed (useful for debugging)
- if (auto val = dynamic_cast<const kir::Val*>(node)) {
- uses_.insert(val);
- }
-
- // Make a copy of the node uses (and reset global state)
- const auto node_uses = uses_;
- uses_.clear();
-
- std::stringstream top_level_str;
-
- // Hoist implicit definitions
- for (auto use : node_uses) {
- const auto def = use->definition();
- if (def && visited_.find(def) == visited_.end()) {
- margin_ = "~ ";
- top_level_str << gen(def, true);
- margin_ = "";
- }
- }
-
- top_level_str << node_str.str();
- return top_level_str.str();
- } else {
- return node_str.str();
+ os_ << kTab;
}
+ return os_;
}
-std::string IrPrinter::use(const kir::Val* val) {
- if (val != nullptr) {
- uses_.insert(val);
- }
- return gen(val);
+std::string IrPrinter::gen(const Statement* stmt) {
+ std::stringstream ss;
+ IrPrinter ir_printer(ss);
+ ir_printer.handle(stmt);
+ return ss.str();
}
void IrPrinter::startBlock() {
}
void IrPrinter::handleBlock(const kir::Scope& scope) {
- // Save the uses of the parent scope
- decltype(uses_) outer_uses;
- std::swap(uses_, outer_uses);
-
startBlock();
for (auto expr : scope.exprs()) {
- ir_str_ << gen(expr, true);
+ handle(expr);
}
endBlock();
+}
+
+void IrPrinter::handle(const Statement* s) {
+ OptInConstDispatch::handle(s);
+}
- // Restore parent's uses
- std::swap(uses_, outer_uses);
+void IrPrinter::handle(const Val* v) {
+ OptInConstDispatch::handle(v);
}
-void IrPrinter::visit(const kir::Bool* node) {
- if (node->isConst()) {
- ir_str_ << boolLiteral(*node->value());
+void IrPrinter::handle(const Expr* e) {
+ OptInConstDispatch::handle(e);
+}
+
+void IrPrinter::handle(const kir::Bool* node) {
+ if (node->isSymbolic()) {
+ os_ << "b" << node->name();
} else {
- ir_str_ << varName(node, "b");
+ os_ << boolLiteral(*node->value());
}
}
-void IrPrinter::visit(const kir::Double* node) {
- if (node->isConst()) {
- const int digits = std::numeric_limits<Double::ScalarType>::max_digits10;
- ir_str_ << "double(" << std::setprecision(digits) << *node->value() << ")";
+void IrPrinter::handle(const kir::Float* node) {
+ if (node->isSymbolic()) {
+ os_ << "f" << node->name();
} else {
- ir_str_ << varName(node, "d");
+ const int digits = std::numeric_limits<Float::ScalarType>::max_digits10;
+ os_ << "float(" << std::setprecision(digits) << *node->value() << ")";
}
}
-void IrPrinter::visit(const kir::Int* node) {
- if (node->isConst()) {
- ir_str_ << *node->value();
+void IrPrinter::handle(const kir::Half* node) {
+ if (node->isSymbolic()) {
+ os_ << "h" << node->name();
} else {
- ir_str_ << varName(node, "i");
+ os_ << "half(" << *node->value() << ")";
}
}
-void IrPrinter::visit(const kir::NamedScalar* node) {
- ir_str_ << node->name();
+void IrPrinter::handle(const kir::Int* node) {
+ if (node->isSymbolic()) {
+ os_ << "i" << node->name();
+ } else {
+ os_ << *node->value();
+ }
}
-void IrPrinter::visit(const kir::Predicate* node) {
- switch (node->predicate_type()) {
- case PredicateType::Inline: {
- ir_str_ << "Inline";
- break;
- }
- case PredicateType::Manual: {
- ir_str_ << node->value();
- break;
- }
- case PredicateType::Misaligned: {
- ir_str_ << "Misaligned";
- break;
- }
- case PredicateType::Padding: {
- ir_str_ << "Padding";
- break;
- }
- case PredicateType::Shift: {
- ir_str_ << "Shift";
- break;
- }
- case PredicateType::Unswitch: {
- ir_str_ << "Unswitch";
- break;
- }
- case PredicateType::Vectorize: {
- ir_str_ << "Vectorize";
- break;
- }
- default:
- break;
- }
+void IrPrinter::handle(const kir::NamedScalar* node) {
+ os_ << node->name();
}
-void IrPrinter::visit(const kir::TensorIndex* node) {
- ir_str_ << gen(node->view()) << "[";
+void IrPrinter::handle(const kir::TensorIndex* node) {
+ os_ << gen(node->view()) << "[";
for (auto index : node->indices()) {
- ir_str_ << use(index);
+ os_ << gen(index);
if (index != node->indices().back()) {
- ir_str_ << ", ";
+ os_ << ", ";
}
}
- ir_str_ << "]";
+ os_ << "]";
}
-void IrPrinter::visit(const kir::IterDomain* node) {
- ir_str_ << varName(node, "id") << "[";
+void IrPrinter::handle(const kir::IterDomain* node) {
if (node->isRFactorProduct()) {
- ir_str_ << "rfactor.";
+ os_ << "rfactor.";
}
- ir_str_ << node->parallelType() << "." << node->iterType() << "("
- << use(node->start()) << " .. " << use(node->extent()) << ")]";
+ os_ << node->getParallelType() << "." << node->getIterType() << "("
+ << gen(node->start()) << " .. " << gen(node->rawExtent()) << ")";
}
-void IrPrinter::visit(const kir::TensorDomain*) {
+void IrPrinter::handle(const kir::TensorDomain*) {
// TODO(kir): print Tensor shapes?
- ir_str_ << "kir::TensorDomain";
+ os_ << "kir::TensorDomain";
}
-void IrPrinter::visit(const kir::TensorView* node) {
- // TODO(kir): print memory type too?
- ir_str_ << varName(node, "T");
+void IrPrinter::handle(const kir::TensorView* node) {
+  // TODO(kir): print memory type too?
+ os_ << "T" << node->name();
}
-void IrPrinter::visit(const kir::UnaryOp* node) {
+void IrPrinter::handle(const kir::UnaryOp* node) {
indent() << gen(node->out()) << " = ";
- auto op_type = node->operation();
-
- if (auto op = inline_op_str(op_type)) {
- if (alsoBooleanOperator(op_type) &&
- node->out()->dtype() == DataType::Bool) {
- ir_str_ << stringifyBooleanOp(op_type) << gen(node->in());
- } else {
- ir_str_ << *op << gen(node->in());
- }
+ if (auto op = inline_op_str(node->getUnaryOpType())) {
+ os_ << *op << gen(node->in());
} else {
- if (op_type == UnaryOpType::Cast) {
- const auto cast_str =
- cast_func_str({node->in()->dtype(), node->out()->dtype()});
- ir_str_ << cast_str.value();
+ if (node->getUnaryOpType() == UnaryOpType::Cast) {
+ const auto cast_str = cast_func_str(
+ {node->in()->getDataType().value(),
+ node->out()->getDataType().value()});
+ os_ << cast_str.value();
} else {
- ir_str_ << op_type;
- if (needFloatSuffix(op_type) && node->out()->dtype() == DataType::Float) {
- ir_str_ << "f";
- }
+ os_ << node->getUnaryOpType();
}
- if (op_type == UnaryOpType::RandLike) {
- ir_str_ << "(RND";
+ os_ << "(";
+ if (node->getUnaryOpType() == UnaryOpType::RandLike) {
+ os_ << "RND";
} else {
- ir_str_ << "(";
- ir_str_ << use(node->in());
+ os_ << gen(node->in());
}
- ir_str_ << ")";
+ os_ << ")";
}
- ir_str_ << "\n";
+ os_ << "\n";
}
-void IrPrinter::visit(const kir::BinaryOp* node) {
+void IrPrinter::handle(const kir::BinaryOp* node) {
indent() << gen(node->out()) << " = ";
- const auto op_type = node->operation();
- const auto lhs = use(node->lhs());
- const auto rhs = use(node->rhs());
+ const auto op_type = node->getBinaryOpType();
+ const auto lhs = gen(node->lhs());
+ const auto rhs = gen(node->rhs());
if (auto op = inline_op_str(op_type)) {
- ir_str_ << lhs << " ";
- if (alsoBooleanOperator(op_type) &&
- node->out()->dtype() == DataType::Bool) {
- ir_str_ << stringifyBooleanOp(op_type);
- } else {
- ir_str_ << *op;
- }
- ir_str_ << " " << rhs;
+ os_ << lhs << " " << *op << " " << rhs;
} else {
- ir_str_ << op_type;
- if (needFloatSuffix(op_type) && node->out()->dtype() == DataType::Float) {
- ir_str_ << "f";
- }
- ir_str_ << "(" << lhs << ", " << rhs << ")";
+ os_ << op_type << "(" << lhs << ", " << rhs << ")";
}
- ir_str_ << "\n";
+ os_ << "\n";
}
-void IrPrinter::visit(const kir::TernaryOp* node) {
- indent() << gen(node->out()) << " = " << node->operation() << "("
- << use(node->in1()) << ", " << use(node->in2()) << ", "
- << use(node->in3()) << ")\n";
+void IrPrinter::handle(const kir::TernaryOp* node) {
+ indent() << gen(node->out()) << " = " << node->getTernaryOpType() << "("
+ << gen(node->in1()) << ", " << gen(node->in2()) << ", "
+ << gen(node->in3()) << ")\n";
}
-void IrPrinter::visit(const kir::ReductionOp* node) {
+void IrPrinter::handle(const kir::ReductionOp* node) {
indent() << gen(node->out()) << " = "
- << "REDUCTION(op='" << node->operation() << "'"
- << ", in=" << use(node->in()) << ", init=" << use(node->init())
- << ", pred=" << use(node->predicate()) << ")\n";
-}
-
-void IrPrinter::visit(const kir::WelfordOp* node) {
- indent() << gen(node->outVar()) << "," << gen(node->outAvg()) << ","
- << gen(node->outN()) << " = "
- << "Welford( inAvg=" << use(node->inAvg());
- if (!node->inN()->isOneInt()) {
- indent() << " inVar=" << use(node->inVar());
- }
- indent() << " inN=" << use(node->inN());
- if (!node->initN()->isZeroInt()) {
- indent() << ", initVar=" << use(node->initVar())
- << " initAvg=" << use(node->initAvg())
- << " initN=" << use(node->initN());
- }
- indent() << ", pred=" << use(node->predicate()) << ")\n";
+ << "REDUCTION(op='" << node->getReductionOpType() << "'"
+ << ", in=" << gen(node->in()) << ", init=" << gen(node->init())
+ << ", pred=" << gen(node->pred()) << ")\n";
}
-void IrPrinter::visit(const kir::GridReduction* node) {
+void IrPrinter::handle(const kir::GridReduction* node) {
const auto* reduction_op = node->reduction_op();
indent() << gen(reduction_op->out()) << " = "
- << "GRID_REDUCTION(op='" << reduction_op->operation() << "'"
- << ", in=" << use(reduction_op->in())
- << ", init=" << use(reduction_op->init())
- << ", pred=" << use(reduction_op->predicate()) << ")\n";
- indent() << kTab << kTab
- << ".reduction_buffer=" << use(node->reduction_buffer()->buffer())
+ << "GRID_REDUCTION(op='" << reduction_op->getReductionOpType() << "'"
+ << ", in=" << gen(reduction_op->in())
+ << ", init=" << gen(reduction_op->init())
+ << ", pred=" << gen(reduction_op->pred()) << ")\n";
+ indent() << kTab << ".reduction_buffer=" << gen(node->reduction_buffer())
<< "\n";
- indent() << kTab << kTab
- << ".sync_buffer=" << use(node->sync_buffer()->buffer()) << "\n";
- indent() << kTab << kTab << ".grid_pred=" << use(node->predicate()) << "\n";
+ indent() << kTab << ".sync_buffer=" << gen(node->sync_buffer()) << "\n";
+ indent() << kTab << ".grid_pred=" << gen(node->pred()) << "\n";
}
-void IrPrinter::visit(const kir::GridWelford* node) {
- const auto* welford_op = node->welford_op();
- indent() << gen(welford_op->outVar()) << "," << gen(welford_op->outAvg())
- << "," << gen(welford_op->outN()) << " = "
- << "GRID_WELFORD("
- << "inAvg=" << use(welford_op->inAvg());
- if (!welford_op->inN()->isOneInt()) {
- indent() << ", inVar=" << use(welford_op->inVar());
- }
- indent() << ", inN=" << use(welford_op->inN());
- if (!welford_op->initN()->isZeroInt()) {
- indent() << ", initVar=" << use(welford_op->initVar())
- << " initAvg=" << use(welford_op->initAvg())
- << " initN=" << use(welford_op->initN());
- }
- indent() << ", pred=" << use(welford_op->predicate()) << ")\n";
- indent() << kTab << kTab
- << ".var_buffer=" << use(node->var_buffer()->buffer())
- << ".avg_buffer=" << use(node->avg_buffer()->buffer())
- << ".n_buffer=" << use(node->N_buffer()->buffer()) << "\n";
- indent() << kTab << kTab
- << ".sync_buffer=" << use(node->sync_buffer()->buffer()) << "\n";
- indent() << kTab << kTab << ".grid_pred=" << use(node->predicate()) << "\n";
+void IrPrinter::handle(const kir::BroadcastOp* node) {
+ indent() << gen(node->out()) << " = BROADCAST(" << gen(node->in()) << ")\n";
}
-void IrPrinter::visit(const kir::BroadcastOp* node) {
- indent() << gen(node->out()) << " = BROADCAST(" << use(node->in()) << ")\n";
-}
-
-void IrPrinter::visit(const kir::ForLoop* node) {
+void IrPrinter::handle(const kir::ForLoop* node) {
indent() << "FOR " << gen(node->index()) << " in " << gen(node->iter_domain())
<< ":\n";
handleBlock(node->body());
}
-void IrPrinter::visit(const kir::IfThenElse* node) {
- indent() << "IF " << use(node->predicate()) << ":\n";
+void IrPrinter::handle(const kir::IfThenElse* node) {
+ indent() << "IF " << gen(node->cond()) << ":\n";
handleBlock(node->thenBody());
if (node->hasElse()) {
indent() << "ELSE:\n";
}
}
-void IrPrinter::visit(const kir::Allocate* node) {
+void IrPrinter::handle(const kir::Allocate* node) {
indent() << gen(node->buffer()) << " = ALLOCATE("
- << "mem_type=" << node->memoryType() << ", "
- << "size=" << use(node->size()) << ", "
+ << "mem_type=" << node->getMemoryType() << ", "
+ << "size=" << gen(node->size()) << ", "
<< "zero_init=" << boolLiteral(node->zeroInit()) << ")\n";
- if (node->alias() != nullptr) {
- indent() << kTab << kTab << ".alias=" << gen(node->alias()->buffer())
- << "\n";
- }
}
-void IrPrinter::visit(const kir::Sync* node) {
+void IrPrinter::handle(const kir::Sync* node) {
indent() << "SYNC(war_hazard=" << boolLiteral(node->isWarHazardSync())
<< ")\n";
}
-void IrPrinter::visit(const kir::InitMagicZero* node) {
- indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n";
-}
-
-void IrPrinter::visit(const kir::UpdateMagicZero* node) {
- indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
-}
-
-std::string toString(const kir::Node* stmt, bool implicit_definitions) {
+std::string toString(const Statement* stmt) {
std::stringstream ss;
- IrPrinter ir_printer(ss, implicit_definitions);
+ IrPrinter ir_printer(ss);
ir_printer.printNode(stmt);
return ss.str();
}
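
// Illustrative usage (assumption): toString() gives a one-off dump of any
// Kernel IR node, which is handy when tracing lowering passes:
//
//   std::cerr << kir::toString(expr) << std::endl;
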
-std::string toString(
- const std::vector<kir::Expr*>& exprs,
- bool implicit_definitions) {
- std::stringstream ss;
- IrPrinter ir_printer(ss, implicit_definitions);
- for (auto expr : exprs) {
- ir_printer.printNode(expr);
- }
- return ss.str();
-}
-
} // namespace kir
} // namespace cuda
} // namespace fuser
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <iostream>
-#include <sstream>
#include <string>
-#include <unordered_set>
namespace torch {
namespace jit {
//! This class is intended for debug printing, so it attempts
//! to handle invalid IR states as much as possible.
//!
-//! implicit_definition_ = true will recursively print the definition of all
-//! inputs to an expression if they haven't been printed.
-class TORCH_CUDA_CU_API IrPrinter : private kir::IrVisitor {
- static constexpr char const* kTab = " ";
+class TORCH_CUDA_CU_API IrPrinter : private OptInConstDispatch {
+  static constexpr const char* kTab = "  ";
public:
//! Constructs a new IrPrinter which outputs to the specified stream
- explicit IrPrinter(std::ostream& os, bool implicit_definition = true)
- : os_(os), implicit_definition_(implicit_definition) {}
+ explicit IrPrinter(std::ostream& os) : os_(os) {}
//! Print a single Kernel IR node
- void printNode(const kir::Node* node);
+ void printNode(const Statement* stmt);
//! Print a complete Kernel definition
void printKernel(const Kernel* kernel);
private:
- // Generates a string representation of an IR node
- //
- // If `top_level` is true, all the value uses are tracked and
- // their definitions are implicitly printed before the node itself
- //
- std::string gen(const kir::Node* node, bool top_level = false);
-
- // Generate a string representation of an used value
- // (this helps automatically tracking the value uses)
- std::string use(const kir::Val* val);
+ static std::string gen(const Statement* stmt);
std::ostream& indent();
void endBlock();
void handleBlock(const kir::Scope& scope);
- void visit(const kir::Bool*) final;
- void visit(const kir::Double*) final;
- void visit(const kir::Int*) final;
- void visit(const kir::NamedScalar*) final;
- void visit(const kir::Predicate*) final;
-
- void visit(const kir::TensorIndex*) final;
- void visit(const kir::IterDomain*) final;
- void visit(const kir::TensorDomain*) final;
- void visit(const kir::TensorView*) final;
-
- void visit(const kir::UnaryOp*) final;
- void visit(const kir::BinaryOp*) final;
- void visit(const kir::TernaryOp*) final;
- void visit(const kir::ReductionOp*) final;
- void visit(const kir::WelfordOp*) final;
- void visit(const kir::BroadcastOp*) final;
-
- void visit(const kir::GridReduction*) final;
- void visit(const kir::GridWelford*) final;
- void visit(const kir::ForLoop*) final;
- void visit(const kir::IfThenElse*) final;
- void visit(const kir::Allocate*) final;
- void visit(const kir::Sync*) final;
- void visit(const kir::InitMagicZero*) final;
- void visit(const kir::UpdateMagicZero*) final;
+ void handle(const Statement*) final;
+ void handle(const Val*) final;
+ void handle(const Expr*) final;
+
+ void handle(const kir::Bool*) final;
+ void handle(const kir::Float*) final;
+ void handle(const kir::Half*) final;
+ void handle(const kir::Int*) final;
+ void handle(const kir::NamedScalar*) final;
+
+ void handle(const kir::TensorIndex*) final;
+ void handle(const kir::IterDomain*) final;
+ void handle(const kir::TensorDomain*) final;
+ void handle(const kir::TensorView*) final;
+
+ void handle(const kir::UnaryOp*) final;
+ void handle(const kir::BinaryOp*) final;
+ void handle(const kir::TernaryOp*) final;
+ void handle(const kir::ReductionOp*) final;
+ void handle(const kir::BroadcastOp*) final;
+
+ void handle(const kir::GridReduction*) final;
+ void handle(const kir::ForLoop*) final;
+ void handle(const kir::IfThenElse*) final;
+ void handle(const kir::Allocate*) final;
+ void handle(const kir::Sync*) final;
private:
std::ostream& os_;
-
- // Current indentation level
int indent_level_ = 0;
-
- // Internal IR generation stream
- std::stringstream ir_str_;
-
- // Tracks the set of nodes which have been printed
- std::unordered_set<const kir::Node*> visited_;
-
- // Optional left margin printed after the indentation
- const char* margin_ = "";
-
- // The set of values used by the current top-level IR node
- std::unordered_set<const kir::Val*> uses_;
-
- // If the definition of all inputs to an expression haven't been printed
- // already implicit_definition_ = true will print them before printing the
- // requested node.
- bool implicit_definition_ = true;
};
-//! Returns the string representation of a Kernel IR node. If the definition of
-//! all inputs to an expression haven't been printed already
-//! implicit_definition_ = true will print them before printing the requested
-//! node.
-TORCH_CUDA_CU_API std::string toString(
- const kir::Node* stmt,
- bool implicit_definitions = true);
-
-//! Returns the string representation of a vector of kir::Expr, a convenient
-//! debug mechanism during lowering. If the definition of all inputs to an
-//! expression haven't been printed already implicit_definition_ = true will
-//! print them before printing the requested node.
-TORCH_CUDA_CU_API std::string toString(
- const std::vector<kir::Expr*>& exprs,
- bool implicit_definitions = true);
+//! Returns the string representation of a Kernel IR node
+std::string toString(const Statement* stmt);
} // namespace kir
} // namespace cuda
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/lower_alias_memory.h>
-#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
-#include <torch/csrc/jit/codegen/cuda/lower_expr_sort.h>
#include <torch/csrc/jit/codegen/cuda/lower_index.h>
#include <torch/csrc/jit/codegen/cuda/lower_insert_syncs.h>
#include <torch/csrc/jit/codegen/cuda/lower_loops.h>
-#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
-#include <torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h>
-#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
-#include <list>
-#include <unordered_map>
-#include <unordered_set>
-
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
// TODO(kir): revisit this
-thread_local GpuLower* active_gpu_lower = nullptr; // NOLINT
-namespace {
-
-// Going to generate a map of tensor view root domain extents to reduce the
-// number used during lowering. For example if we have:
-//
-// T2[i0, i1] = T1[i0, i1] + T2[i2, i3]
-//
-// We know it would be safe to use:
-//
-// T2[i0, i1] = T1[i0, i1] + T2[i0, i1]
-//
-// And that way we don't generate T2.size[0] and T2.size[1], instead we will
-// reuse T1.size[0] and T1.size[1]
-// This is important when doing CSE as T2 and T1 would otherwise look like
-// they're using different values, even though we know they're the same
-//
-// There's some duplicate logic here that's in computeAt map, but it's not so
-// concice there to pull out. May want to consider making this mapping its own
-// class especially as it may be useful during scheduling.
-std::unordered_map<Val*, Val*> getSimplificationMap(Fusion* fusion) {
- std::list<std::unordered_set<IterDomain*>> disjoint_root_sets;
- std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>*>
- id_to_disjoint_root_set;
-
- auto map_root_ids = [&disjoint_root_sets, &id_to_disjoint_root_set](
- IterDomain* id0, IterDomain* id1) {
- if (id0->isBroadcast() || id1->isBroadcast()) {
- return;
- }
-
- auto disjoint_set_0_it = id_to_disjoint_root_set.find(id0);
- auto disjoint_set_1_it = id_to_disjoint_root_set.find(id1);
- bool set_0_found = disjoint_set_0_it != id_to_disjoint_root_set.end();
- bool set_1_found = disjoint_set_1_it != id_to_disjoint_root_set.end();
-
- if (set_0_found && set_1_found) {
- if (disjoint_set_0_it->second == disjoint_set_1_it->second) {
- return;
- }
- // merge second disjoint set into first
- auto* set_0 = disjoint_set_0_it->second;
- auto* set_1 = disjoint_set_1_it->second;
- for (auto id : *set_1) {
- set_0->emplace(id);
- id_to_disjoint_root_set[id] = set_0;
- }
- // remove second set from disjoint_root_sets
- disjoint_root_sets.erase(std::find(
- disjoint_root_sets.begin(), disjoint_root_sets.end(), *set_1));
- } else if (set_0_found || set_1_found) {
- auto existing_set =
- set_0_found ? disjoint_set_0_it->second : disjoint_set_1_it->second;
- auto to_add_id = set_0_found ? id1 : id0;
- existing_set->emplace(to_add_id);
- id_to_disjoint_root_set[to_add_id] = existing_set;
- // add entry into existing set
- } else {
- // create new set entry
- disjoint_root_sets.push_back(std::unordered_set<IterDomain*>());
- auto* new_set = &disjoint_root_sets.back();
- new_set->emplace(id0);
- new_set->emplace(id1);
- id_to_disjoint_root_set[id0] = new_set;
- id_to_disjoint_root_set[id1] = new_set;
- }
- };
-
- auto fusion_vals = fusion->usedMathVals();
- for (auto producer_tv : ir_utils::filterByType<TensorView>(fusion_vals)) {
- auto consumer_tvs = ir_utils::consumerTvsOf(producer_tv);
- for (auto consumer_tv : consumer_tvs) {
- auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv);
- auto c2p_root_map = pairwise_map.mapConsumerToProducer(
- consumer_tv->domain(), producer_tv->domain());
- for (auto entry : c2p_root_map) {
- auto c_id = entry.first;
- auto p_id = entry.second;
- map_root_ids(p_id, c_id);
- }
- }
- }
-
- // Map each set to an input ID (if it exists) that has the smallest ->name()
- // entry value
- std::unordered_map<std::unordered_set<IterDomain*>*, IterDomain*>
- set_to_input_id;
-
-  // Loop over the root domains of the inputs to the fusion. Pick an input ID
-  // to use as the representative ID of the collected sets. Only consider inputs
-  // as those are the ones that map to values like "T0.size[1]". They are the
-  // IDs that propagated their extents into the problem. We could also check
- // the outputs as we do have C++ examples of using output dimensions for the
- // problem size instead of inputs. However, we don't do anything where we can
- // translate to those kinds of kernels integrated into PyTorch.
- for (auto input_tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
- for (auto id :
- TensorDomain::noReductions(input_tv->getMaybeRFactorDomain())) {
- auto id_set_it = id_to_disjoint_root_set.find(id);
- if (id_set_it == id_to_disjoint_root_set.end()) {
- continue;
- }
- auto* id_set = id_set_it->second;
- if (set_to_input_id.find(id_set) == set_to_input_id.end()) {
- set_to_input_id[id_set] = id;
- } else {
- auto input_id_of_set = set_to_input_id.at(id_set);
- // Swap id's if new name is less than previously set
- bool swap_ids = id->name() < input_id_of_set->name();
-      // If the new id is a const scalar but the previous one wasn't, use the
-      // const scalar
- swap_ids = swap_ids ||
- (id->extent()->isConstScalar() &&
- !input_id_of_set->extent()->isConstScalar());
- // If previous scalar was const and new isn't, don't swap
- swap_ids = swap_ids &&
- !(input_id_of_set->extent()->isConstScalar() &&
- !id->extent()->isConstScalar());
-
- if (swap_ids) {
- set_to_input_id[id_set] = id;
- }
- }
- }
- }
+thread_local GpuLower* active_gpu_lower = nullptr;
-  // Finally, make a map from ID extents to the representative ID extent.
- std::unordered_map<Val*, Val*> extent_to_min_input_id_extent;
- for (auto entry : set_to_input_id) {
- auto* set = entry.first;
- auto input_id = entry.second;
- for (auto id : *set) {
- extent_to_min_input_id_extent[id->extent()] = input_id->extent();
- }
- }
- return extent_to_min_input_id_extent;
-}
-
-} // namespace
void GpuLower::replaceSymbolicSizes() {
- FUSER_PERF_SCOPE("GpuLower::Lower::replaceSymbolicSizes");
+ FUSER_PERF_SCOPE("replaceSymbolicSizes");
kir::IrBuilder ir_builder(kernel());
// Grab inputs and outputs
+  // TODO: Only run through inputs for the size map; outputs don't actually
+  // set any sizes of the problem.
std::vector<TensorView*> inputs_and_outputs;
for (auto val : fusion_->inputs()) {
if (ir_utils::isTV(val)) {
}
}
- // Generate map for all tensorview root domain values to map them to symbolic
- // values. i.e. T0->getRootDomain()[0] would map to a named scalar
- // "T0.size[0]". This map will be used when lowering fusion ir to kernel ir.
+  // Run through inputs and outputs first. Since we're replacing full
+  // tensorviews, their names are going to change. We need the new reference
+  // name for the inputs/outputs so we won't reference the wrong tensor view.
+  // For example, T0 may be translated to T9; we don't want our new variable
+  // to be T0->size[...], we need it to be T9->size[...].
for (TensorView* tv : inputs_and_outputs) {
// Replace the domain with one based on Ti.size[j]
+ std::vector<IterDomain*> new_domain_iters;
const std::vector<IterDomain*>& root_td = tv->getRootDomain();
size_t dim = 0;
// Output sizes could have reduction axes, which isn't what gets output.
// NOLINTNEXTLINE(bugprone-branch-clone)
- if (id->isReduction() ||
- (id->getIterType() == IterType::BroadcastWithoutStride)) {
+ if (id->isReduction()) {
+ continue;
+ } else if (id->getIterType() == IterType::BroadcastWithoutStride) {
+ continue;
+ // NOLINTNEXTLINE(bugprone-branch-clone)
+ } else if (id->getIterType() == IterType::BroadcastWithStride) {
+ dim++;
continue;
- } else if (
- // NOLINTNEXTLINE(bugprone-branch-clone)
- (id->getIterType() == IterType::BroadcastWithStride) ||
- orig_size->isConstScalar()) {
+ } else if (orig_size->isConstScalar()) {
dim++;
continue;
}
// TODO(kir): consider a different implementation which doesn't
- // hijack the kir_val_map_
- // Currently turn off this part for inputs of segmented fusion,
- // since FusionKernelRuntime will provide these as integer inputs
- if (kir_val_map_.find(orig_size) == kir_val_map_.end() &&
- !orig_size->isFusionInput() && !orig_size->isConstScalar()) {
+ // hijack the kir_map_
+ if (kir_map_.find(orig_size) == kir_map_.end()) {
std::stringstream ss;
ss << "T" << tv->name() << ".size[" << dim++ << "]";
- kir_val_map_[orig_size] = ir_builder.create<kir::NamedScalar>(
+ kir_map_[orig_size] = ir_builder.create<kir::NamedScalar>(
ss.str(), orig_size->getDataType().value());
- } else {
- dim++;
- }
- }
- }
-
- // Use a minimal number of sizes from provided tensors.
- auto extent_simplification_map = getSimplificationMap(fusion_);
- for (auto extent_entry : extent_simplification_map) {
- auto orig_extent = extent_entry.first;
- auto simplified_extent = extent_entry.second;
- if (kir_val_map_.count(orig_extent)) {
- if (kir_val_map_.count(simplified_extent)) {
- kir_val_map_[orig_extent] = kir_val_map_[simplified_extent];
- } else {
- kir_val_map_[orig_extent] = lowerValue(simplified_extent);
}
}
}
}
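
// Illustrative example (assumption): after replaceSymbolicSizes(), a symbolic
// root-domain extent of input T0 is backed by a NamedScalar whose name matches
// the runtime size expression, conceptually:
//
//   kir_map_[orig_size] = ir_builder.create<kir::NamedScalar>(
//       "T0.size[0]", orig_size->getDataType().value());
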
void GpuLower::lower() {
- FUSER_PERF_SCOPE("GpuLower::lower");
+ FUSER_PERF_SCOPE("lower");
TORCH_INTERNAL_ASSERT(fusion_ != nullptr);
TORCH_INTERNAL_ASSERT(
FusionGuard fg(fusion_);
// Start with a fresh kernel
- kernel_ = std::make_unique<kir::Kernel>();
+ kernel_ = std::make_unique<Kernel>();
// prepare for lowering
validateIr(fusion_);
replaceSymbolicSizes();
- trivial_reduction_info_.build(fusion_, this);
- // In the future we may directly use this map, but for now it will propagate
- // and validate (to some extent) the parallelization strategy.
- // This is the first time nodes will be lowered to kir nodes. Since for now we
- // propagate the parallel strategy in some instances, we need to do it before
- // lowering.
- ca_parallel_map_ = ComputeAtMap(ComputeAtMap::MappingMode::PARALLEL);
- ca_parallel_map_.build(fusion_, current());
-
- // Want to run this after parallel map is created
- validateVectorize(fusion_);
-
- // Generate mappings to generate indices
- ca_index_map_ = ComputeAtMap(ComputeAtMap::MappingMode::INDEX);
- ca_index_map_.build(fusion_, current());
+ // Compute thread predicates
+ ThreadPredicateMap preds(fusion_);
- // Generate mappings to generate and map to loop nests
- ca_loop_map_ = ComputeAtMap(ComputeAtMap::MappingMode::LOOP);
- ca_loop_map_.build(fusion_, current());
+ // Run our passes keeping the lowered expressions and forwarding them
+ const auto lowered_exprs =
+ LoopNestGenerator::loweredExprs(fusion_, preds, fusion_->exprs(true));
- validateParallelize(fusion_);
+ const auto unrolled_loops =
+ UnrollPass::runPass(fusion_, lowered_exprs, preds);
- parallelDimensionMap().build(fusion_);
- if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) {
- std::cout << parallelDimensionMap().toString();
- }
+  // Reuse memory locations if:
+  //  - the TensorView is dynamic shared memory,
+  //  - the TensorViews have the same size, and
+  //  - the output TensorView is modified using the input TensorView
+ const auto reuse_mem_exprs = reuseMemoryAllocations(fusion_, unrolled_loops);
- // Scan the whole fusion and build mappings about halo extensions of
- // all IterDomains
- haloInfo().build(fusion_);
+ // Insert SyncThreads at end of for-loop to avoid WAR race condition
+ const auto sync_exprs = insertThreadSynchronization(fusion_, reuse_mem_exprs);
- // Compute thread predicates
- thread_pred_map_.build(fusion_);
+ const auto indexed_loops =
+ IndexLowering::getIndexedExprs(fusion_, sync_exprs);
-  // Detects all expressions that don't need predicates
- predicateElimination().build(fusion_);
+ // We now have the lowered expressions, finalize the kernel IR
+ kernel_->finalize(indexed_loops, preds);
// Set the kernel inputs & outputs
for (auto input : fusion_->inputs()) {
kernel_->addInput(GpuLower::lowerValue(input));
}
-
for (auto output : fusion_->outputs()) {
kernel_->addOutput(GpuLower::lowerValue(output));
}
-
- // Run our passes keeping the lowered expressions and forwarding
- // them
-
- // Reorder expressions for loop-nest generation respecting computeAt
- // relationships
- auto sorted_exprs = reorderExprsForComputeAt();
-
- // Generate loop-nests and place each expression at its
- // corresponding loop
- const auto lowered_exprs = LoopNestGenerator::loweredExprs(sorted_exprs);
-
- // Insert allocations
- const auto alloced_exprs = insertAllocations(lowered_exprs);
-
- // Insert read after write smem syncs
- const auto raw_sync_exprs = insertRawThreadSynchronization(alloced_exprs);
-
- // Inserts predicates after this, need to be careful in later passes when
- // inserting in loop nest structure as insertions could be on if then else
- // instead of directly on a for loop
- const auto unrolled_loops = UnrollPass::runPass(fusion_, raw_sync_exprs);
-
- const auto unrolled_mv_loops =
- processMisalignedVectorization(fusion_, unrolled_loops);
-
- // Reuse memory locations
- // TODO: Reenable once fixed.
- // const auto reuse_mem_exprs = reuseMemoryAllocations(unrolled_mv_loops);
-
- // Insert SyncThreads at end of for-loop to avoid WAR race condition
- // const auto war_sync_exprs =
- // insertWarThreadSynchronization(reuse_mem_exprs);
- const auto war_sync_exprs = insertWarThreadSynchronization(unrolled_mv_loops);
-
- const auto indexed_loops = IndexLowering::getIndexedExprs(war_sync_exprs);
-
- const auto conditional_loops =
- generateConditionalFromPredicate(fusion_, indexed_loops);
-
- // Insert fake zero updates to make sure nvrtc doesn't blow out register use
- // on index and predicate reuse
- const auto register_adjusted = insertMagicZero(conditional_loops);
-
- // We now have the lowered expressions, finalize the kernel IR
- kernel_->finalize(register_adjusted);
}
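
// Illustrative driver sketch (assumption: the GpuLower constructor runs
// lower(), which this diff does not show):
//
//   GpuLower gpu_lower(fusion);
//   Kernel* kernel = gpu_lower.kernel();
//   // kernel now holds the finalized Kernel IR used for code generation.
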
-kir::Kernel* GpuLower::kernel() const {
+Kernel* GpuLower::kernel() const {
TORCH_CHECK(kernel_);
return kernel_.get();
}
// Maps Fusion IR nodes to the Kernel IR counterparts
-class GpuLower::KernelIrMapper : private OptInConstDispatch {
+//
+// TODO(kir): this is an interim solution for easing the Kernel IR splitting
+//
+class TORCH_CUDA_CU_API GpuLower::KernelIrMapper : private OptInConstDispatch {
public:
explicit KernelIrMapper(GpuLower* gpu_lower)
: gpu_lower_(gpu_lower), ir_builder_(gpu_lower->kernel()) {}
- kir::Val* lowerValue(const Val* value) {
- const auto it = gpu_lower_->kir_val_map_.find(value);
- if (it != gpu_lower_->kir_val_map_.end()) {
+ Val* lower(const Val* value) {
+ const auto it = gpu_lower_->kir_map_.find(value);
+ if (it != gpu_lower_->kir_map_.end()) {
return it->second;
} else {
handle(value);
- const auto kir_value = gpu_lower_->kir_val_map_[value];
- TORCH_CHECK(kir_value != nullptr);
+ const auto lowered_node = gpu_lower_->kir_map_[value];
+ TORCH_CHECK(lowered_node != nullptr);
+ TORCH_CHECK(kir::isLoweredVal(lowered_node));
- // Lower the value definition, if any
+ // Lower the arithmetic expression defining the value, if any
if (value->isScalar()) {
- if (auto def = value->definition()) {
- const auto kir_def = lowerExpr(def);
- TORCH_INTERNAL_ASSERT(kir_value->definition() == kir_def);
+ if (auto def = value->getOrigin()) {
+ lowerDefinition(lowered_node, def);
}
}
- return kir_value;
+ return lowered_node;
}
}
- kir::Expr* lowerExpr(const Expr* expr) {
- const auto it = gpu_lower_->kir_expr_map_.find(expr);
- if (it != gpu_lower_->kir_expr_map_.end()) {
- return it->second;
- } else {
- handle(expr);
- const auto lowered_node = gpu_lower_->kir_expr_map_[expr];
- TORCH_CHECK(lowered_node != nullptr);
- return lowered_node;
+ private:
+ // TODO(kir): rewrite this
+ void lowerDefinition(Val* lowered_value, const Expr* def) {
+ switch (def->type()) {
+ case ExprType::UnaryOp: {
+ const auto op = def->as<UnaryOp>();
+ ir_builder_.create<kir::UnaryOp>(
+ op->getUnaryOpType(), lowered_value, lower(op->in()));
+ break;
+ }
+ case ExprType::BinaryOp: {
+ const auto op = def->as<BinaryOp>();
+ ir_builder_.create<kir::BinaryOp>(
+ op->getBinaryOpType(),
+ lowered_value,
+ lower(op->lhs()),
+ lower(op->rhs()));
+ break;
+ }
+ case ExprType::TernaryOp: {
+ const auto op = def->as<TernaryOp>();
+ ir_builder_.create<kir::TernaryOp>(
+ op->getTernaryOpType(),
+ lowered_value,
+ lower(op->in1()),
+ lower(op->in2()),
+ lower(op->in3()));
+ break;
+ }
+ default:
+ TORCH_CHECK(false, "Unexpected expression type");
}
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
}
- private:
- void handle(const Statement* node) final {
+ void handle(const Statement* node) override {
OptInConstDispatch::handle(node);
}
- void handle(const Val* node) final {
+ void handle(const Val* node) override {
OptInConstDispatch::handle(node);
}
- void handle(const Expr* node) final {
+ void handle(const Expr* node) override {
OptInConstDispatch::handle(node);
}
- void handle(const TensorDomain* node) final {
+ void handle(const TensorDomain* node) override {
const auto lowered_node = ir_builder_.create<kir::TensorDomain>(node);
- TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
+ TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
}
- void handle(const IterDomain* node) final {
+ void handle(const IterDomain* node) override {
const auto lowered_node = ir_builder_.create<kir::IterDomain>(node);
- TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
+ TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
}
- void handle(const TensorView* node) final {
+ void handle(const TensorView* node) override {
const auto lowered_node = ir_builder_.create<kir::TensorView>(node);
- TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
+ TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
}
- void handle(const Bool* node) final {
+ void handle(const Bool* node) override {
const auto lowered_node = ir_builder_.create<kir::Bool>(node);
- TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
+ TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
}
- void handle(const Double* node) final {
- const auto lowered_node = ir_builder_.create<kir::Double>(node);
- TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
+ void handle(const Float* node) override {
+ const auto lowered_node = ir_builder_.create<kir::Float>(node);
+ TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
}
- void handle(const Int* node) final {
- const auto lowered_node = ir_builder_.create<kir::Int>(node);
- TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
+ void handle(const Half* node) override {
+ const auto lowered_node = ir_builder_.create<kir::Half>(node);
+ TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
}
- void handle(const NamedScalar* node) final {
- const auto lowered_node = ir_builder_.create<kir::NamedScalar>(
- node->name(), node->getDataType().value());
- TORCH_CHECK(gpu_lower_->kir_val_map_.insert({node, lowered_node}).second);
- }
-
- void handle(const UnaryOp* node) final {
- const auto lowered_node = ir_builder_.create<kir::UnaryOp>(
- node->getUnaryOpType(),
- lowerValue(node->out()),
- lowerValue(node->in()));
- TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
- }
-
- void handle(const BinaryOp* node) final {
- const auto lowered_node = ir_builder_.create<kir::BinaryOp>(
- node->getBinaryOpType(),
- lowerValue(node->out()),
- lowerValue(node->lhs()),
- lowerValue(node->rhs()));
- TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
- }
-
- void handle(const TernaryOp* node) final {
- const auto lowered_node = ir_builder_.create<kir::TernaryOp>(
- node->getTernaryOpType(),
- lowerValue(node->out()),
- lowerValue(node->in1()),
- lowerValue(node->in2()),
- lowerValue(node->in3()));
- TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
- }
-
- void handle(const ReductionOp* node) final {
- auto out_tv = node->out()->as<TensorView>();
-    // If this is a trivial reduction, lower it to a set operation.
- if (std::all_of(
- out_tv->domain()->domain().begin(),
- out_tv->domain()->domain().end(),
- [&](IterDomain* id) {
- // If id is a reduction axis, is it a trivial reduction?
- if (id->isReduction()) {
- return gpu_lower_->trivialReductionInfo().isDerived(id);
- } else {
- return true;
- }
- })) {
- const auto lowered_node = ir_builder_.create<kir::UnaryOp>(
- UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in()));
- TORCH_CHECK(
- gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
- return;
- }
-
- const auto lowered_node = ir_builder_.create<kir::ReductionOp>(
- node->getReductionOpType(),
- lowerValue(node->init()),
- lowerValue(node->out()),
- lowerValue(node->in()));
- TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
+ void handle(const Int* node) override {
+ const auto lowered_node = ir_builder_.create<kir::Int>(node, false);
+ TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
}
- void handle(const WelfordOp* node) final {
- auto lowerOptional = [&](Val* v) { return v ? lowerValue(v) : nullptr; };
- const auto lowered_node = ir_builder_.create<kir::WelfordOp>(
- lowerValue(node->outVar()),
- lowerValue(node->outAvg()),
- lowerValue(node->outN()),
- lowerValue(node->initVar()),
- lowerValue(node->initAvg()),
- lowerValue(node->initN()),
- lowerOptional(node->inVar()),
- lowerValue(node->inAvg()),
- lowerValue(node->inN()));
-
- TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
- }
-
- void handle(const BroadcastOp* node) final {
- const auto lowered_node = ir_builder_.create<kir::BroadcastOp>(
- lowerValue(node->out()), lowerValue(node->in()));
- TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
- }
-
- void handle(const TransposeOp* node) final {
- const auto lowered_node = ir_builder_.create<kir::UnaryOp>(
- UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in()));
- TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
- }
-
- void handle(const ShiftOp* node) final {
- const auto lowered_node = ir_builder_.create<kir::UnaryOp>(
- UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in()));
- TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
- }
-
- void handle(const GatherOp* node) final {
- const auto lowered_node = ir_builder_.create<kir::UnaryOp>(
- UnaryOpType::Set, lowerValue(node->out()), lowerValue(node->in()));
- TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
+ void handle(const NamedScalar* node) override {
+ const auto lowered_node = ir_builder_.create<kir::NamedScalar>(
+ node->name(), node->getDataType().value());
+ TORCH_CHECK(gpu_lower_->kir_map_.insert({node, lowered_node}).second);
}
private:
kir::IrBuilder ir_builder_;
};
-kir::Val* GpuLower::lowerValue(const Val* val) {
- KernelIrMapper kir_mapper(this);
- return kir_mapper.lowerValue(val);
+Val* GpuLower::lowerValue(const Val* val) {
+ TORCH_INTERNAL_ASSERT(!kir::isLoweredVal(val));
+ TORCH_INTERNAL_ASSERT(active_gpu_lower != nullptr);
+ KernelIrMapper kir_mapper(active_gpu_lower);
+ return kir_mapper.lower(val);
}
-kir::Expr* GpuLower::lowerExpr(const Expr* expr) {
+Val* GpuLower::getLowerValue(const Val* val) {
KernelIrMapper kir_mapper(this);
- return kir_mapper.lowerExpr(expr);
+ return kir_mapper.lower(val);
}
GpuLower* GpuLower::current() {
+ TORCH_INTERNAL_ASSERT(active_gpu_lower != nullptr);
return active_gpu_lower;
}
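
// Usage sketch (illustrative only, not part of this diff): during an active
// lowering pass, Fusion IR values are translated on demand, e.g.:
//
//   GpuLower* lower = GpuLower::current();  // asserts lowering is active
//   Val* kir_val = GpuLower::lowerValue(fusion_val);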
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
-#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
-#include <torch/csrc/jit/codegen/cuda/parallel_dimension_map.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <memory>
#include <ostream>
namespace fuser {
namespace cuda {
-// TODO: we frequently use pairwise root mapping from consumers to producers.
-// This information is implicitly in the computeAtMaps, but there's no isolated
-// container for this information that we can reuse. Would be nice to generate
-// such a structure and propagate it through lowering.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API GpuLower {
class KernelIrMapper;
lower();
}
- kir::Kernel* kernel() const;
+ Kernel* kernel() const;
- //! Converts a Fusion IR value into the Kernel IR equivalent
- kir::Val* lowerValue(const Val* val);
+ // Converts a Fusion IR value into the Kernel IR equivalent
+ //
+ // TODO(kir): revisit this interface
+ //
+ static Val* lowerValue(const Val* val);
- //! Converts a Fusion IR expression into the Kernel IR equivalent
- kir::Expr* lowerExpr(const Expr* expr);
+ // TODO(kir): we have two methods which do almost the same thing
+ //
+ Val* getLowerValue(const Val* val);
//! Returns the currently active lowering object
//! (or nullptr if no lowering is in progress)
static GpuLower* current();
- const ThreadPredicateMap& threadPredMap() const {
- return thread_pred_map_;
- }
-
- const ComputeAtMap& caLoopMap() const {
- return ca_loop_map_;
- }
-
- const ComputeAtMap& caIndexMap() const {
- return ca_index_map_;
- }
-
- const ComputeAtMap& caParallelMap() const {
- return ca_parallel_map_;
- }
-
- const auto& trivialReductionInfo() const {
- return trivial_reduction_info_;
- }
-
- const HaloInfo& haloInfo() const {
- return halo_info_;
- }
-
- HaloInfo& haloInfo() {
- return halo_info_;
- }
-
- const ParallelDimensionMap& parallelDimensionMap() const {
- return parallel_dimension_map_;
- }
-
- ParallelDimensionMap& parallelDimensionMap() {
- return parallel_dimension_map_;
- }
-
- PredicateElimination& predicateElimination() {
- return pred_elimination_;
- }
-
- const PredicateElimination& predicateElimination() const {
- return pred_elimination_;
- }
-
private:
void lower();
private:
// Lowered Kernel IR
- std::unique_ptr<kir::Kernel> kernel_;
+ std::unique_ptr<Kernel> kernel_;
// Fusion IR node to Kernel IR node mapping
- std::unordered_map<const Val*, kir::Val*> kir_val_map_;
- std::unordered_map<const Expr*, kir::Expr*> kir_expr_map_;
-
- // Some stateful information during lowering
- ThreadPredicateMap thread_pred_map_;
- PredicateElimination pred_elimination_;
- ComputeAtMap ca_loop_map_;
- ComputeAtMap ca_index_map_;
- ComputeAtMap ca_parallel_map_;
- TrivialReductionInfo trivial_reduction_info_;
- HaloInfo halo_info_;
- ParallelDimensionMap parallel_dimension_map_;
+ std::unordered_map<const Val*, Val*> kir_map_;
Fusion* fusion_ = nullptr;
};
#include <torch/csrc/jit/codegen/cuda/lower_alias_memory.h>
+#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
+#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
+#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <sstream>
-#include <unordered_map>
-#include <unordered_set>
-
namespace torch {
namespace jit {
namespace fuser {
//! Get string representation of Allocate size for symbolic comparison
//!
-class SymbolicSizePrinter : private kir::IrVisitor {
+class SymbolicSizePrinter final : private OptOutConstDispatch {
public:
- static std::string printSize(const kir::Allocate* allocate) {
+ static std::string print_size(const kir::Allocate* alloc) {
SymbolicSizePrinter printer;
- allocate->size()->accept(&printer);
+ printer.handle(alloc->size());
return printer.os_.str();
}
private:
- void visit(const kir::Int* node) final {
- if (auto def = node->definition()) {
- def->accept(this);
- } else if (node->isConst()) {
- os_ << *node->value();
- } else {
- os_ << "ki" << node->id();
- }
+ void handle(const Val* v) final {
+ OptOutConstDispatch::handle(v);
}
- void visit(const kir::NamedScalar* named_scalar) final {
- os_ << "@" << named_scalar->name();
+ void handle(const Expr* e) final {
+ OptOutConstDispatch::handle(e);
+ }
+
+ void handle(const kir::Int* node) final {
+ if (auto def = FusionGuard::getCurFusion()->origin(node)) {
+ os_ << "( ";
+ handle(def);
+ os_ << " )";
+ return;
+ } else if (node->isSymbolic()) {
+ os_ << "i" << node->name();
+ } else {
+ os_ << *node->value();
+ }
}
- void visit(const kir::UnaryOp* unary_op) final {
- os_ << unary_op->operation() << "(";
- unary_op->accept(this);
- os_ << ")";
+ void handle(const kir::NamedScalar* node) final {
+ os_ << node->name();
}
- void visit(const kir::BinaryOp* binary_op) final {
- os_ << binary_op->operation() << "(";
- binary_op->lhs()->accept(this);
- os_ << ",";
- binary_op->rhs()->accept(this);
- os_ << ")";
+ void handle(const kir::BinaryOp* node) final {
+ if (auto inline_bop = inline_op_str(node->getBinaryOpType())) {
+ handle(node->lhs());
+ os_ << " " << inline_bop.value() << " ";
+ handle(node->rhs());
+ } else {
+ os_ << node->getBinaryOpType() << "(";
+ handle(node->lhs());
+ os_ << ", ";
+ handle(node->rhs());
+ os_ << ")";
+ }
}
private:
//! Reuse Allocation nodes via pointer aliasing
//!
-class AllocateReuseModifier {
- // Alias local memory if it exceeds this threshold
- static constexpr size_t kRegisterSizeThreshold = 1;
-
+class AllocateReuseModifier final : private OptOutDispatch {
public:
- void modify(const std::vector<kir::Expr*>& exprs) {
+ explicit AllocateReuseModifier(Fusion* fusion, size_t register_size_threshold)
+ : eval_evaluator_(fusion),
+ register_size_threshold_(register_size_threshold) {}
+
+ void modify(const std::vector<Expr*>& exprs) {
// Find candidate TensorViews and collect analysis information
for (auto expr : exprs) {
handle(expr);
// Iterate over candidates to find match
for (auto tv : candidate_alias_tv_) {
- const auto def = tv->definition();
- TORCH_INTERNAL_ASSERT(def != nullptr);
+ TORCH_INTERNAL_ASSERT(
+ map_tv_to_origin_expr_.find(tv) != map_tv_to_origin_expr_.end());
- const auto alloc_it = map_tv_to_allocations_.find(tv->name());
- TORCH_INTERNAL_ASSERT(alloc_it != map_tv_to_allocations_.end());
- const auto output_alloc = alloc_it->second;
+ const auto& expr = map_tv_to_origin_expr_[tv];
+ const auto output = expr->output(0)->as<TensorView>();
- const auto input_alloc = findCompatibleInputAllocate(
- tv->dtype(), SymbolicSizePrinter::printSize(output_alloc), def);
+ TORCH_INTERNAL_ASSERT(
+ map_tv_to_allocations_.find(output->name()) !=
+ map_tv_to_allocations_.end());
+ auto output_alloc = map_tv_to_allocations_[output->name()];
+
+ auto input_alloc = findCompatibleInputAllocate(
+ SymbolicSizePrinter::print_size(output_alloc), expr);
if (input_alloc != nullptr) {
output_alloc->setAlias(input_alloc);
}
}
}
private:
- // Do we have a true pointwise op?
- // (ie. a TV op, excluding direct assignments and reductions)
- static bool isPointwiseTvOp(const kir::Expr* expr) {
- if (ir_utils::isTVOp(expr)) {
- if (auto unary_op = dynamic_cast<const kir::UnaryOp*>(expr)) {
- return unary_op->operation() != UnaryOpType::Set;
- } else {
- return expr->isA<kir::BinaryOp>() || expr->isA<kir::TernaryOp>();
- }
- }
+  // Check if the expression is a pointwise TensorView op.
+ bool isPwiseTVOp(const Expr* expr) {
+ // Ignore set operations
+ if (expr->outputs().size() == 1 && ir_utils::isTV(expr->output(0)) &&
+ ((expr->getExprType().value() == ExprType::UnaryOp &&
+ expr->as<UnaryOp>()->getUnaryOpType() != UnaryOpType::Set) ||
+ expr->getExprType().value() == ExprType::BinaryOp ||
+ expr->getExprType().value() == ExprType::TernaryOp))
+ return true;
return false;
}
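
  // For example (hypothetical tensors): `T2 = T0 + T1` qualifies as a
  // pointwise TV op, while a plain copy `T2 = set(T0)` does not, since Set
  // unary ops are excluded above.
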
// Find an Input Allocate that is compatible with the Output Allocate
- const kir::Allocate* findCompatibleInputAllocate(
- const DataType output_dtype,
+ kir::Allocate* findCompatibleInputAllocate(
const std::string& output_size_str,
- const kir::Expr* expr) {
+ Expr* expr) {
// Stop searching if current op is not point-wise
- if (!isPointwiseTvOp(expr)) {
+ if (!isPwiseTVOp(expr)) {
return nullptr;
}
- const kir::TensorView* first_tv_input = nullptr;
- for (const auto input : expr->inputs()) {
- if (auto input_tv = dynamic_cast<const kir::TensorView*>(input)) {
- if (first_tv_input == nullptr) {
- first_tv_input = input_tv;
- }
+ const auto& expr_inputs_iter =
+ ir_utils::filterByType<TensorView>(expr->inputs());
- // input_alloc == nullptr implies that input_tv is a kernel input
- const auto input_alloc = map_tv_to_allocations_[input_tv->name()];
- if (input_alloc != nullptr) {
- if (candidate_alias_tv_.find(input_tv) != candidate_alias_tv_.end() &&
- output_size_str == SymbolicSizePrinter::printSize(input_alloc) &&
- output_dtype == input_tv->dtype() &&
- map_tv_to_last_usage_[input_tv] <= map_expr_to_pos_[expr]) {
- return input_alloc;
- }
+ std::vector<TensorView*> expr_inputs(
+ expr_inputs_iter.begin(), expr_inputs_iter.end());
+
+ for (const auto input : expr_inputs) {
+ auto input_alloc = map_tv_to_allocations_[input->name()];
+
+      // input_alloc == nullptr implies that the input is a fusion input.
+ if (input_alloc != nullptr) {
+ if (candidate_alias_tv_.find(input) != candidate_alias_tv_.end() &&
+ output_size_str == SymbolicSizePrinter::print_size(input_alloc) &&
+ map_tv_to_last_usage_[input] <= map_expr_to_pos_[expr]) {
+ return input_alloc;
}
}
}
// Assume the first argument contains the primary variable
// Follow path along point-wise operations
- if (first_tv_input != nullptr &&
- map_tv_to_last_usage_[first_tv_input] <= map_expr_to_pos_[expr]) {
- if (const auto def = first_tv_input->definition()) {
- return findCompatibleInputAllocate(output_dtype, output_size_str, def);
+ if (!expr_inputs.empty()) {
+ auto first_input_argument_tv = expr_inputs.front()->getOrigin();
+ if (first_input_argument_tv != nullptr) {
+ return findCompatibleInputAllocate(
+ output_size_str, first_input_argument_tv);
}
}
-
return nullptr;
}
- void handle(kir::Expr* expr) {
- const size_t expr_index = map_expr_to_pos_.size();
+ void handle(Expr* expr) final {
+ size_t expr_index = map_expr_to_pos_.size();
map_expr_to_pos_[expr] = expr_index;
if (ir_utils::isTVOp(expr)) {
- const auto output_tv = expr->outputs()[0]->as<kir::TensorView>();
+ const auto output = expr->output(0)->as<TensorView>();
+ map_tv_to_origin_expr_[output] = expr;
- const auto alloc_it = map_tv_to_allocations_.find(output_tv->name());
- if (alloc_it != map_tv_to_allocations_.end()) {
- const bool smem_valid = (output_tv->memoryType() == MemoryType::Shared);
+ bool has_allocation = map_tv_to_allocations_.find(output->name()) !=
+ map_tv_to_allocations_.end();
+
+ if (has_allocation) {
+ bool smem_valid = output->getMemoryType() == MemoryType::Shared;
bool local_valid = false;
- if (output_tv->memoryType() == MemoryType::Local) {
- const auto allocation = alloc_it->second;
- const auto register_size =
- expr_evaluator_.evaluate(allocation->size());
- if (register_size.has_value()) {
- local_valid = size_t(*register_size) > kRegisterSizeThreshold;
+ if (output->getMemoryType() == MemoryType::Local) {
+ auto allocation = map_tv_to_allocations_[output->name()];
+ auto inferred_register_size =
+ eval_evaluator_.inferValue(allocation->size());
+ if (inferred_register_size.has_value()) {
+ local_valid = inferred_register_size.value() >
+ static_cast<int64_t>(register_size_threshold_);
}
}
// its allocation size must exceed the threshold
// OR be in shared memory
if (smem_valid || local_valid) {
- candidate_alias_tv_.insert(output_tv);
+ candidate_alias_tv_.insert(output);
}
}
- for (auto input_tv :
- ir_utils::filterByType<kir::TensorView>(expr->inputs())) {
- map_tv_to_last_usage_[input_tv] = expr_index;
+ const auto& expr_inputs =
+ ir_utils::filterByType<TensorView>(expr->inputs());
+ for (const auto input : expr_inputs) {
+ map_tv_to_last_usage_[input] = expr_index;
}
- } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
- handle(ite);
- } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
- handle(for_loop);
- } else if (auto allocate = dynamic_cast<kir::Allocate*>(expr)) {
- handle(allocate);
+ } else {
+ OptOutDispatch::handle(expr);
}
}
- void handle(kir::Allocate* allocate) {
- if (auto tv = dynamic_cast<const kir::TensorView*>(allocate->buffer())) {
- map_tv_to_allocations_[tv->name()] = allocate;
+ void handle(kir::Allocate* a) final {
+ if (a->buffer()->getValType().value() == ValType::KirTensorView) {
+ auto tv = a->buffer()->as<kir::TensorView>()->fuserTv();
+ map_tv_to_allocations_[tv->name()] = a;
}
}
- void handle(const kir::ForLoop* for_loop) {
- for (auto expr : for_loop->body().exprs()) {
+ void handle(kir::ForLoop* fl) final {
+ for (auto expr : fl->body().exprs()) {
handle(expr);
}
}
- void handle(const kir::IfThenElse* ite) {
+ void handle(kir::IfThenElse* ite) final {
for (auto expr : ite->thenBody().exprs()) {
handle(expr);
}
private:
// Expression Evaluator to infer size of register allocation
- kir::ExpressionEvaluator expr_evaluator_;
+ StatefulExpressionEvaluator eval_evaluator_;
+
+ // Alias local memory if it exceeds this threshold
+ const size_t register_size_threshold_;
// Map expression to unique position
- // TODO: elaborate - position relative to what?
- std::unordered_map<const kir::Expr*, size_t> map_expr_to_pos_;
+ std::unordered_map<Expr*, size_t> map_expr_to_pos_;
+
+ // Map TensorView to origin expression
+ std::unordered_map<const TensorView*, Expr*> map_tv_to_origin_expr_;
// Map TensorView to last usage expression position
- std::unordered_map<const kir::TensorView*, size_t> map_tv_to_last_usage_;
+ std::unordered_map<const TensorView*, size_t> map_tv_to_last_usage_;
// Map TensorView name to Allocate node
- std::unordered_map<StmtNameType, kir::Allocate*> map_tv_to_allocations_;
+ std::unordered_map<unsigned int, kir::Allocate*> map_tv_to_allocations_;
// Track candidate TensorViews whose Allocate nodes
// could potentially alias another Allocate node
- std::unordered_set<const kir::TensorView*> candidate_alias_tv_;
+ std::unordered_set<const TensorView*> candidate_alias_tv_;
};
} // namespace
-std::vector<kir::Expr*> reuseMemoryAllocations(
- const std::vector<kir::Expr*>& exprs) {
- FUSER_PERF_SCOPE("GpuLower::Lower::reuseMemoryAllocations");
- AllocateReuseModifier arm;
+std::vector<Expr*> reuseMemoryAllocations(
+ Fusion* fusion,
+ const std::vector<Expr*>& exprs) {
+ FUSER_PERF_SCOPE("reuseMemoryAllocations");
+ FusionGuard fg(fusion);
+
+ // Alias local memory if it exceeds this threshold
+ const size_t register_size_threshold = 1;
+ AllocateReuseModifier arm(fusion, register_size_threshold);
arm.modify(exprs);
return exprs;
}
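
// A minimal sketch of this pass's effect (hypothetical tensors of matching
// symbolic size and type, above the register size threshold):
//
//   T2 = T1 + 1;   // pointwise op
//   T3 = T2 * 2;   // T2's last use; T3's Allocate may alias T2's buffer,
//                  // reusing T2's storage instead of allocating anew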
//! is not used after this op:
//! then alias output Allocate to input Allocate.
//!
-std::vector<kir::Expr*> reuseMemoryAllocations(
- const std::vector<kir::Expr*>& exprs);
+std::vector<Expr*> reuseMemoryAllocations(
+ Fusion* fusion,
+ const std::vector<Expr*>& exprs);
} // namespace cuda
} // namespace fuser
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_allocation.h>
-
-#include <unordered_set>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-class AllocationInserter : public kir::MutableIrVisitor {
- private:
- struct AllocationInformation {
- // The for loop that the allocation must be placed in, nullptr if not within
- // a loop
- kir::ForLoop* for_loop = nullptr;
-
- // The expression that this allocation must be placed before
- kir::Expr* place_before = nullptr;
-
- // The allocation position relative to buffer
- size_t alloc_pos = 0;
-
- // The buffer this allocation is for
- kir::TensorView* buffer = nullptr;
-
- // The allocation expression
- kir::Allocate* alloc_expr = nullptr;
-
- // Initialization
- kir::Expr* init_expr = nullptr;
- };
-
- // Find allocation point
- void findAllocationPosition(AllocationInformation& info, kir::Expr* expr) {
- size_t alloc_pos = 0;
- kir::ForLoop* for_loop = nullptr;
- auto fuser_tv = info.buffer->fuserTv();
- size_t fl_idx_next = 0;
-
- for (auto fl : for_loops) {
- if (alloc_pos == fuser_tv->getComputeAtPosition()) {
- break;
- }
-
- if (fuser_tv->axis(alloc_pos)->isReduction()) {
- const auto outputs =
- FusionGuard::getCurFusion()->getTerminatingOutputs();
- TORCH_INTERNAL_ASSERT(
- std::find(outputs.begin(), outputs.end(), fuser_tv) !=
- outputs.end(),
- "Invalid computeAt of T",
- fuser_tv->name(),
- ". A reducation axis is detected within computeAt axes even though it is not an output tensor.");
- break;
- }
-
- auto fl_id = fl->iter_domain();
-
- if (fl_id->parallelType() == ParallelType::Unroll) {
- break;
- }
-
- auto local_id = gpu_lower->lowerValue(fuser_tv->axis(alloc_pos))
- ->as<kir::IterDomain>();
-
- if (gpu_lower->caLoopMap().areMapped(local_id, fl_id)) {
- alloc_pos++;
- }
-
- for_loop = fl;
- ++fl_idx_next;
- }
-
- info.alloc_pos = alloc_pos;
- info.for_loop = for_loop;
-
- if (info.for_loop == nullptr) {
- info.place_before = for_loops.size() > 0 ? for_loops[0] : expr;
- } else {
- if (info.for_loop == for_loops.back()) {
- // Inline allocation, place before expr
- info.place_before = expr;
- } else {
- // Place allocation after the last computeAt axis
- // TODO: may be more efficient to place before the first non-computeAt
- // axis
- info.place_before = for_loops.at(fl_idx_next);
- }
- }
- }
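-
-  // Illustrative sketch (hypothetical tensor): for a buffer computed-at
-  // position 2 inside loops [i, j, k], the walk above stops after the j
-  // loop, so info.for_loop is the j loop and the allocation is placed
-  // inside it, before the expression or the first non-computeAt loop.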
-
- // Create initialization expression if init_val is non-null.
- void createInitExpr(AllocationInformation& info, kir::Val* init_val) {
- if (init_val == nullptr) {
- info.init_expr = nullptr;
- return;
- }
-
- auto fuser_tv = info.buffer->fuserTv();
-
- std::vector<kir::IterDomain*> init_dims;
- for (size_t axis_i = info.alloc_pos; axis_i < fuser_tv->nDims(); axis_i++) {
- if (info.buffer->fuserTv()->axis(axis_i)->isReduction() ||
- info.buffer->fuserTv()->axis(axis_i)->isBroadcast()) {
- continue;
- }
- auto concrete_id =
- gpu_lower
- ->lowerValue(gpu_lower->caParallelMap().getConcreteMappedID(
- fuser_tv->axis(axis_i)))
- ->as<kir::IterDomain>();
- init_dims.push_back(concrete_id);
- }
- kir::Expr* init_expr = ir_builder.create<kir::UnaryOp>(
- UnaryOpType::Set, info.buffer, init_val);
- for (auto init_loop_it = init_dims.rbegin();
- init_loop_it != init_dims.rend();
- ++init_loop_it) {
- auto id = *init_loop_it;
- kir::ForLoop* new_loop = nullptr;
- auto extent_with_halo = gpu_lower->haloInfo().getExtent(id);
- if (extent_with_halo) {
- new_loop = ir_builder.create<kir::ForLoop>(
- id,
- ir_builder.create<kir::Int>(c10::nullopt),
- nullptr,
- extent_with_halo,
- nullptr,
- false,
- nullptr);
- } else {
- new_loop = ir_builder.create<kir::ForLoop>(id);
- }
- new_loop->body().push_back(init_expr);
- init_expr = new_loop;
- }
- info.init_expr = init_expr;
- }
-
- std::vector<kir::Val*> getGlobalAllocationSizes(AllocationInformation& info) {
- const auto& domain = info.buffer->domain();
- const auto& maybe_rfactor_domain =
- domain->hasRFactor() ? domain->rfactorDomain() : domain->rootDomain();
-
- std::vector<kir::Val*> alloc_dims;
-
- for (const auto id : maybe_rfactor_domain) {
- if (id->isReduction() ||
- id->iterType() == IterType::BroadcastWithoutStride) {
- continue;
- }
- auto extent = id->extent();
- // Use halo-extended extent if found
- auto halo_extent = gpu_lower->haloInfo().getRootAxisInfo(id);
- if (halo_extent.hasHalo()) {
- extent = ir_builder.addExpr(extent, halo_extent.width());
- }
- alloc_dims.push_back(extent);
- }
-
- return alloc_dims;
- }
-
- // Get allocation extents of root axes with halo
- //
- // Allocation can be done with leaf IDs with halo as well, but
- // allocation size could be larger than necessary.
- //
- // For example, suppose the shift offset of an axis is 1. When it is
- // split by N, the halo size of the inner output is N+1. When the
- // allocation only has the inner split output, the allocation size
- // would be N+1. Suppose that ID is further split by M, the output
- // extents would be N/M and M+1. The allocation size based on the
- // leaves would be N/M*(M+1) or N+N/M, which is larger than N+1.
- //
-  // This function tries to propagate back halo information to root
- // axes to avoid inflating allocations. It fails when merged domains
- // are split and only one of the split outputs is used for
- // allocations since in such a case we can't un-merge and properly
- // determine the extents of the merge inputs. Currently, that
- // results in an exception, but it may be more reasonable to simply
- // fall back to the leaf-based allocation.
- //
- // See the FusionShiftDoubleSplit test for an example case.
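-  //
-  // A concrete instance of the arithmetic above (hypothetical sizes): with
-  // N = 8, M = 4, and a shift offset of 1, the root-based allocation size
-  // is N + 1 = 9, whereas the leaf-based size would be
-  // N/M * (M + 1) = 2 * 5 = 10.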
- std::vector<kir::Val*> getNonGlobalAllocExprWithHalo(
- TensorView* tv,
- const std::vector<IterDomain*>& alloc_domains) {
- std::vector<Val*> start_vals;
- std::transform(
- alloc_domains.begin(),
- alloc_domains.end(),
- std::back_inserter(start_vals),
- [](IterDomain* dom) { return dom->as<Val>(); });
-
- // Get all exprs involved in generating the allocation IDs
- auto exprs = ExprSort::getExprs(tv->fusion(), start_vals);
-
- // Get the halo extent if found
- auto getExtent = [this](IterDomain* id) {
- auto extent = gpu_lower->haloInfo().getExtent(id);
- if (extent == nullptr) {
- extent = gpu_lower->lowerValue(id->extent());
- }
- return extent;
- };
-
- std::unordered_map<IterDomain*, kir::Val*> known_extents;
-
- // IterDomains that are allocated fully. For example, if an ID is
-    // split and only one of the split outputs is used for allocation, that's not
- // considered full. Only full domains can be unmerged, which is
- // needed to propagate back the halo information to root domains.
- std::unordered_set<IterDomain*> full_domains;
-
- for (auto alloc_domain : alloc_domains) {
- known_extents.insert({alloc_domain, getExtent(alloc_domain)});
- full_domains.insert(alloc_domain);
- }
-
- for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
- auto expr = *it;
- if (auto merge = dynamic_cast<Merge*>(expr)) {
- auto out_it = known_extents.find(merge->out());
-        // If nothing is known about the out id, no propagation can be
- // done. Note that's not necessarily an error.
- if (out_it == known_extents.end()) {
- continue;
- }
- // Similarly, if the extent of the out id is not full extent,
- // we can't un-merge it.
- if (full_domains.find(merge->out()) == full_domains.end()) {
- continue;
- }
- // Since the extent of the out id is full, the extent of each
- // of the input axes is also full
- known_extents.insert({merge->inner(), getExtent(merge->inner())});
- full_domains.insert(merge->inner());
- known_extents.insert({merge->outer(), getExtent(merge->outer())});
- full_domains.insert(merge->outer());
- known_extents.erase(out_it);
- } else if (auto split = dynamic_cast<Split*>(expr)) {
- auto inner = split->inner();
- const auto inner_it = known_extents.find(inner);
- auto outer = split->outer();
- const auto outer_it = known_extents.find(outer);
- if (inner_it != known_extents.end() &&
- outer_it != known_extents.end()) {
- if (full_domains.find(inner) != full_domains.end() &&
- full_domains.find(outer) != full_domains.end()) {
- known_extents.insert({split->in(), getExtent(split->in())});
- full_domains.insert(split->in());
- } else {
- known_extents.insert(
- {split->in(),
- ir_builder.mulExpr(outer_it->second, inner_it->second)});
- }
- known_extents.erase(inner_it);
- known_extents.erase(outer_it);
- } else if (inner_it != known_extents.end()) {
- known_extents.insert({split->in(), inner_it->second});
- known_extents.erase(inner_it);
- } else if (outer_it != known_extents.end()) {
- known_extents.insert({split->in(), outer_it->second});
- known_extents.erase(outer_it);
- }
- } else {
- TORCH_INTERNAL_ASSERT(false, "Unexpected expr: ", expr);
- }
- }
-
- std::vector<kir::Val*> alloc_dims;
-
- for (auto root_axis : tv->getRootDomain()) {
- auto it = known_extents.find(root_axis);
- if (it == known_extents.end()) {
- continue;
- }
- alloc_dims.push_back(it->second);
- known_extents.erase(it);
- }
-
- // known_extents should have only mappings for root axes, so
- // if anything remains in the map, it's an error
- if (!known_extents.empty()) {
- std::stringstream ss;
- for (auto kv : known_extents) {
- ss << kv.first << " ";
- }
- TORCH_INTERNAL_ASSERT(
- false, "Non-root axes found for TV", tv->name(), ": ", ss.str());
- }
-
- return alloc_dims;
- }
-
- std::vector<kir::Val*> getNonGlobalAllocExpr(AllocationInformation& info) {
- auto fuser_tv = info.buffer->fuserTv();
- const auto memory_type = info.buffer->memoryType();
- TORCH_INTERNAL_ASSERT(
- memory_type != MemoryType::Global,
- "Invalid memory type: ",
- memory_type);
-
- std::vector<kir::Val*> alloc_dims;
-
- bool has_halo = false;
- std::vector<IterDomain*> alloc_domains;
-
- for (size_t axis_i = 0; axis_i < fuser_tv->nDims(); axis_i++) {
- const auto local_id =
- gpu_lower->lowerValue(fuser_tv->axis(axis_i))->as<kir::IterDomain>();
-
- if (
- // If we're reducing this dimension, don't use it in the allocation
- // computation
- local_id->isReduction() ||
- // If this is a broadcast dimension, don't use it in the allocation
- // computation
- local_id->isBroadcast()) {
- continue;
- }
-
- auto concrete_id =
- gpu_lower
- ->lowerValue(gpu_lower->caParallelMap().getConcreteMappedID(
- fuser_tv->axis(axis_i)))
- ->as<kir::IterDomain>();
- const bool is_block_dim =
- isParallelTypeBlockDim(concrete_id->parallelType());
- const bool is_thread_dim =
- isParallelTypeThreadDim(concrete_id->parallelType());
- const bool is_thread = isParallelTypeThread(concrete_id->parallelType());
-
- if (axis_i < info.alloc_pos) {
- // Even when the axis is outside the allocation position, if the
- // tensor is shared with respect to the axis, the buffer size
- // needs to be expanded for the axis. Sharing occurs in two
- // cases: 1) the tensor is on shared memory with the axis
- // parallelized by TIDs, and 2) the tensor is on global memory
- // with the axis parallelized by TIDs or BIDs.
- if (!((memory_type == MemoryType::Shared && is_thread_dim) ||
- (memory_type == MemoryType::Global && is_thread))) {
- continue;
- }
- alloc_domains.push_back(fuser_tv->axis(axis_i));
- } else {
- if (
- // If shared memory, don't use any IDs bound to a grid dimension
- (memory_type == MemoryType::Shared && is_block_dim) ||
- // If local memory, don't use any IDs bound to a grid or block
- // dimension
- (memory_type == MemoryType::Local && is_thread)) {
- continue;
- }
- alloc_domains.push_back(fuser_tv->axis(axis_i));
- }
-
- auto extent = concrete_id->extent();
-
- if (gpu_lower->haloInfo().getExtent(fuser_tv->axis(axis_i)) != nullptr) {
- has_halo = true;
- }
-
- alloc_dims.push_back(extent);
- }
-
- // When an axis with halo extension is detected, propagate back
- // the halo extents from leaf IDs to root IDs
- if (has_halo) {
- return getNonGlobalAllocExprWithHalo(fuser_tv, alloc_domains);
- }
-
- return alloc_dims;
- }
-
- void createAllocExpr(AllocationInformation& info, bool is_output) {
- if (is_output) {
- info.alloc_expr = nullptr;
- return;
- }
-
- std::vector<kir::Val*> alloc_dims;
- const MemoryType memory_type = info.buffer->memoryType();
-
- if (memory_type == MemoryType::Global) {
- alloc_dims = getGlobalAllocationSizes(info);
- } else {
- alloc_dims = getNonGlobalAllocExpr(info);
- }
-
- if (alloc_dims.size() == 0 &&
- info.buffer->domain()->noReductions().size() != 0) {
- alloc_dims.push_back(ir_builder.create<kir::Int>(1));
- }
-
- // Create the allocation node
- info.alloc_expr = ir_builder.create<kir::Allocate>(
- info.buffer, info.buffer->memoryType(), alloc_dims);
- }
-
- void handle(kir::Expr* expr) {
- if (!ir_utils::isTVOp(expr) || expr->isA<kir::Allocate>()) {
- expr->accept(this);
- return;
- }
-
-    // Found where the allocation needs to be inserted
-
- for (auto out : expr->outputs()) {
- if (!out->isA<kir::TensorView>()) {
- continue;
- }
-
- auto out_tv = out->as<kir::TensorView>();
- auto default_val =
- gpu_lower->predicateElimination().getInitValue(out_tv->fuserTv());
-
- kir::Val* init = nullptr;
- if (expr->isA<kir::ReductionOp>() && out_tv->fuserTv()->hasReduction()) {
- TORCH_INTERNAL_ASSERT(
- default_val == nullptr,
- "Reduction should not have a default initialization value for predicate elimination.");
- init = expr->as<kir::ReductionOp>()->init();
- } else if (expr->isA<kir::WelfordOp>()) {
- TORCH_INTERNAL_ASSERT(
- default_val == nullptr,
- "Welford should not have a default initialization value for predicate elimination.");
- const auto welford = expr->as<kir::WelfordOp>();
- if (out->id() == welford->outVar()->id()) {
- init = welford->initVar() == nullptr
- ? ir_builder.create<kir::Double>(0)
- : welford->initVar();
- } else if (out->id() == welford->outAvg()->id()) {
- init = welford->initAvg() == nullptr
- ? ir_builder.create<kir::Double>(0)
- : welford->initAvg();
- } else {
- TORCH_INTERNAL_ASSERT(
- out->id() == welford->outN()->id(), "Unreachable");
- init = welford->initN();
- }
- } else if (default_val != nullptr) {
- init = default_val;
- }
-
- const bool is_output = gpu_lower->kernel()->isOutput(out);
-
- // Don't need to alloc outputs, and if we don't need to initialize we're
- // done.
- if (is_output && init == nullptr) {
- continue;
- }
-
- AllocationInformation allocation;
- allocation.buffer = out_tv;
- findAllocationPosition(allocation, expr);
- createAllocExpr(allocation, is_output);
- createInitExpr(allocation, init);
-
- allocs.push_back(allocation);
- }
- }
-
- void visit(kir::ForLoop* fl) final {
- for_loops.push_back(fl);
- // Modifying in place, make a copy of the vector
- const std::vector<kir::Expr*> exprs = fl->body().exprs();
- for (auto expr : exprs) {
- handle(expr);
- }
- for_loops.pop_back();
- }
-
- void visit(kir::IfThenElse*) final {
- TORCH_INTERNAL_ASSERT(
- false,
- "Pass does not support conditional statements, ",
- "this pass should be run before any conditionals are placed in code.");
- }
-
- AllocationInserter(std::vector<kir::Expr*> _loop_nests)
- : loop_nests_(std::move(_loop_nests)),
- gpu_lower(GpuLower::current()),
- ir_builder(gpu_lower->kernel()) {
- // Compute all allocations
- const std::vector<kir::Expr*> exprs = loop_nests_;
- for (auto expr : exprs) {
- handle(expr);
- }
-
- // First, place allocations of dynamic smem tensors at the very
- // beginning of the expr list. Traverse backward as they should be
- // placed in topological order.
- for (auto it = allocs.rbegin(); it != allocs.rend(); ++it) {
- const auto& alloc = *it;
- if (alloc.alloc_expr == nullptr) {
- continue;
- }
-      // Dynamic smem exprs need to be at the beginning of the kernel outside for
- // loops
- if (alloc.buffer->memoryType() == MemoryType::Shared &&
- !kir::ExpressionEvaluator::isConst(alloc.alloc_expr->size())) {
- loop_nests_.insert(loop_nests_.begin(), alloc.alloc_expr);
- }
- }
-
- // Place the remaining allocations.
- for (const auto& alloc : allocs) {
- if (alloc.alloc_expr == nullptr) {
- continue;
- }
- if (alloc.buffer->memoryType() == MemoryType::Shared &&
- !kir::ExpressionEvaluator::isConst(alloc.alloc_expr->size())) {
- continue;
- }
- if (alloc.for_loop == nullptr) {
- auto place_before_it = std::find(
- loop_nests_.begin(), loop_nests_.end(), alloc.place_before);
- TORCH_INTERNAL_ASSERT(
- place_before_it != loop_nests_.end(),
- "Could not figure out where to place allocation. ",
- "Use of the buffer, ",
- toString(alloc.buffer),
- ", could not be found.",
- toString(alloc.place_before));
- loop_nests_.insert(place_before_it, alloc.alloc_expr);
- } else {
- alloc.for_loop->body().insert_before(
- alloc.place_before, alloc.alloc_expr);
- }
- }
-
- // Now that allocations are in place, place the initializations
- for (const auto& alloc : allocs) {
- if (alloc.init_expr == nullptr) {
- continue;
- }
- if (alloc.for_loop == nullptr) {
- auto place_before_it = std::find(
- loop_nests_.begin(), loop_nests_.end(), alloc.place_before);
- // Don't need a check here as if the allocation placement succeeded
- // this will too
- loop_nests_.insert(place_before_it, alloc.init_expr);
- } else {
- alloc.for_loop->body().insert_before(
- alloc.place_before, alloc.init_expr);
- }
- }
- }
-
- private:
- std::deque<AllocationInformation> allocs;
-
- std::vector<kir::ForLoop*> for_loops;
-
- std::vector<kir::Expr*> loop_nests_;
-
- GpuLower* gpu_lower;
-
- kir::IrBuilder ir_builder;
-
- public:
- static std::vector<kir::Expr*> insert(
- const std::vector<kir::Expr*>& loop_nests) {
- AllocationInserter inserter(loop_nests);
- return inserter.loop_nests_;
- }
-};
-
-} // namespace
-
-std::vector<kir::Expr*> insertAllocations(
- const std::vector<kir::Expr*>& exprs) {
- FUSER_PERF_SCOPE("GpuLower::Lower::insertAllocations");
- return AllocationInserter::insert(exprs);
-}
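-
-// Sketch of the structure this pass produces (illustrative only; hypothetical
-// tensor T1 with an initialized reduction buffer inside a loop nest):
-//
-//   for (i : I) {
-//     float T1[N];               // inserted Allocate
-//     for (j : J) { T1[j] = 0; } // inserted initialization loop
-//     ...                        // original expressions
-//   }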
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Insert buffer allocations
-std::vector<kir::Expr*> insertAllocations(const std::vector<kir::Expr*>& exprs);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_expr_sort.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-
-#include <deque>
-#include <list>
-#include <sstream>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-// TODO: Review const model and objects
-// ExprSegmentationSorter
-// Responsible for going through DAG and proposing things we could try to
-// merge together, calls "supportedMerge" on these proposed groups to see
-// if they should be merged together, then merges them if so.
-// ExprGroup
-// A group of exprs that are grouped together based on their loop nest
-// structures.
-// ExprGroupConnections
-//   Holds vals and what they connect. In other words, it's a val that is an
-//   output of an ExprSegmentationSorter "from" and an input of an
-//   ExprSegmentationSorter "to". There's nothing preventing a val from
-//   appearing between groups twice.
-// TODO: make sure there's nothing wrong with grouping of nodes that
-// have the same value input twice. i.e. (B = A*A)
-
-// Selecting segments to propose is based on Theorem 4.2 in the paper below,
-// which ensures that the segmented graph remains a DAG after segmentation
-// (assuming the Fusion is already a DAG). The segmentation code relies on
-// this DAG-ness during segmentation, meaning proposed merges of groups must
-// maintain the DAG property of the graph.
-//
-// Julien Herrmann, Yusuf Özkaya, Bora Uçar, Kamer Kaya, Umit Catalyurek.
-// Multilevel Algorithms for Acyclic Partitioning of Directed Acyclic Graphs.
-// SIAM Journal on Scientific Computing, Society for Industrial and Applied
-// Mathematics, 2019, 41 (4), pp. A2117-A2145. doi:10.1137/18M1176865.
-// hal-02306566
-
-class ExprGroup;
-struct ExprGroupConnections;
-class ExprSegmentationSorter;
-
-// Debug printing disabled due to clang tidy, see below for definitions
-// std::ostream& operator<<(std::ostream& os, const ExprGroup* group);
-
-// Wrapper for values; these are edges between expr groups. Multiple edges can
-// exist between expr groups, and the same Val can show up more than once in
-// multiple edges.
-struct ExprGroupConnections {
- ExprGroupConnections(
- ExprGroup* group_from,
- ExprGroup* group_to,
- Val* producer_val,
- Val* consumer_val)
- : from(group_from),
- to(group_to),
- producer_val_(producer_val),
- consumer_val_(consumer_val) {}
- // Producer group from which the edge starts
- ExprGroup* from;
-
-  // Consumer group at which the edge ends
- ExprGroup* to;
-
- // The value from the producer group connecting the groups
- // This value helps us resolve the compute at position of expr groups
-
- Val* producer_val_;
-
- // The value that the producer val gets used to create on this edge
- // This value helps us resolve the produce at position of expr groups
- Val* consumer_val_;
-};
-
-struct ExprSortPayload : public PolymorphicBase {
- // Need to track compute at domains as well as produce at domains. Produce at
-  // domains will be matched to producers' compute at domains. Track the active
- // domains that will be matched from inner most dim to outer most.
- std::vector<IterDomain*> ca_domains_;
- std::vector<IterDomain*> pa_domains_;
-
- // Maximum path distance from an input expr group required for
- // Theorem 4.2
- int level = -1;
-
- // Traversal marker, marks if this group has been visited by current pass
- bool visited = false;
-
- // Marks if this group is already selected to merge with another group, marks
- // which group to merge with
- ExprGroup* merge_with = nullptr;
-
- // Marks if this group is already selected to merge with another group
- bool merged = false;
-};
-
-// Groups together expressions that make up an expr group
-class ExprGroup {
- public:
- ExprGroup() : payload_(std::make_unique<ExprSortPayload>()) {}
-
- ExprGroup(Expr* expr) : payload_(std::make_unique<ExprSortPayload>()) {
- exprs_.push_back(expr);
- }
-
- ExprGroup(const ExprGroup& other)
- : payload_(new ExprSortPayload(*(other.payload_))) {}
-
- ExprGroup& operator=(const ExprGroup& other) {
- *payload_ = *other.payload_;
- exprs_ = other.exprs_;
- return *this;
- }
-
- // Clears the traversal information in the payload
- void clearTraversalInfo();
-
- // Returns all neighbors, producers and consumers
- std::vector<ExprGroup*> getNeighbors();
-
-  // Return this group's neighbors that are proven safe to merge with, in
-  // regards to maintaining an acyclic graph. This looks at whether neighbors
-  // are merged, neighbor levels, and the levels of groups merged with
-  // neighbors. If fallback_mode_enabled is true, returns the inverse set:
-  // the ExprGroups not proven to be safe merges.
- std::vector<ExprGroup*> getMergeCandidates(
- bool fallback_mode_enabled = false);
-
- std::unique_ptr<ExprSortPayload>& payload() {
- return payload_;
- }
-
- const auto& producerEdges() const {
- return producer_edges_;
- }
-
- void addProducerEdge(ExprGroupConnections* edge) {
- addEdge(producer_edges_, edge);
- }
-
- void removeProducerEdge(ExprGroupConnections* edge) {
- removeEdge(producer_edges_, edge);
- }
-
- void clearProducerEdges() {
- producer_edges_.clear();
- }
-
- const auto& consumerEdges() const {
- return consumer_edges_;
- }
-
- void addConsumerEdge(ExprGroupConnections* edge) {
- addEdge(consumer_edges_, edge);
- }
-
- void removeConsumerEdge(ExprGroupConnections* edge) {
- removeEdge(consumer_edges_, edge);
- }
-
- void clearConsumerEdges() {
- consumer_edges_.clear();
- }
-
- auto& exprs() {
- return exprs_;
- }
-
- const auto& exprs() const {
- return exprs_;
- }
-
- private:
- static void addEdge(
- std::vector<ExprGroupConnections*>& edges,
- ExprGroupConnections* edge_to_add) {
- edges.push_back(edge_to_add);
- }
-
- static void removeEdge(
- std::vector<ExprGroupConnections*>& edges,
- ExprGroupConnections* edge_to_remove) {
- auto it = std::find(edges.begin(), edges.end(), edge_to_remove);
- TORCH_INTERNAL_ASSERT(it != edges.end(), "Could not find edge to remove.");
- edges.erase(it);
- }
-
- private:
- // "Ancestor nodes", towards inputs of segmentedDAG
- std::vector<ExprGroupConnections*> producer_edges_;
-
- // "Descendent nodes", towards outputs of segmentedDAG
- std::vector<ExprGroupConnections*> consumer_edges_;
-
- // Exprs that make up the group
- std::vector<Expr*> exprs_;
-
- // Stateful traversal information
- std::unique_ptr<ExprSortPayload> payload_;
-};
-
-// This class sorts expressions and guarantees two things: 1) Tensors are
-// produced before they're consumed, and 2) if the production of two tensors is
-// supposed to share a for loop, they're in an order where they can. (1) is
-// pretty standard for ordering a DAG. (2) is where things get a bit
-// complicated, and why we do
-// this sorting through segmentation. Consider a section of a DAG: T4 = T3 + T2.
-// Where T2 and T3 are not inputs to the fusion, all tensors are 3D, and we want
-// the production of T3 to share the inner most loop of T4 and we want the
-// production of T2 to share the middle loop with T4. i.e. we're looking for
-// For(i:I){
-// For(j: J){
-// For(k: K){
-// T2[i, j, k] = ...
-// }
-// For(k: K){
-// T3[i, j, k] = ...
-// T4[i, j, k] = T2[i, j, k] + T3[i, j, k]
-// }
-// }
-// }
-// The only valid ordering of expressions is producing T2, then T3, then T4. If
-// we swapped T3 and T2, then T3 and T4 couldn't share their inner most loop,
-// because T2 has its own inner most loop. If we swapped either tensor with T4,
-// then we'd try to use T2 or T3 without producing them (back to guarantee
-// 1).
-class ExprSegmentationSorter {
- public:
- ExprSegmentationSorter(Fusion* fusion) : complete_fusion_(fusion) {}
-
- void sort();
-
- std::string toString(int verbosity = 0) const;
-
- //! Returns a flattened list of sorted exprs
- std::vector<Expr*> getExprs() const;
-
- private:
- // Allocate an empty expr group and return it
- ExprGroup* makeEmptyGroup();
-
- // Allocate an expr group with the provided expr and return it
- ExprGroup* makeEmptyGroup(Expr*);
-
-  // Returns whether sg1 and sg2 should be merged together; called only if
-  // they can be, based on the current status of the DAG.
- bool supportedMerge(ExprGroup* sg1, ExprGroup* sg2);
-
- // Returns true if the graph will remain an acyclic graph after merging sg1
- // and sg2
- bool testStillDag(ExprGroup* sg1, ExprGroup* sg2);
-
- // Merges two ExprGroups and returns the new ExprGroup
- ExprGroup* makeMergedNode(ExprGroup* sg1, ExprGroup* sg2);
-
- // This is called once no more groups can be merged together. This will lower
- // the compute at position of a segment group if the last dimension of the
- // segment group doesn't map to any of the dimensions of its neighbors.
- bool interIterUpdate();
-
- // Reset the ExprSortPayload of the groups so we can traverse and identify
- // merge candidates.
- void resetTraversal();
-
-  // Reset the levels of each group. This is what's used to identify which
- // nodes can be merged together.
- void resetLevels();
-
- // Go through groups that are marked as to merge and merge them.
- void mergeNodes();
-
- // Disconnect the edges connecting group to the rest of the graph, and return
- // all the edges that were disconnected
- std::unordered_set<ExprGroupConnections*> disconnectGroup(ExprGroup* group);
-
- private:
-  // Track how many groups we have from iteration to iteration so we can tell
-  // when we've stopped merging nodes.
- size_t n_groups_ = 0;
-
- // Lifetime of the graph view of the fusion and segmentation. Use list to not
- // invalidate any entries on insertion/deletion.
- std::list<std::unique_ptr<ExprGroupConnections>> edges_;
- std::list<std::unique_ptr<ExprGroup>> groups_;
-
- std::deque<ExprGroup*> to_visit_;
-
- std::unordered_set<ExprGroup*> to_merge_;
-
-  // Maintain our own fusion, the state of which is not always the same as the
-  // original provided fusion.
- Fusion* complete_fusion_;
-
- // We use a theorem out of a paper mentioned in other comments. This theorem
- // is good at identifying multiple expr groups to merge during a single
- // iteration without producing a cyclic graph from an acyclic graph. This
- // theorem is not guaranteed to find all possible nodes that can be merged
- // together. We need to be able to group all disjoint groups of exprs or
- // we fail to generate code. Therefore, if we can't find anything to make
-  // forward progress based on the theorem, we fall back to manually checking
-  // whether we can segment all combinations we haven't previously looked at.
- bool fallback_mode_enabled_ = false;
-};
-
-// // Debug printing, disabled due to clang-tidy see above for declarations.
-// std::ostream& operator<<(std::ostream& os, ExprGroup* group) {
-// os << "Group Start{\n ca, pa ("
-// << group->payload()->ca_domains_.size() << ", "
-// << group->payload()->pa_domains_.size() << ")";
-// os << " ca_ids {";
-// for (size_t i = 0; i < group->payload()->ca_domains_.size(); i++) {
-// os << group->payload()->ca_domains_[i];
-// if (i + 1 != group->payload()->ca_domains_.size())
-// os << ", ";
-// }
-// os << "} pa_ids {";
-// for (size_t i = 0; i < group->payload()->pa_domains_.size(); i++) {
-// os << group->payload()->pa_domains_[i];
-// if (i + 1 != group->payload()->pa_domains_.size())
-// os << ", ";
-// }
-// os << "}";
-// os << "\nExprs {\n";
-// for(auto expr : group->exprs()){
-// os << expr;
-// }
-// os << "}Group End\n";
-// return os;
-// }
-
-std::vector<ExprGroup*> ExprGroup::getNeighbors() {
- std::vector<ExprGroup*> neighbors;
- for (auto inp : producer_edges_) {
- neighbors.push_back(inp->from);
- }
- for (auto out : consumerEdges()) {
- neighbors.push_back(out->to);
- }
- return neighbors;
-}
-
-std::vector<ExprGroup*> ExprGroup::getMergeCandidates(
- bool fallback_mode_enabled) {
- std::vector<ExprGroup*> neighbors = getNeighbors();
-
- // Don't look for candidates if already merged
- if (payload()->merged) {
- return {};
- }
-
-  // Can this node be merged with another? Check if any neighbor is merged;
-  // if so, and that neighbor is within 1 level, or the node it merged with
-  // is within 1 level, this node can't be merged with anything else.
- bool can_merge_this = true;
- bool neighbor_merged = false;
- for (auto neighbor : neighbors) {
- if (!neighbor->payload()->merged) {
- continue;
- }
- neighbor_merged = true;
- if (std::abs(neighbor->payload()->level - payload()->level) <= 1) {
- can_merge_this = false;
- }
- if (std::abs(
- neighbor->payload()->merge_with->payload()->level -
- payload()->level) <= 1) {
- can_merge_this = false;
- }
- }
-
- // If something prevents us from merging this node, and we're not in fallback
- // mode, return empty set.
- if (!can_merge_this && !fallback_mode_enabled) {
- return {};
- }
-
- // If fallback mode already detected a merge somewhere, we shouldn't still be
- // traversing.
- if (fallback_mode_enabled) {
- TORCH_INTERNAL_ASSERT(
- !neighbor_merged,
- "Shouldn't still be traversing in fallback mode if a merge was found.");
- }
-
-  // Note: the count must come first; (true, neighbors.size()) would construct
-  // a one-element vector and lead to out-of-bounds accesses below.
-  std::vector<bool> can_merge(neighbors.size(), true);
-
-  // Keep only neighbors whose level is within 1 of this group's level
- for (size_t i = 0; i < neighbors.size(); i++) {
- if (std::abs(neighbors[i]->payload()->level - payload()->level) > 1) {
- can_merge[i] = false;
- }
- }
-
-  // Check the neighbors of the neighbors we're considering; if any of them
-  // are merged with another node, make sure the resulting edge wouldn't have
-  // a level difference of 1
- for (size_t i = 0; i < neighbors.size(); i++) {
- if (!can_merge[i]) {
- continue;
- }
-
- for (auto neighbor_neighbor : neighbors[i]->getNeighbors()) {
- // Don't check self
- if (neighbor_neighbor == neighbors[i]) {
- continue;
- }
- if (neighbor_neighbor->payload()->merged) {
- // check neighbor_neighbor level
- if (std::abs(neighbor_neighbor->payload()->level - payload()->level) <=
- 1) {
- can_merge[i] = false;
- }
- if (std::abs(
- neighbor_neighbor->payload()->level -
- neighbors[i]->payload()->level) <= 1) {
- can_merge[i] = false;
- }
-
-          // check neighbor_neighbor->merge_with->level
- if (std::abs(
- neighbor_neighbor->payload()->merge_with->payload()->level -
- payload()->level) <= 1) {
- can_merge[i] = false;
- }
- if (std::abs(
- neighbor_neighbor->payload()->merge_with->payload()->level -
- neighbors[i]->payload()->level) <= 1) {
- can_merge[i] = false;
- }
- }
- }
- }
-
- std::vector<ExprGroup*> merge_candidates;
- for (size_t i = 0; i < neighbors.size(); i++) {
- if ((can_merge[i] && !fallback_mode_enabled) ||
- (!can_merge[i] && fallback_mode_enabled)) {
- merge_candidates.push_back(neighbors[i]);
- }
- }
- return merge_candidates;
-}
-
-void ExprGroup::clearTraversalInfo() {
- payload()->level = -1;
- payload()->visited = false;
- payload()->merge_with = nullptr;
- payload()->merged = false;
-}
-
-void ExprSegmentationSorter::resetTraversal() {
- for (auto& group : groups_) {
- // Start traversal at input groups
- if (group->producerEdges().empty()) {
- to_visit_.push_back(group.get());
- }
- group->clearTraversalInfo();
- }
-}
-
-// Level is maximum distance from inputs. It's the metric used to select what
-// nodes can be merged while maintaining a DAG
-void ExprSegmentationSorter::resetLevels() {
- std::vector<ExprGroup*> next_to_visit;
-
- while (!to_visit_.empty()) {
- auto visit = to_visit_.front();
- to_visit_.pop_front();
-
- // All inputs processed?
- bool ready = true;
- if (!visit->producerEdges().empty()) {
- ready = std::all_of(
- visit->producerEdges().begin(),
- visit->producerEdges().end(),
- [&](ExprGroupConnections* dep) {
- return dep->from->payload()->visited;
- });
- }
-
- if (!ready) {
- // In case traversal doesn't complete because there's an error in the
- // DAG topology.
- next_to_visit.push_back(visit);
- continue;
- }
-
- visit->payload()->visited = true;
-
- to_visit_.insert(
- to_visit_.end(), next_to_visit.begin(), next_to_visit.end());
- next_to_visit.clear();
-
- for (auto out : visit->consumerEdges()) {
- to_visit_.push_back(out->to);
- }
-
- visit->payload()->level = 0;
- for (auto inp : visit->producerEdges()) {
- visit->payload()->level =
- std::max(visit->payload()->level, inp->from->payload()->level + 1);
- }
- }
-  TORCH_INTERNAL_ASSERT(next_to_visit.empty(), "Error in graph: it is not a DAG.");
-}
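-
-// For example (hypothetical groups), in a DAG with edges A->B, A->C, B->D and
-// C->D, the computed levels would be A:0, B:1, C:1, D:2. These levels are the
-// metric getMergeCandidates() uses to decide which merges are provably safe.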
-
-ExprGroup* ExprSegmentationSorter::makeEmptyGroup() {
- groups_.push_back(std::make_unique<ExprGroup>());
- return groups_.back().get();
-}
-
-ExprGroup* ExprSegmentationSorter::makeEmptyGroup(Expr* expr) {
- auto group = makeEmptyGroup();
- group->exprs().push_back(expr);
- if (ir_utils::isTVOp(expr)) {
- auto out_tv = expr->outputs()[0]->as<TensorView>();
- // Grab all id's that are shared with other tensors.
- for (size_t tv_i = 0; tv_i < out_tv->getComputeAtPosition(); tv_i++) {
- group->payload()->ca_domains_.push_back(out_tv->axis(tv_i));
- }
- for (size_t tv_i = 0; tv_i < out_tv->getMaxProducerPosition(); tv_i++) {
- group->payload()->pa_domains_.push_back(out_tv->axis(tv_i));
- }
- }
- return group;
-}
-
-// Debug function that prints the current state of the sorter.
-std::string ExprSegmentationSorter::toString(int verbosity) const {
- std::stringstream ss;
- ss << "{\n";
- for (auto& group : groups_) {
- ss << " " << group.get() << "\n";
-
- if (verbosity > 1) {
- if (group->producerEdges().size() > 0) {
- ss << "Produced by groups with edges: { \n";
- for (auto producer_edge : group->producerEdges()) {
- ss << producer_edge->producer_val_ << " -> "
- << producer_edge->consumer_val_ << "\n";
- }
- ss << " }"
- << "\n";
- }
- }
-
- if (verbosity > 1) {
- if (group->consumerEdges().size() > 0) {
- ss << "Consumed by groups with edges: { \n";
- for (auto consumer_edge : group->consumerEdges()) {
- ss << consumer_edge->producer_val_ << " -> "
- << consumer_edge->consumer_val_ << "\n";
- }
- ss << " }"
- << "\n";
- }
- }
- }
- ss << "}\n";
- return ss.str();
-}
-
-namespace {
-
-// Concatenates the edges of sg1 and sg2, removing any edges between sg1 and sg2
-std::vector<ExprGroupConnections*> getMergedEdges(
- const ExprGroup* sg1,
- const std::vector<ExprGroupConnections*>& edges1,
- const ExprGroup* sg2,
- const std::vector<ExprGroupConnections*>& edges2) {
- TORCH_INTERNAL_ASSERT(
- sg1 != nullptr && sg2 != nullptr,
- "This function doesn't handle trivial.");
-
- auto merged_edges = edges1;
- merged_edges.insert(merged_edges.end(), edges2.begin(), edges2.end());
-
- // Remove intra edges
- merged_edges.erase(
- std::remove_if(
- merged_edges.begin(),
- merged_edges.end(),
- [&sg1, &sg2](ExprGroupConnections* se) {
- return (se->to == sg1 && se->from == sg2) ||
- (se->to == sg2 && se->from == sg1);
- }),
- merged_edges.end());
-
- return merged_edges;
-}
-
-// Concatenates the producer edges of sg1 and sg2, removing any edges between sg1 and sg2
-std::vector<ExprGroupConnections*> getMergedProducerEdges(
- const ExprGroup* sg1,
- const ExprGroup* sg2) {
- return getMergedEdges(sg1, sg1->producerEdges(), sg2, sg2->producerEdges());
-}
-
-// Concatenates the consumer edges of sg1 and sg2, removing any edges between sg1 and sg2
-std::vector<ExprGroupConnections*> getMergedConsumerEdges(
- const ExprGroup* sg1,
- const ExprGroup* sg2) {
- return getMergedEdges(sg1, sg1->consumerEdges(), sg2, sg2->consumerEdges());
-}
-
-// Assuming sg1 and sg2 are connected, figure out which is the producer
-ExprGroup* getProducer(ExprGroup* sg1, ExprGroup* sg2) {
- for (auto producer_edge : sg1->producerEdges()) {
- if (producer_edge->from == sg2) {
- return sg2;
- }
- }
-
- for (auto consumer_edge : sg1->consumerEdges()) {
- if (consumer_edge->to == sg2) {
- return sg1;
- }
- }
-
- return nullptr;
-}
-
-// Go through all expressions and compute a local ordering of loops. Since
-// overloading comparison operators for iter domains doesn't make a lot of
-// sense, we instead fake having a < operator by considering that every
-// expression's output domain must be relatively ordered correctly. So we use all
-// of the expressions in a group to get a "local" ordering of the output IDs in
-// the group. We can't rely on any single expression because it may or may not
-// have all loops in the group. We also can't break ties without all
-// expressions.
-//
-// For example, two expressions may have domains [I0] and [I1], yet we
-// won't know the ordering unless we see a domain with [I0, I1]. This happened
-// in the advancedIndexing9 test when merging T5 with the group containing T10
-// (cache of T5, which is the post-broadcast output) and T6 (the pre-broadcast
-// output).
-// T5 had the domain [0, 1, 2, 3, 4] produce at 3
-// T6 had the domain [0, 3, 4] compute at 3
-// Merging [0, 1, 2] and [0, 3, 4] resulted in the domain [0, 3, 4, 1, 2]
-//
-// If IDs are not in the filter, we don't care about their ordering and ignore
-// them. This is because we're really focused on loops we will have to merge
-// across groups. If the domain is not in a produce at position in the producer
-// edges, or a compute at position in the consumer edges, the expressions we
-// look at may not have a unique ordering.
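-//
-// As a small worked example of the merge below: the local orderings [I0, I1]
-// and [I0, I2] merge to [I0, I1, I2]. I0 is at the front of every list that
-// contains it, so it is emitted first; I1 and I2 never co-occur, so their
-// relative order falls out of the round-robin scan over the domains.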
-std::vector<IterDomain*> getLocalDomainOrdering(
- const std::vector<Expr*>& exprs,
- const ComputeAtMap& map,
- const std::unordered_set<IterDomain*> filter) {
- if (exprs.empty()) {
- return std::vector<IterDomain*>();
- }
-
- std::vector<std::vector<IterDomain*>> domains;
-
- for (auto expr : exprs) {
- if (!ir_utils::isTVOp(expr)) {
- continue;
- }
-
- auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
- for (auto tv_input : tv_inputs) {
- std::vector<IterDomain*> domain(
- tv_input->domain()->domain().begin(),
- tv_input->domain()->domain().begin() +
- std::max(
- tv_input->getComputeAtPosition(),
- tv_input->getMaxProducerPosition()));
-
- domain.erase(
- std::remove_if(
- domain.begin(),
- domain.end(),
- [&filter, &map](IterDomain* id) {
- return filter.find(map.getConcreteMappedID(id)) == filter.end();
- }),
- domain.end());
-
- domains.emplace_back(domain);
- }
- }
-
- if (domains.size() == 1) {
- return domains[0];
- }
-
- std::vector<IterDomain*> merged_domains;
-
- // For each domain, keep an iterator to the current iter domain we're
- // checking, and an iterator for the end of the domain.
- typedef std::pair<
- std::vector<IterDomain*>::const_iterator,
- std::vector<IterDomain*>::const_iterator>
- iter_pair_t;
-
- std::vector<iter_pair_t> iterators(domains.size());
- for (auto i : c10::irange(domains.size())) {
- iterators[i] = std::make_pair(domains[i].begin(), domains[i].end());
- }
-
- auto empty = [](iter_pair_t& iter_pair) {
- return iter_pair.first == iter_pair.second;
- };
-
- size_t candidate_i = 0;
- size_t iterations_since_merge = 0;
- IterDomain* last_id_checked = nullptr;
-
- while (std::any_of(
- iterators.begin(), iterators.end(), [](iter_pair_t iter_pair) {
- return iter_pair.first != iter_pair.second;
- })) {
- TORCH_INTERNAL_ASSERT(
- iterations_since_merge <= iterators.size(),
- "Infinite loop detected in lower_expr_sort:mergeDomains.");
- iterations_since_merge++;
-
- if (candidate_i == iterators.size()) {
- candidate_i = 0;
- }
- if (empty(iterators[candidate_i])) {
- candidate_i++;
- continue;
- }
-
- auto iter_dom_candidate = *iterators[candidate_i].first;
- if (iter_dom_candidate == last_id_checked) {
- candidate_i++;
- continue;
- }
- last_id_checked = iter_dom_candidate;
-
- bool candidate_is_next = true;
-
- // Make sure this iter domain is in all first positions of all iter
- // lists that contain it, otherwise it shouldn't be the next iter domain.
- for (auto iterator : iterators) {
- if (empty(iterator)) {
- continue;
- }
- if (!map.areMapped(iter_dom_candidate, *iterator.first)) {
- if (std::any_of(
- iterator.first + 1,
- iterator.second,
- [&map, iter_dom_candidate](IterDomain* id) {
- return map.areMapped(iter_dom_candidate, id);
- })) {
- candidate_is_next = false;
- break;
- }
- }
- }
-
- if (!candidate_is_next) {
- candidate_i++;
- continue;
- }
-
- merged_domains.emplace_back(map.getConcreteMappedID(iter_dom_candidate));
-
- for (auto match_i : c10::irange(iterators.size())) {
- if (empty(iterators[match_i])) {
- continue;
- }
- if (map.areMapped(iter_dom_candidate, *iterators[match_i].first)) {
- iterators[match_i] = std::make_pair(
- iterators[match_i].first + 1, iterators[match_i].second);
- }
- }
- iterations_since_merge = 0;
- }
-
- return merged_domains;
-}
-} // namespace
-
-// Disconnect group from neighbors, and return edges that were disconnected
-std::unordered_set<ExprGroupConnections*> ExprSegmentationSorter::
- disconnectGroup(ExprGroup* group) {
- std::unordered_set<ExprGroupConnections*> removed_edges(
- group->producerEdges().begin(), group->producerEdges().end());
-
- for (auto edge : group->producerEdges()) {
- edge->from->removeConsumerEdge(edge);
- }
-
- for (auto edge : group->consumerEdges()) {
- edge->to->removeProducerEdge(edge);
- }
-
- group->clearProducerEdges();
- group->clearConsumerEdges();
-
- return removed_edges;
-}
-
-// TODO: This function may be suboptimal. If we find that an iteration domain
-// matches later in the other domain, we will hold all other iteration domains
-// until that one matches. There may be cases where duplicating that iteration
-// domain and moving on could be more efficient.
-ExprGroup* ExprSegmentationSorter::makeMergedNode(
- ExprGroup* sg1,
- ExprGroup* sg2) {
- // Keep Expr's sorted in topological order.
- const auto producer = getProducer(sg1, sg2);
- const auto consumer = sg1 == producer ? sg2 : sg1;
-
- // Make the new joined node
- auto joined_groups = makeEmptyGroup();
-
- TORCH_INTERNAL_ASSERT(
- producer != nullptr,
- "Tried to merge expr's together that aren't neighbors.");
-
- joined_groups->exprs() = producer->exprs();
- joined_groups->exprs().insert(
- joined_groups->exprs().end(),
- consumer->exprs().begin(),
- consumer->exprs().end());
-
- auto producer_edges = getMergedProducerEdges(sg1, sg2);
- // Connect joined group to resulting neighbors
- for (auto& edge : producer_edges) {
- auto from = edge->from;
- auto producer_val = edge->producer_val_;
- auto consumer_val = edge->consumer_val_;
-
- edges_.push_back(std::make_unique<ExprGroupConnections>(
- from, joined_groups, producer_val, consumer_val));
-
- joined_groups->addProducerEdge(edges_.back().get());
- from->addConsumerEdge(edges_.back().get());
- }
-
- auto consumer_edges = getMergedConsumerEdges(sg1, sg2);
-
- for (auto& edge : consumer_edges) {
- auto to = edge->to;
- auto producer_val = edge->producer_val_;
- auto consumer_val = edge->consumer_val_;
-
- edges_.push_back(std::make_unique<ExprGroupConnections>(
- joined_groups, to, producer_val, consumer_val));
- joined_groups->addConsumerEdge(edges_.back().get());
- edge->to->addProducerEdge(edges_.back().get());
- }
-
- // Merge the compute at domain of all edges going out from the newly joined
-  // group. The vals we're looking for are from our consumer edges, but we want
- // to grab the producer val as that's the one we generate.
- std::unordered_set<IterDomain*> ca_ids;
- for (auto consumer_group_edge : joined_groups->consumerEdges()) {
- auto producer_of_consumer_edge = consumer_group_edge->producer_val_;
- if (producer_of_consumer_edge->isA<TensorView>()) {
- auto tv = producer_of_consumer_edge->as<TensorView>();
- for (size_t tv_i = 0; tv_i < tv->getComputeAtPosition(); tv_i++) {
- ca_ids.emplace(GpuLower::current()->caLoopMap().getConcreteMappedID(
- tv->axis(tv_i)));
- }
- }
- }
-
- // Merge the produce at domain of all edges coming into the newly joined
-  // group. The vals we're looking for are from our producer edges, but we want
- // to grab the consumer val as that's the one we generate.
- std::unordered_set<IterDomain*> pa_ids;
- for (auto producer_group_edge : joined_groups->producerEdges()) {
- auto consumer_of_producer_edge = producer_group_edge->consumer_val_;
- if (consumer_of_producer_edge->isA<TensorView>()) {
- auto tv = consumer_of_producer_edge->as<TensorView>();
- for (size_t tv_i = 0; tv_i < tv->getMaxProducerPosition(); tv_i++) {
- pa_ids.emplace(GpuLower::current()->caLoopMap().getConcreteMappedID(
- tv->axis(tv_i)));
- }
- }
- }
-
- auto all_ca_pa_ids = ca_ids;
- all_ca_pa_ids.insert(pa_ids.begin(), pa_ids.end());
-
- auto ordered_ids = getLocalDomainOrdering(
- joined_groups->exprs(), GpuLower::current()->caLoopMap(), all_ca_pa_ids);
-
- for (auto id : ordered_ids) {
- if (ca_ids.count(id)) {
- joined_groups->payload()->ca_domains_.emplace_back(id);
- }
- if (pa_ids.count(id)) {
- joined_groups->payload()->pa_domains_.emplace_back(id);
- }
- }
-
- return joined_groups;
-}
-
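-// Returns true if the group's innermost produce at domain can be dropped,
-// i.e. the group has produce at domains and no producer edge still computes
-// at a loop mapped to that innermost domain.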
-bool canReducePA(ExprGroup* group) {
- if (group->payload()->pa_domains_.empty()) {
- return false;
- }
-
- IterDomain* group_pa_last_id = group->payload()->pa_domains_.back();
-
- // Look through producer edges to see if we can reduce our produce at domain
- for (auto producer_edge : group->producerEdges()) {
- auto producer_val = producer_edge->producer_val_;
- auto consumer_val = producer_edge->consumer_val_;
-
- // If producer isn't a tensor view it can't be mapped into a producer dim of
- // this group
- if (!(consumer_val->isA<TensorView>() && producer_val->isA<TensorView>())) {
- continue;
- }
-
-    // If the compute at domains of the producer group are empty, they can't map
- // the produce at domains of this group
- auto producer_group = producer_edge->from;
- if (producer_group->payload()->ca_domains_.empty()) {
- continue;
- }
-
- auto producer_tv = producer_val->as<TensorView>();
- auto consumer_tv = consumer_val->as<TensorView>();
-
-    // If this consumer_tv doesn't map to the last produce at domain of this
-    // group, we can't decide whether it can be reduced
- bool has_matching_pa = false;
- for (size_t i = 0; i < consumer_tv->getMaxProducerPosition(); i++) {
- if (GpuLower::current()->caLoopMap().areMapped(
- consumer_tv->axis(i), group_pa_last_id)) {
- has_matching_pa = true;
- break;
- }
- }
-
- if (!has_matching_pa) {
- continue;
- }
-
- // If any compute at positions of producers directly map to the last produce
- // at position it can't be lowered.
- for (int producer_pos_i = producer_tv->getComputeAtPosition();
- producer_pos_i > 0;
- producer_pos_i--) {
- if (GpuLower::current()->caLoopMap().areMapped(
- producer_tv->axis(producer_pos_i - 1), group_pa_last_id)) {
- return false;
- }
- }
- }
-
- return true;
-}
-
-// Update in between attempts to segment. This is called once no more groups
-// can be merged together. Typically we will want to lower the compute at
-// axes of groups that have finished being merged together. However, if no
-// groups have been merged after we've done this, we may need to stop, as we
-// could have multiple disjoint groups that won't be merged.
-bool ExprSegmentationSorter::interIterUpdate() {
- // Go through groups and lower either pa or ca domain return if anything was
- // lowered
- bool lowered_a_domain = false;
- for (auto& group : groups_) {
- if (canReducePA(group.get())) {
- group->payload()->pa_domains_.pop_back();
- lowered_a_domain = true;
- }
- }
-
- // If we couldn't lower compute at domain any further, and we haven't merged
- // any new groups after fallback_mode_enabled_ has been turned on, make sure
- // we've finished successfully
- if (!lowered_a_domain && n_groups_ == groups_.size()) {
- // Make sure none of the groups are still connected, as that would mean we
- // should have been able to merge them.
- bool successfully_finished = std::all_of(
- groups_.begin(), groups_.end(), [](std::unique_ptr<ExprGroup>& sg) {
- return sg->producerEdges().empty() && sg->consumerEdges().empty();
- });
- if (successfully_finished) {
- return false;
- }
- // If we didn't finish and we tried the fallback, throw.
- TORCH_INTERNAL_ASSERT(
- !fallback_mode_enabled_,
- "Couldn't succcessfully sort out the fusion expressions. ",
- "There are remaining connections of the heirarchical segmentation which should have been ",
- "flattened to a single ordered group, or disjoint ordered groups.");
- // We didn't finish, but we haven't tried the fallback, try again with that.
- fallback_mode_enabled_ = true;
- }
-
- n_groups_ = groups_.size();
- // Not done, continue.
- return true;
-}
-
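-// Merge all group pairs marked in to_merge_ into new joined groups, then
-// disconnect the original groups and erase them along with their edges.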
-void ExprSegmentationSorter::mergeNodes() {
- std::unordered_set<ExprGroup*> clean_up_groups;
- std::unordered_set<ExprGroupConnections*> clean_up_edges;
-
- while (!to_merge_.empty()) {
- auto group1 = *to_merge_.begin();
- auto group2 = group1->payload()->merge_with;
- to_merge_.erase(group1);
- to_merge_.erase(group2);
- clean_up_groups.emplace(group1);
- clean_up_groups.emplace(group2);
- makeMergedNode(group1, group2);
- }
-
- for (auto group : clean_up_groups) {
- auto disconnected_edges = disconnectGroup(group);
- clean_up_edges.insert(disconnected_edges.begin(), disconnected_edges.end());
- }
-
- edges_.remove_if([&](std::unique_ptr<ExprGroupConnections>& edge) {
- return clean_up_edges.find(edge.get()) != clean_up_edges.end();
- });
-
- groups_.remove_if([&](std::unique_ptr<ExprGroup>& group) {
- return clean_up_groups.find(group.get()) != clean_up_groups.end();
- });
-}
-
-// Two expression groups can be merged together if there's a value produced by
-// the producer group and consumed by the consumer group, where the compute at
-// position maps to the innermost compute at domain of the producer group and
-// to the innermost produce at domain of the consumer. If this value doesn't
-// exist, we can't be certain these domains share the "next" innermost loop.
-//
-// We're looking for this because we're starting at the innermost loops of all
-// expressions and looking for neighboring expressions that share inner loops.
-// Once we've found all the innermost loops that expressions share, we merge
-// them together, then look at the next innermost loop of the group and figure
-// out which other groups share this next innermost loop.
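-//
-// For example (a sketch): if a TensorView produced in group P and consumed in
-// group C has its innermost compute at axis mapped (via the loop map) to both
-// the back of P's ca_domains_ and the back of C's pa_domains_, the two groups
-// share that inner loop and the merge is supported.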
-bool ExprSegmentationSorter::supportedMerge(ExprGroup* sg1, ExprGroup* sg2) {
- auto producer_group = getProducer(sg1, sg2);
- auto consumer_group = sg1 == producer_group ? sg2 : sg1;
-
- if (producer_group->payload()->ca_domains_.size() <
- producer_group->payload()->pa_domains_.size()) {
- return false;
- }
-
- if (consumer_group->payload()->pa_domains_.size() <
- consumer_group->payload()->ca_domains_.size()) {
- return false;
- }
-
- const auto& producer_ca_domain = producer_group->payload()->ca_domains_;
- const auto& consumer_pa_domain = consumer_group->payload()->pa_domains_;
-
- if (producer_ca_domain.empty() && consumer_pa_domain.empty()) {
- return true;
- }
-
- if (producer_ca_domain.empty() || consumer_pa_domain.empty()) {
- return false;
- }
-
- const auto& loop_map = GpuLower::current()->caLoopMap();
-
- for (auto edge : producer_group->consumerEdges()) {
- if (edge->to != consumer_group) {
- continue;
- }
- auto producer_val = edge->producer_val_;
- auto consumer_val = edge->consumer_val_;
-
- if (!producer_val->isA<TensorView>()) {
- continue;
- }
-
- TORCH_INTERNAL_ASSERT(
- consumer_val->isA<TensorView>(),
- "Mismatched tensorview to non-tensorview in expression sorting. ",
- producer_val,
- " is consumed by ",
- consumer_val);
-
- auto producer_tv = producer_val->as<TensorView>();
-
- auto compute_at_pos = producer_tv->getComputeAtPosition();
- auto compute_at_dim = compute_at_pos > 0
- ? producer_tv->axis((int)producer_tv->getComputeAtPosition() - 1)
- : nullptr;
-
- if (compute_at_dim == nullptr) {
- continue;
- }
-
- if (!loop_map.areMapped(compute_at_dim, producer_ca_domain.back())) {
- continue;
- }
-
- if (loop_map.areMapped(compute_at_dim, consumer_pa_domain.back())) {
- return true;
- }
- }
- return false;
-}
-
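-// Returns true if merging sg1 and sg2 would keep the graph a DAG: BFS from
-// the consumers of both groups (skipping the direct sg1<->sg2 edges); if the
-// traversal reaches sg1 or sg2 again, the merge would create a cycle.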
-bool ExprSegmentationSorter::testStillDag(ExprGroup* sg1, ExprGroup* sg2) {
- std::deque<ExprGroup*> to_visit;
- std::unordered_set<ExprGroup*> visited;
- // Add consumers of sg1 if not sg2
- for (auto sg1_consumer_edge : sg1->consumerEdges()) {
- if (sg1_consumer_edge->to != sg2) {
- to_visit.emplace_back(sg1_consumer_edge->to);
- }
- }
-
- // Add consumers of sg2 if not sg1
- for (auto sg2_consumer_edge : sg2->consumerEdges()) {
- if (sg2_consumer_edge->to != sg1) {
- to_visit.emplace_back(sg2_consumer_edge->to);
- }
- }
-
- while (to_visit.size() > 0) {
- auto group = to_visit.front();
-    // Arrived back at one of the original groups; merging these two groups
-    // would generate a cycle
- if (group == sg1 || group == sg2) {
- return false;
- }
- to_visit.pop_front();
- if (visited.find(group) != visited.end()) {
- continue;
- }
- visited.emplace(group);
- for (auto consumer_edge : group->consumerEdges()) {
- to_visit.emplace_back(consumer_edge->to);
- }
- }
-
- // No cycles found, we're good.
- return true;
-}
-
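-// Build the initial DAG with one ExprGroup per expression, then repeatedly
-// merge neighboring groups that share their innermost loops, lowering
-// compute at / produce at domains between iterations (and falling back to a
-// more permissive merge strategy if the default one stalls).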
-void ExprSegmentationSorter::sort() {
- // Need this for initialization of the DAG that is processed
- std::unordered_map<Expr*, ExprGroup*> expr2group;
-
- // Initialize DAG, convert each expr to a segment group
- for (auto expr : complete_fusion_->exprs()) {
- auto group = makeEmptyGroup(expr);
- expr2group.insert(std::make_pair(expr, group));
- }
-
- // Create edges between the Exprs. Mark inputs and outputs of the fusion.
- for (auto expr : complete_fusion_->exprs()) {
- auto expr_group = expr2group.at(expr);
- auto out = expr->outputs()[0];
- for (auto inp : expr->inputs()) {
- if (inp->isFusionInput()) {
- continue;
- }
-
- // Could be something like a constant scalar, definition is nullptr, but
- // isn't an "input" to the fusion. At least not one provided by an
- // external source.
- if (inp->definition() == nullptr) {
- continue;
- }
-
- auto inp_def_group = expr2group.at(inp->definition());
- edges_.push_back(std::make_unique<ExprGroupConnections>(
- inp_def_group, expr_group, inp, out));
- expr_group->addProducerEdge(edges_.back().get());
- inp_def_group->addConsumerEdge(edges_.back().get());
- }
- }
- bool inter_iter_update = true;
- while (inter_iter_update) {
- // If we didn't do any update, stop traversal, we're done.
- if (!fallback_mode_enabled_) {
- // Merge expressions in sorted order
- bool merged_nodes = true;
- while (merged_nodes) {
- // Reset stateful traversal details in ExprGroups
- resetTraversal();
- resetLevels();
-
- for (auto& group : groups_) {
- if (group->payload()->merged) {
- continue;
- }
- auto candidates = group->getMergeCandidates(fallback_mode_enabled_);
- if (candidates.empty()) {
- continue;
- }
-
- auto candidate_it = candidates.begin();
- while (candidate_it != candidates.end() &&
- !supportedMerge(group.get(), *candidate_it)) {
- candidate_it++;
- }
- if (candidate_it == candidates.end()) {
- continue;
- }
-
- to_merge_.emplace(group.get());
- to_merge_.emplace(*candidate_it);
-
- group->payload()->merged = true;
- group->payload()->merge_with = *candidate_it;
-
- (*candidate_it)->payload()->merged = true;
- (*candidate_it)->payload()->merge_with = group.get();
- }
-
- if (to_merge_.empty()) {
- merged_nodes = false;
- }
-
- mergeNodes();
-
- // Move compute at axes left
- inter_iter_update = interIterUpdate();
- }
- } else {
- // fallback_mode_enabled = true
- // Reset stateful traversal details in ExprGroups as we'll exclude merge
- // options that were already ruled out and therefore need traversal and
- // levels reset.
- resetTraversal();
- resetLevels();
-
- for (auto& group : groups_) {
- if (group->payload()->merged) {
- continue;
- }
- // Get merge candidates that weren't proven safe to merge with default
- // algorithm.
- auto candidates = group->getMergeCandidates(fallback_mode_enabled_);
- if (candidates.empty()) {
- continue;
- }
-
- auto candidate_it = candidates.begin();
-
- while (candidate_it != candidates.end()) {
- while (candidate_it != candidates.end() &&
- !supportedMerge(group.get(), *candidate_it)) {
- candidate_it++;
- }
-
- if (candidate_it == candidates.end()) {
- break;
- }
-
- if (testStillDag(group.get(), *candidate_it)) {
- // Mark in same style as default algorithm for convenience even
- // though we will only merge once with the fallback
- to_merge_.emplace(group.get());
- to_merge_.emplace(*candidate_it);
-
- group->payload()->merged = true;
- group->payload()->merge_with = *candidate_it;
-
- (*candidate_it)->payload()->merged = true;
- (*candidate_it)->payload()->merge_with = group.get();
- break;
- }
-
- candidate_it++;
- }
-
- if (to_merge_.size() > 0) {
- break;
- }
- }
-
- // If we can merge something, merge it, disable fallback, and bail
- if (to_merge_.size() > 0) {
- mergeNodes();
- }
-
- // Move compute at axes left
- // If fallback didn't work, interIterUpdate will catch that we failed.
- inter_iter_update = interIterUpdate();
- fallback_mode_enabled_ = false;
- }
- }
-}
-
-std::vector<Expr*> ExprSegmentationSorter::getExprs() const {
- std::vector<Expr*> exprs;
- for (auto& group : groups_) {
- exprs.insert(exprs.end(), group->exprs().begin(), group->exprs().end());
- }
- return exprs;
-}
-
-} // namespace
-
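-// Entry point for expression sorting. A minimal usage sketch (assuming an
-// active FusionGuard on the fusion being lowered):
-//   FusionGuard fg(&fusion);
-//   std::vector<Expr*> sorted = reorderExprsForComputeAt();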
-std::vector<Expr*> reorderExprsForComputeAt() {
- auto fusion = FusionGuard::getCurFusion();
- TORCH_INTERNAL_ASSERT(fusion != nullptr);
- ExprSegmentationSorter sorter(fusion);
- sorter.sort();
- auto sorted_exprs = sorter.getExprs();
- TORCH_INTERNAL_ASSERT(
- sorted_exprs.size() > 0,
- "Error during expression sorting, no expressions produced.");
- return sorted_exprs;
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-std::vector<Expr*> reorderExprsForComputeAt();
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
IndexLowering::IndexLowering() : ir_builder_(GpuLower::current()->kernel()) {}
-kir::Val* IndexLowering::lowerSrcIndex(kir::Val* val, kir::Val* dst) const {
- if (auto tv = dynamic_cast<kir::TensorView*>(val)) {
- TORCH_INTERNAL_ASSERT(dst->isA<kir::TensorView>());
+Val* IndexLowering::lowerOperand(Val* op, Val* out) const {
+ if (ir_utils::isTV(op)) {
return Index::getProducerIndex(
- tv->fuserTv(),
- dst->as<kir::TensorView>()->fuserTv(),
- scope_utils::getLoops(active_scope_expr_));
+ ir_utils::asTV(op),
+ ir_utils::asTV(out),
+ scope_utils::getLoops(active_scope_expr));
} else {
- return val;
+ return GpuLower::lowerValue(op);
}
}
-kir::Val* IndexLowering::lowerDstIndex(kir::Val* dst) const {
- if (auto tv = dynamic_cast<kir::TensorView*>(dst)) {
+Val* IndexLowering::lowerOutput(Expr* expr) const {
+ TORCH_CHECK(expr->outputs().size() == 1);
+ const auto out = expr->output(0);
+ if (ir_utils::isTVOp(expr)) {
return Index::getConsumerIndex(
- tv->fuserTv(), scope_utils::getLoops(active_scope_expr_));
+ ir_utils::asTV(out), scope_utils::getLoops(active_scope_expr));
} else {
- return dst;
+ return GpuLower::lowerValue(out);
}
}
-void IndexLowering::pushBack(kir::Expr* expr) {
- if (active_scope_ == nullptr) {
- lowered_exprs_.push_back(expr);
+void IndexLowering::pushBack(Expr* expr) {
+ if (active_scope == nullptr) {
+ lowered_exprs.push_back(expr);
} else {
- active_scope_->push_back(expr);
+ active_scope->push_back(expr);
}
}
-void IndexLowering::visit(const kir::IfThenElse* ite) {
- const auto prev_scope_expr = active_scope_expr_;
- const auto prev_scope = active_scope_;
+void IndexLowering::handle(kir::IfThenElse* ite) {
+ Expr* prev_scope_expr = active_scope_expr;
+ kir::Scope* prev_scope = active_scope;
- // TODO(kir): try to avoid recreating new nodes and leaving old ones around
- auto new_ite = ir_builder_.create<kir::IfThenElse>(ite->predicate());
+ auto new_ite =
+ ir_builder_.create<kir::IfThenElse>(ite->cond(), prev_scope_expr);
pushBack(new_ite);
-
- active_scope_expr_ = new_ite;
- active_scope_ = &new_ite->thenBody();
+ active_scope_expr = new_ite;
+ active_scope = &new_ite->thenBody();
for (auto expr : ite->thenBody().exprs()) {
- expr->accept(this);
+ OptInDispatch::handle(expr);
}
- active_scope_ = &new_ite->elseBody();
+ active_scope = &new_ite->elseBody();
for (auto expr : ite->elseBody().exprs()) {
- expr->accept(this);
+ OptInDispatch::handle(expr);
}
- active_scope_ = prev_scope;
- active_scope_expr_ = prev_scope_expr;
+ active_scope = prev_scope;
+ active_scope_expr = prev_scope_expr;
}
-void IndexLowering::visit(const kir::ForLoop* for_loop) {
- const auto prev_scope_expr = active_scope_expr_;
- const auto prev_scope = active_scope_;
+void IndexLowering::handle(kir::ForLoop* fl) {
+ Expr* prev_scope_expr = active_scope_expr;
+ kir::Scope* prev_scope = active_scope;
- auto new_for_loop = ir_builder_.create<kir::ForLoop>(for_loop);
- pushBack(new_for_loop);
+ auto newFl = ir_builder_.create<kir::ForLoop>(
+ fl->index(), fl->iter_domain(), prev_scope_expr);
+ pushBack(newFl);
- active_scope_expr_ = new_for_loop;
- active_scope_ = &new_for_loop->body();
+ active_scope_expr = newFl;
+ active_scope = &newFl->body();
- for (auto expr : for_loop->body().exprs()) {
- expr->accept(this);
+ for (auto expr : fl->body().exprs()) {
+ OptInDispatch::handle(expr);
}
- active_scope_ = prev_scope;
- active_scope_expr_ = prev_scope_expr;
+ active_scope = prev_scope;
+ active_scope_expr = prev_scope_expr;
}
-void IndexLowering::visit(const kir::UnaryOp* uop) {
- const auto in = lowerSrcIndex(uop->in(), uop->out());
- const auto out = lowerDstIndex(uop->out());
- pushBack(ir_builder_.create<kir::UnaryOp>(uop->operation(), out, in));
+void IndexLowering::handle(UnaryOp* uop) {
+ if (ir_utils::isTVOp(uop)) {
+ const auto in = lowerOperand(uop->in(), uop->out());
+ const auto out = lowerOutput(uop);
+ pushBack(ir_builder_.create<kir::UnaryOp>(uop->getUnaryOpType(), out, in));
+ } else {
+ // This will automatically lower the expression defining the value
+ pushBack(GpuLower::lowerValue(uop->out())->getOrigin());
+ }
}
-void IndexLowering::visit(const kir::BinaryOp* bop) {
- const auto lhs = lowerSrcIndex(bop->lhs(), bop->out());
- const auto rhs = lowerSrcIndex(bop->rhs(), bop->out());
- const auto out = lowerDstIndex(bop->out());
- pushBack(ir_builder_.create<kir::BinaryOp>(bop->operation(), out, lhs, rhs));
+void IndexLowering::handle(BinaryOp* bop) {
+ if (ir_utils::isTVOp(bop)) {
+ const auto lhs = lowerOperand(bop->lhs(), bop->out());
+ const auto rhs = lowerOperand(bop->rhs(), bop->out());
+ const auto out = lowerOutput(bop);
+ pushBack(ir_builder_.create<kir::BinaryOp>(
+ bop->getBinaryOpType(), out, lhs, rhs));
+ } else {
+ // This will automatically lower the expression defining the value
+ pushBack(GpuLower::lowerValue(bop->out())->getOrigin());
+ }
}
-void IndexLowering::visit(const kir::TernaryOp* top) {
- const auto in1 = lowerSrcIndex(top->in1(), top->out());
- const auto in2 = lowerSrcIndex(top->in2(), top->out());
- const auto in3 = lowerSrcIndex(top->in3(), top->out());
- const auto out = lowerDstIndex(top->out());
- pushBack(
- ir_builder_.create<kir::TernaryOp>(top->operation(), out, in1, in2, in3));
+void IndexLowering::handle(TernaryOp* top) {
+ if (ir_utils::isTVOp(top)) {
+ const auto in1 = lowerOperand(top->in1(), top->out());
+ const auto in2 = lowerOperand(top->in2(), top->out());
+ const auto in3 = lowerOperand(top->in3(), top->out());
+ const auto out = lowerOutput(top);
+ pushBack(ir_builder_.create<kir::TernaryOp>(
+ top->getTernaryOpType(), out, in1, in2, in3));
+ } else {
+ // This will automatically lower the expression defining the value
+ pushBack(GpuLower::lowerValue(top->out())->getOrigin());
+ }
}
namespace {
-void allocateGridReductionFlag(
- kir::TensorView* out_tv,
- kir::Expr* current_scope_expr) {
+void allocateGridReductionFlag(TensorView* out_tv, Expr* current_scope_expr) {
kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-
- const auto flag_name = kir::GridReduction::getPredicateFlagName(out_tv);
- const auto flag_var = ir_builder.create<kir::Allocate>(
+ auto flag_name = kir::GridReduction::getPredicateFlagName(out_tv);
+ auto flag_var = ir_builder.create<kir::Allocate>(
ir_builder.create<kir::NamedScalar>(flag_name, DataType::Bool),
MemoryType::Local,
ir_builder.create<kir::Int>(1));
-
// When enclosed by IfThenElse, place the variable outside of the
  // IfThenElse. This IfThenElse is assumed to be the predicate for
// this grid reduction expression.
- if (current_scope_expr->isA<kir::IfThenElse>()) {
+ if (current_scope_expr->getExprType() == ExprType::IfThenElse) {
scope_utils::insertBefore(
- current_scope_expr->parentScope(), current_scope_expr, flag_var);
+ scope_utils::getParent(current_scope_expr),
+ current_scope_expr,
+ flag_var);
} else {
- TORCH_INTERNAL_ASSERT(current_scope_expr->isA<kir::ForLoop>());
- current_scope_expr->as<kir::ForLoop>()->body().push_back(flag_var);
+ scope_utils::pushBack(current_scope_expr, flag_var);
}
}
} // namespace
-void IndexLowering::visit(const kir::ReductionOp* rop) {
- TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(rop));
+void IndexLowering::handle(ReductionOp* rop) {
+ TORCH_INTERNAL_ASSERT(
+ ir_utils::isTVOp(rop),
+ "Cannot have a reduction operation on something other than a tensor view, but received ",
+ rop);
- const auto out_tv = rop->out()->as<kir::TensorView>();
- const auto out_domain = out_tv->domain();
+ auto out_tv = ir_utils::asTV(rop->out());
- const bool is_block_reduce = out_domain->hasBlockReduction();
- const bool is_grid_reduce = out_domain->hasGridReduction();
+ const bool is_block_reduce = out_tv->hasBlockReduction();
+ const bool is_grid_reduce = out_tv->hasGridReduction();
// If we do a grid reduction we can't have a reduction axis that is not bound
  // to a grid or block dim.
if (is_grid_reduce) {
TORCH_INTERNAL_ASSERT(
std::none_of(
- out_domain->domain().begin(),
- out_domain->domain().end(),
- [](kir::IterDomain* id) {
- return !id->isThread() && id->isReduction() &&
- !id->extent()->isOneInt();
+ out_tv->domain()->domain().begin(),
+ out_tv->domain()->domain().end(),
+ [](IterDomain* id) {
+ return !id->isThread() && id->isReduction();
}),
- "Found a reduction stage that has both a non-parallelized ",
- "reduction and a grid reduction. This is not supported, ",
- "please use rfactor to do the serialized reduction first, ",
- "then the grid reduction.");
+ "Found a reduction stage that has both a non-parallelized reduction and a grid reduction.",
+ " This is not supported, please use rfactor to do the serialized reduction first, then the grid reduction.");
}
+ const auto loops = scope_utils::getLoops(active_scope_expr);
- const auto out = lowerDstIndex(rop->out());
- const auto in = lowerSrcIndex(rop->in(), rop->out());
+ kir::TensorIndex* out = Index::getConsumerIndex(out_tv, loops);
+ kir::TensorIndex* in = Index::getProducerIndex(
+ ir_utils::asTV(rop->in()), ir_utils::asTV(rop->out()), loops);
kir::ReductionOp* block_reduction_op = nullptr;
-
if (is_block_reduce) {
+ auto pred =
+ PredicateCompute::getInlinePredicate(rop, loops, nullptr, false);
+
block_reduction_op = ir_builder_.create<kir::ReductionOp>(
- rop->operation(), rop->init(), out, in);
- if (rop->predicate()) {
- block_reduction_op->setPredicate(rop->predicate());
- }
- if (rop->writePredicate()) {
- block_reduction_op->setWritePredicate(rop->writePredicate());
- }
+ rop->getReductionOpType(),
+ GpuLower::lowerValue(rop->init()),
+ out,
+ in,
+ pred);
pushBack(block_reduction_op);
}
if (is_grid_reduce) {
// First, declare a boolean flag variable storing the return value
- // of the gridReduce() helper
- allocateGridReductionFlag(out_tv, active_scope_expr_);
+ // of gridReduce.
+ allocateGridReductionFlag(out_tv, active_scope_expr);
- auto buffer_ids = out_domain->domain();
+ std::vector<IterDomain*> buffer_ids(out_tv->domain()->domain());
buffer_ids.erase(
std::remove_if(
buffer_ids.begin(),
buffer_ids.end(),
- [](kir::IterDomain* id) {
- return id->isReduction() && !id->isBlockDim();
+ [](IterDomain* id) {
+            return id->isReduction() && !id->isBlockDim();
}),
buffer_ids.end());
- kir::Val* buffer_size = buffer_ids.empty() ? ir_builder_.create<kir::Int>(1)
- : buffer_ids[0]->extent();
-
+ Val* buffer_size =
+ buffer_ids.empty() ? new Int(1) : buffer_ids[0]->rawExtent();
for (const auto i : c10::irange(1, buffer_ids.size())) {
- buffer_size = ir_builder_.mulExpr(buffer_size, buffer_ids[i]->extent());
+ buffer_size = mul(buffer_size, buffer_ids[i]->rawExtent());
}
- auto sync_ids = out_domain->domain();
+ std::vector<IterDomain*> sync_ids(out_tv->domain()->domain());
sync_ids.erase(
std::remove_if(
sync_ids.begin(),
sync_ids.end(),
- [](kir::IterDomain* id) {
+ [](IterDomain* id) {
return id->isReduction() || !id->isBlockDim();
}),
sync_ids.end());
- kir::Val* sync_size = sync_ids.empty() ? ir_builder_.create<kir::Int>(1)
- : sync_ids[0]->extent();
-
+ Val* sync_size = sync_ids.empty() ? new Int(1) : sync_ids[0]->rawExtent();
for (const auto i : c10::irange(1, sync_ids.size())) {
- sync_size = ir_builder_.mulExpr(sync_size, sync_ids[i]->extent());
+ sync_size = mul(sync_size, sync_ids[i]->rawExtent());
}
- const auto zero = ir_builder_.create<kir::Int>(0);
+ IterDomain* buffer_id = new IterDomain(new Int(0), buffer_size);
+ TensorView* reduce_buffer_tv = new TensorView(
+ new TensorDomain({buffer_id}),
+ out->getDataType().value(),
+ MemoryType::Global);
- const std::vector<kir::IterDomain*> new_buffer_ids = {
- ir_builder_.create<kir::IterDomain>(zero, buffer_size)};
- const auto buffer_domain =
- ir_builder_.create<kir::TensorDomain>(new_buffer_ids);
- const auto reduce_buffer_tv = ir_builder_.create<kir::TensorView>(
- out->dtype(), buffer_domain, MemoryType::Global);
-
- const std::vector<kir::IterDomain*> new_sync_ids = {
- ir_builder_.create<kir::IterDomain>(zero, sync_size)};
- const auto sync_domain =
- ir_builder_.create<kir::TensorDomain>(new_sync_ids);
- const auto reduce_sync_tv = ir_builder_.create<kir::TensorView>(
- DataType::Int, sync_domain, MemoryType::Global);
+ IterDomain* sync_id = new IterDomain(new Int(0), sync_size);
+ TensorView* reduce_sync_tv = new TensorView(
+ new TensorDomain({sync_id}), DataType::Int, MemoryType::Global);
const auto reduce_buffer = ir_builder_.create<kir::Allocate>(
- reduce_buffer_tv, reduce_buffer_tv->memoryType());
-
+ GpuLower::lowerValue(reduce_buffer_tv),
+ reduce_sync_tv->getMemoryType());
const auto sync_buffer = ir_builder_.create<kir::Allocate>(
- reduce_sync_tv, reduce_sync_tv->memoryType(), nullptr, true);
+ GpuLower::lowerValue(reduce_sync_tv),
+ reduce_sync_tv->getMemoryType(),
+ nullptr,
+ true);
- const auto grid_reduction_op = (block_reduction_op == nullptr)
+ const auto grid_reduction_op = block_reduction_op == nullptr
? ir_builder_.create<kir::ReductionOp>(
- rop->operation(), rop->init(), out, in)
+ rop->getReductionOpType(),
+ GpuLower::lowerValue(rop->init()),
+ out,
+ in)
: block_reduction_op;
-
- // The thread predicate for GridReduction needs to be set
- // separately from the main predicate. Do not combine them like
- // other expressions.
- const auto& thread_pred =
- GpuLower::current()->threadPredMap().at(out_tv->fuserTv()).pred;
- auto grid_reduction = ir_builder_.create<kir::GridReduction>(
- grid_reduction_op, reduce_buffer, sync_buffer);
- grid_reduction->setThreadPredicate(thread_pred);
-
- if (rop->predicate()) {
- // If preceded by a blockReduce, all thread blocks should have
- // valid inputs to gridReduce. In fact, using the original
- // predicate does not work when the write predicate of the
- // blockReduce is different from the read predicate.
- if (is_block_reduce) {
- grid_reduction->setPredicate(
- ir_builder_.create<kir::Predicate>(ir_builder_.trueVal()));
- } else {
- grid_reduction->setPredicate(rop->predicate());
- }
- }
-
- if (rop->writePredicate()) {
- grid_reduction->setWritePredicate(rop->writePredicate());
- }
+ auto pred =
+ PredicateCompute::getInlinePredicate(rop, loops, nullptr, false);
+ const auto grid_reduction = ir_builder_.create<kir::GridReduction>(
+ grid_reduction_op, reduce_buffer, sync_buffer, pred);
pushBack(reduce_buffer);
pushBack(sync_buffer);
}
if (!is_block_reduce && !is_grid_reduce) {
- // TODO(kir): this breaks our "SSA" form
- pushBack(ir_builder_.create<kir::BinaryOp>(rop->operation(), out, out, in));
- }
-}
-
-namespace {
-
-template <typename T>
-kir::Allocate* allocGlobalBuffer(
- kir::IrBuilder& ir_builder,
- const kir::TensorDomain* td,
- T id_filter,
- DataType dtype,
- bool zero_init = false) {
- auto buffer_ids = td->domain();
- buffer_ids.erase(
- std::remove_if(buffer_ids.begin(), buffer_ids.end(), id_filter),
- buffer_ids.end());
-
- kir::Val* buffer_size = buffer_ids.empty() ? ir_builder.create<kir::Int>(1)
- : buffer_ids[0]->extent();
- for (size_t i = 1; i < buffer_ids.size(); i++) {
- buffer_size = ir_builder.mulExpr(buffer_size, buffer_ids[i]->extent());
+ pushBack(ir_builder_.create<kir::BinaryOp>(
+ rop->getReductionOpType(), out, out, in));
}
- const auto zero = ir_builder.create<kir::Int>(0);
- const std::vector<kir::IterDomain*> new_buffer_ids = {
- ir_builder.create<kir::IterDomain>(zero, buffer_size)};
- const auto buffer_domain =
- ir_builder.create<kir::TensorDomain>(new_buffer_ids);
- const auto buffer_tv = ir_builder.create<kir::TensorView>(
- dtype, buffer_domain, MemoryType::Global);
- return ir_builder.create<kir::Allocate>(
- buffer_tv, buffer_tv->memoryType(), nullptr, zero_init);
}
-} // namespace
-
-void IndexLowering::visit(const kir::WelfordOp* wop) {
- TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(wop));
-
- const auto out_tv = wop->outAvg()->as<kir::TensorView>();
- const auto out_domain = out_tv->domain();
+void IndexLowering::handle(BroadcastOp* bop) {
+ TORCH_INTERNAL_ASSERT(
+ ir_utils::isTVOp(bop),
+ "Cannot have a broadcast operation on something other than a tensor view, but received ",
+ bop);
- const bool is_block_reduce = out_domain->hasBlockReduction();
- const bool is_grid_reduce = out_domain->hasGridReduction();
+ auto loops = scope_utils::getLoops(active_scope_expr);
- // If we do a grid reduction we can't have a reduction axis that is not bound
-  // to a grid or block dim.
- if (is_grid_reduce) {
- TORCH_INTERNAL_ASSERT(
- std::none_of(
- out_domain->domain().begin(),
- out_domain->domain().end(),
- [](kir::IterDomain* id) {
- return !id->isThread() && id->isReduction();
- }),
- "Found a reduction stage that has both a non-parallelized ",
- "reduction and a grid reduction. This is not supported, ",
- "please use rfactor to do the serialized reduction first, ",
- "then the grid reduction.");
- }
-
- // lower IO tensors
- const auto in_var =
- wop->inVar() ? lowerSrcIndex(wop->inVar(), wop->outAvg()) : nullptr;
- const auto in_avg = lowerSrcIndex(wop->inAvg(), wop->outAvg());
- auto in_N = wop->inN();
-
-  // In the rfactor-ed case, the input N is actually a TV
- if (!in_N->isScalar()) {
- in_N = lowerSrcIndex(in_N, wop->outN());
- }
-
- auto out_avg = lowerDstIndex(wop->outAvg());
- auto out_var = lowerDstIndex(wop->outVar());
- auto out_N = lowerDstIndex(wop->outN());
-
- kir::WelfordOp* welford_op = ir_builder_.create<kir::WelfordOp>(
- out_var,
- out_avg,
- out_N,
- wop->initVar(),
- wop->initAvg(),
- wop->initN(),
- in_var,
- in_avg,
- in_N);
-
- kir::WelfordOp* block_welford_op = nullptr;
-
- if (is_block_reduce) {
- block_welford_op = welford_op;
- if (wop->predicate()) {
- block_welford_op->setPredicate(wop->predicate());
- }
- if (wop->writePredicate()) {
- block_welford_op->setWritePredicate(wop->writePredicate());
- }
- pushBack(block_welford_op);
- }
-
- if (is_grid_reduce) {
- // Allocate T_pred
- allocateGridReductionFlag(out_tv, active_scope_expr_);
-
- // Buffer allocation
- auto buffer_filter = [](const kir::IterDomain* id) {
- return id->isReduction() && !id->isBlockDim();
- };
- const auto out_var_buffer = allocGlobalBuffer(
- ir_builder_, out_domain, buffer_filter, out_var->dtype());
- const auto out_avg_buffer = allocGlobalBuffer(
- ir_builder_, out_domain, buffer_filter, out_avg->dtype());
- const auto out_N_buffer = allocGlobalBuffer(
- ir_builder_, out_domain, buffer_filter, out_N->dtype());
- const auto sync_buffer = allocGlobalBuffer(
- ir_builder_, out_domain, buffer_filter, DataType::Int, true);
-
- // Grid Welford instantiation
- const auto grid_welford_op =
- (block_welford_op == nullptr) ? welford_op : block_welford_op;
-
- // The thread predicate for GridReduction needs to be set
- // separately from the main predicate. Do not combine them like
- // other expressions.
- const auto& thread_pred =
- GpuLower::current()->threadPredMap().at(out_tv->fuserTv()).pred;
-
- auto grid_welford = ir_builder_.create<kir::GridWelford>(
- grid_welford_op,
- out_var_buffer,
- out_avg_buffer,
- out_N_buffer,
- sync_buffer);
-
- grid_welford->setThreadPredicate(thread_pred);
-
- if (wop->predicate()) {
- grid_welford->setPredicate(wop->predicate());
- }
-
- pushBack(out_var_buffer);
- pushBack(out_avg_buffer);
- pushBack(out_N_buffer);
- pushBack(sync_buffer);
- pushBack(grid_welford);
- }
-
- if (!is_block_reduce && !is_grid_reduce) {
- pushBack(welford_op);
- }
-}
-
-void IndexLowering::visit(const kir::BroadcastOp* bop) {
- TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(bop));
-
- const auto out = lowerDstIndex(bop->out());
- const auto in = lowerSrcIndex(bop->in(), bop->out());
- auto indexed_expr = ir_builder_.create<kir::BroadcastOp>(out, in);
-
- if (bop->predicate()) {
- indexed_expr->setPredicate(bop->predicate());
- }
+ kir::TensorIndex* out =
+ Index::getConsumerIndex(ir_utils::asTV(bop->out()), loops);
- pushBack(indexed_expr);
+ Val* in = bop->in();
+ if (ir_utils::isTV(in))
+ in = Index::getProducerIndex(
+ ir_utils::asTV(in), ir_utils::asTV(bop->out()), loops);
+ pushBack(ir_builder_.create<kir::BroadcastOp>(out, in));
}
-void IndexLowering::visit(const kir::Allocate* allocate) {
- // TODO(kir): remove the need for const_cast
- pushBack(const_cast<kir::Allocate*>(allocate)); // NOLINT
+void IndexLowering::handle(kir::Allocate* allocate) {
+ pushBack(allocate);
}
-void IndexLowering::visit(const kir::Sync* sync) {
- // TODO(kir): remove the need for const_cast
- pushBack(const_cast<kir::Sync*>(sync)); // NOLINT
+void IndexLowering::handle(kir::Sync* sync) {
+ pushBack(sync);
}
-void IndexLowering::generate(const std::vector<kir::Expr*>& exprs) {
- for (auto expr : exprs) {
- expr->accept(this);
+void IndexLowering::generate(const std::vector<Expr*>& exprs) {
+ // Run through loop nests and further lower the expressions
+ for (auto* expr : exprs) {
+ OptInDispatch::handle(expr);
}
}
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <vector>
namespace fuser {
namespace cuda {
-class TORCH_CUDA_CU_API IndexLowering : private kir::IrVisitor {
+class TORCH_CUDA_CU_API IndexLowering : public OptInDispatch {
public:
- static std::vector<kir::Expr*> getIndexedExprs(
- std::vector<kir::Expr*> incoming_exprs) {
- FUSER_PERF_SCOPE("GpuLower::Lower::IndexLowering::getIndexedExprs");
+ static std::vector<Expr*> getIndexedExprs(
+ Fusion* fusion,
+ std::vector<Expr*> incoming_exprs) {
+ FUSER_PERF_SCOPE("IndexLowering::getIndexedExprs");
+ FusionGuard fg(fusion);
IndexLowering il;
il.generate(incoming_exprs);
- return il.lowered_exprs_;
+ return il.lowered_exprs;
}
private:
IndexLowering();
- void pushBack(kir::Expr*);
+  // Wrap pushBack; if active_scope is null we want it to go
+  // straight to lowered_exprs
+ void pushBack(Expr*);
- void visit(const kir::ForLoop*) final;
- void visit(const kir::IfThenElse*) final;
- void visit(const kir::UnaryOp*) final;
- void visit(const kir::BinaryOp*) final;
- void visit(const kir::TernaryOp*) final;
- void visit(const kir::ReductionOp*) final;
- void visit(const kir::WelfordOp*) final;
- void visit(const kir::BroadcastOp*) final;
- void visit(const kir::Allocate*) final;
- void visit(const kir::Sync*) final;
+ // Open the for loop.
+ void handle(kir::ForLoop*) final;
- void generate(const std::vector<kir::Expr*>& exprs);
+  // Open the if-then-else.
+ void handle(kir::IfThenElse*) final;
- kir::Val* lowerSrcIndex(kir::Val* val, kir::Val* dst) const;
- kir::Val* lowerDstIndex(kir::Val* dst) const;
+ // Remake operations with TensorIndex
+ void handle(UnaryOp*) final;
+ void handle(BinaryOp*) final;
+ void handle(TernaryOp*) final;
+ void handle(ReductionOp*) final;
+ void handle(BroadcastOp*) final;
+ void handle(kir::Allocate*) final;
+ void handle(kir::Sync*) final;
+
+ void generate(const std::vector<Expr*>& exprs);
+
+ Val* lowerOperand(Val* op, Val* out) const;
+ Val* lowerOutput(Expr* expr) const;
private:
- std::vector<kir::Expr*> lowered_exprs_;
+ std::vector<Expr*> lowered_exprs;
  // This is a slight workaround, as "scope" has a couple of definitions: there
  // is the Scope that's in ForLoop/IfThenElse, which is really just a wrapper
  // around a vector of expressions, and we need to carry both the scope and
  // its owning expression around, because when we push back to a scope it
  // could be either the then-body or else-body of the IfThenElse. However, we
  // also want to understand the nesting of IfThenElse/ForLoop nodes.
- kir::Scope* active_scope_ = nullptr;
- kir::Expr* active_scope_expr_ = nullptr;
+ kir::Scope* active_scope = nullptr;
+ Expr* active_scope_expr = nullptr;
kir::IrBuilder ir_builder_;
};
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <torch/csrc/jit/codegen/cuda/lower_insert_syncs.h>
+
+#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_insert_syncs.h>
-
-#include <unordered_set>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
namespace torch {
namespace jit {
namespace {
-//! Scan through Kernel IR for-loops to insert Sync nodes to avoid
-//! Write-After-Read (WAR) race condition.
-//!
-//! Example:
-//! for () {
-//! smem_buf[threadIdx.x] = x;
-//! __syncthreads();
-//! buf[threadId.x] = smem_buf[threadIdx.x + 1];
-//! }
+//! Scan through Kernel IR to insert Sync nodes to avoid
+//! Write-After-Read (WAR) race condition
//!
-//! In this case, additional syncthreads is needed at the end of the
-//! loop body to avoid a hazard with smem_buf.
-class LocalSyncInserter {
- using TvSet = std::unordered_set<const kir::TensorView*>;
-
+class LocalSyncInserter final : private OptOutDispatch {
public:
- //! Write-After-Read race conditions are only found within for-loops.
- //! Sync nodes are inserted directly into the for-loops.
- //! The expressions are modified in-place and exprs is const.
- static void insertSyncs(const std::vector<kir::Expr*>& exprs) {
+ // Write-After-Read race conditions are only found within for-loops.
+ // Sync nodes are inserted directly into the for-loops.
+ // The expressions are modified in-place and exprs is const.
+ static void InsertSyncs(const std::vector<Expr*>& exprs) {
+ LocalSyncInserter sync_inserter;
for (auto expr : exprs) {
- if (auto fl = dynamic_cast<kir::ForLoop*>(expr)) {
- LocalSyncInserter sync_inserter(fl);
- }
- }
- }
-
- private:
- //! Insert Sync nodes at the end of a given for-loop when a WAR
- //! hazard may happen.
- LocalSyncInserter(kir::ForLoop* fl) {
- for (auto expr : fl->body().exprs()) {
- handle(expr);
- }
-
- // No need to insert sync when the loop is not actually generated
- if (fl->iter_domain()->isThread() || fl->iter_domain()->isBroadcast()) {
- return;
- }
-
- // Determine if any smem TV is written to at beginning of the for-loop
- // and whether that smem TV is read from at the end of the for-loop
- // Insert new SyncThreads at end of for-loop to prevent WAR race condition
- //
- // TODO: replace __syncthreads with __threadfence for alias ops
- //
- if (detectIntersection(initial_, final_) &&
- !fl->body().exprs().back()->isA<kir::Sync>() && !is_last_op_sync_) {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- fl->body().push_back(ir_builder.create<kir::Sync>(true));
- initial_sync_ = true;
- is_last_op_sync_ = true;
- final_.clear();
+ sync_inserter.handle(expr);
}
}
- const auto& initial() const {
+ const std::unordered_set<const TensorView*>& initial() const {
return initial_;
}
- const auto& final() const {
+ const std::unordered_set<const TensorView*>& final() const {
return final_;
}
- const auto& all_smem_inputs() const {
+ const std::unordered_set<const TensorView*>& all_smem_inputs() const {
return all_smem_inputs_;
}
- const auto& all_smem_outputs() const {
+ const std::unordered_set<const TensorView*>& all_smem_outputs() const {
return all_smem_outputs_;
}
- void handle(kir::Expr* expr) {
- if (ir_utils::isTVOp(expr)) {
- is_last_op_sync_ = false;
+ const std::unordered_set<unsigned int>& all_aliased_allocations() const {
+ return all_alias_allocations_;
+ }
+
+ private:
+ explicit LocalSyncInserter(
+ const std::unordered_set<unsigned int>* parent_alias_allocations =
+ nullptr) {
+ if (parent_alias_allocations != nullptr) {
+ all_alias_allocations_.insert(
+ parent_alias_allocations->begin(), parent_alias_allocations->end());
+ }
+ }
+ void handle(Expr* expr) final {
+ if (ir_utils::isTVOp(expr)) {
// For this SyncInserter
- if (initial_sync_) {
- addInputSmemTvs(expr, final_);
- } else {
- addInputSmemTvs(expr, final_);
- addOutputSmemTvs(expr, initial_);
- }
+ (!initial_sync_) ? hasOutputSmemExpr(expr, initial_)
+ : hasInputSmemExpr(expr, final_);
// For parent SyncInserter
- addOutputSmemTvs(expr, all_smem_outputs_);
- addInputSmemTvs(expr, all_smem_inputs_);
- } else if (auto sync = dynamic_cast<kir::Sync*>(expr)) {
- handle(sync);
- } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
- handle(ite);
- } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
- handle(for_loop);
+ hasOutputSmemExpr(expr, all_smem_outputs_);
+ hasInputSmemExpr(expr, all_smem_inputs_);
+ } else {
+ OptOutDispatch::handle(expr);
}
}
- void handle(kir::Sync* sync) {
- is_last_op_sync_ = true;
- initial_sync_ = true;
- final_.clear();
+ void handle(kir::Allocate* a) final {
+ if (a->buffer()->getValType().value() == ValType::KirTensorView &&
+ a->alias() != nullptr && a->getMemoryType() == MemoryType::Shared) {
+ auto tv = a->buffer()->as<kir::TensorView>()->fuserTv();
+ all_alias_allocations_.insert(tv->name());
+ }
}
- void handle(kir::IfThenElse* ite) {
+ void handle(kir::IfThenElse* ite) final {
for (auto expr : ite->thenBody().exprs()) {
handle(expr);
}
}
}
- void handle(kir::ForLoop* fl) {
- LocalSyncInserter child_sync_inserter(fl);
-
- const auto& child_inputs = child_sync_inserter.all_smem_inputs();
- const auto& child_outputs = child_sync_inserter.all_smem_outputs();
- const bool maybe_skipped = !fl->start()->isZeroInt() &&
- !isParallelTypeThread(fl->iter_domain()->parallelType());
-
- // Default - Track all smem inputs / outputs
- all_smem_inputs_.insert(child_inputs.begin(), child_inputs.end());
- all_smem_outputs_.insert(child_outputs.begin(), child_outputs.end());
-
- // Propagate the last_op_sync flag from the child loop. If the
- // child is deterministically executed at least once, just set the
- // flag with the child flag. Otherwise, conservatively set the
- // flag, i.e., if the current flag is true and the child flag is
- // also true, we can say the last op is still sync.
- if (!maybe_skipped) {
- is_last_op_sync_ = child_sync_inserter.is_last_op_sync_;
- } else {
- is_last_op_sync_ =
- is_last_op_sync_ && child_sync_inserter.is_last_op_sync_;
- }
-
- // When the child is not guaranteed to have sync.
- if (!child_sync_inserter.initial_sync_) {
- // If no sync is yet found, add the child outputs to
- // initial.
- if (!initial_sync_) {
- initial_.insert(child_outputs.begin(), child_outputs.end());
- }
- // Add the child inputs to final even when inital_sync is false,
- // which only means sync may not be found yet.
- final_.insert(child_inputs.begin(), child_inputs.end());
- } else {
- // Similar to the above case, but here, the child is guaranteed
- // to have sync, so we only need to look at initial and final.
- if (!initial_sync_) {
- initial_.insert(
- child_sync_inserter.initial().begin(),
- child_sync_inserter.initial().end());
- }
- if (!maybe_skipped) {
+ void handle(kir::ForLoop* fl) final {
+ // Track if last op in body is sync in nested for-loop
+ bool is_last_op_sync_ = false;
+ for (auto expr : fl->body().exprs()) {
+ is_last_op_sync_ = false;
+ if (expr->getExprType().value() == ExprType::Sync) {
initial_sync_ = true;
final_.clear();
+ } else if (expr->getExprType().value() == ExprType::ForLoop) {
+ // Recursively handle nested for-loop
+ LocalSyncInserter child_sync_inserter(&all_alias_allocations_);
+ child_sync_inserter.handle(expr);
+ const auto& child_inputs = child_sync_inserter.all_smem_inputs();
+ const auto& child_outputs = child_sync_inserter.all_smem_outputs();
+ const auto& child_alias_allocations =
+ child_sync_inserter.all_aliased_allocations();
+
+ // Default - Track all smem inputs / outputs
+ all_smem_inputs_.insert(child_inputs.begin(), child_inputs.end());
+ all_smem_outputs_.insert(child_outputs.begin(), child_outputs.end());
+ all_alias_allocations_.insert(
+ child_alias_allocations.begin(), child_alias_allocations.end());
+
+ if (!initial_sync_) {
+ // Parent - None
+ if (!child_sync_inserter.initial_sync_) {
+ // Child - None
+ // Append All Child Outputs to Parent Initial
+ initial_.insert(child_outputs.begin(), child_outputs.end());
+ } else if (child_sync_inserter.has_war_hazard_sync_) {
+ // Child - WAR race
+ // Parent first sync
+ // Inherit Child Initial / Clear Parent Final
+ initial_sync_ = true;
+ is_last_op_sync_ = true;
+ initial_.insert(
+ child_sync_inserter.initial().begin(),
+ child_sync_inserter.initial().end());
+ final_.clear();
+ } else {
+ // Child - 1+
+ // Parent first sync
+ // Inherit Child Initial + Final
+ initial_sync_ = true;
+ initial_.insert(
+ child_sync_inserter.initial().begin(),
+ child_sync_inserter.initial().end());
+ final_.insert(
+ child_sync_inserter.final().begin(),
+ child_sync_inserter.final().end());
+ }
+ } else {
+ // Parent - 1+
+ if (!child_sync_inserter.initial_sync_) {
+ // Child - None
+ // Append All Child to Parent Last
+ final_.insert(child_inputs.begin(), child_inputs.end());
+ } else if (child_sync_inserter.has_war_hazard_sync_) {
+ // Child - WAR race
+ // Clear Parent Last / Discard Child Initial
+ is_last_op_sync_ = true;
+ final_.clear();
+ } else {
+ // Child - 1+
+ // Inherit Child Final / Discard Child Initial
+ final_.insert(
+ child_sync_inserter.final().begin(),
+ child_sync_inserter.final().end());
+ }
+ }
+ } else {
+ handle(expr);
+ }
+ }
+
+ // This level of the nested for-loop may not exist in the kernel.
+ // However, subsequent levels can exist, so we handle the body of the
+ // for-loop first.
+ if (!fl->iter_domain()->isThread() && !fl->iter_domain()->isBroadcast()) {
+ // Determine if any smem TV is written to at beginning of the for-loop
+ // and whether that smem TV is read from at the end of the for-loop
+ // Insert new SyncThreads at end of for-loop to prevent WAR race condition
+ // TODO: replace __syncthreads with __threadfence for alias ops
+ if (detect_intersection(initial_, final_) &&
+ fl->body().exprs().back()->getExprType().value() != ExprType::Sync &&
+ !is_last_op_sync_) {
+ // std::cout << "WAR race detected; Add Sync" << std::endl;
+ has_war_hazard_sync_ = true;
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+ fl->body().push_back(ir_builder.create<kir::Sync>(true));
}
- final_.insert(
- child_sync_inserter.final().begin(),
- child_sync_inserter.final().end());
}
}
- static bool detectIntersection(const TvSet& left, const TvSet& right) {
+ bool detect_intersection(
+ std::unordered_set<const TensorView*>& left,
+ std::unordered_set<const TensorView*>& right) {
for (auto item : left) {
if (right.find(item) != right.end()) {
        return true;
      }
    }
    return false;
}
- static void addOutputSmemTvs(const kir::Expr* expr, TvSet& set) {
+ void hasOutputSmemExpr(
+ Expr* expr,
+ std::unordered_set<const TensorView*>& set) {
for (auto out : expr->outputs()) {
- if (auto tv = dynamic_cast<kir::TensorView*>(out)) {
- if (tv->memoryType() == MemoryType::Shared) {
+ if (ir_utils::isTV(out)) {
+ auto tv = out->as<TensorView>();
+ if (tv->getMemoryType() == MemoryType::Shared) {
set.insert(tv);
}
}
}
}
- static void addInputSmemTvs(const kir::Expr* expr, TvSet& set) {
- for (auto in : expr->inputs()) {
- if (auto tv = dynamic_cast<kir::TensorView*>(in)) {
- if (tv->memoryType() == MemoryType::Shared) {
+ void hasInputSmemExpr(
+ Expr* expr,
+ std::unordered_set<const TensorView*>& set) {
+ for (auto inp : expr->inputs()) {
+ if (ir_utils::isTV(inp)) {
+ auto tv = inp->as<TensorView>();
+ if (tv->getMemoryType() == MemoryType::Shared) {
set.insert(tv);
}
}
    }
  }
private:
+ // Track TensorViews for Allocate nodes that alias another memory location
+ std::unordered_set<unsigned int> all_alias_allocations_;
+
// Track Shared Memory Inputs (Reads) for parent for-loop
- TvSet all_smem_inputs_;
+ std::unordered_set<const TensorView*> all_smem_inputs_;
// Track Shared Memory Outputs (Writes) for parent for-loop
- TvSet all_smem_outputs_;
+ std::unordered_set<const TensorView*> all_smem_outputs_;
// Shared Memory Writes at beginning of the for-loop
// before first SyncThreads
- TvSet initial_;
+ std::unordered_set<const TensorView*> initial_;
// Shared Memory Reads at end of the for-loop
// Cleared after each SyncThreads
- TvSet final_;
+ std::unordered_set<const TensorView*> final_;
- // Track first sync deterministically found in for-loop. Even when a
- // child loop has a sync, if it may not be executed due to non-zero
- // start value, this flag remains false.
+ // Track first sync found in for-loop
bool initial_sync_ = false;
- // Track if last op is sync
- bool is_last_op_sync_ = false;
-};
-
-class ExprFlattener : private kir::IrVisitor {
- private:
- void handle(kir::Expr* expr) {
- if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
- expr->accept(this);
- } else {
- exprs_.push_back(expr);
- }
- }
-
- void visit(const kir::ForLoop* fl) final {
- for (auto expr : fl->body().exprs()) {
- handle(expr);
- }
- }
-
- void visit(const kir::IfThenElse* ite) final {
- for (auto expr : ite->thenBody().exprs()) {
- handle(expr);
- }
- for (auto expr : ite->elseBody().exprs()) {
- handle(expr);
- }
- }
-
- private:
- std::vector<kir::Expr*> exprs_;
-
- public:
- //! Flattens scopes extracting out a single ordered list of exprs.
- static std::vector<kir::Expr*> flatten(
- const std::vector<kir::Expr*>& loop_nests) {
- ExprFlattener flattener;
- for (auto expr : loop_nests) {
- flattener.handle(expr);
- }
- return flattener.exprs_;
- }
-};
-
-class ValidatePlacementAfterWrites : private kir::IrVisitor {
- public:
- //! Validate no expr in writes found under loop
- static void validate(
- kir::ForLoop* loop,
- const std::unordered_set<kir::Expr*>& writes) {
- ValidatePlacementAfterWrites validator(writes);
- validator.handle(loop);
- }
-
- private:
- ValidatePlacementAfterWrites(const std::unordered_set<kir::Expr*>& writes)
- : writes_(writes) {}
-
- void handle(kir::Expr* expr) {
- if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
- expr->accept(this);
- } else {
- TORCH_INTERNAL_ASSERT(
- writes_.find(expr) == writes_.end(),
- "Block sync must be placed after ",
- kir::toString(expr));
- }
- }
-
- void visit(const kir::ForLoop* fl) final {
- for (auto expr : fl->body().exprs()) {
- handle(expr);
- }
- }
-
- void visit(const kir::IfThenElse* ite) final {
- for (auto expr : ite->thenBody().exprs()) {
- handle(expr);
- }
- for (auto expr : ite->elseBody().exprs()) {
- handle(expr);
- }
- }
-
- private:
- const std::unordered_set<kir::Expr*>& writes_;
-};
-
-class ReadAfterWriteSyncs : public kir::MutableIrVisitor {
- private:
- //! Traverse up the loop stack from loops_it and if a halo loop is
- //! found, place a given sync expr before the outer-most halo loop.
- bool insertBeforeHaloLoop(
- std::vector<kir::ForLoop*>::iterator loops_it,
- kir::Sync* sync_expr,
- const std::unordered_set<kir::Expr*>& writes) {
- std::vector<kir::ForLoop*>::iterator halo_loop_it;
- bool halo_loop_found = false;
-
- while (true) {
- if ((*loops_it)->iter_domain()->isThreadDim() &&
- (*loops_it)->iter_domain()->extent() != (*loops_it)->stop()) {
- halo_loop_found = true;
- halo_loop_it = loops_it;
- }
-
- if (loops_it == for_loops_.begin()) {
- break;
- }
- --loops_it;
- }
-
- // No halo loop found. Do not place the sync expr here. Return
- // false to indicate nothing is done.
- if (!halo_loop_found) {
- return false;
- }
-
- auto halo_loop = *halo_loop_it;
-
- // Make sure there's no write to the smem buffer inside the halo
- // loop. syncthreads is moved before the halo loop, so having
- // writes inside the loop invalidates the consistency.
- ValidatePlacementAfterWrites::validate(halo_loop, writes);
-
- if (halo_loop_it == for_loops_.begin()) {
- // place in global scope
- auto place_before_it =
- std::find(loop_nests_.begin(), loop_nests_.end(), halo_loop);
- TORCH_INTERNAL_ASSERT(place_before_it != loop_nests_.end());
- loop_nests_.insert(place_before_it, sync_expr);
- } else {
- auto place_in = *(halo_loop_it - 1);
- place_in->body().insert_before(halo_loop, sync_expr);
- }
-
- return true;
- }
-
- void handle(kir::Expr* expr) {
- if (!ir_utils::isTVOp(expr) || expr->isA<kir::Allocate>()) {
- expr->accept(this);
- return;
- }
-
- if (sync_after_.size() > 0 && sync_after_.front() == expr) {
- sync_after_.pop_front();
- auto last_writes = last_writes_.front();
- last_writes_.pop_front();
- // Found that a sync is needed
- TORCH_INTERNAL_ASSERT(expr->outputs()[0]->isA<kir::TensorView>());
- auto out_tv = expr->outputs()[0]->as<kir::TensorView>();
-
- // Find where a sync needs to be inserted
- // This is very similar to how allocations are placed, simply place sync
- // after the expression instead of placing like allocation where it goes
- // before.
- // TODO: This may be a common operation, could be worth making a utility
- // out of or saving state for tensor view ID -> for loop
- // TODO: Explicitly test the 3 cases below
-
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- auto sync_expr = ir_builder.create<kir::Sync>();
- if (out_tv->fuserTv()->getComputeAtPosition() == 0) {
- // Sync should be placed at global scope, after its outer most loop if
- // it has one.
- kir::Expr* place_after = for_loops_.size() > 0 ? for_loops_[0] : expr;
- // Find location in loop_nests_
- auto place_after_it =
- std::find(loop_nests_.begin(), loop_nests_.end(), place_after);
- TORCH_INTERNAL_ASSERT(
- place_after_it != loop_nests_.end(),
- "Could not figure out where to place synchronization. ",
- "Tried to place after, ",
- toString(place_after),
- ", but could not find this expression at the global scope.");
- loop_nests_.insert(place_after_it + 1, sync_expr);
- } else {
- // Find the last loop in computeAt of out_tv, this is the loop where we
- // would place an allocation for out_tv
- auto fuser_tv = out_tv->fuserTv();
- auto lowered_local_id =
- GpuLower::current()
- ->lowerValue(fuser_tv->axis(
- (int)out_tv->fuserTv()->getComputeAtPosition() - 1))
- ->as<kir::IterDomain>();
-
- auto loops_it = std::find_if(
- for_loops_.begin(),
- for_loops_.end(),
- [&lowered_local_id](const auto& loop) {
- return GpuLower::current()->caLoopMap().areMapped(
- loop->iter_domain(), lowered_local_id) ||
- loop->iter_domain()->parallelType() == ParallelType::Unroll;
- });
-
- TORCH_INTERNAL_ASSERT(loops_it != for_loops_.end());
-
- // block sync must be placed before halo-extended loops
- if (insertBeforeHaloLoop(loops_it, sync_expr, last_writes)) {
- return;
- }
-
- auto place_in = *loops_it;
- kir::Expr* place_after = nullptr;
-
- if (loops_it + 1 == for_loops_.end()) {
- // Inline allocation, place after expr
- place_after = expr;
- } else {
- // Place allocation after the last computeAt axis
- // TODO: may be more efficient to place after the first non-computeAt
- // axis
- place_after = *(loops_it + 1);
- }
-
- place_in->body().insert_after(place_after, sync_expr);
- }
- }
- }
-
- void visit(kir::ForLoop* fl) final {
- for_loops_.push_back(fl);
- // Modifying in place, make a copy of the vector
- const std::vector<kir::Expr*> exprs = fl->body().exprs();
- for (auto expr : exprs) {
- handle(expr);
- }
- for_loops_.pop_back();
- }
-
- void visit(kir::IfThenElse*) final {
- TORCH_INTERNAL_ASSERT(
- false,
- "Pass does not support conditional statements, ",
- "this pass should be run before any conditionals are placed in code.");
- }
-
- // Clear the modify status for all shared memory buffers
- static void cleanSharedMemory(
- std::unordered_map<kir::Val*, kir::Expr*>& smem) {
- smem.clear();
- }
-
- // Return a set of expressions that modify shared-memory
- // tensors. Expressions are excluded when syncthreads are already
- // placed.
- std::unordered_set<kir::Expr*> isModifiedSharedMemory(
- const std::unordered_map<kir::Val*, kir::Expr*>& smem,
- const std::vector<kir::Val*>& tvs) const {
- std::unordered_set<kir::Expr*> last_writes;
- for (auto tv : tvs) {
- auto it = smem.find(tv);
- if (it != smem.end()) {
- last_writes.insert(it->second);
- }
- }
- return last_writes;
- }
-
- ReadAfterWriteSyncs(std::vector<kir::Expr*> _loop_nests)
- : loop_nests_(std::move(_loop_nests)) {
- // Fusion shared_memory values
- // Tracks if shared memory is modified
- std::unordered_map<kir::Val*, kir::Expr*> smem;
-
- // Flatten all the expressions
- auto flattened_exprs = ExprFlattener::flatten(loop_nests_);
-
- kir::Expr* prev_tv_expr = nullptr;
- for (auto expr : flattened_exprs) {
- if (!ir_utils::isTVOp(expr) || expr->isA<kir::Allocate>()) {
- continue;
- }
-
- auto last_writes = isModifiedSharedMemory(smem, expr->inputs());
- if (!last_writes.empty()) {
- TORCH_INTERNAL_ASSERT(
- prev_tv_expr != nullptr,
- "Can't require sync on inputs, however, detected it's needed.");
- sync_after_.push_back(prev_tv_expr);
- last_writes_.push_back(last_writes);
- cleanSharedMemory(smem);
- }
-
- for (auto out : expr->outputs()) {
- if (out->isA<kir::TensorView>()) {
- if (out->as<kir::TensorView>()->memoryType() == MemoryType::Shared) {
- smem[out] = expr;
- }
- }
- }
-
- prev_tv_expr = expr;
- }
-
- // Insert read after write syncs
- const std::vector<kir::Expr*> exprs = loop_nests_;
- for (auto expr : exprs) {
- handle(expr);
- }
-
- TORCH_INTERNAL_ASSERT(
- sync_after_.empty(), "Didn't place all required syncs.");
- }
-
- private:
- //! Keep track of expressions that must be followed by syncthreads
- std::deque<kir::Expr*> sync_after_;
-
- //! Keep track of write expressions that must be placed before
- //! syncthreads.
- //!
- //! syncthreads is placed after for each expression of
- //! sync_after_. However, if it's inside a loop with halo, it must
- //! be placed before that. last_writes_ keeps track of expressions
- //! modifying the smem buffer each syncthreads is used for so that
- //! it is not placed before those write expressions.
- std::deque<std::unordered_set<kir::Expr*>> last_writes_;
-
- //! Keep track of for loops while inserting syncthreads
- std::vector<kir::ForLoop*> for_loops_;
-
- //! Loop-nests where syncthreads are inserted
- std::vector<kir::Expr*> loop_nests_;
-
- public:
- static std::vector<kir::Expr*> insert(
- const std::vector<kir::Expr*>& loop_nests) {
- ReadAfterWriteSyncs inserter(loop_nests);
- return inserter.loop_nests_;
- }
+  // Track whether a sync was inserted for a WAR hazard
+ bool has_war_hazard_sync_ = false;
};
} // namespace
-std::vector<kir::Expr*> insertRawThreadSynchronization(
- const std::vector<kir::Expr*>& exprs) {
- FUSER_PERF_SCOPE("GpuLower::Lower::insertRawThreadSynchronization");
- return ReadAfterWriteSyncs::insert(exprs);
-}
-
-std::vector<kir::Expr*> insertWarThreadSynchronization(
- const std::vector<kir::Expr*>& exprs) {
- FUSER_PERF_SCOPE("GpuLower::Lower::insertWarThreadSynchronization");
- LocalSyncInserter::insertSyncs(exprs);
+std::vector<Expr*> insertThreadSynchronization(
+ Fusion* fusion,
+ const std::vector<Expr*>& exprs) {
+ FUSER_PERF_SCOPE("insertThreadSynchronization");
+ FusionGuard fg(fusion);
+ LocalSyncInserter::InsertSyncs(exprs);
return exprs;
}
+
} // namespace cuda
} // namespace fuser
} // namespace jit
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <vector>
namespace cuda {
//! Insert sync at end of for-loops to prevent write-after-read race condition.
-//!
//! A WAR race condition occurs when the next iteration of the loop
//! overwrites a shared memory value before a previous operation has
//! finished reading it.
//!
//! If the child loop ends with unsynchronized reads (Child - End) and the
//! parent has zero remaining operations, then the parent inherits the
//! child's end-of-loop reads (Child End).
//!
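//! As an illustrative sketch (not part of this change), the kind of kernel
//! this pass guards, assuming a shared buffer reused across iterations:
//!
//!   __shared__ float smem[128];
//!   for (int i = 0; i < n; ++i) {
//!     smem[threadIdx.x] = in[i * 128 + threadIdx.x];
//!     __syncthreads(); // RAW: writes must land before the read below
//!     out[i * 128 + threadIdx.x] = smem[127 - threadIdx.x];
//!     __syncthreads(); // WAR: inserted by this pass so the next
//!                      // iteration's write cannot overtake the read
//!   }
//!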
-std::vector<kir::Expr*> insertWarThreadSynchronization(
- const std::vector<kir::Expr*>& exprs);
-
-//! Insert syncs between writing to shared memory and then reading it.
-std::vector<kir::Expr*> insertRawThreadSynchronization(
- const std::vector<kir::Expr*>& exprs);
+std::vector<Expr*> insertThreadSynchronization(
+ Fusion* fusion,
+ const std::vector<Expr*>& exprs);
} // namespace cuda
} // namespace fuser
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <algorithm>
-#include <deque>
#include <numeric>
namespace torch {
namespace fuser {
namespace cuda {
-std::vector<kir::Expr*> LoopNestGenerator::loweredExprs(
- const std::vector<Expr*>& exprs) {
- FUSER_PERF_SCOPE("GpuLower::Lower::LoopNestGenerator::loweredExprs");
- TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr);
- LoopNestGenerator generator(exprs);
- return generator.lowered_exprs_;
-}
-
-LoopNestGenerator::LoopNestGenerator(const std::vector<Expr*>& exprs) {
+LoopNestGenerator::LoopNestGenerator(
+ Fusion* fusion,
+ ThreadPredicateMap& thread_predicates,
+ const std::vector<Expr*>& exprs)
+ : fusion_(fusion),
+ thread_predicates_(thread_predicates),
+ ir_builder_(GpuLower::current()->kernel()) {
generate(exprs);
}
-namespace {
+// Create, place, and return the allocation for tv
+Expr* LoopNestGenerator::pushAlloc(TensorView* tv) {
+ TORCH_INTERNAL_ASSERT(
+ !(FusionGuard::getCurFusion()->hasInput(tv) ||
+ FusionGuard::getCurFusion()->hasOutput(tv)),
+ "Tried to allocate an input or output tensor.");
+
+ const auto alloc_point = loop_utils::getAllocPoint(tv, for_loops);
+ const auto alloc_loop = alloc_point.first;
+ const auto alloc_pos = alloc_point.second;
+
+ // Grab the dimensions the allocation will be based on to compute a size
+ std::vector<Val*> alloc_dims;
+ for (size_t i = alloc_pos; i < tv->nDims(); i++) {
+ IterDomain* compute_at_dim = tv->getComputeAtAxis(i).first;
+ IterDomain* local_dim = tv->axis(i);
+ if (
+ // If shared memory, don't use any IDs bound to a grid dimension
+ (tv->memory_type_ == MemoryType::Shared &&
+ compute_at_dim->isBlockDim()) ||
+ // If local memory, don't use any IDs bound to a grid or block dimension
+ (tv->memory_type_ == MemoryType::Local && compute_at_dim->isThread()) ||
+ // If we're reducing this dimension, don't use it in the allocation
+ // computation
+ local_dim->isReduction() ||
+ // If this is a broadcast dimension, don't use it in the allocation
+ // computation
+ local_dim->isBroadcast()) {
+ continue;
+ }
+ alloc_dims.push_back(compute_at_dim->rawExtent());
+ }
-kir::ForLoop* openForHelper(kir::ForLoop* scope, IterDomain* id) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
- const auto kir_id = gpu_lower->lowerValue(id)->as<kir::IterDomain>();
- auto extent_with_halo = gpu_lower->haloInfo().getExtent(kir_id);
- kir::ForLoop* new_scope = nullptr;
- if (extent_with_halo) {
- // When an axis is extended with halo, unrolling and vectorization
- // are assumed to not be used for now.
- TORCH_INTERNAL_ASSERT(
- id->getParallelType() != ParallelType::Unroll &&
- !isParallelTypeVectorize(id->getParallelType()));
- // Use the extent that's extended by halo
- new_scope = ir_builder.create<kir::ForLoop>(
- kir_id,
- id->isBroadcast() ? ir_builder.zeroVal()
- : ir_builder.create<kir::Int>(c10::nullopt),
- nullptr,
- extent_with_halo,
- nullptr,
- false,
- nullptr);
+ // Multiply all the dimensions we're going to use for the allocation together
+ // to get the total size
+ Val* size = nullptr;
+ if (alloc_dims.size() == 0) {
+ size = ir_builder_.create<kir::Int>(1);
} else {
- new_scope = ir_builder.create<kir::ForLoop>(kir_id);
+ size = GpuLower::lowerValue(alloc_dims[0]);
+ for (const auto i : c10::irange(1, alloc_dims.size())) {
+ size = ir_builder_.mulExpr(size, GpuLower::lowerValue(alloc_dims[i]));
+ }
}
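+
+  // For example (illustrative extents, not from this diff): if the surviving
+  // allocation axes have extents {4, 8}, size lowers to 4 * 8 = 32; with no
+  // surviving axes (fully inlined, reduced, or broadcast), size is 1.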
- if (scope != nullptr) {
- scope->body().insert(0, new_scope);
+
+ // Create the allocation node
+ const auto lowered_tv = ir_builder_.create<kir::TensorView>(tv);
+ const auto alloc = ir_builder_.create<kir::Allocate>(
+ lowered_tv, lowered_tv->memoryType(), size);
+
+ // Track Dynamic Shared Memory Allocation Nodes
+ if (tv->getMemoryType() == MemoryType::Shared) {
+ if (!size->isConstScalar()) {
+ dynamic_smem_.push_front(alloc);
+ return nullptr;
+ }
}
- return new_scope;
-}
-} // namespace
+ // Place the allocation
+ if (alloc_loop != nullptr) {
+ alloc_loop->body().insert(for_loop_allocations_[alloc_loop], alloc);
+ ++for_loop_allocations_[alloc_loop];
+ } else {
+ lowered_exprs.insert(lowered_exprs.begin(), alloc);
+ }
-void LoopNestGenerator::openFor(IterDomain* iter_domain) {
- if (for_loops_.size() > 0) {
- const auto new_scope = openForHelper(for_loops_.back(), iter_domain);
- // for_loop_allocations_.insert({new_scope, 0});
- for_loops_.push_back(new_scope);
+ return alloc;
+}
+
+void LoopNestGenerator::openFor(std::pair<IterDomain*, TensorView*> id_pair) {
+ compute_at_scope.push_back(id_pair);
+ IterDomain* id = id_pair.first;
+ if (for_loops.size() > 0) {
+ kir::ForLoop* new_scope = scope_utils::openFor(for_loops.back(), id);
+ for_loop_allocations_.insert({new_scope, 0});
+ for_loops.push_back(new_scope);
} else {
- for_loops_.push_back(openForHelper(nullptr, iter_domain));
- lowered_exprs_.insert(lowered_exprs_.begin(), for_loops_.back());
+ for_loops.push_back(scope_utils::openFor(nullptr, id));
+ lowered_exprs.push_back(for_loops.back());
}
}
-void LoopNestGenerator::closeFor() {
- TORCH_INTERNAL_ASSERT(!for_loops_.empty());
- for_loops_.pop_back();
+void LoopNestGenerator::popFor() {
+ TORCH_INTERNAL_ASSERT(
+ !for_loops.empty() && !compute_at_scope.empty(),
+ "Can't pop for loop, scope is empty.");
+ for_loops.pop_back();
+ compute_at_scope.pop_back();
}
-void LoopNestGenerator::pushFront(kir::Expr* expr) {
- if (for_loops_.size() == 0) {
- lowered_exprs_.insert(lowered_exprs_.begin(), expr);
+void LoopNestGenerator::pushBack(Expr* expr) {
+ if (for_loops.size() == 0) {
+ lowered_exprs.push_back(expr);
} else {
- for_loops_.back()->body().insert(0, expr);
+ scope_utils::pushBack(for_loops.back(), expr);
}
}
-void LoopNestGenerator::handle(Expr* expr) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
+// Update the for-loop structure based on this TensorView. If there's an
+// allocation stmt, pass it in so we can make sure this initialization is
+// inserted after it
+void LoopNestGenerator::initReduction(
+ TensorView* tv,
+ Val* init_val,
+ Expr* alloc_expr) {
+ auto alloc_point = loop_utils::getAllocPoint(tv, for_loops);
+ auto alloc_loop = alloc_point.first;
+ auto alloc_pos = alloc_point.second;
+
+ // Grab the IDs that will be involved in the initialization, ignore reduction
+ // dimensions. Everything else will be iterated over to cover the entire
+ // buffer. Index compute will ignore [block, grid]Dims depending on buffer
+ // memory location
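+  //
+  // Illustrative result (assumed tensor, not from this diff): for
+  // T2 = sum(T1, {1}) with alloc_pos 0 and one non-reduction axis I0{n},
+  // the nest built below amounts to:
+  //
+  //   for (int i0 = 0; i0 < n; ++i0) {
+  //     T2[i0] = init_val;
+  //   }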
+ std::vector<kir::IterDomain*> ids;
+ for (size_t i = alloc_pos; i < tv->nDims(); i++) {
+ IterDomain* dim = tv->getComputeAtAxis(i).first;
+ if (dim->isReduction())
+ continue;
+ ids.push_back(GpuLower::lowerValue(dim)->as<kir::IterDomain>());
+ }
+
+ // Unsafe clone, as we want an exact replica of tv so we can create a UnaryOp
+ // to set the buffer to the init_val.
+ auto clone = tv->unsafeClone();
+ thread_predicates_.duplicate(clone, tv);
+  // The initialization stmt that will be located inside the loop nest (if
+  // there is one)
+ auto init_stmt = new UnaryOp(UnaryOpType::Set, clone, init_val);
+
+ // Init a pointer that will become the entirety of the initialization
+ Expr* init_loop_nest = nullptr;
+
+ // The for loop that we will place the initialization within (alloc_pos - 1),
+ // if one exists. Once we're done this inner_fl will be the inner most loop
+ // containing the init_stmt
+ kir::ForLoop* inner_fl = nullptr;
+ if (alloc_pos >= 1)
+ inner_fl = for_loops[alloc_pos - 1];
+
+ // Work through the iter domains that we need to initialize on, outside to
+ // inside, to construct the loop nest for the initialization.
+ for (auto id : ids) {
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ kir::ForLoop* new_fl;
+
+ if (id->isThread()) {
+ // If based on a thread, make sure we get the named Int right
+ std::stringstream ss;
+ ss << id->getParallelType();
+ new_fl = ir_builder_.create<kir::ForLoop>(
+ ir_builder_.create<kir::NamedScalar>(ss.str(), DataType::Int),
+ id,
+ inner_fl);
+ } else {
+ // Otherwise it's just a new int-
+ new_fl = ir_builder_.create<kir::ForLoop>(
+ ir_builder_.create<kir::Int>(c10::nullopt), id, inner_fl);
+ }
+ for_loop_allocations_.insert({new_fl, 0});
+
+ if (init_loop_nest == nullptr) {
+ // If this is our first generated loop, then it will be our outer most
+ // loop nest
+ init_loop_nest = new_fl;
+ } else {
+ // Otherwise place it inside the last generated loop
+ inner_fl->body().push_back(new_fl);
+ }
+ // Increment the inner most for loop
+ inner_fl = new_fl;
+ }
+
+ if (init_loop_nest == nullptr) {
+    // If no loops were generated, then our init_stmt is all we need
+ init_loop_nest = init_stmt;
+ } else {
+ // If there were for loops generated, place the init_stmt in the inner most
+ // for loop.
+ inner_fl->body().push_back(init_stmt);
+ }
+
+ // If we don't have an alloc_loop defined it means it needs to go in
+ // lowered_exprs. Make sure to place after the allocation of what we're
+ // initializing if there is one.
+ if (alloc_loop == nullptr) {
+ if (alloc_expr != nullptr) {
+ auto it =
+ std::find(lowered_exprs.begin(), lowered_exprs.end(), alloc_expr);
+ TORCH_INTERNAL_ASSERT(
+ it != lowered_exprs.end(),
+ "Could not figure out where to initialize the buffer for ",
+ tv);
+ lowered_exprs.insert(it + 1, init_loop_nest);
+ } else {
+ lowered_exprs.insert(lowered_exprs.begin(), init_loop_nest);
+ }
+ } else {
+ if (alloc_expr != nullptr) {
+ // If there is an allocation for this TensorView
+ // place this loop nest after it
+ alloc_loop->body().insert_after(alloc_expr, init_loop_nest);
+ ++for_loop_allocations_[alloc_loop];
+ } else {
+ // Otherwise we're allocating a global value
+ alloc_loop->body().insert(0, init_loop_nest);
+ }
+ }
+}
+void LoopNestGenerator::handle(Expr* expr) {
// Check if it's a tensor view expression we need to place in the loop nest
// structure
if (!ir_utils::isTVOp(expr)) {
- // Close all the loops, scalar operations cannot be inside for loops based
- // on expr sorting.
- while (!for_loops_.empty()) {
- closeFor();
- }
- pushFront(gpu_lower->lowerExpr(expr));
-
for (auto out : expr->outputs()) {
TORCH_INTERNAL_ASSERT(
out->getValType().value() == ValType::Scalar,
" cannot lower ",
out->getValType().value());
- pushFront(ir_builder.create<kir::Allocate>(
- gpu_lower->lowerValue(out),
+ pushBack(ir_builder_.create<kir::Allocate>(
+ GpuLower::lowerValue(out),
MemoryType::Local,
- ir_builder.create<kir::Int>(1)));
+ ir_builder_.create<kir::Int>(1)));
}
+ pushBack(expr);
return;
}
- TensorView* out_tv = expr->output(0)->as<TensorView>();
+ // 0) Apply SyncThreads if any shared memory inputs are modified
+ bool shared_memory_sync = false;
+ for (auto in : expr->inputs()) {
+ shared_memory_sync |= isModifiedSharedMemory(in);
+ }
+ if (shared_memory_sync) {
+ TORCH_INTERNAL_ASSERT(!for_loops.empty(), "Attempted to add SyncThreads");
+ // push Sync to the back of the last for loop
+ scope_utils::pushBack(for_loops.back(), ir_builder_.create<kir::Sync>());
+ cleanSharedMemory();
+ }
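+
+  // E.g. (illustrative, assumed tensors): with T_s in shared memory,
+  //   T_s = set(T0);     // a previous handle() marked T_s modified
+  //   T2  = add(T_s, 1); // T_s is a modified input -> push kir::Sync at the
+  //                      // end of the innermost loop and reset smem statuses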
+
+ TensorView* out = expr->output(0)->as<TensorView>();
// Figure out what the entire loop structure should look like.
- std::deque<IterDomain*> loop_structure;
-
- // Fill the entire loop structure by Looking at each axis
- // individually in out's domain
- for (size_t out_i = 0; out_i < out_tv->nDims(); out_i++) {
- // Note: It is not safe to skip trivial reduction axes as they could be
- // inlined with other tensor views. This happens in
- // NVFuserTest.FusionBNRepro_CUDA as of this commit on norm_hack_2_rebased
- // branch
-
- // Look up the concrete ID in the parallel map, not in the loop
- // map, which also maps non-CA axes.
- auto concrete_id =
- gpu_lower->caParallelMap().getConcreteMappedID(out_tv->axis(out_i));
- loop_structure.push_back(concrete_id);
- }
-
- auto loop_structure_it = loop_structure.begin();
- auto for_loop_it = for_loops_.begin();
- auto last_for_loop_matched = for_loops_.begin();
-
- // Match the loop structure with the current for-loops. Reuse
- // matching loops and close unmatched ones.
- while (loop_structure_it != loop_structure.end() &&
- for_loop_it != for_loops_.end()) {
- auto lowered_out_id =
- gpu_lower->lowerValue(*loop_structure_it)->as<kir::IterDomain>();
- // Similar to the above, the parallel map is used rather than the
- // loop map. Again, non-CA axes should not share loops, so the
- // parallel map should be used.
- if (gpu_lower->caParallelMap().areMapped(
- lowered_out_id, (*for_loop_it)->iter_domain())) {
- loop_structure_it++;
- last_for_loop_matched = ++for_loop_it;
+ std::deque<std::pair<IterDomain*, TensorView*>> loop_structure;
+
+ // As we go through iteration domains track the previous view
+ TensorView* last_ca_view = nullptr;
+ // Check where in the previous view our last axis was in that view
+ int64_t last_ca_view_ind = 0;
+
+ // Look at each axis individually in out's domain
+ for (int64_t out_i = 0; out_i < (int64_t)out->getThisComputeAtAxis();
+ out_i++) {
+ // Grab the axis information
+ auto ca_point = out->getComputeAtAxis(out_i);
+ auto ca_view = ca_point.second;
+ auto ca_id = ca_point.first;
+
+ // Figure out if there are axes in the compute at tensor view that aren't
+ // in out, make sure to also open them. Check where to start looking for
+ // them in the compute at view.
+ size_t start = 0;
+ if (last_ca_view == nullptr) {
+      // Start at the beginning; we haven't processed any axes yet.
+ start = 0;
+ } else if (last_ca_view == ca_view) {
+ // This view is the same as the last axis, so start where we left off.
+ start = last_ca_view_ind + 1;
} else {
- ++for_loop_it;
+ // This is a new view, figure out where we are in it, and start from there
+ for (start = 0; start < ca_view->nDims(); start++) {
+ if (loop_structure.back().first ==
+ ca_view->getComputeAtAxis(start).first) {
+ break;
+ }
+ }
+ start++;
+ }
+
+ // Go from start, and open all loops in the computeAt view until we hit the
+ // one associated with out->getComputeAtAxis(out_i)
+ for (size_t ca_i = start; ca_i < ca_view->nDims(); ca_i++) {
+ // Note that ca_view->getComputeAtAxis(ca_i) is equivalent to
+ // std::pair(ca_view->axis(ca_i), ca_view)
+ loop_structure.push_back(ca_view->getComputeAtAxis(ca_i));
+
+ // Update the last view processed
+ last_ca_view_ind = ca_i;
+ last_ca_view = ca_view;
+ if (ca_view->getComputeAtAxis(ca_i).first == ca_id) {
+ break;
+ }
+ }
+
+ // Shouldn't ever hit this, but make sure we hit the break above, meaning we
+ // added all necessary axes from the compute at view.
+ TORCH_INTERNAL_ASSERT(
+ ca_view->getComputeAtAxis(last_ca_view_ind).first == ca_id);
+ }
+
+ // We're up to the compute at point in loop_structure, grab the remaining
+ // axes.
+ for (int64_t out_i = (int64_t)out->getThisComputeAtAxis();
+ out_i < (int64_t)out->nDims();
+ out_i++) {
+    // These axes are local, but getComputeAtAxis returns the std::pair that
+    // loop_structure stores, which axis() doesn't
+ loop_structure.push_back(out->getComputeAtAxis(out_i));
+ }
+
+  // At this point loop_structure contains our overall target loop nest
+  // structure. Let's get a copy of it and figure out which loops we need
+  // to open.
+ decltype(loop_structure) loops_to_open(loop_structure);
+ // Pop out loops already opened
+ for (const auto& existing_loop : for_loops) {
+ if (loops_to_open.empty()) {
+ // Nothing to open
+ break;
+ }
+ if (GpuLower::lowerValue(loops_to_open.front().first)
+ ->as<kir::IterDomain>() == existing_loop->iter_domain()) {
+ loops_to_open.pop_front();
+ }
+ }
+
+  // At this point for_loops + loops_to_open contains our overall target loop
+ // nest structure. Open loops in "loops_to_open".
+ while (!loops_to_open.empty()) {
+ openFor(loops_to_open.front());
+ loops_to_open.pop_front();
+ }
+
+ Expr* alloc_expr = nullptr;
+ // Place the allocation for out
+ if (!FusionGuard::getCurFusion()->hasInput(out) &&
+ !FusionGuard::getCurFusion()->hasOutput(out)) {
+ alloc_expr = pushAlloc(out);
+ }
+
+ // If this is a reduction, initialize the output (open for loops to inner
+ // most, predicate, initialize, place next after allocation if exists, close
+ // to computeAt)
+ if (out->hasReduction()) {
+ initReduction(out, expr->as<ReductionOp>()->init(), alloc_expr);
+ }
+
+ // Place the expression
+ pushBack(expr);
+
+ // If output is a shared memory buffer, set modified status
+ modifySharedMemory(out);
+
+ // Reduce the loop nest structure back to computeAt
+ if (out->getThisComputeAtAxis() == 0) {
+ while (!for_loops.empty()) {
+ popFor();
+ }
+ } else {
+ auto ca_axis = out->getThisComputeAtAxis() - 1;
+ while (for_loops.size() > 0 &&
+ for_loops.back()->iter_domain() !=
+ GpuLower::lowerValue(out->getComputeAtAxis(ca_axis).first)
+ ->as<kir::IterDomain>()) {
+ popFor();
+ }
+ }
+}
+
+namespace {
+
+TensorView* findOutputTensor(Expr* expr) {
+ TORCH_INTERNAL_ASSERT(
+ expr->outputs().size() <= 1, "Unexpected number of outputs");
+ if (expr->outputs().size() != 1) {
+ return nullptr;
+ }
+ auto out = expr->output(0);
+ if (out->getValType() != ValType::TensorView) {
+ return nullptr;
+ }
+ return out->as<TensorView>();
+}
+
+void findTargetTensor(Expr* expr, TensorView*& target, unsigned& score) {
+ TORCH_INTERNAL_ASSERT(expr->outputs().size() <= 1);
+
+ TensorView* out_tv = findOutputTensor(expr);
+ if (out_tv == nullptr) {
+ target = nullptr;
+ score = 0;
+ return;
+ }
+
+ if (!out_tv->hasComputeAt()) {
+ target = out_tv;
+ // No computeAt, so this should come last.
+ score = std::numeric_limits<unsigned>::max();
+ return;
+ }
+
+ auto axis = out_tv->getRelativeComputeAtAxis();
+ target = out_tv->getComputeAtView();
+ while (target->hasComputeAt()) {
+ if (target->getThisComputeAtAxis() < axis) {
+ break;
}
+ axis = target->getComputeAtRelPos(axis);
+ target = target->getComputeAtView();
}
- auto n_loops_to_close =
- std::distance(last_for_loop_matched, for_loops_.end());
+ score = axis;
+}
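+
+// Illustrative walk (assumed chain, not from this diff): with
+// T1.computeAt(T2, 1) and T2.computeAt(T3, 2), an expr producing T1 starts
+// at target T2 with axis 1. T2's own computeAt position (2) is not shallower
+// than the axis, so the axis is mapped through getComputeAtRelPos and the
+// walk advances to T3, which has no computeAt: target = T3, score = axis.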
+// Type definitions for brevity
+using ExprListT = std::vector<Expr*>;
+using TargetGroupMapT = std::unordered_map<TensorView*, ExprListT>;
+using ExprTargetMapT = std::unordered_map<Expr*, TensorView*>;
+using ScoreT = unsigned;
+using ExprScoreMapT = std::unordered_map<const Expr*, ScoreT>;
+
+void sanityCheck(
+ const ExprListT& exprs,
+ const ExprListT& reordered_exprs,
+ const ExprScoreMapT& scores,
+ const ExprTargetMapT& target_map,
+ const TargetGroupMapT& computed_at_exprs) {
+ const auto num_exprs = exprs.size();
+ TORCH_INTERNAL_ASSERT(scores.size() == num_exprs);
TORCH_INTERNAL_ASSERT(
- n_loops_to_close >= 0 &&
- n_loops_to_close <= (std::ptrdiff_t)for_loops_.size(),
- "Tried to close an invalid number of loops: ",
- n_loops_to_close);
-
- if (max_close < n_loops_to_close && max_close > 0) {
- // Figure out where the last for loop matches from out_tv, go until the
- // max_close loop marked from previous tv's producer domain. Make sure
- // none of these domains are actually present in current out_tv. If these
- // loops map to current out_tv, it should be responsible for deciding if
- // they stay or go, this could result from an invalid compute at topology
- // on the DAG or bad expression sorting.
- auto for_loops_it = for_loops_.end() - n_loops_to_close;
- auto for_loops_it_end = for_loops_.end() - max_close;
-
- for (; for_loops_it != for_loops_it_end; for_loops_it++) {
+ reordered_exprs.size() + target_map.size() == num_exprs);
+ int num_computed_exprs = std::accumulate(
+ computed_at_exprs.begin(),
+ computed_at_exprs.end(),
+ 0,
+ [](int acc, const std::pair<TensorView*, ExprListT>& p) {
+ return acc + p.second.size();
+ });
+ TORCH_INTERNAL_ASSERT(num_computed_exprs == (int)target_map.size());
+}
+
+// Arrange exprs into loop-nest groups. Loop-nest groups are
+// disjoint grouping of expressions based on the expression
+// where each expression is computed at.
+void groupExpressions(
+ Expr* expr,
+ ExprListT& reordered_exprs,
+ ExprTargetMapT& target_map,
+ TargetGroupMapT& computed_at_exprs,
+ ExprScoreMapT& scores) {
+ TensorView* target_tensor = nullptr;
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ ScoreT score;
+ findTargetTensor(expr, target_tensor, score);
+ scores.emplace(expr, score);
+ if (target_tensor == nullptr) {
+ reordered_exprs.push_back(expr);
+ } else {
+ target_map.emplace(expr, target_tensor);
+ if (computed_at_exprs.find(target_tensor) == computed_at_exprs.end()) {
+ computed_at_exprs.emplace(target_tensor, TargetGroupMapT::mapped_type());
+ }
+ auto& exprs = computed_at_exprs[target_tensor];
+ exprs.push_back(expr);
+ }
+}
+
+// Sort each loop-nest group based on axis (i.e., score)
+void sortGroup(ExprListT& exprs, ExprScoreMapT& scores) {
+ std::stable_sort(
+ exprs.begin(),
+ exprs.end(),
+ [&scores](const Expr* expr1, const Expr* expr2) {
+ return scores[expr1] < scores[expr2];
+ });
+}
+
+// If an expression is missing from expr_status, search for all ancestors
+// that are necessary for the expression
+void mapMissingInputsToAncestors(
+ const TensorView* tv,
+ const std::unordered_map<const Expr*, bool>& expr_status,
+ std::vector<const TensorView*>& ancestors) {
+ const Expr* expr = tv->getOrigin();
+ const auto& expr_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
+ for (auto input : expr_inputs) {
+ const Expr* input_origin = input->getOrigin();
+ if (input_origin != nullptr) {
+ if (expr_status.find(input_origin) == expr_status.end()) {
+ mapMissingInputsToAncestors(input, expr_status, ancestors);
+ } else {
+ ancestors.push_back(input);
+ }
+ }
+ }
+}
+
+// For each expression, find all TensorView inputs.
+// If an input TensorView is missing from expr_status,
+// find that input's ancestors that are present in expr_status.
+std::unordered_map<const Expr*, std::vector<const TensorView*>> findExprTvInputs(
+ const std::unordered_map<const Expr*, bool>& expr_status) {
+ std::unordered_map<const Expr*, std::vector<const TensorView*>>
+ map_expr_to_tv_inputs;
+
+ // Iterate over all exprs and filter missing expr
+ for (auto item : expr_status) {
+ const auto expr = item.first;
+ const auto& expr_inputs =
+ ir_utils::filterByType<TensorView>(expr->inputs());
+
+ map_expr_to_tv_inputs.insert({expr, std::vector<const TensorView*>()});
+ auto& tv_inputs = map_expr_to_tv_inputs[expr];
+
+ for (auto input : expr_inputs) {
+ const Expr* input_origin = input->getOrigin();
+ bool missing_input = input_origin != nullptr &&
+ expr_status.find(input_origin) == expr_status.end();
+
+ if (missing_input) {
+ // Map missing input to ancestor that is present in exprs_status
+ std::vector<const TensorView*> ancestors;
+ mapMissingInputsToAncestors(input, expr_status, ancestors);
+ tv_inputs.insert(tv_inputs.begin(), ancestors.begin(), ancestors.end());
+ } else {
+ tv_inputs.push_back(input);
+ }
+ }
+ }
+ return map_expr_to_tv_inputs;
+}
+
+// Reorder expressions that are computed at the same position in a
+// breadth-first order.
+void reorderSegmentBreadthFirst(
+ ExprListT::iterator seg_begin,
+ ExprListT::const_iterator seg_end) {
+ // mapping of each expression to a bool flag indicating if it's
+ // already been visited
+ std::unordered_map<const Expr*, bool> expr_status;
+ for (auto it = seg_begin; it != seg_end; ++it) {
+ expr_status.insert({*it, false});
+ }
+
+ // Holds all input TVs necessary for every expression.
+ const auto map_expr_to_tv_inputs = findExprTvInputs(expr_status);
+
+ while (seg_begin != seg_end) {
+ std::vector<const Expr*> visited_exprs;
+ for (auto it = seg_begin; it != seg_end; ++it) {
+ const auto expr = *it;
+ const auto& expr_inputs = map_expr_to_tv_inputs.at(expr);
+
+ // if all input expressions are visited
+ // then expr can be visited
+ const bool ready_to_visit = std::all_of(
+ expr_inputs.begin(),
+ expr_inputs.end(),
+ [&expr_status](const TensorView* input) {
+ const Expr* input_origin = input->getOrigin();
+ return input_origin == nullptr ||
+ (expr_status.find(input_origin) != expr_status.end() &&
+ expr_status.at(input_origin));
+ });
+ if (ready_to_visit) {
+ std::iter_swap(seg_begin, it);
+ TORCH_INTERNAL_ASSERT(*seg_begin == expr);
+ ++seg_begin;
+ visited_exprs.push_back(expr);
+ }
+ }
+ for (const auto& visited_expr : visited_exprs) {
+ expr_status.at(visited_expr) = true;
+ }
+ }
+}
+
+// Reorder expressions in a group in a breadth-first order. Reordering
+// is done within a subset of expressions that have the same score
+// (i.e., computeAt position). For each subset,
+// reorderSegmentBreadthFirst is called.
+void reorderGroupBreadthFirst(ExprListT& exprs, const ExprScoreMapT& scores) {
+ auto seg_begin = exprs.begin();
+ auto seg_end = exprs.begin();
+ ScoreT seg_score = scores.at(*seg_begin);
+ while (seg_end != exprs.end()) {
+ const auto expr = *seg_end;
+ const auto cur_score = scores.at(expr);
+ if (seg_score == cur_score) {
+ // advance further
+ ++seg_end;
+ continue;
+ } else if (seg_score < cur_score) {
+ // segment ended
+ reorderSegmentBreadthFirst(seg_begin, seg_end);
+ seg_begin = seg_end;
+ seg_score = cur_score;
+ } else {
+ // exprs list is assumed to be sorted in the order of scores, so
+ // this should never be reachable
TORCH_INTERNAL_ASSERT(
- std::none_of(
- loop_structure_it,
- loop_structure.end(),
- [&gpu_lower, &for_loops_it](IterDomain* loop_structure_id) {
- // Check loop structure doesn't map for_loops in for loop map
- auto id0 = (*for_loops_it)->iter_domain();
- auto id1 = gpu_lower->lowerValue(loop_structure_id)
- ->as<kir::IterDomain>();
- return gpu_lower->caLoopMap().areMapped(id0, id1);
- }),
- "Invalid loop found to close.");
+ false, "Unexpected expression: ", expr, ", score: ", cur_score);
}
+ }
+ reorderSegmentBreadthFirst(seg_begin, seg_end);
+}
- n_loops_to_close = std::min(n_loops_to_close, max_close);
+void mergeNonRootGroupsIntoRootGroups(
+ TargetGroupMapT& computed_at_exprs,
+ ExprTargetMapT& target_map) {
+ for (auto it = computed_at_exprs.begin(); it != computed_at_exprs.end();) {
+ TensorView* target = it->first;
+ if (target->hasComputeAt()) {
+ Expr* target_expr = target->getOrigin();
+ TensorView* target_of_target = target_map.at(target_expr);
+ auto& target_group = computed_at_exprs.at(target_of_target);
+ auto pos =
+ std::find(target_group.begin(), target_group.end(), target_expr);
+ TORCH_INTERNAL_ASSERT(pos != target_group.end());
+ target_group.insert(pos, it->second.begin(), it->second.end());
+ // Update the target map
+ for (auto& inserted_expr : it->second) {
+ TORCH_INTERNAL_ASSERT(target_map.at(inserted_expr) == target);
+ target_map.at(inserted_expr) = target_of_target;
+ }
+ it = computed_at_exprs.erase(it);
+ } else {
+ ++it;
+ }
}
+}
- for (int64_t i_loop_close = 0; i_loop_close < n_loops_to_close;
- i_loop_close++) {
- closeFor();
+// Merge root loop-nests into reordered_exprs
+void mergeGroupsIntoSortedList(
+ TargetGroupMapT& computed_at_exprs,
+ ExprListT& reordered_exprs) {
+ while (computed_at_exprs.size() > 0) {
+    // Find a root loop-nest that has no dependency on the other
+    // loop-nests
+ TensorView* cur_target = computed_at_exprs.begin()->first;
+ for (auto& group : computed_at_exprs) {
+ auto target = group.first;
+ if (cur_target == target)
+ continue;
+ if (DependencyCheck::isDependencyOf(target, cur_target)) {
+ cur_target = target;
+ }
+ }
+ // cur_target can be visited
+ reordered_exprs.insert(
+ reordered_exprs.end(),
+ computed_at_exprs.at(cur_target).begin(),
+ computed_at_exprs.at(cur_target).end());
+ computed_at_exprs.erase(cur_target);
}
+}
- // Open the remaining needed loops
- for (; loop_structure_it != loop_structure.end(); ++loop_structure_it) {
- openFor(*loop_structure_it);
+// Reorder exprs so that LoopNestGenerator::handle(Expr*) can generate
+// correct loop nests. Vector exprs is assumed to be topologically
+// sorted, but that is not sufficient as tensors computed at
+// outer loops need to be located earlier.
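+//
+// Illustrative example (assumed tensors, not from this diff):
+//
+//   T2 = relu(T0);    // T2.computeAt(T4, 2) -> score 2
+//   T3 = sigmoid(T1); // T3.computeAt(T4, 1) -> score 1
+//   T4 = add(T2, T3);
+//
+// Both orders of the first two exprs are topologically valid, but T3 must
+// come first so its outer loop is opened before T2's deeper nest.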
+void reorderExprsForComputeAt(std::vector<Expr*>& exprs) {
+ ExprListT reordered_exprs;
+
+ // expr -> target
+ ExprTargetMapT target_map;
+
+ // target -> [computed at expressions]
+ TargetGroupMapT computed_at_exprs;
+
+ // score of each expression that is calculated based on the
+ // computeAt axis. A lower score of an expression means it should be
+ // placed earlier in the expression list. This is a requirement for
+ // the loop-nest generation of this class to work.
+ ExprScoreMapT scores;
+
+ // 1. Group expressions by target tensors. Non-grouped expressions
+ // are copied into reordered_exprs.
+ for (auto& expr : exprs) {
+ groupExpressions(
+ expr, reordered_exprs, target_map, computed_at_exprs, scores);
}
- if (out_tv->getMaxProducerPosition() == 0) {
- max_close = -1;
- } else {
- auto produce_at_id = loop_structure[out_tv->getMaxProducerPosition() - 1];
- auto max_close_loop = std::find_if(
- for_loops_.begin(),
- for_loops_.end(),
- [&produce_at_id, &gpu_lower](kir::ForLoop* fl) {
- auto produce_at_lowered_it =
- gpu_lower->lowerValue(produce_at_id)->as<kir::IterDomain>();
- return gpu_lower->caParallelMap().areMapped(
- produce_at_lowered_it, fl->iter_domain());
- });
+ sanityCheck(exprs, reordered_exprs, scores, target_map, computed_at_exprs);
- max_close = std::distance(max_close_loop, for_loops_.end());
- max_close = max_close > 0 ? max_close - 1 : max_close;
+ // If no computeAt found, no need to reorder.
+ if (computed_at_exprs.size() == 0) {
+ return;
}
- pushFront(gpu_lower->lowerExpr(expr));
+
+ // 2. Sort each loop-nest group based on axis (i.e., score)
+ for (auto& group : computed_at_exprs) {
+ sortGroup(group.second, scores);
+
+ // Reorder expressions in a breadth-first order
+ reorderGroupBreadthFirst(group.second, scores);
+ }
+
+ // 3. Merge non-root loop-nests into root loop-nests
+ mergeNonRootGroupsIntoRootGroups(computed_at_exprs, target_map);
+
+ // At this point, only root loop-nests (i.e., no computeAt'ed)
+ // should exist.
+ for (auto& group : computed_at_exprs) {
+ // Guarantee only root loop-nests exist.
+ TensorView* target = group.first;
+ TORCH_INTERNAL_ASSERT(!target->hasComputeAt());
+ }
+
+ sanityCheck(exprs, reordered_exprs, scores, target_map, computed_at_exprs);
+
+ mergeGroupsIntoSortedList(computed_at_exprs, reordered_exprs);
+
+ // Reordering completed. Reordered exprs exist in reordered_exprs.
+
+ TORCH_INTERNAL_ASSERT(exprs.size() == reordered_exprs.size());
+ exprs = std::move(reordered_exprs);
}
-// Generate the loop nest structure and place it in lowered_exprs_
+} // namespace
+
+// Generate the loop nest structure and place it in lowered_exprs
void LoopNestGenerator::generate(const std::vector<Expr*>& exprs) {
- TORCH_INTERNAL_ASSERT(lowered_exprs_.empty());
+ FusionGuard fg(fusion_);
+
+ // Identify all shared memory TensorViews
+ // Insert into shared_memory map <tv, modify status>
+ for (auto v : fusion_->vals()) {
+ if (v->getValType().value() == ValType::TensorView) {
+ if (v->as<TensorView>()->getMemoryType() == MemoryType::Shared) {
+ smem_.insert({v, false});
+ }
+ }
+ }
+
+ // Initialize members of the class
+ lowered_exprs = std::vector<Expr*>();
+
+ auto reordered = exprs;
+ reorderExprsForComputeAt(reordered);
+
+ for (auto* expr : reordered) {
+ handle(expr);
+ }
+
+ // Insert Dynamic Shared Memory at beginning of kernel
+ for (auto smem_alloc : dynamic_smem_) {
+ lowered_exprs.insert(lowered_exprs.begin(), smem_alloc);
+ }
+}
+
+void LoopNestGenerator::cleanSharedMemory() {
+ for (auto& item : smem_) {
+ item.second = false;
+ }
+}
+
+void LoopNestGenerator::modifySharedMemory(Val* key) {
+ auto it = smem_.find(key);
+ if (it != smem_.end()) {
+ it->second = true;
+ }
+}
- // Process the carefully ordered expressions
- for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
- handle(*it);
+bool LoopNestGenerator::isModifiedSharedMemory(Val* key) const {
+ auto it = smem_.find(key);
+ if (it != smem_.end()) {
+ return it->second;
}
+ return false;
}
} // namespace cuda
-
#pragma once
-
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
namespace fuser {
namespace cuda {
-//! Loop nest generator pass will get IR that looks something like:
-//! T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...* for( i :
-//! I0o{ceil(I0/4)} ) { and will generate the loop nest structure for these
-//! exprs like:
-//!
-//! for( i : I0o{ceil(I0/4)} ) {
-//! for( j : I1o{ceil(I1/128)} ) {
-//! for( k : I0i{4} )
-//! for( l : I1i{128} )
-//! T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
-//!
-//! It does not generate predicates, but it will generate allocations, and loop
-//! nests to initialize reduction buffers.
-class TORCH_CUDA_CU_API LoopNestGenerator {
+/*
+ * Loop nest generator pass will get IR that looks something like:
+ * T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...* for( i :
+ * I0o{ceil(I0/4)} ) { and will generate the loop nest structure for these exprs
+ * like:
+ *
+ * for( i : I0o{ceil(I0/4)} ) {
+ * for( j : I1o{ceil(I1/128)} ) {
+ * for( k : I0i{4} )
+ * for( l : I1i{128} )
+ * T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
+ *
+ * It does not generate predicates, but it will generate allocations, and loop
+ * nests to initialize reduction buffers.
+ *
+ */
+class TORCH_CUDA_CU_API LoopNestGenerator : public OptOutDispatch {
public:
- static std::vector<kir::Expr*> loweredExprs(const std::vector<Expr*>& exprs);
+ static std::vector<Expr*> loweredExprs(
+ Fusion* fusion,
+ ThreadPredicateMap& thread_predicates,
+ const std::vector<Expr*>& exprs) {
+ FUSER_PERF_SCOPE("LoopNestGenerator::loweredExprs");
+ LoopNestGenerator generator(fusion, thread_predicates, exprs);
+ return generator.lowered_exprs;
+ }
private:
- LoopNestGenerator(const std::vector<Expr*>& exprs);
+ LoopNestGenerator(
+ Fusion* fusion,
+ ThreadPredicateMap& thread_predicates,
+ const std::vector<Expr*>& exprs);
+
+ // Create the allocation for tv, place it inside the loop associated with
+ // alloc_id, return the node
+ Expr* pushAlloc(TensorView*);
+
+ // Fusion shared_memory values
+ // Tracks if shared memory is modified
+ std::unordered_map<Val*, bool> smem_;
+
+ // Track dynamic shared memory buffers
+ // Insert allocation at the beginning of the kernel
+ std::deque<kir::Allocate*> dynamic_smem_;
+
+ // Clear the modify status for all shared memory buffers
+ void cleanSharedMemory();
+
+ // Toggle modify status for this shared memory buffer
+ void modifySharedMemory(Val* key);
+
+ // Return the status of the shared memory buffer
+ // False if TensorView is not shared memory buffer
+ bool isModifiedSharedMemory(Val* key) const;
// Open a new inner most for loop, track which TV it was constructed from
// according to the computeAt chain.
- void openFor(IterDomain*);
+ void openFor(std::pair<IterDomain*, TensorView*>);
// Close the inner most for loop
- void closeFor();
+ void popFor();
- // Appends an expression to the current scope
- void pushFront(kir::Expr* expr);
+  // Wraps pushBack from lower_utils; if the active scope is null the expr
+  // goes straight to lowered_exprs
+ void pushBack(Expr*);
- void handle(Expr* expr);
+ // Initialize a buffer to init_val. If this buffer is in smem or registers,
+ // pass in its allocation statement so we can make sure that we insert this
+ // initialization after the allocation.
+ void initReduction(TensorView* tv, Val* init_val, Expr* alloc_expr = nullptr);
- // Run the pass and accumulate output in lowered_exprs_
+ // Check if expr is a TV op and handle accordingly.
+ void handle(Expr*) final;
+
+ // Run the pass and accumulate output in lowered_exprs
void generate(const std::vector<Expr*>& exprs);
private:
+ // Track number of allocations in each for loop. It is used to insert
+ // allocations in the correct order, which is necessary for memory aliasing
+ std::unordered_map<kir::ForLoop*, size_t> for_loop_allocations_;
+
// Lowered exprs to return
- std::vector<kir::Expr*> lowered_exprs_;
+ std::vector<Expr*> lowered_exprs;
+
+ // Fusion pointer for convenience
+ Fusion* fusion_;
// Keep all for loops conveniently to make unrolling easier, basically just a
// stack of the active for_loops
- std::vector<kir::ForLoop*> for_loops_;
+ std::vector<kir::ForLoop*> for_loops;
+
+ // Track the active computeAt scope, and what view we're "computeAt-ing" into
+ std::vector<std::pair<IterDomain*, TensorView*>> compute_at_scope;
+
+ // Predicates from ThreadPredicates that we will extend to reduction buffer
+ // initialization
+ ThreadPredicateMap& thread_predicates_;
- // How many loops can the next iteration close
- std::ptrdiff_t max_close = -1;
+ // Kernel IR builder
+ kir::IrBuilder ir_builder_;
};
} // namespace cuda
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
-
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-class MagicZeroInserter : public kir::MutableIrVisitor {
- public:
- static std::vector<kir::Expr*> insert(const std::vector<kir::Expr*>& exprs) {
- MagicZeroInserter inserter(exprs);
- return inserter.loop_nests_;
- }
-
- private:
- struct InsertionInfo {
- kir::Scope* scope = nullptr;
- kir::ForLoop* fl = nullptr;
- };
-
- MagicZeroInserter(const std::vector<kir::Expr*>& exprs)
- : loop_nests_(exprs), ir_builder(GpuLower::current()->kernel()) {
- loop_nests_.insert(
- loop_nests_.begin(), ir_builder.create<kir::InitMagicZero>());
- for (auto expr : exprs) {
- handle(expr);
- }
- insertAll();
- }
-
- void handle(kir::Expr* expr) {
- if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
- handle(ite);
- } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
- handle(for_loop);
- }
- }
-
- void handle(kir::IfThenElse* ite) {
- scope_nest_.push_back(&ite->thenBody());
- for (auto expr : ite->thenBody().exprs()) {
- handle(expr);
- }
- scope_nest_.pop_back();
- scope_nest_.push_back(&ite->elseBody());
- for (auto expr : ite->elseBody().exprs()) {
- handle(expr);
- }
- scope_nest_.pop_back();
- }
-
- void handle(kir::ForLoop* fl) {
- if (fl->isUnrollable()) {
- kir::Scope* scope = nullptr;
- if (!scope_nest_.empty()) {
- scope = scope_nest_.back();
- }
- insertion_list_.push_back({scope, fl});
- } else {
- scope_nest_.push_back(&fl->body());
- for (auto expr : fl->body().exprs()) {
- handle(expr);
- }
- scope_nest_.pop_back();
- }
- }
-
- void insertAll() {
- for (const auto& info : insertion_list_) {
- auto fl = info.fl;
- auto scope = info.scope;
- if (scope == nullptr) {
- // place in global scope
- auto loop_it = std::find(loop_nests_.begin(), loop_nests_.end(), fl);
- TORCH_INTERNAL_ASSERT(loop_it != loop_nests_.end());
- // Place after the loop
- loop_it++;
- loop_nests_.insert(loop_it, ir_builder.create<kir::UpdateMagicZero>());
- } else {
- scope->insert_after(fl, ir_builder.create<kir::UpdateMagicZero>());
- }
- }
- }
-
- //! Keep track for loop structure
- std::vector<kir::Scope*> scope_nest_;
-
- // Keep a copy of the expressions provided
- std::vector<kir::Expr*> loop_nests_;
-
- kir::IrBuilder ir_builder;
-
- std::vector<InsertionInfo> insertion_list_;
-};
-
-} // namespace
-
-std::vector<kir::Expr*> insertMagicZero(const std::vector<kir::Expr*>& exprs) {
- FUSER_PERF_SCOPE("GpuLower::Lower::insertMagicZero");
- // Check if magic zero was even used, if not we don't have to define it or
- // update it.
- bool has_magic_zero = false;
- const auto gpu_lower = GpuLower::current();
- auto kernel = gpu_lower->kernel();
- for (auto& val : kernel->irNodes()) {
- if (val->isA<kir::NamedScalar>()) {
- auto named_scalar = val->as<kir::NamedScalar>();
- if (named_scalar->dtype() == DataType::Int &&
- named_scalar->name() == "nvfuser_zero") {
- has_magic_zero = true;
- break;
- }
- }
- }
-
- if (!has_magic_zero) {
- return exprs;
- }
-
- return MagicZeroInserter::insert(exprs);
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Insert magic zero definition at the beginning of the kernel. Insert magic
-//! zero update after every (outer most) loop nest with a compile time extent.
-//!
-//! This will make sure nvrtc does not aggressively save predicate and indices.
-std::vector<kir::Expr*> insertMagicZero(const std::vector<kir::Expr*>& exprs);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h>
-
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-class MisalignedVectorizationModifier {
- public:
- void process(const std::vector<kir::Expr*>& exprs) {
- FUSER_PERF_SCOPE(
- "GpuLower::Lower::MisalignedVectorizationModifier::process");
- // Run through loop nests
- // Find for-loops with misaligned vectorization domains
- for (auto* expr : exprs) {
- handle(expr);
- }
- }
-
- const std::unordered_map<kir::Expr*, kir::Expr*>& replacementMap() const {
- return expr_replacement_map_;
- }
-
- private:
- void handle(kir::Expr* expr) {
- if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
- handle(for_loop);
- } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
- handle(ite);
- }
- }
-
- void handle(kir::ForLoop* fl) {
- for_loops_structure_.push_back(fl);
-
- // Make copy of exprs because we replace them inplace in fl
- const auto exprs_copy = fl->body().exprs();
-
- if (containsAnyDirectChildMisalignedVectorize(fl)) {
- auto new_fl = handleMisalignedVectorize(for_loops_structure_, fl);
- expr_replacement_map_.insert({fl, new_fl});
- } else {
- for (auto expr : exprs_copy) {
- handle(expr);
- }
- }
-
- for_loops_structure_.pop_back();
- }
-
- void handle(kir::IfThenElse* ite) {
- for (auto expr : ite->thenBody().exprs()) {
- handle(expr);
- }
- for (auto expr : ite->elseBody().exprs()) {
- handle(expr);
- }
- }
-
- struct ReferenceTensors {
- // Input TensorView to Vectorize Set operation
- kir::TensorView* in_tv = nullptr;
- // Output TensorView to Vectorize Set operation
- kir::TensorView* out_tv = nullptr;
- // TensorView in global memory
- kir::TensorView* global_tv = nullptr;
- // TensorView with vectorize IterDomain and not in global memory
- kir::TensorView* vec_tv = nullptr;
- };
-
- ReferenceTensors getReferenceTensors(kir::Expr* vectorized_expr) {
- TORCH_INTERNAL_ASSERT(vectorized_expr != nullptr);
- TORCH_INTERNAL_ASSERT(
- vectorized_expr->outputs().front()->isA<kir::TensorView>());
- TORCH_INTERNAL_ASSERT(
- vectorized_expr->inputs().front()->isA<kir::TensorView>());
-
- auto in_tv = vectorized_expr->inputs().front()->as<kir::TensorView>();
- auto out_tv = vectorized_expr->outputs().front()->as<kir::TensorView>();
-
- const bool global_vectorize_write_op =
- (out_tv->memoryType() == MemoryType::Global &&
- in_tv->memoryType() == MemoryType::Local);
- const bool global_vectorize_read_op =
- (out_tv->memoryType() == MemoryType::Local &&
- in_tv->memoryType() == MemoryType::Global);
- TORCH_INTERNAL_ASSERT(
- global_vectorize_write_op || global_vectorize_read_op,
- "Unsupported vectorize memory configuration detected.");
-
- // TensorView on global memory. This is the tensor that may have
- // a non-aligned base address.
- auto global_tv =
- (out_tv->memoryType() == MemoryType::Global) ? out_tv : in_tv;
-
-    // TensorView with the misaligned vectorized IterDomain. It is the
-    // consumer of a vectorized load or the producer of a vectorized store.
-    // It is assumed that when the output TV is not on global memory, this
-    // expression is a vectorized load, so the output TV is vec_tv.
- auto vec_tv = (out_tv->memoryType() != MemoryType::Global) ? out_tv : in_tv;
-
- return {in_tv, out_tv, global_tv, vec_tv};
- }
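-
-  // For example (a sketch): for a vectorized global->local read, out_tv is
-  // local and in_tv is global, so global_tv == in_tv (the tensor whose base
-  // address may be misaligned) and vec_tv == out_tv (the tensor carrying the
-  // vectorized IterDomain).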
-
- struct VectorizeData {
- kir::Val* vector_size = nullptr;
- kir::Val* shift = nullptr;
- kir::Val* extent = nullptr;
- kir::Val* remainder = nullptr;
- kir::Val* extent_minus_remainder = nullptr;
- kir::Val* last_root_domain_index = nullptr;
- kir::Val* last_root_domain_index_shift = nullptr;
- };
-
- // Create constants for handling misaligned addresses
- VectorizeData createVectorizeConstants(
- const std::vector<kir::ForLoop*>& for_loop_structure,
- const ReferenceTensors& tensors,
- kir::IfThenElse* parent_scope_ite) {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-
- // Generate vectorize index
- auto indices = (tensors.out_tv->memoryType() == MemoryType::Global)
- ? Index::getConsumerStridedIndices(
- tensors.out_tv->fuserTv(), for_loop_structure)
- : Index::getProducerStridedIndices(
- tensors.in_tv->fuserTv(),
- tensors.out_tv->fuserTv(),
- for_loop_structure);
-
- // Number of elements in vectorize access
- auto vector_size =
- tensors.vec_tv->domain()->domain().back()->extent()->as<kir::Int>();
-
- // Size of memory type for the elements
- kir::Int* data_size_in_bytes =
- ir_builder.create<kir::Int>(dataTypeSize(tensors.vec_tv->dtype()));
-
- // The number of bytes in the vectorize access
- auto vector_size_in_bytes =
- ir_builder.mulExpr(vector_size, data_size_in_bytes);
-
- auto index = ir_builder.create<kir::TensorIndex>(
- tensors.global_tv->fuserTv(), indices);
- auto address = createNamedScalarFromValue(
- parent_scope_ite->thenBody(), index, "address", true);
-
- // offset_size = (address % vector_size_bytes) / data_type_size_bytes
- // shift_init = vector_size - offset_size
- auto a = ir_builder.modExpr(address, vector_size_in_bytes);
- auto b = ir_builder.divExpr(a, data_size_in_bytes);
- auto c = ir_builder.subExpr(vector_size, b);
- auto shift_init = createNamedScalarFromValue(
- parent_scope_ite->thenBody(), c, "shift_val");
-
- // shift = (shift_init == vector_size) ? 0 : shift_init
- // The number of elements until the first aligned address
- auto shift_pred = ir_builder.eqExpr(shift_init, vector_size);
- auto shift_val =
- ir_builder.whereExpr(shift_pred, ir_builder.zeroVal(), shift_init);
-
- auto shift = createNamedScalarFromValue(
- parent_scope_ite->thenBody(), shift_val, "shift");
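-
-    // Worked example (a sketch; actual values depend on runtime addresses):
-    // for dtype Half, vector_size = 8 gives vector_size_in_bytes = 16. If
-    // address % 16 == 4, then offset_size = 4 / 2 = 2 elements, shift_init =
-    // 8 - 2 = 6, and shift = 6: six scalar iterations run before the first
-    // 16-byte-aligned vectorized access.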
-
- // Get full extent for the inner-most, merged root domain
- auto extent = getVectorizeExtent(tensors.in_tv, tensors.out_tv);
-
- // remainder = (extent - shift) % vector_size
-    // The number of trailing elements not covered by vectorized accesses
- auto remaining_extent = ir_builder.subExpr(extent, shift);
- auto remainder_val = ir_builder.modExpr(remaining_extent, vector_size);
- auto remainder = createNamedScalarFromValue(
- parent_scope_ite->thenBody(), remainder_val, "remainder");
-
- // (extent - remainder) is the upper-bound for the vectorize section
- auto extent_remainder_val = ir_builder.subExpr(extent, remainder);
-
- auto extent_minus_remainder = createNamedScalarFromValue(
- parent_scope_ite->thenBody(),
- extent_remainder_val,
- "extent_minus_remainder");
-
- auto last_root_domain_index = createNamedScalarFromValue(
- parent_scope_ite->thenBody(), indices.back(), "last_root_domain_index");
-
- auto last_root_domain_index_shift =
- ir_builder.addExpr(last_root_domain_index, shift);
-
- return {
- vector_size,
- shift,
- extent,
- remainder,
- extent_minus_remainder,
- last_root_domain_index,
- last_root_domain_index_shift};
- }
-
-  // Vectorized : [shift, extent-remainder)
- // From the first to the last aligned address
- kir::IfThenElse* createVectorizeSection(
- const std::vector<kir::ForLoop*>& child_loops,
- const VectorizeData& params) {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-
- auto vectorized_child_loops =
- cloneForLoops(child_loops, params.vector_size, true, params.shift);
-
-    // Vectorize Range: [shift, extent-remainder)
- // (last_root_domain_index + shift) < (extent - remainder)
- kir::Val* vectorize_cond = ir_builder.ltExpr(
- params.last_root_domain_index_shift, params.extent_minus_remainder);
-
- kir::Predicate* vectorize_pred =
- ir_builder.create<kir::Predicate>(vectorize_cond->as<kir::Bool>());
- kir::IfThenElse* vectorize_ite =
- ir_builder.create<kir::IfThenElse>(vectorize_pred);
-
- for (auto cloned_loop : vectorized_child_loops) {
- vectorize_ite->thenBody().push_back(cloned_loop);
- }
-
- return vectorize_ite;
- }
-
-  // Initial : [0, shift)
- // From the initial address until the first aligned address
- kir::IfThenElse* createInitialSection(
- const std::vector<kir::ForLoop*>& child_loops,
- const VectorizeData& params) {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-
- auto pre_child_loops =
- cloneForLoops(child_loops, params.shift, false, nullptr);
-
-    // Initial Range: [0, shift)
- // last_root_domain_index == 0
- kir::Val* initial_cond =
- ir_builder.eqExpr(params.last_root_domain_index, ir_builder.zeroVal());
-
- kir::Predicate* initial_pred =
- ir_builder.create<kir::Predicate>(initial_cond->as<kir::Bool>());
- kir::IfThenElse* initial_ite =
- ir_builder.create<kir::IfThenElse>(initial_pred);
-
- for (auto cloned_loop : pre_child_loops) {
- initial_ite->thenBody().push_back(cloned_loop);
- }
-
- return initial_ite;
- }
-
-  // Remainder : [extent-remainder, extent)
- // From the last aligned address until the end of the extent
- kir::IfThenElse* createRemainderSection(
- const std::vector<kir::ForLoop*>& child_loops,
- const VectorizeData& params) {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-
- auto post_child_loops =
- cloneForLoops(child_loops, params.remainder, false, params.shift);
-
-    // Remainder Range: [extent-remainder, extent)
- // (extent - remainder) <= last_root_domain_index + shift < extent
- kir::Val* lower_bound = ir_builder.geExpr(
- params.last_root_domain_index_shift, params.extent_minus_remainder);
- kir::Val* upper_bound =
- ir_builder.ltExpr(params.last_root_domain_index_shift, params.extent);
- kir::Val* remainder_cond = ir_builder.andExpr(lower_bound, upper_bound);
-
- kir::Predicate* remainder_pred =
- ir_builder.create<kir::Predicate>(remainder_cond->as<kir::Bool>());
- kir::IfThenElse* remainder_ite =
- ir_builder.create<kir::IfThenElse>(remainder_pred);
-
- for (auto cloned_loop : post_child_loops) {
- remainder_ite->thenBody().push_back(cloned_loop);
- }
-
- return remainder_ite;
- }
-
- kir::ForLoop* handleMisalignedVectorize(
- std::vector<kir::ForLoop*> for_loop_structure,
- const kir::ForLoop* parent_for_loop) {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
-
- auto child_loops = findChildForLoops(parent_for_loop);
-
- // Assumption: All vectorize operations have the same shift
- auto vectorized_expr =
- findFirstVectorizedSetOp(for_loop_structure, child_loops);
- TORCH_INTERNAL_ASSERT(vectorized_expr != nullptr);
-
- auto reference_tensors = getReferenceTensors(vectorized_expr);
-
- // The parent_for_loop contains allocate, read, compute, write operations
- const auto new_parent_for_loop =
- ir_builder.create<kir::ForLoop>(parent_for_loop);
-
- // Transfer all expressions except for-loops to new parent for-loop
- // All expressions are placed at the beginning of the new for-loop
- moveExprsExceptForLoops(parent_for_loop, new_parent_for_loop);
-
- // Get the predicate for all but the last root domain
- auto pred_except_last_root_domain = ir_builder.create<kir::Predicate>(
- PredicateType::Misaligned, vectorized_expr, ir_builder.trueVal());
- kir::IfThenElse* pred_ite =
- ir_builder.create<kir::IfThenElse>(pred_except_last_root_domain);
- new_parent_for_loop->body().push_back(pred_ite);
-
- auto constants = createVectorizeConstants(
- for_loop_structure, reference_tensors, pred_ite);
-
- // The last root domain is divided into three sections.
- // | Initial - N/A Shift | Vectorize - Shift | Remainder - Shift |
-
- // Vectorized set operation with vectorize shift
- auto vectorize_ite = createVectorizeSection(child_loops, constants);
- pred_ite->thenBody().push_back(vectorize_ite);
-
- // Standard set operation without vectorize shift
- auto initial_ite = createInitialSection(child_loops, constants);
- pred_ite->thenBody().push_back(initial_ite);
-
- // Standard set operation with vectorize shift
- auto remainder_ite = createRemainderSection(child_loops, constants);
- pred_ite->thenBody().push_back(remainder_ite);
-
- return new_parent_for_loop;
- }
-
-  // Determine whether the expression is UnaryOpType::Set AND
- // the output TensorView domain is vectorized
- bool isVectorizeSetOp(kir::ForLoop* fl, kir::Expr* expr) {
- if (fl->iter_domain()->parallelType() !=
- ParallelType::MisalignedVectorize) {
- return false;
- }
-
- if (expr->isA<kir::UnaryOp>()) {
- auto unaryOp = expr->as<kir::UnaryOp>();
- if (unaryOp->out()->isA<kir::TensorView>()) {
- auto out_tv = unaryOp->out()->as<kir::TensorView>();
- return unaryOp->operation() == UnaryOpType::Set &&
- out_tv->domain()->hasVectorize();
- }
- }
- return false;
- }
-
- // Clone each for loop
- // stop value - for (index = start; index < stop; index += step)
- // vectorize flag - Do not generate for loop header
- // shift value - Add shift to global indices generated within for loop
- std::vector<kir::ForLoop*> cloneForLoops(
- const std::vector<kir::ForLoop*>& for_loops,
- kir::Val* stop,
- bool vectorize,
- kir::Val* vectorize_shift) {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- std::vector<kir::ForLoop*> cloned_for_loops;
-
- for (auto fl : for_loops) {
- auto first_expr = fl->body().exprs().front();
- bool has_vectorize_op = isVectorizeSetOp(fl, first_expr);
-
- // If the for loop contains a vectorize Set operation, then
- // it should only contain a single expression
- TORCH_INTERNAL_ASSERT(
- !has_vectorize_op || fl->body().exprs().size() == 1);
-
- const auto new_loop = ir_builder.create<kir::ForLoop>(
- fl->iter_domain(),
- fl->index(),
- ir_builder.zeroVal(),
- stop,
- ir_builder.oneVal(),
- vectorize && has_vectorize_op,
- vectorize_shift);
-
- for (auto expr : fl->body().exprs()) {
- new_loop->body().push_back(expr);
- }
-
- cloned_for_loops.push_back(new_loop);
- }
- return cloned_for_loops;
- }
-
- // Add all expressions except for loops to new parent for loop
- void moveExprsExceptForLoops(
- const kir::ForLoop* for_loop,
- kir::ForLoop* new_loop) {
- for (auto expr : for_loop->body().exprs()) {
- if (!expr->isA<kir::ForLoop>()) {
- new_loop->body().push_back(expr);
- }
- }
- }
-
- // Find any child for loops inside parent for loop
- std::vector<kir::ForLoop*> findChildForLoops(const kir::ForLoop* for_loop) {
- std::vector<kir::ForLoop*> loops;
- for (auto expr : for_loop->body().exprs()) {
- if (auto nested_for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
- loops.push_back(nested_for_loop);
- }
- }
- return loops;
- }
-
-  // Find the first vectorized set operation - either a read or a write.
-  // If one is found, its enclosing child for-loop is appended to
-  // for_loop_structure and the expression is returned.
- kir::Expr* findFirstVectorizedSetOp(
- std::vector<kir::ForLoop*>& for_loop_structure,
- const std::vector<kir::ForLoop*>& for_loops) {
- for (auto fl : for_loops) {
- auto first_expr = fl->body().exprs().front();
- bool has_vectorize_op = isVectorizeSetOp(fl, first_expr);
- if (has_vectorize_op) {
- for_loop_structure.push_back(fl);
- return first_expr;
- }
- }
- return nullptr;
- }
-
- // Get full extent for the inner-most, merged root domain
- kir::Val* getVectorizeExtent(
- kir::TensorView* producer_tv,
- kir::TensorView* consumer_tv) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- auto consumer_fuser_tv = consumer_tv->fuserTv();
- auto producer_fuser_tv = producer_tv->fuserTv();
-
- auto p2c =
- PairwiseRootDomainMap(producer_fuser_tv, consumer_fuser_tv)
- .mapProducerToConsumer(
- producer_fuser_tv->domain(), consumer_fuser_tv->domain());
-
- auto consumer_root_right_of_ca_domains = IterVisitor::getInputsTo(
- {consumer_fuser_tv->domain()->domain().begin() +
- consumer_fuser_tv->getComputeAtPosition(),
- consumer_fuser_tv->domain()->domain().end()});
- auto producer_root_right_of_ca_domains = IterVisitor::getInputsTo(
- {producer_fuser_tv->domain()->domain().begin() +
- producer_fuser_tv->getComputeAtPosition(),
- producer_fuser_tv->domain()->domain().end()});
-
- const auto& consumer_contig = consumer_fuser_tv->domain()->contiguity();
- const auto& producer_contig = producer_fuser_tv->domain()->contiguity();
-
- // No rfactor should exist in the producer TVs
- TORCH_INTERNAL_ASSERT(
- !producer_tv->domain()->hasRFactor(),
- "Invalid producer tensor: ",
- producer_fuser_tv);
- auto producer_root_domain = producer_fuser_tv->getRootDomain();
-
- // Calculate extent of merged root domains
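-    // For example (a sketch): with contiguous producer/consumer root domains
-    // [I0, I1, I2] mapped one-to-one, the extent accumulates as I2.extent,
-    // then I2.extent * I1.extent, and so on, until a non-contiguous axis or
-    // the compute-at boundary stops the merge.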
- kir::Val* extent = nullptr;
- auto consumer_root_idx = int(consumer_fuser_tv->getRootDomain().size()) - 1;
- for (int i = int(producer_root_domain.size()) - 1; i >= 0; --i) {
- auto producer_root_id = producer_root_domain.at(i);
-
- TORCH_INTERNAL_ASSERT(
- !gpu_lower->trivialReductionInfo().isDerived(producer_root_id),
-          "No trivial reduction axis should exist: ",
- producer_root_id);
-
- // If the producer ID is reduction or broadcast, it should be safe
- // to ignore.
- if (producer_root_id->isReduction()) {
- continue;
- } else if (producer_root_id->isBroadcast()) {
- --consumer_root_idx;
- continue;
- }
-
- // There must be a matching consumer root ID as the producer ID is
- // not reduction and the expression between them is UnaryOpType::Set.
- auto it = p2c.find(producer_root_id);
- TORCH_INTERNAL_ASSERT(
- it != p2c.end(), "No matching consumer root ID found");
- auto consumer_root_id = it->second;
-
- // Don't extend the vectorization domain beyond the CA position
- if (std::find(
- consumer_root_right_of_ca_domains.begin(),
- consumer_root_right_of_ca_domains.end(),
- consumer_root_id) == consumer_root_right_of_ca_domains.end() ||
- std::find(
- producer_root_right_of_ca_domains.begin(),
- producer_root_right_of_ca_domains.end(),
- producer_root_id) == producer_root_right_of_ca_domains.end()) {
- break;
- }
-
- // We now know it's safe to extend the vectorization domain to these
- // axes. It shouldn't matter whether producer or consumer is used.
- auto consumer_extent = gpu_lower->lowerValue(consumer_root_id->extent());
- if (extent == nullptr) {
- extent = consumer_extent;
- } else {
- extent = ir_builder.mulExpr(extent, consumer_extent);
- }
-
- // If it's not contiguous, extending the vectorization domain
- // further is not possible
- if (!(producer_contig.at(i) && consumer_contig.at(consumer_root_idx))) {
- break;
- }
-
- --consumer_root_idx;
- }
-
- TORCH_INTERNAL_ASSERT(extent != nullptr);
-
- return extent;
- }
-
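-  // Materializes `val` as a named scalar (e.g. "shift" or "remainder"),
-  // appending a local allocation and the defining expression to `body`. When
-  // `address` is true, the scalar instead holds the byte address of `val`,
-  // as in the header's `address = (int64_t) &T1[...]` sketch.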
- kir::Val* createNamedScalarFromValue(
- kir::Scope& body,
- kir::Val* val,
- const std::string& name,
- bool address = false) {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- auto namedScalar = (address) ? ir_builder.addressExprNamedScalar(name, val)
- : ir_builder.setExprNamedScalar(name, val);
- TORCH_INTERNAL_ASSERT(namedScalar->definition() != nullptr);
-
- auto alloc = ir_builder.create<kir::Allocate>(
- namedScalar, MemoryType::Local, ir_builder.oneVal());
- body.push_back(alloc);
- body.push_back(namedScalar->definition());
- return namedScalar;
- }
-
- private:
- // We will track which loops in the incoming IR will be replaced and by what
- std::unordered_map<kir::Expr*, kir::Expr*> expr_replacement_map_;
-
- // A depth-first ordering of nested for loops
- // It is used for indexing and predicate generation
- std::vector<kir::ForLoop*> for_loops_structure_;
-};
-
-} // namespace
-
-std::vector<kir::Expr*> processMisalignedVectorization(
- Fusion* fusion,
- const std::vector<kir::Expr*>& exprs) {
- FUSER_PERF_SCOPE("GpuLower::Lower::processMisalignedVectorization");
-
- MisalignedVectorizationModifier mvm;
- mvm.process(exprs);
-
- std::vector<kir::Expr*> mutated_exprs;
- mutated_exprs.reserve(exprs.size());
- for (auto expr : exprs) {
- mutated_exprs.push_back(
- ir_utils::applyReplacements(mvm.replacementMap(), expr));
- }
-
- return mutated_exprs;
-}
-
-bool containsAnyDirectChildMisalignedVectorize(const kir::ForLoop* fl) {
- for (auto expr : fl->body().exprs()) {
- if (expr->isA<kir::ForLoop>()) {
- auto child_fl = expr->as<kir::ForLoop>();
- if (child_fl->iter_domain()->parallelType() ==
- ParallelType::MisalignedVectorize) {
- return true;
- }
- }
- }
- return false;
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Transform for-loop structure to handle misaligned addresses
-//!
-//! Sections of misaligned addresses are handled sequentially
-//! while aligned addresses use vectorized memory accesses.
-//!
-//! ---------------------------------------------------------------------------
-//! Before Misaligned Vectorization:
-//!
-//! Inputs: T0
-//! Outputs: T3
-//!
-//! for(...) {
-//! T1[vector_size];
-//! for( i : vector_size ) {
-//! T1[i] = T0[...]
-//! }
-//!
-//! T2[vector_size];
-//! for( i : vector_size ) {
-//! T2[i] = unaryOp(T1[i])
-//! }
-//!
-//! for( i : vector_size ) {
-//! T3[...] = T2[i]
-//! }
-//! }
-//!
-//! ---------------------------------------------------------------------------
-//! After Misaligned Vectorization:
-//!
-//! Inputs: T0
-//! Outputs: T3
-//!
-//! for(...) {
-//! T1[vector_size];
-//! T2[vector_size];
-//!
-//! if (inline_predicate_except_last_root_domain) {
-//! index_except_last_root_domain = ...
-//! address = (int64_t) &T1[index_except_last_root_domain]
-//!
-//! offset_size = (address % vector_size_bytes) / data_type_size_bytes
-//! shift_init = vector_size - offset_size
-//! shift = (shift_init == vector_size) ? 0 : shift_init
-//!
-//! // size of the last root domain
-//! extent = ...
-//! remainder = (extent - shift) % vector_size
-//!
-//! last_root_domain_index = ...
-//!
-//! // Vectorize Section
-//! if ( (last_root_domain_index + shift) < (extent - remainder) ) {
-//! T1[0] = vectorize_load( T0[index + shift] );
-//!
-//! for( i : vector_size ) {
-//! T2[i] = unaryOp(T1[i])
-//! }
-//!
-//! T3[index + shift] = vectorize_store( T2[0] );
-//! }
-//!
-//! // Initial Section
-//! if ( last_root_domain_index == 0 ) {
-//! for( i : shift ) {
-//! T1[i] = T0[...]
-//! }
-//!
-//! for( i : shift ) {
-//! T2[i] = unaryOp(T1[i])
-//! }
-//!
-//! for( i : shift ) {
-//! T3[...] = T2[i]
-//! }
-//! }
-//!
-//! // Remainder Section
-//! if ( (last_root_domain_index + shift) >= (extent - remainder) &&
-//! (last_root_domain_index + shift) < extent) {
-//!
-//! for( i : remainder ) {
-//! T1[i] = T0[index + shift]
-//! }
-//!
-//! for( i : remainder ) {
-//! T2[i] = unaryOp(T1[i])
-//! }
-//!
-//! for( i : remainder ) {
-//! T3[index + shift] = T2[i]
-//! }
-//! }
-//! }
-//! }
-//!
-std::vector<kir::Expr*> processMisalignedVectorization(
- Fusion* fusion,
- const std::vector<kir::Expr*>& exprs);
-
-bool containsAnyDirectChildMisalignedVectorize(const kir::ForLoop* fl);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
-
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
-#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-class ConditionalFromPredicateModifier {
- public:
- ConditionalFromPredicateModifier(const std::vector<kir::Expr*>& exprs) {
- FUSER_PERF_SCOPE(
- "GpuLower::Lower::ConditionalFromPredicateModifier::process");
- for (auto* expr : exprs) {
- handle(expr);
- }
- }
-
- const std::unordered_map<kir::Expr*, kir::Expr*>& replacementMap() const {
- return expr_replacement_map_;
- }
-
- private:
- void handle(kir::Expr* expr) {
- if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
- handle(for_loop);
- } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
- handle(ite);
- } else if (expr != nullptr && expr->predicate() != nullptr) {
- // Replace expr predicate with bool conditional
- auto conditional = generateConditional(expr->predicate());
- TORCH_INTERNAL_ASSERT(conditional != nullptr);
- expr->predicate()->setValue(conditional);
- TORCH_INTERNAL_ASSERT(expr->predicate()->value() != nullptr);
- setWritePredicate(expr, conditional);
- }
- }
-
- void setWritePredicate(kir::Expr* expr, kir::Bool* read_cond) {
- if (expr->writePredicate() != nullptr) {
- auto write_cond = generateConditional(expr->writePredicate());
- if (write_cond) {
- expr->writePredicate()->setValue(write_cond);
- } else {
- // If generateConditional returns null, it means no specific
- // predicate needs to be used.
- expr->setWritePredicate(nullptr);
- }
- }
- }
-
- void handle(kir::ForLoop* fl) {
- for_loops_structure_.push_back(fl);
-
- const auto exprs_copy = fl->body().exprs();
- for (auto expr : exprs_copy) {
- handle(expr);
- }
-
- for_loops_structure_.pop_back();
- }
-
- void handle(kir::IfThenElse* ite) {
- TORCH_INTERNAL_ASSERT(ite->predicate() != nullptr);
-
-    // If ite already has a Bool conditional, handle its internal
-    // expressions; otherwise, generate the conditional and update the
-    // predicate in place.
- if (ite->predicate()->hasValue()) {
- const auto then_exprs_copy = ite->thenBody().exprs();
- for (auto expr : then_exprs_copy) {
- handle(expr);
- }
-
- const auto else_exprs_copy = ite->elseBody().exprs();
- for (auto expr : else_exprs_copy) {
- handle(expr);
- }
- } else {
- auto conditional = generateConditional(ite->predicate());
- TORCH_INTERNAL_ASSERT(conditional != nullptr);
- TORCH_INTERNAL_ASSERT(conditional->isA<kir::Bool>());
-
- // Update bool conditional in-place
- ite->predicate()->setValue(conditional);
- handle(ite);
- TORCH_INTERNAL_ASSERT(ite->predicate()->value() != nullptr);
- }
- }
-
- // Generate conditional according to PredicateType
- kir::Bool* generateConditional(kir::Predicate* pred) {
- switch (pred->predicate_type()) {
- case PredicateType::Inline:
- case PredicateType::ReductionWrite:
- case PredicateType::Misaligned: {
- return PredicateCompute::getInlinePredicate(
- pred->expr(),
- for_loops_structure_,
- pred->thread_pred(),
- pred->predicate_type());
- }
- case PredicateType::Vectorize: {
- std::vector<kir::ForLoop*> outer_loops;
- kir::ForLoop* vectorized_loop = nullptr;
- for (auto loop : for_loops_structure_) {
- if (loop->iter_domain()->parallelType() == ParallelType::Vectorize) {
- vectorized_loop = loop;
- break;
- } else {
- outer_loops.emplace_back(loop);
- }
- }
- TORCH_INTERNAL_ASSERT(
- vectorized_loop != nullptr, "Should be unreachable.");
- return UnswitchPredicate::get(outer_loops, vectorized_loop);
- }
- case PredicateType::Unswitch: {
- return UnswitchPredicate::get(
- for_loops_structure_, pred->unrolled_loop());
- }
- case PredicateType::Shift: {
- kir::TensorView* out_tv = ir_utils::getTVOutput(pred->expr());
- TORCH_INTERNAL_ASSERT(
- out_tv != nullptr, "Missing kir::TensorView output");
- return ShiftPredicateInserter::getPredicate(
- pred->expr(),
- for_loops_structure_,
- out_tv,
- pred->thread_pred(),
- true);
- }
- case PredicateType::Padding: {
- kir::TensorView* out_tv = ir_utils::getTVOutput(pred->expr());
- TORCH_INTERNAL_ASSERT(
- out_tv != nullptr, "Missing kir::TensorView output");
- return ShiftPredicateInserter::getPredicate(
- pred->expr(),
- for_loops_structure_,
- out_tv,
- pred->thread_pred(),
- false);
- }
- case PredicateType::Manual: {
- return pred->value();
- }
- default:
- break;
- }
- return nullptr;
- }
-
- private:
- // We will track which loops in the incoming IR will be replaced and by what
- std::unordered_map<kir::Expr*, kir::Expr*> expr_replacement_map_;
-
- // A depth-first ordering of nested for loops
- // It is used for indexing and predicate generation
- std::vector<kir::ForLoop*> for_loops_structure_;
-};
-
-} // namespace
-
-std::vector<kir::Expr*> generateConditionalFromPredicate(
- Fusion* fusion,
- const std::vector<kir::Expr*>& exprs) {
- FUSER_PERF_SCOPE("GpuLower::Lower::generateConditionalFromPredicate");
-
- ConditionalFromPredicateModifier p2cm(exprs);
-
- std::vector<kir::Expr*> mutated_exprs;
- mutated_exprs.reserve(exprs.size());
- for (auto expr : exprs) {
- mutated_exprs.push_back(
- ir_utils::applyReplacements(p2cm.replacementMap(), expr));
- }
-
- return mutated_exprs;
-}
-
-namespace {
-
-class PredicateAnalyzer : public OptOutDispatch {
- public:
- //! Checks if a predicate is needed to avoid out-of-bound accesses.
- //!
- //! Due to the way we allocate local-memory tensors, there should
- //! never be out-of-bound accesses with consumer tensors when allocated on
- //! local memory. However, accessing producer tensors still may
-  //! result in out-of-bound accesses, as they are replayed as consumers.
- static bool needsPredicate(TensorView* producer, TensorView* consumer) {
-    // Both tensors must be on local memory. Global tensors must be
-    // predicated as allocation is done based on root domains. Smem
-    // and local tensors are allocated based on leaf domains; however,
-    // when smem tensors are parallelized, which is highly likely, the
-    // allocated size of the parallelized axis is the actual axis size,
-    // not the number of threads. Since the number of threads can be
-    // larger than the axis size, it's not safe to skip predication.
- if (!(producer->getMemoryType() == MemoryType::Local &&
- consumer->getMemoryType() == MemoryType::Local)) {
- return true;
- }
-
- auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
- auto c2p =
- BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
- .getReplay();
-
- PredicateAnalyzer analyzer(c2p);
-
- for (auto id : consumer->domain()->domain()) {
- if (analyzer.needsPredicate(id)) {
- return true;
- }
- }
-
- return false;
- }
-
- private:
- PredicateAnalyzer(const std::unordered_map<IterDomain*, IterDomain*>& c2p_map)
- : c2p_map_(c2p_map) {}
-
-  // Returns true if an out-of-bound access could occur with a
-  // producer, i.e., a predicate is needed
- bool needsPredicate(IterDomain* consumer_id) {
- needs_predicate_ = false;
- handle(consumer_id);
- return needs_predicate_;
- }
-
- using OptOutDispatch::handle;
-
- void handle(IterDomain* consumer_id) override {
- // The traversal should have ended if needs_predicate_ was true
- TORCH_INTERNAL_ASSERT(!needs_predicate_);
-
- // If consumer_id is not going to be materialized as a loop (e.g.,
- // broadcast), no need to predicate
- const auto gpu_lower = GpuLower::current();
- if (consumer_id->isBroadcast() ||
- gpu_lower->trivialReductionInfo().isDerived(consumer_id)) {
- return;
- }
-
- // If the producer has a matching domain, it should not cause
- // out-of-bound accesses
- if (c2p_map_.find(consumer_id) != c2p_map_.end()) {
- return;
- }
-
- // If no definition exists, stop traversing
- if (consumer_id->definition() == nullptr) {
- return;
- }
-
- handle(consumer_id->definition());
- }
-
-  // If the split divides the input axis evenly, proceed to check the
-  // input axis. Otherwise, predication can't be skipped as it might
-  // cause out-of-bound accesses with the producer tensor.
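-  // For example (a sketch): an axis of extent 10 split by factor 4 does not
-  // divide evenly, so a predicate is required; extent 12 split by factor 4
-  // divides evenly, and the traversal proceeds to the input axis.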
- void handle(Split* split) override {
- auto factor = split->factor()->getInt();
- if (!factor.has_value()) {
- needs_predicate_ = true;
- return;
- }
-
- ExpressionEvaluator ee(split->fusion());
- const auto in_extent = ee.evaluate(split->in()->extent());
-
- if (!in_extent.has_value() || ((in_extent.value() % factor.value()) != 0)) {
- needs_predicate_ = true;
- return;
- }
-
- handle(split->in());
- }
-
- void handle(Merge* merge) override {
- handle(merge->inner());
- if (needs_predicate_) {
- return;
- }
- handle(merge->outer());
- }
-
- private:
- //! BestEffort map from consumer IDs to producer IDs
- const std::unordered_map<IterDomain*, IterDomain*>& c2p_map_;
- bool needs_predicate_ = false;
-};
-
-} // namespace
-
-bool PredicateElimination::needsPredicate(Expr* expr) const {
- if (!ir_utils::isTVOp(expr)) {
- return false;
- }
-
- std::vector<std::function<bool(Expr*)>> filters;
-
- // Always predicate integer division and related ops as we don't
- // know what values are in the out-of-bound region and they may
- // cause exceptions
- filters.push_back([](Expr* expr) {
- auto dt = expr->outputs()[0]->getDataType().value();
- return (
- (dt == DataType::Int || dt == DataType::Int32) &&
- expr->isA<BinaryOp>() &&
- (expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::Div ||
- expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::Mod ||
- expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::Remainder ||
- expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::CeilDiv));
- });
-
- // Skip if MisalignedVectorize is involved for now. This could be
- // relaxed.
- filters.push_back([](Expr* expr) {
- std::vector<const std::vector<Val*>*> inputs_and_outputs = {
- &(expr->inputs()), &(expr->outputs())};
- for (const auto& inputs_or_outputs : inputs_and_outputs) {
- for (auto tv : ir_utils::filterByType<TensorView>(*inputs_or_outputs)) {
- if (std::any_of(
- tv->domain()->domain().begin(),
- tv->domain()->domain().end(),
- [](IterDomain* axis) {
- return axis->getParallelType() ==
- ParallelType::MisalignedVectorize;
- })) {
- return true;
- }
- }
- }
- return false;
- });
-
- // Shift is not supported yet.
- filters.push_back([](Expr* expr) {
- auto& halo_info = GpuLower::current()->haloInfo();
- auto input_tvs = ir_utils::filterByType<TensorView>(expr->inputs());
- return halo_info.needsShiftPredicate(expr) ||
- std::any_of(input_tvs.begin(), input_tvs.end(), [&](auto input_tv) {
- return input_tv->definition() != nullptr &&
- halo_info.needsShiftPredicate(input_tv->definition());
- });
- });
-
- // Predicates the expression if any producer-consumer pair of the
- // expression needs to be predicated
- filters.push_back([](Expr* expr) {
- for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
- for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
- if (PredicateAnalyzer::needsPredicate(input, output)) {
- return true;
- }
- }
- }
- return false;
- });
-
- // Predicates Welford ops
- filters.push_back([](Expr* expr) { return expr->isA<WelfordOp>(); });
-
- // If this is a reduction, and if we omit the predicate for the
- // input, the input may have a garbabe value, which must not be used
- // for this reduction. However, if the input is also an output of
- // another reduction with the same binary op, which is a common
- // pattern with rfactor, the input should be safe to use with no
- // predication.
- filters.push_back([this](Expr* expr) {
- if (expr->isA<ReductionOp>()) {
- auto input = expr->inputs()[0]->as<TensorView>();
- auto input_def = input->definition();
- // When input_def is null, input must be an input to the fusion,
-    // which must be allocated on global memory. Since we don't omit
- // predication for expressions involving global memory, this
- // should never occur.
- TORCH_INTERNAL_ASSERT(
- input_def != nullptr, "Inconsistent input found: ", input);
-
- if (non_predicated_exprs_.find(input_def) !=
- non_predicated_exprs_.end() &&
- !(input_def->isA<ReductionOp>() &&
- (expr->as<ReductionOp>()->getReductionOpType() ==
- input_def->as<ReductionOp>()->getReductionOpType()))) {
- return true;
- }
- }
- return false;
- });
-
- // If any of the filters returns true, predicate must be used.
- return std::any_of(filters.begin(), filters.end(), [expr](auto filter) {
- return filter(expr);
- });
-}
-
-void PredicateElimination::handle(Expr* expr) {
- if (!ir_utils::isTVOp(expr)) {
- return;
- }
-
- if (needsPredicate(expr)) {
- return;
- }
-
- non_predicated_exprs_.insert(expr);
-
- // Ensure all inputs have some values set at the out-of-bound
- // regions
- for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
- auto input_def = input->definition();
- // When input_def is null, input must be an input to the fusion,
-  // which must be allocated on global memory. Since we don't omit
- // predication for expressions involving global memory, this
- // should never occur.
- std::stringstream ss;
- ss << input;
- TORCH_INTERNAL_ASSERT(
- input_def != nullptr, "Inconsistent input found: ", ss.str());
-
-    // If input is an output of a reduction, it should be fully
-    // initialized as it's allocated on local memory.
- if (input_def->isA<ReductionOp>() || input_def->isA<WelfordOp>()) {
- continue;
- }
-
-    // If this expr is a reduction, always initialize the input with the
- // default value. NOTE: This can be done more
- // intelligently. A garbage value can only cause a problem when
- // it's reduced with non-garbage values, so if the non-reduction
- // axes do not have any garbage, it should be just fine without
- // explicit initialization. However, initialization cost should be
- // cheap, so that further optimization should not make a large
- // difference.
- if (expr->isA<ReductionOp>()) {
- setReductionInitValue(input, expr->as<ReductionOp>()->init());
- continue;
- }
-
- // If an input does not need a predicate either, then it should
- // have some value, so no need to set a default value
- if (non_predicated_exprs_.find(input_def) != non_predicated_exprs_.end()) {
- continue;
- }
-
- // Make sure input is initialized
- setDefaultInitValue(input);
- }
-}
-
-bool PredicateElimination::setDefaultInitValue(TensorView* tv) {
- auto it = init_value_map_.find(tv);
- // If there's already a mapping for tv, it should be mapped to a
- // zero val or a reduction init. Either case, no need to modify
- // the existing mapping.
- if (it == init_value_map_.end()) {
- init_value_map_.insert({tv, nullptr});
- }
- return true;
-}
-
-bool PredicateElimination::setReductionInitValue(
- TensorView* tv,
- Val* reduction_init) {
- auto it = init_value_map_.find(tv);
- if (it == init_value_map_.end()) {
- init_value_map_.insert({tv, reduction_init});
- return true;
- }
-
- auto existing_val = it->second;
- if (existing_val == nullptr) {
- // If the existing mapping returns nullptr, it means that a
- // default init was set before. Overwrite with the reduction
- // init val.
- init_value_map_[tv] = reduction_init;
- return true;
- } else if (existing_val->sameAs(reduction_init)) {
- return true;
- } else {
- TORCH_INTERNAL_ASSERT(
- false,
-        "Inconsistent setting of initialization value for t",
- tv->name(),
- ". Prev: ",
- existing_val,
- ", New: ",
- reduction_init);
- return false;
- }
-}
-
-bool PredicateElimination::canOmitPredicate(const Expr* expr) const {
- TORCH_INTERNAL_ASSERT(expr != nullptr);
- const auto out_tv = ir_utils::getTVOutput(expr);
- TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Not a tensor expression");
- // No need to predicate local tensors to which a scalar is assigned
- if (out_tv->getMemoryType() == MemoryType::Local) {
- if (auto uop = dynamic_cast<const UnaryOp*>(expr)) {
- if (uop->getUnaryOpType() == UnaryOpType::Set && uop->in()->isScalar()) {
- return true;
- }
- }
- }
- if (non_predicated_exprs_.find(expr) != non_predicated_exprs_.end()) {
- return true;
- }
-
- return false;
-}
-
-bool PredicateElimination::canOmitPredicate(const kir::Expr* kir_expr) const {
- TORCH_INTERNAL_ASSERT(kir_expr != nullptr);
- const auto out_tv = ir_utils::getTVOutput(kir_expr);
- TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Not a tensor expression");
- // No need to predicate local tensors to which a scalar is assigned
- if (out_tv->memoryType() == MemoryType::Local) {
- if (auto uop = dynamic_cast<const kir::UnaryOp*>(kir_expr)) {
- if (uop->operation() == UnaryOpType::Set && uop->in()->isScalar()) {
- return true;
- }
- }
- }
- const auto fuser_tv = out_tv->fuserTv();
- if (fuser_tv == nullptr) {
- return false;
- }
- return canOmitPredicate(fuser_tv->definition());
-}
-
-kir::Val* PredicateElimination::getInitValue(TensorView* tv) const {
- auto it = init_value_map_.find(tv);
- if (it == init_value_map_.end()) {
- return nullptr;
- }
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
- auto init_val = it->second;
- if (init_val == nullptr) {
- // No reduction restriction. Just use zero
- return ir_builder.zeroVal();
- } else {
- return gpu_lower->lowerValue(init_val);
- }
-}
-
-void PredicateElimination::build(Fusion* fusion) {
- traverseFrom(fusion, fusion->outputs());
-}
-
-std::string PredicateElimination::toString() const {
- std::stringstream ss;
- ss << "Tensors that do not need predication:";
- for (auto expr : non_predicated_exprs_) {
- for (auto out : expr->outputs()) {
- TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
- ss << " T" << out->name();
- }
- }
- ss << "\n";
- ss << "Init values:";
- for (auto kv : init_value_map_) {
- ss << " T" << kv.first->name() << "->";
- if (kv.second == nullptr) {
- ss << "<default(0)>";
- } else {
- ss << kv.second;
- }
- }
- ss << "\n";
- return ss.str();
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Update predicates with valid bool conditionals
-//!
-std::vector<kir::Expr*> generateConditionalFromPredicate(
- Fusion* fusion,
- const std::vector<kir::Expr*>& exprs);
-
-class TORCH_CUDA_CU_API PredicateElimination : public IterVisitor {
- public:
- void build(Fusion* fusion);
-
- //! True if expr does not need a predicate
- //!
- //! \param expr Tensor expression
- bool canOmitPredicate(const Expr* expr) const;
-
- //! True if expr does not need a predicate
- //!
- //! \param expr KIR tensor expr
- bool canOmitPredicate(const kir::Expr* expr) const;
-
- //! Value to initialize out-of-bound regions
- kir::Val* getInitValue(TensorView* tv) const;
-
- //! Dump to string for debugging
- std::string toString() const;
-
- private:
- using IterVisitor::handle;
-
- void handle(Expr* expr) override;
-
- //! Set a value to initialize out-of-bound regions
- bool setDefaultInitValue(TensorView* tv);
- //! Set a value to initialize out-of-bound regions of reduction tensors
- bool setReductionInitValue(TensorView* tv, Val* reduction_init);
-
- //! Check if expr needs to be predicated
- bool needsPredicate(Expr* expr) const;
-
- private:
- //! Expressions that are found to be safe without predicates
- std::unordered_set<const Expr*> non_predicated_exprs_;
- //! Tensors and their initialization values
- std::unordered_map<TensorView*, Val*> init_value_map_;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-
-#include <functional>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-// Utility functions for composing predicate expressions
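-// Usage sketch (hypothetical operands): makeAddExpr(Int(3), 4) folds to
-// Int(7); makeAddExpr(x, -2) with non-constant x emits x - 2; a nullptr or
-// zero operand returns the other operand unchanged. Folding keeps the
-// generated predicate expressions compact.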
-kir::Bool* makeAndExpr(kir::Val* lhs, kir::Val* rhs) {
- TORCH_INTERNAL_ASSERT(!(lhs == nullptr && rhs == nullptr));
- if (lhs == nullptr) {
- return rhs->as<kir::Bool>();
- } else if (rhs == nullptr) {
- return lhs->as<kir::Bool>();
- } else {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- return ir_builder.andExpr(lhs, rhs)->as<kir::Bool>();
- }
-}
-
-kir::Int* makeAddExpr(kir::Int* lhs, kir::Int::ScalarType rhs) {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- if (rhs == 0) {
- return lhs;
- } else if (lhs == nullptr) {
- return ir_builder.create<kir::Int>(rhs);
- } else if (lhs->isConst()) {
- return ir_builder.create<kir::Int>(lhs->value().value() + rhs);
- } else if (rhs > 0) {
- return ir_builder.addExpr(lhs, ir_builder.create<kir::Int>(rhs))
- ->as<kir::Int>();
- } else {
- return ir_builder.subExpr(lhs, ir_builder.create<kir::Int>(-rhs))
- ->as<kir::Int>();
- }
-}
-
-kir::Int* makeAddExpr(kir::Int* lhs, kir::Int* rhs) {
- if (rhs == nullptr) {
- return lhs;
- } else if (lhs == nullptr) {
- return rhs;
- } else if (lhs->isConst()) {
- return makeAddExpr(rhs, lhs->value().value());
- } else if (rhs->isConst()) {
- return makeAddExpr(lhs, rhs->value().value());
- } else {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- return ir_builder.addExpr(lhs, rhs)->as<kir::Int>();
- }
-}
-
-kir::Val* makeAddExpr(kir::Val* lhs, kir::Val* rhs) {
- TORCH_INTERNAL_ASSERT(lhs != nullptr || rhs != nullptr);
- if (lhs == nullptr || lhs->isZeroInt()) {
- return rhs;
- } else if (rhs == nullptr || rhs->isZeroInt()) {
- return lhs;
- }
- auto lhs_int = dynamic_cast<kir::Int*>(lhs);
- auto rhs_int = dynamic_cast<kir::Int*>(rhs);
- if (lhs_int != nullptr && rhs_int != nullptr) {
- return makeAddExpr(lhs_int, rhs_int);
- } else {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- return ir_builder.addExpr(lhs, rhs);
- }
-}
-
-} // namespace
-
-void ShiftPredicateInserter::insert(
- kir::Expr* expr,
- const std::vector<kir::ForLoop*>& loops,
- kir::Bool* thread_pred) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- kir::TensorView* out_tv = ir_utils::getTVOutput(expr);
- TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing kir::TensorView output");
-
- TensorView* out_fuser_tv = out_tv->fuserTv();
- const bool needs_shift_predicate =
- gpu_lower->haloInfo().needsShiftPredicate(out_fuser_tv->definition());
- if (!needs_shift_predicate) {
- return;
- }
-
- // The conditional branches to create:
- //
- // if (shift_pred) {
- // consumer = producer;
- // } else {
- // if (padding_pred) {
- // consumer = 0;
- // }
- // }
-
- kir::Predicate* shift_pred = ir_builder.create<kir::Predicate>(
- PredicateType::Shift, expr, thread_pred);
-
- // If the expr involves a thread-block barrier, set the predicate of
-  // the expr with shift_pred. Since the expr is not a shift, the
- // padding should be safe to omit. In fact, padding is probably not
- // necessary for all non-shift exprs (see #877)
- if (ir_utils::hasBlockSync(expr, gpu_lower->threadPredMap())) {
- expr->setPredicate(shift_pred);
- return;
- }
-
- auto shift_ite = ir_builder.create<kir::IfThenElse>(shift_pred);
-
- auto& scope = loops.back()->body();
-
- // Insert the if statement
- scope.insert_before(expr, shift_ite);
-
- // Remove the expr from the list
- scope.erase(expr);
-
- // Place the expr inside the if statement
- shift_ite->thenBody().push_back(expr);
-
- // Padding by zero
- kir::Predicate* padding_pred = ir_builder.create<kir::Predicate>(
- PredicateType::Padding, expr, thread_pred);
- auto bounds_ite = ir_builder.create<kir::IfThenElse>(padding_pred);
- const int pad_value = 0;
- auto pad_expr = ir_builder.create<kir::UnaryOp>(
- UnaryOpType::Set, out_tv, ir_builder.create<kir::Int>(pad_value));
- bounds_ite->thenBody().push_back(pad_expr);
- // Insert the else block
- shift_ite->elseBody().push_back(bounds_ite);
-}
-
-namespace {
-
-kir::Val* getShiftProducerIndex(
- size_t consumer_root_axis,
- kir::Val* consumer_index,
- ShiftOp* shift_expr) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- const int shift_offset =
- (shift_expr != nullptr) ? shift_expr->offset(consumer_root_axis) : 0;
-
- if (shift_offset == 0) {
- return consumer_index;
- } else if (shift_offset > 0) {
- return ir_builder.subExpr(
- consumer_index, ir_builder.create<kir::Int>(shift_offset));
- } else {
- return ir_builder.addExpr(
- consumer_index, ir_builder.create<kir::Int>(-shift_offset));
- }
-}
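-
-// For example (a sketch): with a shift offset of +1 on an axis, the consumer
-// element at index i reads the producer element at i - 1, so the producer
-// index is consumer_index - 1.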
-
-// Create a producer index by adjusting the corresponding consumer
-// index.
-kir::Val* getGatherProducerIndex(
- size_t consumer_root_axis,
- kir::Val* consumer_index,
- GatherOp* gather_expr,
- const std::vector<kir::Val*>& indices) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- if (gather_expr == nullptr ||
- consumer_root_axis >= gather_expr->windowShape().size() ||
- gather_expr->windowShape()[consumer_root_axis]->isOneInt()) {
- return consumer_index;
- }
-
- // Relative to the consumer index, the producer index needs to
- // account for:
- // - window access
- // - padding at offset 0
- // This adjustment is basically the same as
- // getProducerIndexWithGather in index_compute.cpp.
- // TODO: Refactor shift/gather indexing and predication
- const auto window_axis = gather_expr->gatherAxis(consumer_root_axis);
- TORCH_INTERNAL_ASSERT(window_axis < (int)indices.size());
- auto window_idx = indices[window_axis];
- auto pad_size = gather_expr->padWidth()[consumer_root_axis][0];
- auto producer_index = ir_builder.subExpr(
- ir_builder.addExpr(consumer_index, window_idx),
- ir_builder.create<kir::Int>(pad_size));
- return producer_index;
-}
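-
-// Worked example (a sketch): for a 3-wide gather window padded by {1, 1},
-// pad_size == 1 and window_idx ranges over {0, 1, 2}, so the producer index
-// is consumer_index + window_idx - 1, i.e. the neighbors i-1, i, and i+1 of
-// the consumer position.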
-
-} // namespace
-
-kir::Bool* ShiftPredicateInserter::getPredicate(
- const kir::Expr* expr,
- const std::vector<kir::ForLoop*>& loops,
- kir::TensorView* out_tv,
- kir::Bool* thread_pred,
- bool isShiftPredicate) {
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- TensorView* out_fuser_tv = out_tv->fuserTv();
-
- const bool needs_shift_predicate =
- gpu_lower->haloInfo().needsShiftPredicate(out_fuser_tv->definition());
- TORCH_INTERNAL_ASSERT(needs_shift_predicate);
-
- const auto& root_domain = out_fuser_tv->getRootDomain();
-
- auto shift_expr = dynamic_cast<ShiftOp*>(out_fuser_tv->definition());
- auto gather_expr = dynamic_cast<GatherOp*>(out_fuser_tv->definition());
-
- // Creates indices at the root domain.
-  // Set the contiguity of all axes to false, as separate indices are
-  // needed for each root axis.
- // Note: separate indices should be needed only for axes that
- // require shift predication, so other axes could use the actual
- // contiguity information. See a TODO item of issue #877.
- const auto pred_contiguity = std::vector<bool>(root_domain.size(), false);
- auto pred_indices =
- Index::getConsumerRootPredIndices(out_tv, loops, pred_contiguity);
- const auto& indices = pred_indices.first;
- const bool buffer_init = pred_indices.second;
-
-  // No predication is needed when the expr initializes a reduction
-  // buffer on local memory
- if (out_tv->memoryType() == MemoryType::Local && buffer_init) {
- return ir_builder.trueVal();
- }
-
- TORCH_INTERNAL_ASSERT(indices.size() == root_domain.size());
-
- kir::Bool* predicate = nullptr;
-
- for (size_t i = 0; i < root_domain.size(); ++i) {
- auto root_id = root_domain[i];
-
- if (root_id->isBroadcast() || (buffer_init && root_id->isReduction()) ||
- gpu_lower->trivialReductionInfo().isDerived(root_id)) {
- continue;
- }
-
- const auto halo_info = gpu_lower->haloInfo().getRootAxisInfo(root_id);
-
- if (isShiftPredicate) {
- // Below, "left" and "right" halo mean halo at offset zero and
- // axis extent, respectively.
- //
- // The consumer axis looks like this:
- //
- // [0, left halo)[0, extent)[0, right halo)
- // ^ ^
- // left limit right limit
- //
- // Accesses outside of the left and right limits are filled by
- // zero. As illustrated above, left limit = left halo, and right
- // limit = left halo + extent.
-
- kir::Val* left_limit = halo_info.width(0);
- kir::Val* right_limit = makeAddExpr(
- out_tv->domain()->rootDomain()[i]->extent(), halo_info.width(0));
-
- kir::Val* consumer_index = indices[i];
- kir::Val* producer_index = nullptr;
-
- if (shift_expr != nullptr) {
- producer_index = getShiftProducerIndex(i, consumer_index, shift_expr);
- } else if (gather_expr != nullptr) {
- producer_index =
- getGatherProducerIndex(i, consumer_index, gather_expr, indices);
- } else {
- producer_index = indices[i];
- }
-
- // If the defining expr is ShiftOp and its offset is positive,
- // consumer access at 0 to the offset corresponds to
- // out-of-bound producer access unless the producer has halo as
- // well. For now, always add predication assuming no halo on the
-      // producer. This should be revisited for performance
- // optimization (#877).
- if (shift_expr && shift_expr->offset(i) > 0) {
- predicate = makeAndExpr(
- predicate, ir_builder.geExpr(producer_index, left_limit));
- } else if (gather_expr) {
- // Since it's unknown if producer_index < consumer_index, we need
-        // to predicate using both the producer and consumer
- // indices. This would be the case if dynamic shift offset is
- // used, which is not yet supported. This can be a performance
- // problem, but in a common case where the input tensor is
- // cached at SMEM, it should be possible to remove the
- // predicate for this expression entirely.
- predicate = makeAndExpr(
- predicate, ir_builder.geExpr(consumer_index, left_limit));
- if (consumer_index != producer_index) {
- predicate = makeAndExpr(
- predicate, ir_builder.geExpr(producer_index, left_limit));
- }
- } else if (!left_limit->isZeroInt()) {
- predicate = makeAndExpr(
- predicate, ir_builder.geExpr(consumer_index, left_limit));
- }
-
- // If the shift offset is negative, the maximum index is extent -
- // abs(shift_offset). Instead of subtracting shift_offset from
- // extent, which can result in wrap around, add the absolute value
- // of the shift offset to the index
- if (shift_expr && shift_expr->offset(i) < 0) {
- predicate = makeAndExpr(
- predicate, ir_builder.ltExpr(producer_index, right_limit));
- } else if (gather_expr) {
- predicate = makeAndExpr(
- predicate, ir_builder.ltExpr(consumer_index, right_limit));
- if (consumer_index != producer_index) {
- predicate = makeAndExpr(
- predicate, ir_builder.ltExpr(producer_index, right_limit));
- }
- } else {
- predicate = makeAndExpr(
- predicate, ir_builder.ltExpr(consumer_index, right_limit));
- }
- } else {
- auto padding_max_offset = makeAddExpr(
- out_tv->domain()->rootDomain()[i]->extent(), halo_info.width());
-
- predicate = makeAndExpr(
- predicate, ir_builder.ltExpr(indices[i], padding_max_offset));
- }
- }
-
- if (thread_pred->isConst()) {
- if (!thread_pred->value().value()) {
- predicate = ir_builder.create<kir::Bool>(false);
- }
- } else {
- predicate = makeAndExpr(predicate, thread_pred);
- }
-
- return predicate;
-}
-
-AxisHaloInfo::AxisHaloInfo() {
- auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
- setWidth(0, ir_builder.zeroVal());
- setWidth(1, ir_builder.zeroVal());
-}
-
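-// Total halo width: the sum of the left (pos 0) and right (pos 1) widths.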
-kir::Int* AxisHaloInfo::width() const {
- auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
- return makeAddExpr(width(0), width(1));
-}
-
-kir::Int* AxisHaloInfo::width(int pos) const {
- TORCH_INTERNAL_ASSERT(pos >= 0 && pos < 2);
- TORCH_INTERNAL_ASSERT(widths_[pos] != nullptr);
- return widths_[pos];
-}
-
-void AxisHaloInfo::setWidth(int pos, kir::Int* width) {
- TORCH_INTERNAL_ASSERT(pos >= 0 && pos < 2);
- widths_[pos] = width;
-}
-
-void AxisHaloInfo::merge(int pos, kir::Int* other) {
- auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
- auto cur = width(pos);
- kir::Int* new_width = nullptr;
- if (cur->isConst() && other->isConst()) {
- new_width = ir_builder.create<kir::Int>(
- std::max(cur->value().value(), other->value().value()));
- } else if (cur->isZeroInt()) {
- new_width = other;
- } else if (other->isZeroInt()) {
- new_width = cur;
- } else {
- new_width = ir_builder.maxExpr(width(pos), other)->as<kir::Int>();
- }
- setWidth(pos, new_width);
-}
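-
-// For example (a sketch): merging constant widths 2 and 3 folds to 3 (the
-// max); merging a zero width with x yields x; otherwise a max expression is
-// emitted.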
-
-void AxisHaloInfo::merge(const AxisHaloInfo& other) {
- for (size_t i = 0; i < widths_.size(); ++i) {
- merge(i, other.width(i));
- }
-}
-
-bool AxisHaloInfo::hasHalo() const {
- return std::any_of(
- widths_.begin(), widths_.end(), [](auto w) { return !w->isZeroInt(); });
-}
-
-std::string AxisHaloInfo::toString() const {
- std::stringstream ss;
- ss << "<" << kir::toString(width(0)) << ", " << kir::toString(width(1))
- << ">";
- return ss.str();
-}
-
-const AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) const {
- TORCH_INTERNAL_ASSERT(
- id->definition() == nullptr || id->isRFactorProduct(),
- "Invalid IterDomain: ",
- id);
- auto it = root_axis_map_.find(id);
- TORCH_INTERNAL_ASSERT(
- it != root_axis_map_.end(), "Halo root axis info not found for ", id);
- return it->second;
-}
-
-AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) {
- return const_cast<AxisHaloInfo&>(
- const_cast<const HaloInfo*>(this)->getRootAxisInfo(id));
-}
-
-const AxisHaloInfo& HaloInfo::getRootAxisInfo(kir::IterDomain* id) const {
- TORCH_INTERNAL_ASSERT(
- id->definition() == nullptr || id->isRFactorProduct(),
- "Invalid IterDomain: ",
- id);
- auto it = kir_root_axis_map_.find(id);
- TORCH_INTERNAL_ASSERT(
- it != kir_root_axis_map_.end(), "Halo root axis info not found for ", id);
- return it->second;
-}
-
-AxisHaloInfo& HaloInfo::getRootAxisInfo(kir::IterDomain* id) {
- return const_cast<AxisHaloInfo&>(
- const_cast<const HaloInfo*>(this)->getRootAxisInfo(id));
-}
-
-void HaloInfo::setRootAxisInfo(
- IterDomain* id,
- const AxisHaloInfo& root_axis_info) {
- TORCH_INTERNAL_ASSERT(
- id->definition() == nullptr || id->isRFactorProduct(),
- "Invalid IterDomain: ",
- id);
- root_axis_map_[id] = root_axis_info;
- kir_root_axis_map_
- [GpuLower::current()->lowerValue(id)->as<kir::IterDomain>()] =
- root_axis_info;
- return;
-}
-
-void HaloInfo::build(Fusion* fusion) {
- const auto vals = fusion->usedMathVals();
- auto tvs = ir_utils::filterByType<TensorView>(vals);
-
- // Initialize all root axis info
- for (auto tv : tvs) {
- for (auto root_axis : tv->getRootDomain()) {
- setRootAxisInfo(root_axis, AxisHaloInfo());
- }
-    // Just add placeholders so lookups do not fail. Reduction and
-    // rfactor support is not yet in place.
- if (tv->hasRFactor()) {
- for (auto rf_root_axis : tv->getRFactorDomain()) {
- setRootAxisInfo(rf_root_axis, AxisHaloInfo());
- }
- }
- }
-
- // Propagate backward halo information of root axes from fusion
- // outputs to inputs
- auto exprs = fusion->exprs();
- for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
- auto expr = *it;
- if (!expr->outputs()[0]->isA<TensorView>()) {
- continue;
- }
-
- propagateRootAxisInfo(expr);
- }
-
- // Propagates halo information from root axes down to leaf axes
- for (auto tv : tvs) {
- build(tv->domain());
- }
-
- // Note that validation requires consumer halo info
- for (auto tv : tvs) {
- validate(tv);
- }
-}
-
-void HaloInfo::propagateRootAxisInfo(Expr* expr) {
- for (auto output : expr->outputs()) {
- auto out_tv = dynamic_cast<TensorView*>(output);
- if (out_tv == nullptr) {
- continue;
- }
- for (auto input : expr->inputs()) {
- auto in_tv = dynamic_cast<TensorView*>(input);
- if (in_tv == nullptr) {
- continue;
- }
- propagateRootAxisInfo(in_tv, out_tv, expr);
- }
- }
-}
-
-void HaloInfo::propagateRootAxisInfo(
- TensorView* producer,
- TensorView* consumer,
- Expr* expr) {
- // Do not add halo to input tensors
- if (producer->isFusionInput()) {
- return;
- }
-
- auto c2p = PairwiseRootDomainMap(producer, consumer)
- .mapConsumerToProducer(consumer->domain(), producer->domain());
-
- const auto& c_root = consumer->getRootDomain();
-
- auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- for (size_t i = 0; i < c_root.size(); ++i) {
- auto c_id = c_root[i];
- auto it = c2p.find(c_id);
- if (it == c2p.end()) {
- // nothing to propagate
- continue;
- }
-
- // propagate root-axis halo info from c_id to p_id
-
- auto p_id = it->second;
-
- auto p_info = getRootAxisInfo(p_id);
- const auto c_info = getRootAxisInfo(c_id);
-
- // If the root axes are broadcast, no halo should be associated
- // with them.
- if (c_id->isBroadcast()) {
- TORCH_INTERNAL_ASSERT(!c_info.hasHalo());
- p_info.merge(c_info);
- setRootAxisInfo(p_id, p_info);
- continue;
- }
-
- // If the defining expression is shift, adjust the producer halo
- // width based on the shift offset. If the shift offset is
- // positive, create halo at offset zero of the producer axis so
- // that the consumer can safely access the producer. If the offset
- // is negative, halo is created at the other end of the axis.
- // If the expr is not shift, just merge the consumer halo info
- // to the producer halo info so that the producer halo can be the
- // maximum of all its consumers.
- if (auto shift_op = dynamic_cast<ShiftOp*>(expr)) {
- const auto offset = shift_op->offset(i);
- if (offset == 0) {
- p_info.merge(c_info);
- } else {
- int pos = (offset > 0) ? 0 : 1;
- p_info.merge(pos, makeAddExpr(c_info.width(pos), std::abs(offset)));
- }
- } else if (auto gather_op = dynamic_cast<GatherOp*>(expr)) {
- const auto window_dim =
- gpu_lower->lowerValue(gather_op->windowShape()[i]);
- if (window_dim->isOneInt()) {
- p_info.merge(c_info);
- continue;
- }
- const auto& pad_dim = gather_op->padWidth()[i];
- const auto pad_dim0 = gpu_lower->lowerValue(pad_dim[0])->as<kir::Int>();
- p_info.merge(0, makeAddExpr(c_info.width(0), pad_dim0));
- // The right-side halo is propagated as:
- // consumer_right_halo + (window_dim - 1 - left_padding)
- p_info.merge(
- 1,
- ir_builder
- .subExpr(
- makeAddExpr(c_info.width(1), window_dim),
- makeAddExpr(pad_dim0, 1))
- ->as<kir::Int>());
- } else {
- p_info.merge(c_info);
- }
- setRootAxisInfo(p_id, p_info);
- }
-}
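
To make the shift and gather rules above concrete, here is a minimal sketch of the same width arithmetic over plain ints. The helper names are hypothetical; the real code operates on kir::Int* and merges through AxisHaloInfo:

#include <array>
#include <cstdlib>

using Widths = std::array<int, 2>; // {left halo, right halo}

// shift(producer, {offset}): a positive offset grows the left halo, a
// negative one the right halo, by abs(offset).
Widths propagateShift(Widths c, int offset) {
  if (offset != 0) {
    c[offset > 0 ? 0 : 1] += std::abs(offset);
  }
  return c;
}

// gather with a window and left padding, matching the formula in the
// comment above: the right halo grows by (window - 1 - pad_left).
Widths propagateGather(Widths c, int window, int pad_left) {
  return {c[0] + pad_left, c[1] + window - 1 - pad_left};
}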
-
-// Propagate extent information from root axes to descendants
-void HaloInfo::build(TensorDomain* td) {
- auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- for (auto root_axis : td->getRootDomain()) {
- const auto& halo_info = getRootAxisInfo(root_axis);
- auto halo_width = halo_info.width();
-
- // There should be no existing mapping. Note that at one point it
- // wasn't the case as root axes were reused when creating
- // reference tensors.
- // TODO: This is not the case actually. Root domains are reused
- // when creating some TensorDomains, so a single IterDomain can
- // show up multiple times. That itself should be fixed, but for
- // now disable this assertion.
- TORCH_INTERNAL_ASSERT(
- halo_width_map_.find(root_axis) == halo_width_map_.end(),
- "Invalid domain: ",
- root_axis,
- " of ",
- td->getRootDomain());
-
- if (!halo_info.hasHalo()) {
- halo_width_map_.insert({root_axis, ir_builder.zeroVal()});
- continue;
- }
-
- auto expanded_extent = ir_builder.addExpr(
- gpu_lower->lowerValue(root_axis->extent()), halo_width);
- kir_extent_map_.insert(
- {gpu_lower->lowerValue(root_axis)->as<kir::IterDomain>(),
- expanded_extent});
- halo_width_map_.insert({root_axis, halo_width});
- }
-
- auto exprs = ExprSort::getExprs(
- td->fusion(),
- std::vector<Val*>(td->domain().begin(), td->domain().end()));
-
- // Track IDs that are generated by merging halo-extended IDs
- std::unordered_set<IterDomain*> merged_shifted_ids;
-
- // Propagate halo information by traversing IterDomain
- // expressions. We populate extent_map_ and
- // halo_width_map_.
- // - extent_map_ maps to Expr* representing the
- // extent of each axis including its halo. If no mapping exists for
- // a particular axis in extent_map_, it means the axis does not have
- // halo.
- // - halo_width_map_ just maps to the integer size of the halo,
- // which is used for extent comparison (e.g., extentLessEqual).
- //
- // - When expr is split: if the halo width of the input axis is
- // zero, both the split outputs get zero halo in halo_width_map_. No
- // mapping is added for extent_map_. Otherwise, the halo is
- // propagated only to the inner output, so the inner output gets the
- // same halo width and its mapping is created in extent_map_.
- //
- // One major assumption here is that splitting an axis that is
- // an output of merging halo-extended axes is not allowed. This is
- // because it is unclear how to split the halo part of the merged
- // axis. This is unlikely to be a real limitation in practice.
- //
- // - When expr is merge: if either of the inputs has halo, a mapping
- // for the output is created in extent_map_. No mapping is created
- // for halo_width_map_ (see the comment on HaloInfo::halo_width_map_
- // in lower_shift.h). If both of them don't have halo, just adds a
- // new mapping of the output to zero in halo_width_map_. Also adds
- // it to a set (merged_shifted_ids) to track which axes are merge
- // outputs of halo-extended axes.
-
- for (auto expr : exprs) {
- if (auto split = dynamic_cast<Split*>(expr)) {
- // Merge-then-split of halo-extended IDs is not allowed
- TORCH_INTERNAL_ASSERT(
- merged_shifted_ids.find(split->in()) == merged_shifted_ids.end(),
- "Splitting IterDomain that is a merged domain of halo-extended domains is not allowed");
-
- auto in_id = split->in();
-
-      // There must always be a mapping for the input axis of a split
- // expr. The only exception is when the input axis is an output
- // of merge, but that's excluded by the assertion above.
- const auto& halo_width_it = halo_width_map_.find(in_id);
- TORCH_INTERNAL_ASSERT(halo_width_it != halo_width_map_.end());
-
- const auto halo_width = halo_width_it->second;
-
- if (halo_width->isZeroInt()) {
- halo_width_map_.insert({split->outer(), halo_width});
- halo_width_map_.insert({split->inner(), halo_width});
- continue;
- }
-
- // propagate to inner domain
- auto out_id = split->inner();
-
- auto expanded_extent = ir_builder.addExpr(
- gpu_lower->lowerValue(out_id->extent()), halo_width);
- kir_extent_map_.insert(
- {gpu_lower->lowerValue(out_id)->as<kir::IterDomain>(),
- expanded_extent});
-
- halo_width_map_.insert({split->outer(), ir_builder.zeroVal()});
- halo_width_map_.insert({split->inner(), halo_width});
- } else if (auto merge = dynamic_cast<Merge*>(expr)) {
- // If either of the two inputs has halo extension, propagate it
- // to the merged output ID
- auto inner_extent = getExtent(merge->inner());
- auto outer_extent = getExtent(merge->outer());
- if (inner_extent != nullptr || outer_extent != nullptr) {
- if (inner_extent == nullptr) {
- inner_extent = gpu_lower->lowerValue(merge->inner()->extent());
- }
- if (outer_extent == nullptr) {
- outer_extent = gpu_lower->lowerValue(merge->outer()->extent());
- }
- auto expanded_extent = ir_builder.mulExpr(outer_extent, inner_extent);
- kir_extent_map_.insert(
- {gpu_lower->lowerValue(merge->out())->as<kir::IterDomain>(),
- expanded_extent});
- // Splitting the output of this merge is not allowed, so
- // remember it
- merged_shifted_ids.insert(merge->out());
- // Note that halo_width_map_ is not updated
- } else {
- halo_width_map_.insert({merge->out(), ir_builder.zeroVal()});
- }
- } else {
- TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr);
- }
- }
-}
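
A minimal sketch of the split/merge rules implemented above, using a hypothetical plain-int axis type in place of IterDomain/kir nodes:

struct SimpleAxis {
  int extent;
  int halo; // combined halo width of both sides
};

// Split: the halo is carried entirely by the inner output.
void splitAxis(SimpleAxis in, int factor, SimpleAxis* outer, SimpleAxis* inner) {
  *outer = {in.extent / factor, 0};
  *inner = {factor, in.halo}; // allocated extent: factor + in.halo
}

// Merge: the halo-extended extents multiply; no single halo width is
// recorded for the output, mirroring halo_width_map_ above.
int mergedExtent(const SimpleAxis& outer, const SimpleAxis& inner) {
  return (outer.extent + outer.halo) * (inner.extent + inner.halo);
}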
-
-//! Restriction 1: When allocation is outside of a shifted
-//! axis, the shifted axis must be guaranteed to have a smaller extent
-//! than the concrete axis. For now, shifted axes always mean expanded
-//! allocations when the axis is located inside the allocation
-//! point. This restriction is validated at the allocation lowering
-//! pass.
-//!
-//! Restriction 2: If an expanded axis is parallelized, its memory
-//! must be accessible by all other threads. More specifically:
-//! - TIDx: It must be on shared memory. May want to consider
-//! utilizing the shuffle instructions as well.
-//! - BIDx: Not supported. If on global memory, Cooperative Launch
-//! may be used to support it, however, it's unclear in what
-//! situations block-level parallelization should be used.
-//!
-//! Other types of parallelization should be supported except for
-//! vectorization. Vectorization should be eventually supported but
-//! needs further work.
-void HaloInfo::validate(TensorView* tv) const {
- const auto& par_map = GpuLower::current()->caParallelMap();
- const auto& loop_map = GpuLower::current()->caLoopMap();
- const auto mem_type = tv->getMemoryType();
-
- for (auto axis : tv->domain()->domain()) {
- auto concrete_id = par_map.getConcreteMappedID(axis);
-
- // The extent is assumed to be the same
- TORCH_INTERNAL_ASSERT(
- extentEqual(axis, concrete_id),
- "Axis does not have the same exact size with its concrete ID due to halo extension.",
- " Tensor: T",
- tv->name(),
- ", Axis: ",
- axis,
- ", concrete ID: ",
- concrete_id);
-
- auto halo_extent = getExtent(axis);
-
- // If no halo extent is associated with this axis, it means the
- // axis is not extended.
- if (halo_extent == nullptr) {
- continue;
- }
-
- // Enforce restrictions on parallelization and memory type
- const auto ptype = concrete_id->getParallelType();
-
- if (ptype == ParallelType::Serial) {
- continue;
- }
-
- // Only threading parallelism is considered for now
- TORCH_CHECK(
- isParallelTypeThread(ptype), "Unsupported parallel type: ", ptype);
-
- bool shared_mem_needed = false;
- for (auto use : tv->uses()) {
- if (!ir_utils::isTVOp(use)) {
- continue;
- }
- if (use->isA<ShiftOp>() || use->isA<GatherOp>()) {
- shared_mem_needed = true;
- break;
- }
- auto consumer = use->outputs()[0]->as<TensorView>();
- // Find the corresponding axis in the consumer
- auto it = std::find_if(
- consumer->domain()->domain().begin(),
- consumer->domain()->domain().end(),
- [&](IterDomain* consumer_axis) {
- return loop_map.areMapped(axis, consumer_axis);
- });
- if (it == consumer->domain()->domain().end()) {
- continue;
- }
- if (!extentEqual(axis, *it)) {
- shared_mem_needed = true;
- break;
- }
- }
-
- if (!shared_mem_needed) {
- continue;
- }
-
- if (isParallelTypeThreadDim(ptype)) {
- // If all the consumers have the same extent and none of the
- // expressions is shift, any memory should be fine. Otherwise, it
- // must be accessible by all threads involved in the
- // parallelization.
- TORCH_CHECK(
- mem_type == MemoryType::Shared,
- "TV",
- tv->name(),
- " must be allocated on shared memory as its halo-extended axis is parallelized by ",
- ptype);
-
- } else if (isParallelTypeBlockDim(ptype)) {
- TORCH_CHECK(
- false,
- "Block-based parallelization of a halo-extended axis is not supported: ",
- axis);
- }
- }
- return;
-}
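
A hypothetical schedule illustrating Restriction 2: the halo-extended axis of an intermediate is parallelized across threads, so neighboring threads must read each other's halo elements and the tensor has to live in shared memory. The calls follow the nvfuser C++ API used elsewhere in this patch, assuming a 1-D input tv0; treat it as a sketch, not a tested schedule.

TensorView* tv1 = add(tv0, new Double(1));      // intermediate with halo
TensorView* tv2 = shift(tv1, {1});              // consumer reads tv1's halo
tv1->axis(0)->parallelize(ParallelType::TIDx);  // halo axis spans threads
tv1->setMemoryType(MemoryType::Shared);         // required, or validate() fails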
-
-kir::Val* HaloInfo::getExtent(IterDomain* id) const {
- auto kir_id = GpuLower::current()->lowerValue(id)->as<kir::IterDomain>();
- return getExtent(kir_id);
-}
-
-kir::Val* HaloInfo::getExtent(kir::IterDomain* id) const {
- auto it = kir_extent_map_.find(id);
- if (it != kir_extent_map_.end()) {
- return it->second;
- } else {
- return nullptr;
- }
-}
-
-kir::Int* HaloInfo::getHaloWidth(IterDomain* id) const {
- auto it = halo_width_map_.find(id);
- TORCH_INTERNAL_ASSERT(it != halo_width_map_.end());
- return it->second;
-}
-
-bool HaloInfo::hasHaloWidth(IterDomain* id) const {
- return halo_width_map_.find(id) != halo_width_map_.end();
-}
-
-namespace {
-
-//! Prove whether the comparison operator, cmp, holds for the extents of
-//! id1 and id2, including their halo. The comparison is done
-//! conservatively, meaning false negatives are possible.
-//!
-//! It is assumed that id1 and id2 are mapped with the CA Loop map, so
-//! only halo sizes are checked here, using
-//! HaloInfo::halo_width_map_. Since that map has no entries for merged
-//! axes, the merge input axes are compared individually, and the merge
-//! output axis compares true only when both pairs of inputs do.
-template <typename Cmp>
-bool extentCompare(
- const HaloInfo& halo_map,
- IterDomain* id1,
- IterDomain* id2,
- Cmp cmp) {
- auto gpu_lower = GpuLower::current();
- TORCH_INTERNAL_ASSERT(
- gpu_lower->caLoopMap().areMapped(id1, id2), "Invalid axes to compare");
-
-  // It's invalid to compare two axes when only one of them has
-  // halo.
-
- if (halo_map.hasHaloWidth(id1)) {
- TORCH_INTERNAL_ASSERT(
- halo_map.hasHaloWidth(id2), "Invalid comparison: ", id1, " and ", id2);
- // Both axes have halo. We assume the axes themselves have equal
- // extents, excluding halo, as they are mapped with the CA
- // map. So, we just need to compare the halo width of each axis.
- return cmp(halo_map.getHaloWidth(id1), halo_map.getHaloWidth(id2));
- } else {
- TORCH_INTERNAL_ASSERT(!halo_map.hasHaloWidth(id2));
-    // Neither has halo. The only case where this can happen is when
-    // both axes are outputs of merge expressions, so each merge
-    // input is compared recursively, returning true only when both
-    // inputs do.
- if (auto merge1 = dynamic_cast<Merge*>(id1->definition())) {
- auto merge2 = dynamic_cast<Merge*>(id2->definition());
- TORCH_INTERNAL_ASSERT(
- merge2 != nullptr, "Invalid comparison: ", id1, " and ", id2);
- auto inner_le =
- extentCompare(halo_map, merge1->inner(), merge2->inner(), cmp);
- auto outer_le =
- extentCompare(halo_map, merge1->outer(), merge2->outer(), cmp);
- return inner_le && outer_le;
- } else {
- // This is not considered. Should never reach here.
- TORCH_INTERNAL_ASSERT(false, "Invalid comparison: ", id1, " and ", id2);
- }
- }
-}
-
-} // namespace
-
-bool HaloInfo::extentLessEqual(IterDomain* id1, IterDomain* id2) const {
- auto cmp = [](kir::Int* x, kir::Int* y) {
- if (x == y) {
- return true;
- }
- auto xv = x->value();
- auto yv = y->value();
- return xv.has_value() && yv.has_value() && xv.value() <= yv.value();
- };
- return extentCompare(*this, id1, id2, cmp);
-}
-
-bool HaloInfo::extentEqual(IterDomain* id1, IterDomain* id2) const {
- // Returns true only when x and y are proven to be the same. The
-  // analysis is not comprehensive and can prove equality only in
-  // rather trivial cases. Specifically:
- // - x and y are the same pointers
- // - Both have static values and they are the same
- // - Both are defined by the same expression and the inputs are
- // proven to be equal
- std::function<bool(kir::Int*, kir::Int*)> cmp = [&](kir::Int* x,
- kir::Int* y) {
- if (x == y) {
- return true;
- }
-
- auto xv = x->value();
- auto yv = y->value();
- if (xv.has_value() && yv.has_value() && xv.value() == yv.value()) {
- return true;
- }
-
- // Check if both are defined by an expression of the same type. If
- // so, recursively check the input operands.
- auto x_def = x->definition();
- auto y_def = y->definition();
- if (x_def && y_def &&
- ((x_def->isA<kir::UnaryOp>() && y_def->isA<kir::UnaryOp>() &&
- x_def->as<kir::UnaryOp>()->operation() ==
- y_def->as<kir::UnaryOp>()->operation()) ||
- (x_def->isA<kir::BinaryOp>() && y_def->isA<kir::BinaryOp>() &&
- x_def->as<kir::BinaryOp>()->operation() ==
- y_def->as<kir::BinaryOp>()->operation()))) {
- for (size_t i = 0; i < x_def->inputs().size(); ++i) {
- auto x_input = dynamic_cast<kir::Int*>(x_def->inputs()[i]);
- auto y_input = dynamic_cast<kir::Int*>(y_def->inputs()[i]);
- // Both must be kir::Int
- TORCH_INTERNAL_ASSERT(x_input && y_input);
- if (!cmp(x_input, y_input)) {
- return false;
- }
- }
- return true;
- }
-
- return false;
- };
- return extentCompare(*this, id1, id2, cmp);
-}
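
The recursive cmp above proves equality only in trivial cases. A toy analogue over a hypothetical plain expression tree shows the same three rules (same node, same constant, same operation with provably equal operands):

struct Node {
  bool is_const = false;
  int value = 0;              // meaningful when is_const
  char op = 0;                // non-zero for operator nodes
  const Node* lhs = nullptr;
  const Node* rhs = nullptr;
};

bool provablyEqual(const Node* x, const Node* y) {
  if (x == y) {
    return true;                   // same node
  }
  if (x->is_const && y->is_const) {
    return x->value == y->value;   // same static value
  }
  if (x->op != 0 && x->op == y->op) {
    // Same operation: recursively compare the operands.
    return provablyEqual(x->lhs, y->lhs) && provablyEqual(x->rhs, y->rhs);
  }
  return false; // equality cannot be proven (the values may still be equal)
}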
-
-std::string HaloInfo::toString() const {
- std::stringstream ss;
-
- ss << "HaloInfo:\n";
-
- if (root_axis_map_.empty()) {
- return ss.str();
- }
-
- Fusion* fusion = root_axis_map_.begin()->first->fusion();
-
- auto used_vals = DependencyCheck::getAllValsBetween(
- {fusion->inputs().begin(), fusion->inputs().end()}, fusion->outputs());
-
- for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
- const auto& root = tv->getRootDomain();
- ss << "TV" << tv->name() << " root domain: ";
- for (auto axis : root) {
- ss << axis << " -> " << getRootAxisInfo(axis).toString() << ", ";
- }
- ss << "\n";
- }
-
- return ss.str();
-}
-
-bool HaloInfo::needsShiftPredicate(Expr* expr) const {
- auto consumer_td = ir_utils::getTVOutput(expr)->domain();
- auto shift_expr = dynamic_cast<ShiftOp*>(expr);
- auto gather_expr = dynamic_cast<GatherOp*>(expr);
- for (size_t i = 0; i < consumer_td->getRootDomain().size(); ++i) {
- auto consumer_id = consumer_td->getRootDomain()[i];
- const auto consumer_halo_info = getRootAxisInfo(consumer_id);
- if (consumer_halo_info.hasHalo() ||
- (shift_expr != nullptr && shift_expr->offset(i) != 0 &&
- !consumer_id->isBroadcast()) ||
- (gather_expr != nullptr && !gather_expr->windowShape()[i]->isOneInt() &&
- !consumer_id->isBroadcast())) {
- return true;
- }
- }
- return false;
-}
-
-bool HaloInfo::needsShiftPredicate(kir::Expr* expr) const {
- const auto out_tv = expr->outputs()[0]->as<kir::TensorView>();
- auto fuser_expr = out_tv->fuserTv()->definition();
- TORCH_INTERNAL_ASSERT(fuser_expr != nullptr);
- return needsShiftPredicate(fuser_expr);
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Auxiliary class to represent information about halo of an axis
-class AxisHaloInfo {
- public:
- AxisHaloInfo();
-
- //! Width of halo.
- //!
-  //! pos is either 0 or 1. The width of halo at offset zero is
-  //! returned when pos is 0.
- kir::Int* width(int pos) const;
-
-  //! Sum of the widths of both sides
- kir::Int* width() const;
-
- const auto& widths() const {
- return widths_;
- }
-
- //! Set the halo width of either side.
- //! pos is either 0 or 1. The width of halo at offset zero is set
- //! when pos is 0.
- void setWidth(int pos, kir::Int* width);
-
- //! Extend the halo width to account for another axis.
- void merge(int pos, kir::Int* other);
-
- //! Extend the halo width to account for another axis.
- void merge(const AxisHaloInfo& other);
-
- //! True when halo may be attached
- bool hasHalo() const;
-
- std::string toString() const;
-
- private:
- //! Sizes of the halo regions of two sides. Both values are zero for
- //! axes with no halo. When an axis has halo at offset zero,
- //! widths_[0] is non-zero and designates the size of the
- //! halo. Similarly, non-zero widths_[1] means the axis has halo at
- //! the other end of the axis.
- std::array<kir::Int*, 2> widths_ = {nullptr, nullptr};
-};
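
A small usage sketch, assuming an active GpuLower as in the implementation file: a producer consumed by both shift(+1) and shift(-1) accumulates one halo element on each side.

kir::IrBuilder ir_builder(GpuLower::current()->kernel());
auto one = ir_builder.create<kir::Int>(1);
AxisHaloInfo info;    // both widths initialized to zero
info.merge(0, one);   // a shift(+1) consumer adds halo at offset zero
info.merge(1, one);   // a shift(-1) consumer adds halo at the other end
// Now info.width(0) and info.width(1) are both 1, info.width() is 2,
// and info.hasHalo() returns true.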
-
-//! Helper class for lowering tensors with halo. Only valid at
-//! lowering time.
-class HaloInfo {
- public:
- //! Scan a fusion and collect all information for lowering
- void build(Fusion* fusion);
-
- //! Build mappings of extent information of a TensorDomain
- void build(TensorDomain* td);
-
- //! Set initial AxisHaloInfo of a root axis
- //!
-  //! This is only for root or rfactor axes. It is an error to call
-  //! this with other axes.
- void setRootAxisInfo(IterDomain* id, const AxisHaloInfo& root_axis_info);
-
-  //! Returns the registered AxisHaloInfo of a root axis.
- //!
- //! This is only for root axes. It is an error to query with
- //! non-root axes.
- const AxisHaloInfo& getRootAxisInfo(IterDomain* id) const;
- AxisHaloInfo& getRootAxisInfo(IterDomain* id);
- //! KIR version
- const AxisHaloInfo& getRootAxisInfo(kir::IterDomain* id) const;
- AxisHaloInfo& getRootAxisInfo(kir::IterDomain* id);
-
- //! Query if an axis has a halo width.
- //!
- //! See the comment at halo_width_map_.
- bool hasHaloWidth(IterDomain* id) const;
-
- //! Return the halo width of an axis.
- //!
- //! It's an error if queried for an axis with no halo width
- //! information.
- kir::Int* getHaloWidth(IterDomain* id) const;
-
- //! Returns an extent if id is extended for halo. Nullptr is
- //! returned otherwise.
- kir::Val* getExtent(IterDomain* id) const;
- kir::Val* getExtent(kir::IterDomain* id) const;
-
-  // True when the extent of id1 is guaranteed to be less than or
-  // equal to that of id2. False when it *may* not be.
- bool extentLessEqual(IterDomain* id1, IterDomain* id2) const;
-  // True when the extent of id1 is guaranteed to be equal to that of
-  // id2. False when it *may* not be.
- bool extentEqual(IterDomain* id1, IterDomain* id2) const;
-
- //! Check if expr must be predicated based on boundary conditions
- //! directly or indirectly induced by shift expressions.
- //!
- //! When yes, the expression needs two predications: one for
- //! interior and another for padding. Predicate insertion is done in
- //! the ShiftPredicateInserter class below.
- bool needsShiftPredicate(Expr* expr) const;
- bool needsShiftPredicate(kir::Expr* expr) const;
-
- std::string toString() const;
-
- private:
- //! Propagate root axis information from outputs to inputs of an
- //! expression
- void propagateRootAxisInfo(Expr* expr);
-
- //! Propagate root axis information from consumer to producer
- void propagateRootAxisInfo(
- TensorView* producer,
- TensorView* consumer,
- Expr* expr);
-
- //! Validate shift usage
- void validate(TensorView* td) const;
-
- private:
- //! Halo information of root axes
- std::unordered_map<IterDomain*, AxisHaloInfo> root_axis_map_;
- //! KIR version
- std::unordered_map<kir::IterDomain*, AxisHaloInfo> kir_root_axis_map_;
-
- //! Halo-extended extents. No mapping for axes without halo extension
- std::unordered_map<kir::IterDomain*, kir::Val*> kir_extent_map_;
-
- //! The halo width of an axis.
- //!
-  //! The mapped value is the sum of the halo widths of both sides of
-  //! an axis. For root axes, it is equivalent to AxisHaloInfo.widths_[0]
-  //! + AxisHaloInfo.widths_[1] (or AxisHaloInfo.width()). For
-  //! example, when a root axis is extended by 1 on both sides, it is
-  //! mapped to 2. Axes with no halo are mapped to zero.
- //!
- //! When an axis is split, its halo is only propagated to the inner
- //! output axis, so the value of this map for the inner output is
- //! the same as the input of split, while the outer output is mapped
- //! to zero.
- //!
-  //! When an axis is merged, no mapping is created for its
-  //! output at this point, primarily because it isn't clear what the
-  //! "halo width" of a merged axis should mean. For a merged
-  //! axis of extent (N+a)*(M+b), where N and M are the original
-  //! extents of the two axes, and a and b are their halo widths,
-  //! it might make sense to define the halo width of the merged axis
-  //! as (N+a)*(M+b)-N*M. Currently, however, this isn't necessary, so
-  //! no particular mapping is created for merged axes.
- //!
- //! This is currently used only for conservatively comparing the
- //! overall extents of axes. See HaloInfo::extentLessEqual and
- //! HaloInfo::extentEqual.
- //!
-  //! Example: Suppose a root axis has AxisHaloInfo.widths_ of
-  //! {0, 1}. The root axis is mapped to 1. When it is split, say by
-  //! 4, into output axes [N / 4] and [4], where N is the extent of
-  //! the root axis, the outer axis is mapped to 0, whereas the inner
-  //! axis is mapped to 1. Further, if the inner axis is merged with
-  //! another axis of extent M, we know that the extent of the
-  //! resulting output axis is 5*M, but we don't create a mapping for
-  //! it.
- std::unordered_map<IterDomain*, kir::Int*> halo_width_map_;
-};
-
-class ShiftPredicateInserter {
- public:
-  //! Works mostly the same way as
-  //! PredicateCompute::getInlinePredicate but also inserts the
-  //! generated predicate. The branch structure is different from
-  //! the usual predicated expression, so the insertion is done
-  //! here as well.
- static void insert(
- kir::Expr* expr,
- const std::vector<kir::ForLoop*>& loops,
- kir::Bool* thread_pred);
-
- //! Returns predicates for the interior and overall domains of a
- //! tensor.
- //!
-  //! The isShiftPredicate flag toggles between the predicate for shifted
-  //! accesses and the one for padding.
- static kir::Bool* getPredicate(
- const kir::Expr* expr,
- const std::vector<kir::ForLoop*>& loops,
- kir::TensorView* out_tv,
- kir::Bool* thread_pred,
- bool isShiftPredicate);
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
namespace {
-kir::Val* getPredicatePerParallelType(
+Val* getPredicatePerParallelType(
ParallelType pt,
- const ThreadPredicateMap::SourceMap& source_map) {
+ const ThreadPredicateMap::SourceMapType& source_map) {
kir::IrBuilder ir_builder(GpuLower::current()->kernel());
if (pt == ParallelType::BIDx || pt == ParallelType::BIDy ||
pt == ParallelType::BIDz) {
auto source = source_map.at(pt);
TORCH_INTERNAL_ASSERT(!source.empty(), "No predicate source found");
- kir::Val* pred = nullptr;
- for (auto src : source) {
- if (pred == nullptr) {
- auto flag_name = kir::GridReduction::getPredicateFlagName(src);
- pred = ir_builder.create<kir::NamedScalar>(flag_name, DataType::Bool);
- } else {
- auto flag_name = kir::GridReduction::getPredicateFlagName(src);
- pred = ir_builder.andExpr(
- pred,
- ir_builder.create<kir::NamedScalar>(flag_name, DataType::Bool));
- }
- }
- return pred;
+ TORCH_INTERNAL_ASSERT(source.size() == 1, "Multiple sources detected");
+ auto src = *source.begin();
+ auto flag_name = kir::GridReduction::getPredicateFlagName(src);
+ return ir_builder.create<kir::NamedScalar>(flag_name, DataType::Bool);
} else {
return ir_builder.eqExpr(
kir::NamedScalar::getParallelIndex(pt), ir_builder.create<kir::Int>(0));
}
}
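
For thread dimensions, each set bit lowers to an index == 0 comparison. As a sketch (not code from this file), a bitmap with TIDx and TIDy set corresponds to the following guard in the generated kernel:

// CUDA-side shape of the combined predicate for bits = {TIDx, TIDy}:
if (threadIdx.x == 0 && threadIdx.y == 0) {
  // ... expression whose inputs saw a parallelized reduction ...
}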
-kir::Bool* getPredicateFromParallelTypes(
- const ParallelTypeBitmap& bits,
- const ThreadPredicateMap::SourceMap& source_map) {
+kir::Bool* getPredicate(
+ const ir_utils::ParallelTypeBitmap& bits,
+ const ThreadPredicateMap::SourceMapType& source_map) {
kir::IrBuilder ir_builder(GpuLower::current()->kernel());
if (bits.none()) {
- return ir_builder.trueVal();
+ return ir_builder.create<kir::Bool>(true);
}
- kir::Bool* pred = nullptr;
+ Val* pred = nullptr;
for (const auto& pt_bool : bits.getMap()) {
if (pt_bool.second) {
- const auto tp = getPredicatePerParallelType(pt_bool.first, source_map);
- if (pred == nullptr) {
- pred = ir_builder.create<kir::Bool>(c10::nullopt);
- ir_builder.create<kir::UnaryOp>(UnaryOpType::Set, pred, tp);
- } else {
- pred = ir_builder.andExpr(pred, tp)->as<kir::Bool>();
- }
+ auto tp = getPredicatePerParallelType(pt_bool.first, source_map);
+ pred = (pred == nullptr) ? tp : ir_builder.andExpr(pred, tp);
}
}
+  // pred should never be null here.
TORCH_INTERNAL_ASSERT(pred != nullptr);
- return pred;
+ TORCH_INTERNAL_ASSERT(
+ pred->getDataType().value() == DataType::Bool,
+ "Tried to return a predicate that is not a bool val.");
+
+ return pred->as<kir::Bool>();
}
void mergeSourceMap(
- ThreadPredicateMap::SourceMap& dst,
- const ThreadPredicateMap::SourceMap& src) {
+ ThreadPredicateMap::SourceMapType& dst,
+ const ThreadPredicateMap::SourceMapType& src) {
for (const auto& kv : src) {
const auto& src_key = kv.first;
const auto& src_value = kv.second;
- auto& dst_set = dst[src_key];
+ std::unordered_set<const TensorView*>& dst_set = dst[src_key];
for (const auto& src_tensor : src_value) {
dst_set.insert(src_tensor);
}
}
void addToSouceMap(
- ThreadPredicateMap::SourceMap& dst,
+ ThreadPredicateMap::SourceMapType& dst,
const TensorView* tv,
- const ParallelTypeBitmap& reducton_pred) {
+ const ir_utils::ParallelTypeBitmap& reducton_pred) {
for (const auto& kv : reducton_pred.getMap()) {
if (kv.second) {
ParallelType ptype = kv.first;
}
void maskSouceMap(
- ThreadPredicateMap::SourceMap& src_map,
- const ParallelTypeBitmap& mask) {
+ ThreadPredicateMap::SourceMapType& src_map,
+ const ir_utils::ParallelTypeBitmap& mask) {
for (const auto& kv : mask.getMap()) {
if (!kv.second) {
ParallelType ptype = kv.first;
// A bit of a hack for now for GEMM tiling so we don't fetch tiles multiple
// times. It's safe to do; there may simply be a better place to do it.
-ParallelTypeBitmap avoidRedundantWritesToSmem(
- const TensorView* out_tv,
- const ParallelTypeBitmap& pred) {
- const auto& ca_map = GpuLower::current()->caParallelMap();
- auto new_pred = pred;
+void avoidRedundantWritesToSmem(
+ TensorView* out_tv,
+ ir_utils::ParallelTypeBitmap& pred) {
if (out_tv->getMemoryType() == MemoryType::Shared) {
for (const auto i : c10::irange(out_tv->nDims())) {
- auto id = ca_map.getConcreteMappedID(out_tv->axis(i));
+ auto id = out_tv->getComputeAtAxis(i).first;
if (out_tv->axis(i)->isBroadcast() && id->isThreadDim()) {
- new_pred.set(id->getParallelType(), true);
+ pred.set(id->getParallelType(), true);
}
}
}
- return new_pred;
}
} // namespace
// Update the reduction_deps bitset based on provided Expr
-void ThreadPredicateMap::updateBitSet(const Expr* expr) {
- FUSER_PERF_SCOPE("GpuLower::Lower::ThreadPredicateMap::updateBitSet");
+void ThreadPredicateMap::updateBitSet(Expr* expr) {
+ FUSER_PERF_SCOPE("ThreadPredicateMap::updateBitSet");
// Which predicates were set for the inputs
- ParallelTypeBitmap input_preds;
+ ir_utils::ParallelTypeBitmap input_preds;
// Which dims are reductions in inputs
- ParallelTypeBitmap input_reductions;
+ ir_utils::ParallelTypeBitmap input_reductions;
// Which dims are bcast in inputs
- ParallelTypeBitmap input_bcasts;
+ ir_utils::ParallelTypeBitmap input_bcasts;
- SourceMap src_map;
+ SourceMapType src_map;
// Run through inputs and update bitsets
for (const auto* inp : expr->inputs()) {
if (!ir_utils::isTV(inp))
continue;
- auto tv_inp = inp->as<TensorView>();
-
- // Change for welford Op, we want the users of all outputs of welfordOp
- // to use a single predicate name.
- if (auto tv_def = tv_inp->definition()) {
- if (auto wop = dynamic_cast<WelfordOp*>(tv_def)) {
- tv_inp = wop->out()->as<TensorView>();
- }
- }
-
+ auto tv_inp = ir_utils::asConstTV(inp);
TORCH_INTERNAL_ASSERT(
thread_predicates_.find(tv_inp) != thread_predicates_.end(),
"Thread predicate map was not initialized, couldn't find ",
inp);
- const auto& pred_and_src = at(tv_inp);
-
- input_preds |= pred_and_src.pred;
+ input_preds |= at(tv_inp).first;
- mergeSourceMap(src_map, pred_and_src.source_map);
+ mergeSourceMap(src_map, at(tv_inp).second);
- ParallelTypeBitmap id_reductions;
- ParallelTypeBitmap id_bcasts;
- ParallelTypeBitmap id_ptypes;
+ ir_utils::ParallelTypeBitmap id_reductions;
+ ir_utils::ParallelTypeBitmap id_bcasts;
+ ir_utils::ParallelTypeBitmap id_ptypes;
for (auto id : tv_inp->domain()->domain()) {
if (id->isThread()) {
}
// Validate the combination of ptypes, reductions, bcasts
- for (const auto i : c10::irange(ParallelTypeBitmap::num_p_type)) {
+ for (const auto i : c10::irange(ir_utils::ParallelTypeBitmap::num_p_type)) {
if (input_reductions[i]) {
if (id_ptypes[i]) {
TORCH_INTERNAL_ASSERT(
auto output_preds = input_preds | input_reductions;
// Figure out which dims bcast wants to reset
- const auto bcast_reset_mask = ~(output_preds & input_bcasts);
+ auto bcast_reset_map = output_preds & input_bcasts;
- // Get rid of any reductions which are bcasted
- output_preds &= bcast_reset_mask;
+ // Flip it to make a bit mask
+ bcast_reset_map = ~bcast_reset_map;
+ // Get rid of any reductions which are bcasted
+ output_preds &= bcast_reset_map;
// Similarly, drop non-relevant source tensors
- maskSouceMap(src_map, bcast_reset_mask);
+ maskSouceMap(src_map, bcast_reset_map);
// Run through outputs and set bitset predicates
for (auto* out : expr->outputs()) {
- if (auto tv = dynamic_cast<const TensorView*>(out)) {
- TORCH_INTERNAL_ASSERT(find(tv) == end());
- insert(tv, avoidRedundantWritesToSmem(tv, output_preds), src_map);
- }
+ if (!ir_utils::isTV(out))
+ continue;
+ TORCH_INTERNAL_ASSERT(find(ir_utils::asConstTV(out)) == end());
+ auto pred_for_this_out = output_preds;
+ avoidRedundantWritesToSmem(ir_utils::asTV(out), pred_for_this_out);
+ insert(ir_utils::asConstTV(out), pred_for_this_out, src_map);
}
}
-void ThreadPredicateMap::build(Fusion* fusion) {
- FUSER_PERF_SCOPE("GpuLower::Lower::ThreadPredicateMap");
-
+// TODO(kir): revisit this - can we build it from the kernel IR?
+ThreadPredicateMap::ThreadPredicateMap(Fusion* _fusion) : fusion_(_fusion) {
+ FUSER_PERF_SCOPE("ThreadPredicateMap");
// Initialize mapping for input tensors
- for (auto inp : fusion->inputs()) {
- if (auto tv = dynamic_cast<const TensorView*>(inp)) {
- insert(tv, ParallelTypeBitmap(), SourceMap());
+ for (auto inp : fusion_->inputs()) {
+ if (ir_utils::isTV(inp)) {
+ insert(
+ ir_utils::asConstTV(inp),
+ ir_utils::ParallelTypeBitmap(),
+ SourceMapType());
}
}
- for (auto expr : fusion->exprs()) {
+ for (auto expr : fusion_->exprs(true)) {
updateBitSet(expr);
}
}
return thread_predicates_.end();
}
-const ThreadPredicateMap::PredAndSource& ThreadPredicateMap::at(
+const ThreadPredicateMap::MapType::mapped_type& ThreadPredicateMap::at(
const TensorView* tv) const {
return thread_predicates_.at(tv);
}
-ThreadPredicateMap::PredAndSource& ThreadPredicateMap::at(
+ThreadPredicateMap::MapType::mapped_type& ThreadPredicateMap::at(
const TensorView* tv) {
return thread_predicates_.at(tv);
}
-void ThreadPredicateMap::insert(
- const TensorView* tv,
- const ParallelTypeBitmap& pred,
- const SourceMap& src_map) {
- insert(tv, {pred, src_map});
+ThreadPredicateMap::MapType::mapped_type& ThreadPredicateMap::operator[](
+ const TensorView* tv) {
+ return thread_predicates_[tv];
}
void ThreadPredicateMap::insert(
const TensorView* tv,
- const PredAndSource& pred_and_src) {
- thread_predicates_.insert({tv, pred_and_src});
+ const ir_utils::ParallelTypeBitmap& pred,
+ const SourceMapType& src_map) {
+ insert(tv, std::make_pair(pred, src_map));
}
-kir::Bool* ThreadPredicateMap::getPredicate(const TensorView* tv) const {
- // No thread predicate is needed when tv is an output of a
- // parallel broadcast expression.
- if (auto bop = dynamic_cast<BroadcastOp*>(tv->definition())) {
- if (getParallelBroadcastDomains(tv).any()) {
- return kir::IrBuilder(GpuLower::current()->kernel()).trueVal();
- }
- }
- TORCH_INTERNAL_ASSERT(find(tv) != end(), "Couldn't find ", tv);
- const auto& pred_and_src = at(tv);
- return getPredicateFromParallelTypes(
- pred_and_src.pred, pred_and_src.source_map);
+void ThreadPredicateMap::insert(
+ const TensorView* tv,
+ const std::pair<ir_utils::ParallelTypeBitmap, SourceMapType>&
+ pred_and_src) {
+ thread_predicates_.insert(std::make_pair(tv, pred_and_src));
}
-ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains(
- const TensorView* tv) const {
- // If no pred is found for tv, no predicate is necessary
- if (find(tv) == end()) {
- return ParallelTypeBitmap();
- }
-
- ParallelTypeBitmap parallel_broadcast;
-
- const auto& iter_domains = tv->domain()->domain();
-
- // If the output is on shared memory, assume that all subsequent
- // reads from all threads in its CTA can be done with no parallel
-  // broadcast. Only one thread will write to shared memory, followed
-  // by a proper __syncthreads().
- const bool output_smem = tv->getMemoryType() == MemoryType::Shared;
-
- for (auto id : iter_domains) {
- if (!id->isBroadcast()) {
- continue;
- }
- if (id->isBlockDim() || (!output_smem && id->isThreadDim())) {
- parallel_broadcast.set(id->getParallelType(), true);
- }
+void ThreadPredicateMap::duplicate(
+ const TensorView* copy,
+ const TensorView* origin) {
+ if (find(origin) != end()) {
+ insert(copy, at(origin).first, at(origin).second);
}
-
- return parallel_broadcast & at(tv).pred;
}
-void ThreadPredicateMap::print() const {
- std::cout << "\nThreadPredicateMap\n";
- std::cout << "--------------------------------\n";
- for (const auto& kv : thread_predicates_) {
- std::cout << "T" << kv.first->name() << " {";
- // ParallelTypeBitmap
- for (auto ptkv : kv.second.pred.getMap()) {
- if (ptkv.second) {
- std::cout << " " << ptkv.first;
- }
- }
- std::cout << " }\n";
- // SourceMap
- for (const auto& pkv : kv.second.source_map) {
- std::cout << " " << pkv.first << " : [";
- for (auto tv : pkv.second) {
- std::cout << " T" << tv->name();
- }
- std::cout << " ]\n";
- }
- }
- std::cout << "--------------------------------\n\n";
+kir::Bool* ThreadPredicateMap::getExpr(const TensorView* out_tv) const {
+ TORCH_INTERNAL_ASSERT(find(out_tv) != end(), "Couldn't find ", out_tv);
+ return getPredicate(at(out_tv).first, at(out_tv).second);
}
} // namespace cuda
-
#pragma once
-
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>
-#include <unordered_map>
-#include <unordered_set>
-#include <utility>
+#include <bitset>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
-//! Maps TensorViews to a { ParallelTypeBitmap, SourceMap } pair
+//! Maps TensorViews to std::pair<ir_utils::ParallelTypeBitmap, SourceMapType>
//!
-//! Map from TensorView to bit set representing <BIDx, BIDy, BIDz, TIDx, TIDy,
+//! Map from tensorview to bit set representing <BIDx, BIDy, BIDz, TIDx, TIDy,
//! TIDz>. If any dependency of TV had a parallelized reduction, we will track
//! it here. This will be used for predicate generation to prevent
//! parallelization on that axis. This is important if we have a reduction on
//!
class TORCH_CUDA_CU_API ThreadPredicateMap {
public:
- using SourceMap = std::unordered_map<
+ using SourceMapType = std::unordered_map<
ParallelType,
std::unordered_set<const TensorView*>,
TypeHash>;
-
- struct PredAndSource {
- ParallelTypeBitmap pred;
- SourceMap source_map;
- };
-
- using MapType = std::unordered_map<const TensorView*, PredAndSource>;
-
+ using MapType = std::unordered_map<
+ const TensorView*,
+ std::pair<ir_utils::ParallelTypeBitmap, SourceMapType>>;
using const_iterator = MapType::const_iterator;
- void build(Fusion* fusion);
+ explicit ThreadPredicateMap(Fusion* _fusion);
- // TODO(kir): these methods are only used by getParallelBroadcastDomains() ?
const_iterator find(const TensorView* tv) const;
const_iterator end() const;
- const PredAndSource& at(const TensorView* tv) const;
- PredAndSource& at(const TensorView* tv);
-
- // Returns a Bool predicate for a given TensorView.
- kir::Bool* getPredicate(const TensorView* tv) const;
+ const MapType::mapped_type& at(const TensorView* tv) const;
+ MapType::mapped_type& at(const TensorView* tv);
+ MapType::mapped_type& operator[](const TensorView* tv);
- //! Returns a ParallelTypeBitmap representing which domain needs
- //! blockBroadcast.
- //!
- //! Even when a domain is broadcast and parallelized, it does not need
- //! blockBroadcast unless it is predicated.
- ParallelTypeBitmap getParallelBroadcastDomains(const TensorView* tv) const;
+ void duplicate(const TensorView* copy, const TensorView* origin);
- void print() const;
+ // Returns a Bool predicate expression for a given output TensorView.
+ kir::Bool* getExpr(const TensorView* out_tv) const;
private:
// Update the thread_predicates bitset based on provided Expr
- void updateBitSet(const Expr*);
+ void updateBitSet(Expr*);
void insert(
const TensorView* tv,
- const ParallelTypeBitmap& pred,
- const SourceMap& src_map);
-
- void insert(const TensorView* tv, const PredAndSource& pred_and_src);
+ const ir_utils::ParallelTypeBitmap& pred,
+ const SourceMapType& src_map);
+ void insert(const TensorView* tv, const MapType::mapped_type& pred_and_src);
private:
+ Fusion* fusion_ = nullptr;
MapType thread_predicates_;
};
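
A hypothetical usage sketch of the constructor-based API above, assuming fusion and out_tv are in scope:

ThreadPredicateMap pred_map(fusion);         // eagerly walks fusion->exprs(true)
kir::Bool* pred = pred_map.getExpr(out_tv);  // the conjunction described above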
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-
-#include <unordered_set>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id);
-
-bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {
- TORCH_INTERNAL_ASSERT(
- root_id->definition() == nullptr, "Not root IterDomain: ", root_id);
-
- if (tv->definition() == nullptr) {
- // This is an input tensor, so no rfactor tensor to traverse.
- return false;
- }
-
- const auto& inputs = tv->definition()->inputs();
-
- if (inputs.size() != 1 || !inputs[0]->isA<TensorView>() ||
- (tv->definition()->getExprType() != ExprType::ReductionOp &&
- tv->definition()->getExprType() != ExprType::WelfordOp)) {
- // No rfactor producer found
- return false;
- }
-
- auto producer = inputs[0]->as<TensorView>();
-
- if (!producer->hasRFactor()) {
- return false;
- }
-
- auto c2p = PairwiseRootDomainMap(producer, tv)
- .mapConsumerToProducer(tv->domain(), producer->domain());
-
- auto producer_id_it = c2p.find(root_id);
- if (producer_id_it == c2p.end()) {
- // No matching producer is found. Stop traversing.
- return false;
- }
-
- auto producer_root_id = producer_id_it->second;
-
- return analyzeIfDerivedFromTrivialReduction(producer, producer_root_id);
-}
-
-bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) {
- auto id_inputs = InputsOf::output(id->fusion(), id);
- for (auto root_id : ir_utils::filterByType<IterDomain>(id_inputs)) {
- if (root_id->isReduction() && root_id->extent()->isOneInt()) {
- continue;
- }
- // If not possible to prove the root ID is trivial, see if the ID
-    // is derived from an rfactor tensor and, if so, continue the
- // analysis at the rfactor tensor.
- if (!traverseToRFactorTensor(tv, root_id)) {
- return false;
- }
- }
- return true;
-}
-
-} // namespace
-
-void TrivialReductionInfo::build(Fusion* fusion, GpuLower* gpu_lower) {
- auto used_vals = fusion->usedMathVals();
-
- for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
- for (auto id : tv->domain()->domain()) {
- if (analyzeIfDerivedFromTrivialReduction(tv, id)) {
- // If id is a trivial reduction, all of its ancestor vals are
- // also trivial reductions.
- for (auto dep_id : DependencyCheck::getAllValsBetween(
- std::unordered_set<Val*>(
- tv->getRootDomain().begin(), tv->getRootDomain().end()),
- {id})) {
- domains_.insert(dep_id->as<IterDomain>());
- domains_derived_from_root_.insert(dep_id->as<IterDomain>());
- }
- } else if (id->isReduction() && id->extent()->isOneInt()) {
- // This happens when a leaf domain is trivial but its root
- // axes are not. For example, consider a non-trivial domain
- // split by one. The inner output axis is a trivial domain,
- // whereas the outer output axis is not. Since the root axis
- // is not trivial, a for-loop needs to be generated.
- domains_.insert(id);
- }
- }
- }
-
- buildKir(fusion, gpu_lower);
-}
-
-void TrivialReductionInfo::buildKir(Fusion* fusion, GpuLower* gpu_lower) {
- for (auto id : domains_) {
- auto kir_trivial_id = gpu_lower->lowerValue(id)->as<kir::IterDomain>();
- kir_domains_.insert(kir_trivial_id);
- }
-
- for (auto id : domains_derived_from_root_) {
- auto kir_trivial_id = gpu_lower->lowerValue(id)->as<kir::IterDomain>();
- kir_domains_derived_from_root_.insert(kir_trivial_id);
- }
-}
-
-bool TrivialReductionInfo::isDerived(IterDomain* id) const {
- return domains_.find(id) != domains_.end();
-}
-
-bool TrivialReductionInfo::isDerivedFromRoot(IterDomain* id) const {
- return domains_derived_from_root_.find(id) !=
- domains_derived_from_root_.end();
-}
-
-bool TrivialReductionInfo::isDerived(kir::IterDomain* id) const {
- return kir_domains_.find(id) != kir_domains_.end();
-}
-
-bool TrivialReductionInfo::isDerivedFromRoot(kir::IterDomain* id) const {
- return kir_domains_derived_from_root_.find(id) !=
- kir_domains_derived_from_root_.end();
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <unordered_set>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class GpuLower;
-
-//! Detect almost all IterDomains that are derived from trivial
-//! reductions.
-class TORCH_CUDA_CU_API TrivialReductionInfo {
- public:
- void build(Fusion* fusion, GpuLower* gpu_lower);
-
- bool isDerived(IterDomain* id) const;
- bool isDerivedFromRoot(IterDomain* id) const;
-
- bool isDerived(kir::IterDomain* id) const;
- bool isDerivedFromRoot(kir::IterDomain* id) const;
-
- private:
- //! Convert the sets to KIR sets
- void buildKir(Fusion* fusion, GpuLower* gpu_lower);
-
- private:
-  //! IterDomains that are derived only from trivial
-  //! reductions. Included domains are not limited to reduction axes,
-  //! as rfactor can turn reduction axes into normal axes.
- //!
-  //! Note that the set should cover almost all cases, but there can be
-  //! undetected trivial domains. For example, a split by one creates a
-  //! trivial reduction domain, which is detected. However, if it is
-  //! further split, both of the resulting axes are also trivial,
-  //! yet only the inner axis is recognized as such. While this
-  //! is a limitation, it has very little practical
-  //! implication.
- std::unordered_set<IterDomain*> domains_;
- //! Subset of domains_, whose input root axes are all derived from
- //! trivial reductions. These domains do not need to manifest as
- //! for-loops.
- std::unordered_set<IterDomain*> domains_derived_from_root_;
-
- std::unordered_set<kir::IterDomain*> kir_domains_;
- std::unordered_set<kir::IterDomain*> kir_domains_derived_from_root_;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
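
A hypothetical schedule illustrating the split-by-one case mentioned in the comment above (tv0 is an assumed 1-D input):

TensorView* tv1 = sum(tv0, {0}); // reduction axis, assumed extent N > 1
tv1->split(0, 1);                // [rN] -> [rN, r1]
// The inner extent-1 reduction axis is trivial and needs no for-loop,
// but the outer axis is not, so a loop is still generated for it.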
#include <torch/csrc/jit/codegen/cuda/index_compute.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-#include <torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
namespace fuser {
namespace cuda {
-namespace {
-
-// Provide a new for loop matching the one provided
-kir::ForLoop* cloneLoopNest(const kir::ForLoop* for_loop) {
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- const auto new_loop = ir_builder.create<kir::ForLoop>(
- for_loop->iter_domain(),
- for_loop->index(),
- for_loop->start(),
- for_loop->stop(),
- for_loop->step(),
- for_loop->vectorize(),
- for_loop->vectorize_shift());
- for (auto expr : for_loop->body().exprs()) {
- if (auto nested_for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
- expr = cloneLoopNest(nested_for_loop);
+kir::Bool* UnrollPass::getThreadPredicate(TensorView* tv) {
+  // No thread predicate is needed when tv is an output of a
+  // parallel broadcast expression.
+ const auto origin = tv->getOrigin();
+ if (origin != nullptr && origin->getExprType() == ExprType::BroadcastOp) {
+ const auto out = origin->as<BroadcastOp>()->out();
+ if (ir_utils::getParallelBroadcastDomains(out, thread_predicates_).any()) {
+ return nullptr;
}
- new_loop->body().push_back(expr);
}
- return new_loop;
-}
-// Returns true if expr is an expression that initializes a reduction
-// buffer.
-bool isReductionInitExpr(const kir::Expr* expr) {
- // False if its output isn't a TensorView
- if (!ir_utils::isTVOp(expr)) {
- return false;
- }
- // False if it doesn't have any reduction axis
- const auto out_tv = expr->outputs()[0]->as<kir::TensorView>();
- if (!out_tv->domain()->hasReduction()) {
- return false;
- }
-  // False if it has TensorView inputs, as initialization should
-  // never use TensorViews
- const auto tv_filter_inp_view =
- ir_utils::filterByType<kir::TensorView>(expr->inputs());
- if (tv_filter_inp_view.begin() != tv_filter_inp_view.end()) {
- return false;
- }
- return true;
+ return thread_predicates_.getExpr(tv);
}
-} // namespace
-
-void UnrollPass::handle(kir::Expr* expr) {
+// Custom dispatch for Expr, want to find out if it's a TV op.
+void UnrollPass::handle(Expr* expr) {
+  // If tv op, predicate it.
if (ir_utils::isTVOp(expr)) {
- // If tv op, predicate it
- const auto out_tv = ir_utils::getTVOutput(expr);
- const bool should_predicate = !for_loops_.empty() ||
- out_tv->memoryType() == MemoryType::Global ||
- out_tv->memoryType() == MemoryType::Shared;
- if (!should_predicate) {
- return;
- }
-
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- const auto thread_pred = isReductionInitExpr(expr)
- ? ir_builder.trueVal()
- : GpuLower::current()->threadPredMap().getPredicate(out_tv->fuserTv());
-
- // When a predicate needs to account for ShiftOp, it is currently
-    // taken care of by its own function.
- if (GpuLower::current()->haloInfo().needsShiftPredicate(expr)) {
- ShiftPredicateInserter::insert(expr, for_loops_, thread_pred);
- return;
- }
+ TORCH_INTERNAL_ASSERT(for_loops.size() != 0);
- // Reduction may need a separate predicate for writes.
- if (!isReductionInitExpr(expr) && out_tv->domain()->hasReduction()) {
- const auto write_pred = ir_builder.create<kir::Predicate>(
- PredicateType::ReductionWrite, expr, thread_pred);
- expr->setWritePredicate(write_pred);
- }
-
- // For expr calling a device func with block sync, don't create
- // if-then-else but pass the predicate to the device func
- if (ir_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) {
- const auto pred = ir_builder.create<kir::Predicate>(
- PredicateType::Inline, expr, thread_pred);
- expr->setPredicate(pred);
- return;
- }
-
- // Vectorized expressions should never use inline predicates
- kir::Predicate* vectorized_pred = nullptr;
- if (std::any_of(
- for_loops_.begin(), for_loops_.end(), [](const kir::ForLoop* fl) {
- return fl->iter_domain()->parallelType() ==
- ParallelType::Vectorize;
- })) {
- vectorized_pred =
- ir_builder.create<kir::Predicate>(PredicateType::Vectorize);
- }
-
- const auto pred = vectorized_pred == nullptr
- ? ir_builder.create<kir::Predicate>(
- PredicateType::Inline, expr, thread_pred)
- : vectorized_pred;
-
- TORCH_INTERNAL_ASSERT(pred != nullptr);
+ auto pred = PredicateCompute::getInlinePredicate(
+ expr, for_loops, getThreadPredicate(ir_utils::getTVOutput(expr)));
// If we need a predicate, put expr inside an if then else
- non_trivial_pred_found_ = true;
- kir::IfThenElse* inline_ite = ir_builder.create<kir::IfThenElse>(pred);
- if (for_loops_.empty()) {
- // Special handling for top level output expressions that still
- // need predicates. One motivating example is a reduction op that
- // reduces to a scalar (issue #491)
- expr_replacement_map_.insert({expr, inline_ite});
- } else {
- for_loops_.back()->body().insert_before(expr, inline_ite);
- for_loops_.back()->body().erase(expr);
+ if (!(pred->isConst()) || !(pred->isConst() && pred->value().value())) {
+ non_trivial_pred_found = true;
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+ kir::IfThenElse* inline_ite =
+ ir_builder.create<kir::IfThenElse>(pred, for_loops.back());
+ inline_ite->thenBody().push_back(expr);
+ for_loops.back()->body().insert_before(expr, inline_ite);
+ for_loops.back()->body().erase(expr);
}
- inline_ite->thenBody().push_back(expr);
- } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
- handle(for_loop);
+
+ } else {
+ // If not tv op, dispatch it.
+ OptOutDispatch::handle(expr);
}
}
// IR nodes "unroll_pred" or "inline_pred", then generate those later.
void UnrollPass::handle(kir::ForLoop* fl) {
// Setup for loop scoping
- const bool is_unroll =
- fl->iter_domain()->parallelType() == ParallelType::Unroll ||
- fl->iter_domain()->parallelType() == ParallelType::Unswitch ||
- fl->iter_domain()->parallelType() == ParallelType::Vectorize;
-
+ bool is_unroll = ir_utils::isUnrolledFor(fl);
// If we're not looking for an unroll loop, or didn't find one, process as
// normal.
- if (!is_unroll || !look_for_unroll_) {
- for_loops_.push_back(fl);
+ if (!is_unroll || !look_for_unroll) {
+ for_loops.push_back(fl);
+ std::vector<Expr*> exprs_copy = fl->body().exprs();
// Make copy of exprs because we replace them inplace in fl
- const auto exprs_copy = fl->body().exprs();
-
- // Skip Misaligned Vectorization For-Loops here
- if (!containsAnyDirectChildMisalignedVectorize(fl)) {
- for (auto expr : exprs_copy) {
- handle(expr);
- }
+ for (auto expr : exprs_copy) {
+ handle(expr);
}
+ for_loops.pop_back();
- for_loops_.pop_back();
return;
}
- kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- auto unroll_pred = ir_builder.create<kir::Predicate>(fl);
+ auto unroll_pred = UnrollPredicate::get(for_loops, fl, p2c_root_map);
- kir::IfThenElse* unroll_ite = ir_builder.create<kir::IfThenElse>(unroll_pred);
+ kir::ForLoop* parent_scope = for_loops.empty() ? nullptr : for_loops.back();
+
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+ kir::IfThenElse* unroll_ite =
+ ir_builder.create<kir::IfThenElse>(unroll_pred, parent_scope);
// Get the loop nest for the unrolled path
- kir::ForLoop* unrolled_loop_nest = cloneLoopNest(fl);
+ kir::ForLoop* unrolled_loop_nest = scope_utils::cloneLoopNest(fl, unroll_ite);
unroll_ite->thenBody().push_back(unrolled_loop_nest);
- if (fl->iter_domain()->parallelType() == ParallelType::Vectorize) {
- expr_replacement_map_.insert({fl, unroll_ite});
- return;
- }
// Loop nest for inlined path
- kir::ForLoop* inlined_loop = cloneLoopNest(fl);
+ kir::ForLoop* inlined_loop = scope_utils::cloneLoopNest(fl, unroll_ite);
// Add inline predicates for inlined loop nest
- look_for_unroll_ = false;
- non_trivial_pred_found_ = false;
+ look_for_unroll = false;
+ non_trivial_pred_found = false;
handle(inlined_loop);
- look_for_unroll_ = true;
- if (!non_trivial_pred_found_) {
- expr_replacement_map_.insert({fl, inlined_loop});
+ look_for_unroll = true;
+ if (!non_trivial_pred_found) {
+ inlined_loop->setParentScope(parent_scope);
+ loop_replacement_map.insert({fl, inlined_loop});
} else {
- if (!canOmitElseClause(fl)) {
- unroll_ite->elseBody().push_back(inlined_loop);
- }
- expr_replacement_map_.insert({fl, unroll_ite});
- }
-}
-
-bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) const {
- kir::ExpressionEvaluator eval;
- std::vector<kir::ForLoop*> loops({fl});
-
- const auto& pred_map = GpuLower::current()->threadPredMap();
-
- while (loops.size() > 0) {
- auto loop = loops.back();
- loops.pop_back();
-
- // If there's any expression that requires barrier
- // synchronization, the else part can't be omitted
- for (auto expr : loop->body().exprs()) {
- if (expr->isA<kir::BroadcastOp>()) {
- const ParallelTypeBitmap domains = pred_map.getParallelBroadcastDomains(
- expr->outputs()[0]->as<kir::TensorView>()->fuserTv());
- if (domains.any()) {
- return false;
- }
- } else if (expr->isA<kir::ReductionOp>() || expr->isA<kir::WelfordOp>()) {
- auto td = ir_utils::getTVOutput(expr)->domain();
- if (td->hasBlockReduction() || td->hasGridReduction()) {
- return false;
- }
- }
- }
- // If the number of visits of the loop body per thread is one, the
- // unswitch predicate is sufficient.
- // When the loop stop is the same as the extent of its IterDomain,
- // the per-thread visit count is guaranteed to be one at most (see
- // CudaKernelGenerator::visit(kir::ForLoop*) as well). Also, when a
- // loop is vectorized (not misaligned), the count must be at most
- // one. Even if neither parallelized nor vectorized, it is also
- // sufficient if the loop stop is in fact one.
- bool visit_once = false;
- auto id = loop->iter_domain();
- if ((id->isThread() && (loop->stop() == id->extent())) ||
- id->parallelType() == ParallelType::Vectorize) {
- visit_once = true;
- }
- if (!visit_once) {
- const auto result = eval.evaluate(loop->stop());
- if (result.has_value() && result.value() == 1) {
- visit_once = true;
- }
- }
-
- // The visit count is not guaranteed to be one, so the else part
- // must be created.
- if (!visit_once) {
- return false;
- }
-
- // The unswitch predicate is sufficient for this loop. Proceed to
- // nested loops.
- for (auto nested_loop :
- ir_utils::filterByType<kir::ForLoop>(loop->body().exprs())) {
- loops.push_back(nested_loop);
- }
+ unroll_ite->elseBody().push_back(inlined_loop);
+ loop_replacement_map.insert({fl, unroll_ite});
}
-
- return true;
}
// Generate the loop nest structure and place it in lowered_exprs
-UnrollPass::UnrollPass(const std::vector<kir::Expr*>& exprs) {
- FUSER_PERF_SCOPE("GpuLower::Lower::UnrollPass::computeMap");
+void UnrollPass::computeMap() {
+ FUSER_PERF_SCOPE("UnrollPass::computeMap");
+
+ FusionGuard fg(fusion_);
// Run through loop nests and further lower the expressions
- for (auto* expr : exprs) {
- handle(expr);
+ for (auto* expr : incoming_exprs_) {
+ OptOutDispatch::handle(expr);
}
}
-std::vector<kir::Expr*> UnrollPass::runPass(
+std::vector<Expr*> UnrollPass::runPass(
Fusion* fusion,
- const std::vector<kir::Expr*>& exprs) {
- FUSER_PERF_SCOPE("GpuLower::Lower::UnrollPass::runPass");
-
- UnrollPass unroll_pass(exprs);
-
- std::vector<kir::Expr*> mutated_exprs;
- mutated_exprs.reserve(exprs.size());
- for (auto expr : exprs) {
- mutated_exprs.push_back(
- ir_utils::applyReplacements(unroll_pass.replacementMap(), expr));
+ const std::vector<Expr*>& exprs,
+ const ThreadPredicateMap& thread_predicates) {
+ FUSER_PERF_SCOPE("UnrollPass::runPass");
+ FusionGuard fg(fusion);
+ UnrollPass up(fusion, exprs, thread_predicates);
+ up.computeMap();
+ std::vector<Expr*> mutated_exprs;
+ for (Expr* expr : exprs) {
+ if (up.loop_replacement_map.find(expr) != up.loop_replacement_map.end()) {
+ mutated_exprs.push_back(up.loop_replacement_map[expr]);
+ } else {
+ if (ir_utils::isScope(expr))
+ scope_utils::replaceExprsInScope(expr, up.loop_replacement_map);
+ mutated_exprs.push_back(expr);
+ }
}
-
return mutated_exprs;
}
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <bitset>
-#include <unordered_map>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
-//! Unroll pass
-//!
-//! A bit deceptively: UnrollPass adds all predicates, so it needs to be run
-//! even if we don't unroll any loops.
-//!
-//! Unrolling pass will get IR that looks something like:
-//! for( i : I0o{ceil(I0/4)} ) {
-//! for( j : I1o{ceil(I1/128)} ) {
-//! for( k : I0i{4} )
-//! for( l : I1i{128} )
-//! T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
-//!
-//! And it will return the following:
-//! for( i : I0o{ceil(I0/4)} ) {
-//! for( j : I1o{ceil(I1/128)} ) {
-//!
-//! if( i * 4 + 3 < I && j * 128 + 127 < J ){
-//! for( k : I0i{4} )
-//! for( l : I1i{128} )
-//! T0[ ( i * 4 + k ) * J + j * 128 + l ] = ...
-//! } else {
-//! for( k : I0i{4} )
-//! for( l : I1i{128} )
-//! if( i * 4 + k < I && j * 128 + l < J)
-//! T0[ ( i * 4 + k ) * J + j * 128 + l ] = ...
-//! }
-//!
-//! }
-//! }
-//!
-//! As can be seen it generates two sets of loops for I0i{4} and I1i{128}. The
-//! first set is protected by a predicate that makes sure there's a full
-//! internal tile we can iterate over. This way we remove the predicate nested
-//! in the inner most loop. There's of course a second set of loops, which has a
-//! predicate still in the inner most loop, making sure that we cover edges and
-//! corners.
-//!
-class TORCH_CUDA_CU_API UnrollPass {
- public:
- // Take the incoming exprs and run loop unrolling, returning the new IR
- static std::vector<kir::Expr*> runPass(
- Fusion* fusion,
- const std::vector<kir::Expr*>& exprs);
+/*
+ * A bit deceptively: UnrollPass adds all predicates, so it needs to be run even
+ * if we don't unroll any loops.
+ *
+ * Unrolling pass will get IR that looks something like:
+ * for( i : I0o{ceil(I0/4)} ) {
+ * for( j : I1o{ceil(I1/128)} ) {
+ * for( k : I0i{4} )
+ * for( l : I1i{128} )
+ * T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
+ *
+ * And it will return the following:
+ * for( i : I0o{ceil(I0/4)} ) {
+ * for( j : I1o{ceil(I1/128)} ) {
+ *
+ * if( i * 4 + 3 < I && j * 128 + 127 < J ){
+ * for( k : I0i{4} )
+ * for( l : I1i{128} )
+ * T0[ ( i * 4 + k ) * J + j * 128 + l ] = ...
+ * } else {
+ * for( k : I0i{4} )
+ * for( l : I1i{128} )
+ * if( i * 4 + k < I && j * 128 + l < J)
+ * T0[ ( i * 4 + k ) * J + j * 128 + l ] = ...
+ * }
+ *
+ * }
+ * }
+ *
+ * As can be seen it generates two sets of loops for I0i{4} and I1i{128}. The
+ * first set is protected by a predicate that makes sure there's a full internal
+ * tile we can iterate over. This way we remove the predicate nested in the
+ * inner most loop. There's of course a second set of loops, which has a
+ * predicate still in the inner most loop, making sure that we cover edges and
+ * corners.
+ */
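+// For orientation, a minimal sketch of how this pass is driven (the
+// surrounding names are hypothetical; only runPass below is the real entry
+// point, with the signature declared at the bottom of this class):
+//   ThreadPredicateMap preds = ...;   // built earlier in lowering
+//   std::vector<Expr*> exprs = ...;   // loop-nest IR to add predicates to
+//   std::vector<Expr*> lowered = UnrollPass::runPass(fusion, exprs, preds);
+// runPass returns a new top-level list with unrolled loops swapped in.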
+class TORCH_CUDA_CU_API UnrollPass : public OptOutDispatch {
private:
- // Generate the for Expr replacement map
- UnrollPass(const std::vector<kir::Expr*>& exprs);
+ // Wrapper to access thread_predicates_ based on an output TV
+ kir::Bool* getThreadPredicate(TensorView*);
- const std::unordered_map<kir::Expr*, kir::Expr*>& replacementMap() const {
- return expr_replacement_map_;
- }
+ // We will track which loops in the incoming IR will be replaced and by what
+ std::unordered_map<Expr*, Expr*> loop_replacement_map;
- void handle(kir::ForLoop* fl);
+ // Hold on to a reference to the fusion for convenience
+ Fusion* fusion_;
- void handle(kir::Expr* expr);
+ // Hold on to the incoming exprs, but don't modify them. We don't set the
+ // Expr* to be const, as Exprs are const by virtue of their interface design
+ const std::vector<Expr*>& incoming_exprs_;
- bool canOmitElseClause(kir::ForLoop* fl) const;
+ // Keep all for loops conveniently to make unrolling easier
+ std::vector<kir::ForLoop*> for_loops;
- private:
- // We will track which loops in the incoming IR will be replaced and by what
- std::unordered_map<kir::Expr*, kir::Expr*> expr_replacement_map_;
+ // Map from TensorView to the thread predicates computed for it
+ const ThreadPredicateMap& thread_predicates_;
- // Keep all for loops conveniently to make unrolling easier
- std::vector<kir::ForLoop*> for_loops_;
+ std::unordered_map<IterDomain*, IterDomain*> p2c_root_map;
// keep track if we're within an unrolled loop
- bool look_for_unroll_ = true;
+ bool look_for_unroll = true;
// As we generate inline predicates check if we actually generated a
// non-trivial one.
- bool non_trivial_pred_found_ = false;
+ bool non_trivial_pred_found = false;
+
+ // Custom dispatch for Expr, want to find out if it's a TV op
+ void handle(Expr*) final;
+
+ // Open the for loop.
+ void handle(kir::ForLoop*) final;
+
+ // Constructor
+ UnrollPass(
+ Fusion* _fusion,
+ const std::vector<Expr*>& _incoming_exprs,
+ const ThreadPredicateMap& _thread_predicates)
+ : fusion_(_fusion),
+ incoming_exprs_(_incoming_exprs),
+ thread_predicates_(_thread_predicates) {
+ p2c_root_map = loop_utils::p2cRootMap(_fusion->exprs(true));
+ }
+
+ // Generate the for-loop Expr replacement map
+ void computeMap();
+
+ public:
+ // Take the incoming fusion and exprs and run loop unrolling, returning the
+ // new IR.
+ static std::vector<Expr*> runPass(
+ Fusion* fusion,
+ const std::vector<Expr*>& exprs,
+ const ThreadPredicateMap& thread_predicates);
};
} // namespace cuda
#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <algorithm>
-// TODO: refactor this file (one per namespace)
-
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
-
namespace scope_utils {
-std::vector<kir::ForLoop*> getLoops(kir::Expr* scope) {
- std::vector<kir::ForLoop*> loops;
- while (scope != nullptr) {
- if (auto loop = dynamic_cast<kir::ForLoop*>(scope)) {
- loops.push_back(loop);
+// START SCOPE HELPER SYSTEMS
+namespace {
+
+class Loops : private OptInDispatch {
+ private:
+ std::deque<kir::ForLoop*> loops;
+ void handle(kir::ForLoop* fl) final {
+ loops.insert(loops.begin(), fl);
+ }
+
+ void handle(kir::IfThenElse* ite) final {}
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ public:
+ static std::vector<kir::ForLoop*> getLoops(Expr* scope) {
+ Loops loops;
+ Expr* it = scope;
+ while (it != nullptr) {
+ loops.handle(it);
+ it = scope_utils::getParent(it);
+ }
+ return std::vector<kir::ForLoop*>(loops.loops.begin(), loops.loops.end());
+ }
+};
+
+class scopePushBack : private OptInDispatch {
+ private:
+ Expr* expr_;
+ void handle(kir::ForLoop* fl) final {
+ fl->body().push_back(expr_);
+ }
+
+ void handle(kir::IfThenElse* ite) final {
+ ite->thenBody().push_back(expr_);
+ }
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ scopePushBack(Expr* expr) : expr_(expr) {}
+
+ public:
+ static void push(Expr* scope, Expr* expr) {
+ scopePushBack pb(expr);
+ TORCH_INTERNAL_ASSERT(
+ expr != nullptr && scope != nullptr,
+ "Cannot push back, scope or expr is a nullptr.");
+ pb.handle(scope);
+ }
+};
+
+class scopeInsertBefore : private OptInDispatch {
+ private:
+ Expr* ref_;
+ Expr* expr_;
+ void handle(kir::ForLoop* fl) final {
+ fl->body().insert_before(ref_, expr_);
+ }
+
+ void handle(kir::IfThenElse* ite) final {
+ ite->thenBody().insert_before(ref_, expr_);
+ }
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ scopeInsertBefore(Expr* ref, Expr* expr) : ref_(ref), expr_(expr) {}
+
+ public:
+ static void insert(Expr* scope, Expr* ref, Expr* expr) {
+ scopeInsertBefore scb(ref, expr);
+ TORCH_INTERNAL_ASSERT(
+ expr != nullptr && scope != nullptr,
+ "Cannot push back, scope or expr is a nullptr.");
+ scb.handle(scope);
+ }
+};
+
+class ExprInScope : private OptInDispatch {
+ private:
+ Expr* expr_;
+ bool contains_ = false;
+
+ void handle(kir::ForLoop* fl) final {
+ if (fl->body().contains(expr_)) {
+ contains_ = true;
+ }
+ }
+
+ void handle(kir::IfThenElse* ite) final {
+ if (ite->thenBody().contains(expr_)) {
+ contains_ = true;
+ }
+ }
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ ExprInScope(Expr* expr) : expr_(expr) {}
+
+ public:
+ static bool find(Expr* scope, Expr* expr) {
+ ExprInScope eis(expr);
+ TORCH_INTERNAL_ASSERT(
+ expr != nullptr && scope != nullptr,
+ "Cannot push back, scope or expr is a nullptr.");
+ eis.handle(scope);
+ return eis.contains_;
+ }
+};
+
+class parentScope : private OptInDispatch {
+ private:
+ Expr* parent_ = nullptr;
+
+ void handle(kir::ForLoop* fl) final {
+ parent_ = fl->parentScope();
+ }
+
+ void handle(kir::IfThenElse* ite) final {
+ parent_ = ite->parentScope();
+ }
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ public:
+ static Expr* get(Expr* scope) {
+ parentScope sp;
+ sp.handle(scope);
+ return sp.parent_;
+ }
+};
+
+void assertScope(Expr* expr) {
+ TORCH_INTERNAL_ASSERT(
+ expr->getExprType() == ExprType::ForLoop ||
+ expr->getExprType() == ExprType::IfThenElse,
+ "Assert Scope failed when calling a scope_util function.");
+}
+
+class CloneLoopNest : public OptOutMutator {
+ private:
+ Expr* parent_scope_ = nullptr;
+ Expr* to_clone_ = nullptr;
+
+ Statement* mutate(kir::ForLoop* fl) final {
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+ const auto parent_scope =
+ fl == to_clone_ ? parent_scope_ : fl->parentScope();
+ auto new_loop = ir_builder.create<kir::ForLoop>(
+ fl->index(), fl->iter_domain(), parent_scope);
+ for (Expr* expr : fl->body().exprs()) {
+ new_loop->body().push_back(ir_utils::asExpr(OptOutMutator::mutate(expr)));
+ }
+ return new_loop;
+ }
+
+ CloneLoopNest(Expr* _to_clone, Expr* _parent_scope)
+ : parent_scope_(_parent_scope), to_clone_(_to_clone) {}
+
+ public:
+ static kir::ForLoop* getClone(kir::ForLoop* _to_clone, Expr* _parent_scope) {
+ TORCH_INTERNAL_ASSERT(
+ _to_clone != nullptr,
+ "Tried to clone a scope, but received a nullptr.");
+ CloneLoopNest cln(_to_clone, _parent_scope);
+ return ir_utils::asForLoop(ir_utils::asExpr(cln.mutate(_to_clone)));
+ }
+};
+
+class ReplaceExprsInScope : public OptOutDispatch {
+ public:
+ static void replace(
+ Expr* scope,
+ std::unordered_map<Expr*, Expr*> replacement_map) {
+ ReplaceExprsInScope reis(std::move(replacement_map));
+ reis.handle(scope);
+ }
+
+ private:
+ explicit ReplaceExprsInScope(std::unordered_map<Expr*, Expr*> replacement_map)
+ : replacement_map_(std::move(replacement_map)) {}
+
+ void handleScope(kir::Scope& scope) {
+ for (const auto i : c10::irange(scope.size())) {
+ const auto it = replacement_map_.find(scope[i]);
+ if (it == replacement_map_.end()) {
+ handle(scope[i]);
+ continue;
+ }
+ scope[i] = it->second;
+ }
+ }
+
+ void handle(Expr* expr) final {
+ OptOutDispatch::handle(expr);
+ }
+
+ void handle(kir::ForLoop* fl) final {
+ handleScope(fl->body());
+ }
+
+ void handle(kir::IfThenElse* ite) final {
+ handleScope(ite->thenBody());
+ handleScope(ite->elseBody());
+ }
+
+ private:
+ std::unordered_map<Expr*, Expr*> replacement_map_;
+};
+
+class FirstInnerMostScope : private OptInDispatch {
+ private:
+ Expr* active_scope = nullptr;
+
+ void handle(kir::ForLoop* fl) final {
+ for (auto expr : fl->body().exprs()) {
+ if (ir_utils::isScope(expr)) {
+ active_scope = expr;
+ return;
+ }
}
- scope = scope->parentScope();
+ active_scope = nullptr;
}
- std::reverse(loops.begin(), loops.end());
- return loops;
+
+ void handle(kir::IfThenElse* ite) final {
+ for (auto expr : ite->thenBody().exprs()) {
+ if (ir_utils::isScope(expr)) {
+ active_scope = expr;
+ return;
+ }
+ }
+ for (auto expr : ite->elseBody().exprs()) {
+ if (ir_utils::isScope(expr)) {
+ active_scope = expr;
+ return;
+ }
+ }
+ active_scope = nullptr;
+ }
+
+ Expr* getInner(Expr* expr) {
+ OptInDispatch::handle(expr);
+ return active_scope;
+ }
+
+ public:
+ static Expr* get(Expr* scope) {
+ TORCH_INTERNAL_ASSERT(
+ scope != nullptr,
+ "Tried to get inner most scope, but was provided nullptr.");
+
+ FirstInnerMostScope fims;
+ Expr* inner = fims.getInner(scope);
+
+ if (inner == nullptr)
+ return scope;
+
+ while (fims.getInner(inner) != nullptr)
+ inner = fims.getInner(inner);
+ return inner;
+ }
+};
+
+// END SCOPE HELPER SYSTEMS
+} // namespace
+
+// Collect the ForLoops enclosing scope, ordered outermost to innermost
+std::vector<kir::ForLoop*> getLoops(Expr* scope) {
+ if (scope == nullptr)
+ return std::vector<kir::ForLoop*>();
+ assertScope(scope);
+ return Loops::getLoops(scope);
}
-void insertBefore(kir::Expr* scope, kir::Expr* ref, kir::Expr* expr) {
- if (auto ite = dynamic_cast<kir::IfThenElse*>(scope)) {
- ite->thenBody().insert_before(ref, expr);
- } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(scope)) {
- for_loop->body().insert_before(ref, expr);
+// Push back an expr to scope
+void pushBack(Expr* scope, Expr* expr) {
+ TORCH_INTERNAL_ASSERT(
+ scope != nullptr, "Scope is a nullptr, cannot push an expr to it.");
+ assertScope(scope);
+ scopePushBack::push(scope, expr);
+}
+
+// Insert expr in scope before ref
+void insertBefore(Expr* scope, Expr* ref, Expr* expr) {
+ scopeInsertBefore::insert(scope, ref, expr);
+}
+
+bool exprInScope(Expr* scope, Expr* expr) {
+ return ExprInScope::find(scope, expr);
+}
+
+// Return the parent of the active scope
+Expr* getParent(Expr* scope) {
+ TORCH_INTERNAL_ASSERT(
+ scope != nullptr,
+ "Tried to close the active scope, but there isn't one set.");
+ assertScope(scope);
+ return parentScope::get(scope);
+}
+
+// Open a new inner most for loop
+kir::ForLoop* openFor(Expr* scope, IterDomain* id) {
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+ const auto kir_id = GpuLower::lowerValue(id)->as<kir::IterDomain>();
+ kir::ForLoop* new_scope = nullptr;
+ if (id->isThread()) {
+ std::stringstream ss;
+ ss << id->getParallelType();
+ new_scope = ir_builder.create<kir::ForLoop>(
+ ir_builder.create<kir::NamedScalar>(ss.str(), DataType::Int),
+ kir_id,
+ scope);
} else {
- TORCH_INTERNAL_ASSERT(false, "Unexpected scope expression");
+ new_scope = ir_builder.create<kir::ForLoop>(
+ ir_builder.create<kir::Int>(c10::nullopt), kir_id, scope);
}
+ if (scope != nullptr)
+ pushBack(scope, new_scope);
+ return new_scope;
+}
+
+kir::ForLoop* cloneLoopNest(kir::ForLoop* to_clone, Expr* parent_scope) {
+ return CloneLoopNest::getClone(to_clone, parent_scope);
+}
+
+void replaceExprsInScope(
+ Expr* scope,
+ std::unordered_map<Expr*, Expr*> replacement_map) {
+ TORCH_INTERNAL_ASSERT(
+ replacement_map.find(scope) == replacement_map.end(),
+ "Error trying to replace expressions in a scope, scope wants to be replaced entirely.");
+ ReplaceExprsInScope::replace(scope, std::move(replacement_map));
+}
+
+Expr* firstInnerMostScope(Expr* scope) {
+ return FirstInnerMostScope::get(scope);
}
} // namespace scope_utils
return ordered_inputs;
}
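+
+// Collect the index Vals of the given loops, in loop order; e.g. for loops
+// {fl_i, fl_j} this returns {fl_i->index(), fl_j->index()}.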
+std::vector<Val*> indices(std::vector<kir::ForLoop*> loops) {
+ std::vector<Val*> inds(loops.size());
+ std::transform(
+ loops.begin(), loops.end(), inds.begin(), [](kir::ForLoop* fl) {
+ return fl->index();
+ });
+ return inds;
+}
+
bool isTV(const Val* val) {
return val->getValType().value() == ValType::TensorView;
}
// Check if we're a TensorView op that we can generate code for.
bool isTVOp(const Expr* expr) {
- if (std::any_of(
- expr->outputs().begin(),
- expr->outputs().end(),
- [](Val* v) { return isTV(v); }) &&
+ if (expr->outputs().size() == 1 && isTV(expr->output(0)) &&
(expr->getExprType().value() == ExprType::BinaryOp ||
expr->getExprType().value() == ExprType::UnaryOp ||
expr->getExprType().value() == ExprType::TernaryOp ||
expr->getExprType().value() == ExprType::ReductionOp ||
- expr->getExprType().value() == ExprType::WelfordOp ||
- expr->getExprType().value() == ExprType::BroadcastOp ||
- expr->getExprType().value() == ExprType::TransposeOp ||
- expr->getExprType().value() == ExprType::ShiftOp ||
- expr->getExprType().value() == ExprType::GatherOp)) {
+ expr->getExprType().value() == ExprType::BroadcastOp))
return true;
- }
return false;
}
-bool isTVOp(const kir::Expr* expr) {
- const auto& outputs = expr->outputs();
- return outputs.size() >= 1 && outputs[0]->isA<kir::TensorView>();
-}
-
-// TODO: why do we assume there's a single TV output?
TensorView* getTVOutput(const Expr* expr) {
for (auto out : expr->outputs()) {
if (out->getValType().value() == ValType::TensorView) {
return nullptr;
}
-kir::TensorView* getTVOutput(const kir::Expr* expr) {
- for (auto out : expr->outputs()) {
- if (auto tv = dynamic_cast<kir::TensorView*>(out)) {
- return tv;
- } else if (auto ti = dynamic_cast<kir::TensorIndex*>(out)) {
- return ti->view();
- }
- }
- return nullptr;
-}
-
bool isScalarOp(const Expr* expr) {
for (auto out : expr->outputs())
if (!out->isScalar())
return true;
}
+void ASSERT_EXPR(Statement* stmt) {
+ TORCH_INTERNAL_ASSERT(
+ stmt->isExpr(),
+ "Tried to generate a kernel but hit a non expression during lowering: ",
+ stmt);
+}
+
Expr* asExpr(Statement* stmt) {
- TORCH_INTERNAL_ASSERT(stmt->isExpr());
+ ASSERT_EXPR(stmt);
return stmt->as<Expr>();
}
return val->as<TensorView>();
}
-bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) {
- if (!isTVOp(expr)) {
+bool isScope(const Expr* expr) {
+ return expr->getExprType() == ExprType::ForLoop ||
+ expr->getExprType() == ExprType::IfThenElse;
+}
+
+kir::ForLoop* asForLoop(Statement* stmt) {
+ Expr* expr = asExpr(stmt);
+ TORCH_INTERNAL_ASSERT(expr->getExprType() == ExprType::ForLoop);
+ return expr->as<kir::ForLoop>();
+}
+
+const TensorView* asConstTV(const Val* val) {
+ TORCH_INTERNAL_ASSERT(isTV(val));
+ return val->as<TensorView>();
+}
+
+bool isUnrolledFor(const Expr* expr) {
+ if (expr->getExprType() != ExprType::ForLoop) {
return false;
}
+ return expr->as<kir::ForLoop>()->iter_domain()->getParallelType() ==
+ ParallelType::Unroll;
+}
- auto tv = getTVOutput(expr);
+const std::unordered_map<ParallelType, int, TypeHash>
+ ParallelTypeBitmap::pt_to_offset_{
+ {ParallelType::BIDx, 0},
+ {ParallelType::BIDy, 1},
+ {ParallelType::BIDz, 2},
+ {ParallelType::TIDx, 3},
+ {ParallelType::TIDy, 4},
+ {ParallelType::TIDz, 5}};
+
+const std::unordered_map<int, ParallelType> ParallelTypeBitmap::offset_to_pt_ =
+ {{0, ParallelType::BIDx},
+ {1, ParallelType::BIDy},
+ {2, ParallelType::BIDz},
+ {3, ParallelType::TIDx},
+ {4, ParallelType::TIDy},
+ {5, ParallelType::TIDz}};
+
+bool ParallelTypeBitmap::get(ParallelType pt) const {
+ if (pt_to_offset_.find(pt) == pt_to_offset_.end()) {
+ TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type.");
+ }
+ return bitset_[pt_to_offset_.at(pt)];
+}
- if ((expr->isA<ReductionOp>() || expr->isA<WelfordOp>()) &&
- (tv->hasBlockReduction() || tv->hasGridReduction())) {
- return true;
- } else if (expr->isA<BroadcastOp>()) {
- const ParallelTypeBitmap pt_map =
- GpuLower::current()->threadPredMap().getParallelBroadcastDomains(tv);
- return pt_map.hasTID();
+bool ParallelTypeBitmap::set(ParallelType pt, bool new_val) {
+ if (pt_to_offset_.find(pt) == pt_to_offset_.end()) {
+ TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type.");
}
+ bool old_val = bitset_[pt_to_offset_.at(pt)];
+ bitset_[pt_to_offset_.at(pt)] = new_val;
+ return old_val;
+}
- return false;
+ParallelTypeBitmap ParallelTypeBitmap::operator&=(
+ const ParallelTypeBitmap& other) {
+ bitset_ &= other.bitset_;
+ return *this;
+}
+
+ParallelTypeBitmap ParallelTypeBitmap::operator|=(
+ const ParallelTypeBitmap& other) {
+ bitset_ |= other.bitset_;
+ return *this;
+}
+
+ParallelTypeBitmap ParallelTypeBitmap::operator^=(
+ const ParallelTypeBitmap& other) {
+ bitset_ ^= other.bitset_;
+ return *this;
}
-bool hasBlockSync(const kir::Expr* expr, const ThreadPredicateMap& pred_map) {
- if (expr->isA<kir::ReductionOp>() || expr->isA<kir::GridReduction>() ||
- expr->isA<kir::BroadcastOp>() || expr->isA<kir::WelfordOp>() ||
- expr->isA<kir::GridWelford>()) {
- auto fuser_tv = getTVOutput(expr)->fuserTv();
- auto fuser_expr = fuser_tv->definition();
- TORCH_INTERNAL_ASSERT(fuser_expr != nullptr);
- return hasBlockSync(fuser_expr, pred_map);
+ParallelTypeBitmap ParallelTypeBitmap::operator~() const {
+ return ParallelTypeBitmap(~bitset_);
+}
+
+bool ParallelTypeBitmap::none() const {
+ return bitset_.none();
+}
+
+bool ParallelTypeBitmap::any() const {
+ return bitset_.any();
+}
+
+bool ParallelTypeBitmap::all() const {
+ return bitset_.all();
+}
+
+bool ParallelTypeBitmap::operator[](size_t pos) const {
+ TORCH_INTERNAL_ASSERT(
+ pos < num_p_type, "Invalid index to ParallelTypeBitmap: ", pos);
+ return bitset_[pos];
+}
+
+std::map<ParallelType, bool> ParallelTypeBitmap::getMap() const {
+ std::map<ParallelType, bool> map;
+ for (const auto& pt_offset : pt_to_offset_) {
+ map.emplace(pt_offset.first, bitset_[pt_offset.second]);
}
+ return map;
+}
- return false;
+ParallelTypeBitmap operator&(
+ const ParallelTypeBitmap& lhs,
+ const ParallelTypeBitmap& rhs) {
+ auto x = lhs;
+ x &= rhs;
+ return x;
}
-kir::Expr* applyReplacements(
- const std::unordered_map<kir::Expr*, kir::Expr*>& expr_replacement_map,
- kir::Expr* expr) {
- auto handle_scope = [&](kir::Scope& scope) {
- for (size_t i = 0; i < scope.size(); ++i) {
- scope[i] = applyReplacements(expr_replacement_map, scope[i]);
- }
- };
+ParallelTypeBitmap operator|(
+ const ParallelTypeBitmap& lhs,
+ const ParallelTypeBitmap& rhs) {
+ auto x = lhs;
+ x |= rhs;
+ return x;
+}
- const auto it = expr_replacement_map.find(expr);
- if (it != expr_replacement_map.end()) {
- return it->second;
- } else {
- if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
- handle_scope(for_loop->body());
- } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
- handle_scope(ite->thenBody());
- handle_scope(ite->elseBody());
+ParallelTypeBitmap operator^(
+ const ParallelTypeBitmap& lhs,
+ const ParallelTypeBitmap& rhs) {
+ auto x = lhs;
+ x ^= rhs;
+ return x;
+}
+
+ParallelTypeBitmap getParallelBroadcastDomains(
+ const Val* bop_out,
+ const ThreadPredicateMap& preds) {
+ if (bop_out->getValType().value() == ValType::TensorIndex) {
+ bop_out = bop_out->as<kir::TensorIndex>()->view()->fuserTv();
+ }
+ TORCH_INTERNAL_ASSERT(
+ bop_out->getValType().value() == ValType::TensorView,
+ "Out is not tensor view");
+ auto out_tv = bop_out->as<TensorView>();
+ // If no pred is found for out_tv, no predicate is necessary
+ if (preds.find(out_tv) == preds.end()) {
+ return ParallelTypeBitmap();
+ }
+ const ParallelTypeBitmap& out_pred = preds.at(out_tv).first;
+
+ ParallelTypeBitmap parallel_broadcast;
+ const auto& iter_domains = out_tv->domain()->domain();
+ // If the output is on shared memory, assume that all subsequent
+ // reads from all threads in its CTA can be done with no parallel
+ // broadcast. Only one thread will write to shared memory followed
+ // by a proper __syncthreads().
+ const bool output_smem = out_tv->getMemoryType() == MemoryType::Shared;
+ for (auto id : iter_domains) {
+ if (!id->isBroadcast()) {
+ continue;
+ }
+ if (id->isBlockDim() || (!output_smem && id->isThreadDim())) {
+ parallel_broadcast.set(id->getParallelType(), true);
}
- return expr;
}
+
+ return parallel_broadcast & out_pred;
}
} // namespace ir_utils
namespace loop_utils {
-// TODO: Clean this up, Naoya added a mechanism we should be able to reuse.
std::pair<kir::ForLoop*, int64_t> getAllocPoint(
- const TensorView* tv,
- const std::vector<kir::ForLoop*>& loops,
- const std::unordered_map<IterDomain*, IterDomain*>& id_map,
- bool use_id_map) {
- const auto gpu_lower = GpuLower::current();
-
+ TensorView* tv,
+ const std::vector<kir::ForLoop*>& loops) {
// If in global memory, it can be all the way outside the loops.
if (tv->getMemoryType() == MemoryType::Global) {
return {nullptr, 0};
kir::ForLoop* alloc_loop = nullptr;
auto loops_it = loops.begin();
+
// Look at each axis individually in out's domain
- for (const auto tv_i : c10::irange((int64_t)tv->getComputeAtPosition())) {
+ for (const auto tv_i : c10::irange((int64_t)tv->getThisComputeAtAxis())) {
// Grab the axis ID
- auto local_id = tv->axis(tv_i);
- if (use_id_map) {
- auto id_it = id_map.find(local_id);
- if (id_it != id_map.end()) {
- local_id = id_it->second;
- }
- }
+ auto ca_id = tv->getComputeAtAxis(tv_i).first;
+ auto kir_ca_id = GpuLower::lowerValue(ca_id)->as<kir::IterDomain>();
- if (gpu_lower->trivialReductionInfo().isDerivedFromRoot(local_id)) {
- continue;
- }
-
- auto lowered_local_id =
- gpu_lower->lowerValue(local_id)->as<kir::IterDomain>();
- loops_it = std::find_if(
- loops_it, loops.end(), [&lowered_local_id](const auto& loop) {
- return GpuLower::current()->caLoopMap().areMapped(
- lowered_local_id, loop->iter_domain()) ||
- loop->iter_domain()->parallelType() == ParallelType::Unroll;
+ loops_it =
+ std::find_if(loops_it, loops.end(), [&kir_ca_id](const auto& loop) {
+ return kir_ca_id == loop->iter_domain() ||
+ loop->iter_domain()->getParallelType() == ParallelType::Unroll;
});
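+ // Debug aid: if the lookup failed, dump the candidate loop iter domains
+ // before the assert below fires.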
+ if (loops_it == loops.end()) {
+ for (auto loop : loops) {
+ std::cout << kir::toString(loop->iter_domain()) << " ";
+ }
+ std::cout << std::endl;
+ }
TORCH_INTERNAL_ASSERT(
loops_it != loops.end(),
"Could not find all required axes for indexing when trying to index into ",
tv);
- if ((*loops_it)->iter_domain()->parallelType() == ParallelType::Unroll) {
+
+ if (kir_ca_id->getParallelType() == ParallelType::Unroll) {
return {alloc_loop, tv_i};
}
++loops_it;
}
- return {alloc_loop, (int64_t)tv->getComputeAtPosition()};
+ return {alloc_loop, (int64_t)tv->getThisComputeAtAxis()};
}
-std::pair<kir::ForLoop*, int64_t> getAllocPoint(
- const TensorView* tv,
- const std::vector<kir::ForLoop*>& loops) {
- return getAllocPoint(tv, loops, {}, false);
+std::unordered_map<IterDomain*, IterDomain*> p2cRootMap(
+ const std::vector<Expr*>& exprs) {
+ std::unordered_map<IterDomain*, IterDomain*> p2c_root_map;
+
+ for (auto expr : exprs) {
+ auto out_tv = ir_utils::getTVOutput(expr);
+ for (auto inp : expr->inputs()) {
+ if (inp->getValType().value() != ValType::TensorView) {
+ continue;
+ }
+
+ auto root_p2c = TensorDomain::mapRootPtoC(
+ inp->as<TensorView>()->domain(), out_tv->domain());
+ for (auto entry : root_p2c) {
+ auto p_id = entry.first;
+ auto c_id = entry.second;
+ // Careful we don't allow circular references
+ if (p_id != c_id) {
+ p2c_root_map[p_id] = c_id;
+ }
+ }
+ }
+ }
+
+ return p2c_root_map;
+}
+
+IterDomain* getTermIDInMap(
+ IterDomain* root_id,
+ std::unordered_map<IterDomain*, IterDomain*> p2c_root_map) {
+ auto entry = root_id;
+ while (p2c_root_map.find(entry) != p2c_root_map.end()) {
+ entry = p2c_root_map.at(entry);
+ }
+ return entry;
}
} // namespace loop_utils
-
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>
#include <bitset>
#include <map>
class ThreadPredicateMap;
-using IterDomainMap = std::unordered_map<kir::IterDomain*, kir::IterDomain*>;
-
namespace scope_utils {
-//! Returns the list of nesting loops starting at `scope`
-// Primarily used in indexing, maybe could be moved there
-std::vector<kir::ForLoop*> getLoops(kir::Expr* scope);
+// Collect the ForLoops enclosing scope, ordered outermost to innermost
+std::vector<kir::ForLoop*> getLoops(Expr* scope);
+
+// Return the nesting depth (number of enclosing for loops) of scope
+unsigned int computeForDepth(Expr* scope);
+
+// Push back an expr to scope
+void pushBack(Expr* scope, Expr* expr);
+
+// Insert expr in scope before ref
+void insertBefore(Expr* scope, Expr* ref, Expr* expr);
+
+// Returns true if expr is directly in scope; does not search nested scopes
+bool exprInScope(Expr* scope, Expr* expr);
+
+// Return the parent of the active scope
+Expr* getParent(Expr* scope);
+
+// Open a new inner most for loop
+kir::ForLoop* openFor(Expr* scope, IterDomain*);
-//! Insert expr in scope before ref
-//!
-//! \warning for kir::IfThenElse we implicitly insert in the "then" branch!
-//!
-void insertBefore(kir::Expr* scope, kir::Expr* ref, kir::Expr* expr);
+// Clone the provided loop nest, setting parent_scope as its parent, but do
+// not insert the clone into parent_scope.
+kir::ForLoop* cloneLoopNest(kir::ForLoop* to_clone, Expr* parent_scope);
+
+// Run through a scope and replace expressions inside with replacement_map
+void replaceExprsInScope(
+ Expr* scope,
+ std::unordered_map<Expr*, Expr*> replacement_map);
+
+Expr* firstInnerMostScope(Expr* scope);
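+
+// Example (hypothetical) of composing these helpers to build a loop nest:
+//   kir::ForLoop* outer = openFor(nullptr, outer_id); // top-level loop
+//   kir::ForLoop* inner = openFor(outer, inner_id);   // pushed into outer
+//   pushBack(inner, some_expr);                       // append to inner body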
} // namespace scope_utils
const std::vector<IterDomain*>& of,
const std::vector<IterDomain*>& order);
-bool isTV(const Val* const);
+std::vector<Val*> indices(std::vector<kir::ForLoop*>);
-TORCH_CUDA_CU_API bool isTVOp(const Expr*);
+bool isTV(const Val* const);
-bool isTVOp(const kir::Expr* expr);
+bool isTVOp(const Expr*);
TensorView* getTVOutput(const Expr*);
-kir::TensorView* getTVOutput(const kir::Expr*);
bool isScalarOp(const Expr*);
-// TODO(kir): remove
+void ASSERT_EXPR(Statement*);
+
+bool isScope(const Expr*);
+
Expr* asExpr(Statement*);
-// TODO(kir): Remove in favor of ->as<TensorView>()
+// TODO: Remove in favor of ->as<TensorView>()
TensorView* asTV(Val*);
-bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map);
-bool hasBlockSync(const kir::Expr* expr, const ThreadPredicateMap& pred_map);
-
-// expr_replacement_map maps an expression to its replacement.
-//
-// The applyReplacement function serves two purposes.
-//
-// 1. If expr is found in expr_replacement_map, return the value for expr key.
-// Otherwise, return the original expression.
-//
-// 2. If a replacement is not found and the expression is a ForLoop or an
-// IfThenElse, it modifies the expressions in its scope by running the
-// handle_scope function
-//
-// The handle_scope function iterates over the expressions in the scope.
-// For each expression, it updates the expression the value returned by
-// applyReplacement.
-kir::Expr* applyReplacements(
- const std::unordered_map<kir::Expr*, kir::Expr*>& expr_replacement_map,
- kir::Expr* expr);
+// TODO: Remove in favor of ->as<ForLoop>()
+kir::ForLoop* asForLoop(Statement*);
+
+// TODO: Remove in favor of ->as<TensorView>()
+const TensorView* asConstTV(const Val*);
+
+bool isUnrolledFor(const Expr*);
+
+// Represents mapping to bool from BIDx, BIDy, BIDz, TIDx, TIDy and TIDz.
+class ParallelTypeBitmap {
+ public:
+ static constexpr int num_p_type = 6;
+ ParallelTypeBitmap() = default;
+ bool get(ParallelType pt) const;
+ bool set(ParallelType pt, bool);
+ ParallelTypeBitmap operator&=(const ParallelTypeBitmap& other);
+ ParallelTypeBitmap operator|=(const ParallelTypeBitmap& other);
+ ParallelTypeBitmap operator^=(const ParallelTypeBitmap& other);
+ ParallelTypeBitmap operator~() const;
+ bool none() const;
+ bool any() const;
+ bool all() const;
+ bool operator[](size_t pos) const;
+ std::map<ParallelType, bool> getMap() const;
+
+ private:
+ ParallelTypeBitmap(const std::bitset<num_p_type>& bs) : bitset_(bs) {}
+ std::bitset<num_p_type> bitset_;
+ const static std::unordered_map<ParallelType, int, TypeHash> pt_to_offset_;
+ const static std::unordered_map<int, ParallelType> offset_to_pt_;
+};
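+
+// Example (hypothetical): mark threadIdx.x and blockIdx.y as predicated
+// parallel types, then query one of them:
+//   ParallelTypeBitmap mask;
+//   mask.set(ParallelType::TIDx, true);
+//   mask.set(ParallelType::BIDy, true);
+//   bool b = mask.get(ParallelType::TIDx); // true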
+
+ParallelTypeBitmap operator&(
+ const ParallelTypeBitmap& lhs,
+ const ParallelTypeBitmap& rhs);
+
+ParallelTypeBitmap operator|(
+ const ParallelTypeBitmap& lhs,
+ const ParallelTypeBitmap& rhs);
+
+ParallelTypeBitmap operator^(
+ const ParallelTypeBitmap& lhs,
+ const ParallelTypeBitmap& rhs);
+
+// Returns a ParallelTypeBitmap representing which domain needs
+// blockBroadcast.
+// Even when a domain is broadcast and parallelized, it does not need
+// blockBroadcast unless it is predicated.
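+// E.g. a broadcast axis parallelized with TIDx on a tensor in local memory
+// is a blockBroadcast candidate; the returned bitmap has TIDx set provided
+// the output's thread predicates also include TIDx.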
+ParallelTypeBitmap getParallelBroadcastDomains(
+ const Val* bop_out,
+ const ThreadPredicateMap& preds);
} // namespace ir_utils
// outside the first loop in loops. Also find the index in tv of the first
// dimension that needs to be allocated; that local axis and everything above
// it must be allocated.
-// TODO: Only remaining use of this is in index compute, remove use from there,
-// or refactor and use in lower_allocation
-std::pair<kir::ForLoop*, int64_t> getAllocPoint(
- const TensorView* tv,
- const std::vector<kir::ForLoop*>& loops,
- const std::unordered_map<IterDomain*, IterDomain*>& id_map,
- bool use_id_map);
-
std::pair<kir::ForLoop*, int64_t> getAllocPoint(
- const TensorView* tv,
+ TensorView* tv,
const std::vector<kir::ForLoop*>& loops);
+
+// Go through exprs mapping root domains from producer to consumer. Provides a
+// ground truth for how root domains map through our expressions. Needed for
+// unrolling.
+std::unordered_map<IterDomain*, IterDomain*> p2cRootMap(
+ const std::vector<Expr*>& exprs);
+
+// Given a root IterDomain and a p2c_root_map, find the root IterDomain
+// furthest down the sorted expr list that it maps to. Needed for unrolling.
+IterDomain* getTermIDInMap(
+ IterDomain* root_id,
+ std::unordered_map<IterDomain*, IterDomain*> p2c_root_map);
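+
+// E.g. with p2c_root_map = {i0 -> i1, i1 -> i2}, getTermIDInMap(i0, map)
+// follows the chain until no mapping remains and returns i2.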
+
} // namespace loop_utils
} // namespace cuda
} // namespace fuser
+#include <c10/util/irange.h>
+
#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
namespace fuser {
namespace cuda {
-namespace {
-
-//! A parallel type validation pass to make sure all the outputs of
-//! welford ops are parallelized the same way. Will infer and modify serial
-//! parallel types if other output/s are parallelized, so that
-//! user wouldn't have to specify the same parallelization
-//! 3 times. Will throw if conflicts are detected, i.e.
-//! TIDx vs BIDx etc.
-class ValidateParallelType : public IterVisitor {
- public:
- static void validate(Fusion* fusion) {
- ValidateParallelType VPT;
- VPT.traverse(fusion);
- }
-
- private:
- using IterVisitor::handle;
- // Parallelize id1 and id0 consistently if one is serial and the other isn't
- void convertIterDomain(IterDomain* id0, IterDomain* id1) {
- const auto ptype0 = id0->getParallelType();
- const auto ptype1 = id1->getParallelType();
-
- if (ptype0 == ParallelType::Vectorize ||
- ptype1 == ParallelType::Vectorize) {
- auto other_type = ptype0 == ParallelType::Vectorize ? ptype1 : ptype0;
- TORCH_INTERNAL_ASSERT(
- other_type == ParallelType::Vectorize ||
- (!isParallelTypeThreadDim(other_type) &&
- !isParallelTypeBlockDim(other_type)),
- "Vectorize type was parallelized inconsistently in. ",
- "Detected during promoting parallel types.");
- return;
- }
-
- if (ptype0 != ptype1) {
- TORCH_CHECK(
- ptype0 == ParallelType::Serial || ptype1 == ParallelType::Serial,
- "Error promoting parallel types");
- if (ptype0 == ParallelType::Serial) {
- id0->parallelize(ptype1);
- }
- if (ptype1 == ParallelType::Serial) {
- id1->parallelize(ptype0);
- }
- }
- }
-
- void handle(WelfordOp* wop) override {
- auto out_avg = wop->outAvg()->as<TensorView>();
- auto out_var = wop->outVar()->as<TensorView>();
- auto out_n = wop->outN()->as<TensorView>();
- TORCH_INTERNAL_ASSERT(out_avg->nDims() == out_var->nDims());
- TORCH_INTERNAL_ASSERT(out_avg->nDims() == out_n->nDims());
- for (size_t i = 0; i < out_avg->nDims(); i++) {
- // TODO: can be cleaner.
- convertIterDomain(out_avg->axis(i), out_var->axis(i));
- convertIterDomain(out_avg->axis(i), out_n->axis(i));
- convertIterDomain(out_n->axis(i), out_var->axis(i));
- }
- }
-};
-
-// Make sure all IterDomains are only used for a unique
-// TensorView. Several mappings from IterDomains are
-// created during lowering, which relies on the unique usage of
-// IterDomains.
-void validateIterDomainUsage(Fusion* fusion) {
- FUSER_PERF_SCOPE("GpuLower::Lower::validateIterDomainUse");
- FusionGuard fg(fusion);
-
- auto used_vals = fusion->usedMathVals();
- std::unordered_map<IterDomain*, TensorView*> domain_use_map;
-
- for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
- std::unordered_set<Val*> root_domains;
- std::copy(
- tv->getRootDomain().begin(),
- tv->getRootDomain().end(),
- std::inserter(root_domains, root_domains.begin()));
-
- std::vector<Val*> leaf_domains;
- std::copy(
- tv->domain()->domain().begin(),
- tv->domain()->domain().end(),
- std::back_inserter(leaf_domains));
-
- auto all_domain_vals =
- DependencyCheck::getAllValsBetween(root_domains, leaf_domains);
-
- for (auto id : ir_utils::filterByType<IterDomain>(all_domain_vals)) {
- auto it = domain_use_map.find(id);
- TORCH_INTERNAL_ASSERT(
- it == domain_use_map.end(),
- "Multiple use of ",
- id,
- " detected.",
- " Used in both TV",
- tv->name(),
- " and TV",
- it->second->name());
- domain_use_map.insert({id, tv});
- }
- }
-}
-
-} // namespace
-
void validateIr(Fusion* fusion) {
- FUSER_PERF_SCOPE("GpuLower::Lower::validateIr");
-
- FusionGuard fg(fusion);
-
- fusion->validateInputs();
-
- // Convert all input broadcast iterdomains to strided
- for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
- for (auto id : tv->getMaybeRFactorDomain()) {
- if (id->isBroadcast()) {
- id->toStridedBroadcast();
- }
- }
- }
-
- // Convert all output broadcast iterdomains to strided
- for (auto tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
- for (auto id : tv->getMaybeRFactorDomain()) {
- if (id->isBroadcast()) {
- id->toStridedBroadcast();
- }
- }
- }
-
- // Validate Parallelization
- ValidateParallelType::validate(fusion);
-
- validateIterDomainUsage(fusion);
-}
-
-namespace {
-
-// Check contiguity for all root domains associated with Misaligned Vectorize
-// ParallelType
-void checkContiguity(
- const std::unordered_set<IterDomain*>& domains,
- TensorView* tv) {
- TORCH_INTERNAL_ASSERT(tv->getMemoryType() == MemoryType::Global);
-
- for (size_t idx = 0; idx < tv->getRootDomain().size(); ++idx) {
- auto root = tv->getRootDomain()[idx];
- if (domains.find(root) != domains.end()) {
- TORCH_INTERNAL_ASSERT(
- !root->isBroadcast(),
- "Misaligned vectorization prohibits merging broadcast domains.",
- "Issue found in, ",
- tv);
- TORCH_INTERNAL_ASSERT(
- tv->domain()->contiguity()[idx],
- "Cannot merge non-contiguous root domains with misaligned vectorization.",
- "Issue found in, ",
- tv);
- }
- }
-}
-
-// Check all root iter domains in consumer that are present in domain, making
-// sure they're contiguous. Map these domains to producer and make sure they are
-// also contiguous in producer. Producer-consumer relationship is assumed to be
-// through a set operation.
-void checkContiguity(
- const std::unordered_set<IterDomain*>& domains,
- TensorView* consumer,
- TensorView* producer) {
- // This seems not quite right, shouldn't we be able to reverse this?
- TORCH_INTERNAL_ASSERT(consumer->getMemoryType() == MemoryType::Local);
- TORCH_INTERNAL_ASSERT(producer->getMemoryType() == MemoryType::Global);
-
- auto root_c2p =
- PairwiseRootDomainMap(producer, consumer)
- .mapConsumerToProducer(consumer->domain(), producer->domain());
-
- std::unordered_map<IterDomain*, bool> producer_domain_contiguity;
- for (size_t idx = 0; idx < producer->getRootDomain().size(); ++idx) {
- auto root = producer->getRootDomain()[idx];
- auto contiguity = producer->domain()->contiguity()[idx];
- producer_domain_contiguity.insert({root, contiguity});
- }
-
- for (auto consumer_root : consumer->getRootDomain()) {
- if (domains.find(consumer_root) != domains.end()) {
- auto producer_root = root_c2p[consumer_root];
- TORCH_INTERNAL_ASSERT(
- producer_domain_contiguity.find(producer_root) !=
- producer_domain_contiguity.end());
-
- TORCH_INTERNAL_ASSERT(
- !consumer_root->isBroadcast() || !producer_root->isBroadcast(),
- "Misaligned vectorization prohibits merging broadcast domains.",
- "Issue found in, ",
- consumer);
-
- TORCH_INTERNAL_ASSERT(root_c2p.find(consumer_root) != root_c2p.end());
-
- TORCH_INTERNAL_ASSERT(
- producer_domain_contiguity[producer_root],
- "Cannot merge non-contiguous root domains with misaligned vectorization.",
- "Issue found in, ",
- consumer);
- }
- }
-}
-
-class VectorizeValidator : public OptInDispatch {
- private:
- // Initially, vectorized_id is the IterDomain with Vectorize ParallelType
- // After processing all merge and split operations,
- // vectorized_id is the corresponding root domain
- VectorizeValidator(IterDomain* vectorized_id)
- : vectorized_id_(vectorized_id) {}
-
- using OptInDispatch::handle;
-
- void handle(Split* s) final {
- if (s->outer() == vectorized_id_) {
- is_valid = false;
- } else if (s->inner() == vectorized_id_) {
- vectorized_id_ = s->in();
- }
- domains_.insert(s->outer());
- domains_.insert(s->inner());
- }
+ FUSER_PERF_SCOPE("validateIr");
- void handle(Merge* m) final {
- if (m->out() == vectorized_id_) {
- if (m->inner()->isBroadcast() && !m->outer()->isBroadcast()) {
- vectorized_id_ = m->outer();
- } else {
- vectorized_id_ = m->inner();
- }
- }
- domains_.insert(m->outer());
- domains_.insert(m->inner());
- }
-
- private:
- std::unordered_set<IterDomain*> domains_;
- IterDomain* vectorized_id_ = nullptr;
- bool is_valid = true;
-
- public:
- static void validate(TensorView* tv) {
- // Make sure there's only one vectorized ID
- IterDomain* v_id = nullptr;
- bool misaligned_vectorize = false;
- for (auto id : tv->domain()->domain()) {
- if (id->getParallelType() == ParallelType::Vectorize ||
- id->getParallelType() == ParallelType::MisalignedVectorize) {
- TORCH_INTERNAL_ASSERT(
- v_id == nullptr,
- "Found two vectorized domains in ",
- tv,
- " only one is allowed.");
- v_id = id;
- misaligned_vectorize =
- id->getParallelType() == ParallelType::MisalignedVectorize;
- }
- }
-
- // If no vectorized id's found simply return;
- if (v_id == nullptr) {
- return;
- }
-
- auto fusion = FusionGuard::getCurFusion();
-
- TORCH_CHECK(
- v_id->extent()->isConstScalar(),
- "Vectorizing a domain requires a constant size.");
-
- ExpressionEvaluator const_expr_eval(fusion);
-
- auto vector_size_optional = const_expr_eval.evaluate(v_id->extent());
-
- TORCH_CHECK(
- vector_size_optional.has_value(),
- "Could not evaluate constant value bound to vectorized dim.");
-
- auto vector_size = ((int64_t)dataTypeSize(tv->getDataType().value())) *
- vector_size_optional.value();
-
- // Allow half2, float2, float4 and same sized vtypes.
- std::array<int64_t, 4> allowed_vector_sizes = {2, 4, 8, 16}; // NOLINT
-
- TORCH_CHECK(
- std::find(
- allowed_vector_sizes.begin(),
- allowed_vector_sizes.end(),
- vector_size) != allowed_vector_sizes.end(),
- "Tried to vectorize a dim resulting in a word size of ",
- vector_size,
- " however, vector sizes only upto and including 16 bytes are supported.");
-
- auto replay_exprs = ExprSort::getExprs(fusion, {v_id});
-
- VectorizeValidator validator(v_id);
-
- for (auto expr_it = replay_exprs.rbegin(); expr_it != replay_exprs.rend();
- ++expr_it) {
- auto expr = *expr_it;
- validator.handle(expr);
- }
-
- TORCH_CHECK(
- validator.is_valid,
- "Invalid vectorized pattern found, vectorization iter domains must be descendants of inner-most dimension.",
- "Issue found in, ",
- tv,
- "\n");
-
- if (misaligned_vectorize) {
- if (tv->getMemoryType() == MemoryType::Global) {
- checkContiguity(validator.domains_, tv);
- } else if (
- tv->definition()->getExprType() == ExprType::UnaryOp &&
- tv->definition()->as<UnaryOp>()->getUnaryOpType() ==
- UnaryOpType::Set) {
- auto input = tv->definition()->input(0);
- TORCH_INTERNAL_ASSERT(input->isA<TensorView>());
- auto input_tv = input->as<TensorView>();
- checkContiguity(validator.domains_, tv, input_tv);
- }
- }
-
- TORCH_INTERNAL_ASSERT(validator.vectorized_id_ != nullptr);
-
- // TODO: Contiguity is based on root domain not rfactor. Seems this
- // generally doesn't cause problems, though contiguity should be on rfactor
- // domain as that's the domain we index on.
- IterDomain* last_root_dim = nullptr;
- int last_root_dim_pos = -1;
- for (size_t i = tv->getRootDomain().size(); i > 0; i--) {
- auto r_id = tv->getRootDomain()[i - 1];
- if (r_id->isReduction() || r_id->isBroadcast()) {
- continue;
- }
- last_root_dim = r_id;
- last_root_dim_pos = (int)i - 1;
- break;
- }
-
- if (last_root_dim == nullptr) {
- // Should never get here, but that would mean there are no concrete dims,
- // so we should be fine.
- return;
- }
-
- TORCH_CHECK(
- last_root_dim == validator.vectorized_id_ &&
- tv->domain()->contiguity()[last_root_dim_pos],
- "Vectorized dim has to be from a contiguous inner most position: ",
- tv,
- "\n");
- }
-};
-
-} // namespace
-
-void validateVectorize(Fusion* fusion) {
- FUSER_PERF_SCOPE("GpuLower::Lower::validateVectorize");
FusionGuard fg(fusion);
- auto used_vals = fusion->usedMathVals();
+ auto used_vals = DependencyCheck::getAllValsBetween(
+ {fusion->outputs().begin(), fusion->outputs().end()}, fusion->inputs());
std::unordered_set<TensorView*> used_tvs;
- for (auto val : used_vals) {
+ for (const auto& val : used_vals) {
if (ir_utils::isTV(val)) {
used_tvs.emplace(val->as<TensorView>());
}
}
- for (auto tv : used_tvs) {
- bool has_vectorize_dim = false;
- bool has_misaligned_vectorize_dim = false;
-
- for (size_t i = 0; i < tv->nDims(); i++) {
- IterDomain* id = tv->axis(i);
- IterDomain* concrete_id =
- GpuLower::current()->caParallelMap().getConcreteMappedID(id);
+ fusion->validateInputs();
- auto ptype = concrete_id->getParallelType();
+ for (const auto& tv : used_tvs) {
+ for (const auto i : c10::irange(tv->nDims())) {
+ IterDomain* id = tv->getComputeAtAxis(i).first;
- if (ptype == ParallelType::Vectorize) {
- // If we want to do this check up front we would have to do 2 things:
- // (1) Check that the tensor view with vectorize being set on it is
- // getting set outside the local compute at position
- // (2) Check any producers of the tensor view with vectorize being set
- // on it to make sure their compute at position isn't to the right of
- // the vectorize dim.
- TORCH_INTERNAL_ASSERT(
- i >= tv->getComputeAtPosition(),
- "IterDomains to the left of the compute at point cannot be vectorized: ",
+ if (id->isBlockDim()) {
+ TORCH_CHECK(
+ !id->isBroadcast(),
+ "Parallelization across blocks on broadcast axes is not supported, but found on, ",
tv,
- "\n");
- has_vectorize_dim = true;
- }
-
- if (concrete_id->getParallelType() == ParallelType::MisalignedVectorize) {
- TORCH_INTERNAL_ASSERT(
- !tv->hasComputeAt() ||
- tv->getComputeAtPosition() == tv->nDims() - 1,
- "Only allow misaligned vectorization in the -2 computeAt position.");
- TORCH_INTERNAL_ASSERT(
- tv->getMemoryType() == MemoryType::Local ||
- tv->getMemoryType() == MemoryType::Global,
- "Only allow misaligned vectorization between global and local memory.");
- has_misaligned_vectorize_dim = true;
+ ".");
}
- }
- if (has_vectorize_dim) {
- TORCH_INTERNAL_ASSERT(
- tv->definition() == nullptr ||
- (tv->definition()->isA<UnaryOp>() &&
- tv->definition()->as<UnaryOp>()->getUnaryOpType() ==
- UnaryOpType::Set),
- "Vectorized accesses cannot be inline with computation, they are only supported with a Set operation.",
- "TensorView: ",
- tv);
- }
- if (has_vectorize_dim || has_misaligned_vectorize_dim) {
- VectorizeValidator::validate(tv);
- }
- }
-}
-
-namespace {
-
-//! Return true if axis is derived from a root axis that is an input
-//! to a CA leaf axis.
-bool derivedFromRootCAAxes(TensorView* tv, IterDomain* axis) {
- std::vector<IterDomain*> ca_axes(
- tv->domain()->domain().begin(),
- tv->domain()->domain().begin() + tv->getComputeAtPosition());
-
- auto ca_root_vals = IterVisitor::getInputsTo(
- std::vector<Val*>(ca_axes.begin(), ca_axes.end()));
-
- auto root_vals = IterVisitor::getInputsTo({axis});
-
- return std::any_of(
- root_vals.begin(), root_vals.end(), [&ca_root_vals](auto root) {
- return std::find(ca_root_vals.begin(), ca_root_vals.end(), root) !=
- ca_root_vals.end();
- });
-}
-
-} // namespace
-
-void validateParallelize(Fusion* fusion) {
- FUSER_PERF_SCOPE("GpuLower::Lower::validateParallelize");
- FusionGuard fg(fusion);
-
- const auto& par_map = GpuLower::current()->caParallelMap();
- const auto& loop_map = GpuLower::current()->caLoopMap();
- const auto& index_map = GpuLower::current()->caIndexMap();
- const auto& pred_map = GpuLower::current()->threadPredMap();
-
- auto exprs = ExprSort::getExprs(fusion);
-
- for (auto expr : exprs) {
- if (!ir_utils::isTVOp(expr)) {
- continue;
- }
- for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
- // Parallelization on input tensors have no effect.
- if (producer->isFusionInput()) {
- continue;
- }
- const auto parallel_bcast_doms =
- pred_map.getParallelBroadcastDomains(producer);
- ParallelTypeBitmap pt_map;
- for (size_t i = 0; i < producer->nDims(); ++i) {
- // If a producer axis is threaded, either with threadIdx or
- // blockIdx, there must be a mapped consumer axis with the
- // same ParallelType. An exception is when the producer is
- // allocated on shared memory and its parallelized with
- // threadIdx. In that case, there is no parallelization
- // constraint on the consumer as syncthreads will be inserted
- // when necessary.
- auto producer_axis = producer->axis(i);
- auto producer_ptype =
- par_map.getConcreteMappedID(producer_axis)->getParallelType();
- if (!isParallelTypeThread(producer_ptype)) {
- continue;
- }
- // Each ParallelType can be used only once.
- TORCH_INTERNAL_ASSERT(
- !pt_map.get(producer_ptype),
- "Multiple use of ",
- producer_ptype,
- " in tensor t",
- producer->name(),
- ": ",
- producer);
- pt_map.set(producer_ptype, true);
- // When the producer axis is a broadcast, it is not really
- // parallelized unless thread-predicated
- if (producer_axis->isBroadcast() && parallel_bcast_doms.none()) {
- continue;
- }
- // No constraint on the consumer tensor when the producer
- // axis is parallelized with threadIdx and allocates on
- // shared memory
- if (isParallelTypeThreadDim(producer_ptype) &&
- producer->getMemoryType() == MemoryType::Shared) {
- continue;
- }
- // There should be also nothing to validate when the producer
- // axis is reduction.
- if (producer_axis->isReduction()) {
- continue;
- }
- // There must be a consumer axis that uses the same indexing
- // with the same parallel type as the producer axis. The index
- // map is used to to find such an axis. In addition, even when
- // no mapped axis is found in the index map, but when an
- // mapped axis exists in the loop map, the producer and
- // consumer axes may still use the same indexing. That only
- // happens when the producer is derived from a root axis that
- // is an input to any leaf CA axes. In such a case, the axis
- // in the reference tensor that maps to
- // the producer axis is created based on the consumer, so both
- // the producer and consumer axes should have the same
- // indexing. See issue #995 as well as the
- // FusionValidateParallelize6 test for a concrete example.
- for (auto consumer :
- ir_utils::filterByType<TensorView>(expr->outputs())) {
- auto it = std::find_if(
- consumer->domain()->domain().begin(),
- consumer->domain()->domain().end(),
- [&](IterDomain* consumer_axis) {
- return index_map.areMapped(producer_axis, consumer_axis) ||
- (loop_map.areMapped(producer_axis, consumer_axis) &&
- derivedFromRootCAAxes(producer, producer_axis));
- });
- TORCH_INTERNAL_ASSERT(
- it != consumer->domain()->domain().end(),
- "Inconsistent parallelization found between TV",
- producer->name(),
- " (",
- producer,
- ") and TV",
- consumer->name(),
- "(",
- consumer,
- "). ",
- "TV",
- consumer->name(),
- " does not have a matching axis for parallelized producer axis, ",
- producer_axis,
- ". CA Map: ",
- loop_map.toString());
- auto consumer_axis = *it;
- auto consumer_ptype =
- par_map.getConcreteMappedID(consumer_axis)->getParallelType();
- TORCH_INTERNAL_ASSERT(
- producer_ptype == consumer_ptype,
- "Inconsistent parallelization found between TV",
- producer->name(),
- " (",
- producer,
- ") and TV",
- consumer->name(),
- "(",
- consumer,
- "). "
- "Producer axis, ",
- producer_axis,
- " is parallelized with ",
- stringifyThread(producer_ptype),
- ", but the parallel type of its matching consumer axis, ",
- consumer_axis,
- " is ",
- stringifyThread(consumer_ptype),
- ".");
+ if (tv->hasBroadcast() && tv->getMemoryType() != MemoryType::Global) {
+ auto td = tv->domain()->domain();
+ auto ca_inputs = ir_utils::iterDomainInputsOf(
+ {td.begin(), td.begin() + tv->getThisComputeAtAxis()});
+ auto non_ca_inputs = ir_utils::iterDomainInputsOf(
+ {td.begin() + tv->getThisComputeAtAxis(), td.end()});
+
+ std::unordered_set<IterDomain*> ca_inputs_set(
+ ca_inputs.begin(), ca_inputs.end());
+ std::unordered_set<IterDomain*> non_ca_inputs_set(
+ non_ca_inputs.begin(), non_ca_inputs.end());
+
+ for (const auto& id : tv->getRootDomain()) {
+ if (id->isBroadcast()) {
+        // If a broadcast dimension is an input to axes both within and
+        // outside the computeAt point, we would have to look at consumers
+        // to figure out what that axis will be broadcast to, because we
+        // would have to generate everything the consumer could need on
+        // that axis. This could be supported, but is not at this point.
+ TORCH_INTERNAL_ASSERT(
+ !(ca_inputs_set.find(id) != ca_inputs_set.end() &&
+ non_ca_inputs_set.find(id) != non_ca_inputs_set.end()),
+ "Cannot generate a kernel where a root broadcast dimension is input to both IterDomains outside and within the computeAt point.");
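+          // Illustration (hypothetical schedule): with root domain [i0, b1],
+          // merging b1 into i0 and then splitting the result so that one
+          // output axis lies inside the computeAt point and another lies
+          // outside would make b1 an input to IterDomains on both sides,
+          // which the assert above rejects.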
+ }
}
}
}
void validateIr(Fusion* fusion);
-void validateVectorize(Fusion* fusion);
-
-//! Validates all tensors are consistently parallelized. Basically,
-//! when a producer axis is threaded, either with threadIdx or
-//! blockIdx, there must be a mapped consumer axis with the
-//! same ParallelType with some exceptions.
-//!
-//! This function assumes Loop and Parallel ComputeAtMaps are already
-//! built as they are used to validate consistency.
-void validateParallelize(Fusion* fusion);
-
} // namespace cuda
} // namespace fuser
} // namespace jit
#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
#include <torch/csrc/jit/codegen/cuda/manager.h>
#include <torch/csrc/jit/codegen/cuda/parser.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
+#include <torch/csrc/jit/codegen/cuda/scheduler.h>
#include <torch/csrc/jit/codegen/cuda/shape_inference.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
// We should not call `EraseShapeInformation(graph);` here. The graph
// representation does not incorporate static sizes, just the rank of input
// tensors, which is exactly what we want.
- auto canonical_graph = Canonicalize(graph, false);
- auto repr = canonical_graph->toString(false);
+ Canonicalize(graph, false);
+ auto repr = graph->toString(false);
    // create a new graph_cache_ids_ entry if none exists yet;
if (graph_cache_ids_.count(repr) == 0) {
return graph_cache_ids_[repr];
};
- void unregisterCacheId(std::shared_ptr<Graph>& graph) {
- auto canonical_graph = Canonicalize(graph, false);
- auto repr = canonical_graph->toString(false);
-
- // create new graph_cache_ids_ entry if none existed yet;
- if (graph_cache_ids_.count(repr) > 0) {
- int32_t kernel_id = graph_cache_ids_[repr];
- graph_cache_.erase(kernel_id);
- graph_cache_ids_.erase(repr);
- }
- }
-
std::vector<at::Tensor> runFusionNode(
int32_t kernel_id,
const at::ArrayRef<IValue> inputs) {
std::lock_guard<std::mutex> guard(mutex_);
- TORCH_INTERNAL_ASSERT(
- graph_cache_.count(kernel_id) > 0, "graph cache miss at run time");
return graph_cache_[kernel_id]->runGraphWithInputs(inputs);
}
} // namespace
void compileCudaFusionGroup(Node* fusion_node) {
- FUSER_PERF_SCOPE("nvFuser::Manager::compileCudaFusionGroup");
+ FUSER_PERF_SCOPE("compileCudaFusionGroup");
TORCH_CHECK(
fusion_node->kind() == prim::CudaFusionGroup,
  // This is not a critical code path; it's OK to do a graph copy here.
auto graph = fusion_node->g(attr::Subgraph)->copy();
- auto compile_fusion = [&]() {
- // type propagation is needed, as the protocol only requires scalar type on
- // input tensors.
- // Note that even for Profiling Executor, scalar type could still be
- // missing, especially for output tensor from a given node (as profiling
- // node only insert meta information after itself).
- TypePropagate(graph);
-
- int32_t fusion_cache_id =
- CudaFusionManager::getManager().registerOrGetCacheId(graph);
- fusion_node->i_(attr::cache_id, fusion_cache_id);
- };
-
- if (useFallback()) {
- try {
- compile_fusion();
- } catch (...) {
- TORCH_WARN(
-          "FALLBACK path has been taken. This is an indication that codegen "
-          "failed for some reason. To debug, try disabling the codegen "
-          "fallback path by setting the env variable "
-          "`export PYTORCH_NVFUSER_DISABLE_FALLBACK=1`");
- CudaFusionManager::getManager().unregisterCacheId(graph);
- }
- } else {
- compile_fusion();
- }
+  // Type propagation is needed, as the protocol only requires scalar types on
+  // input tensors.
+  // Note that even for the Profiling Executor, scalar types could still be
+  // missing, especially for the output tensor of a given node (as the
+  // profiling node only inserts meta information after itself).
+ TypePropagate(graph);
+
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ int32_t fusion_cache_id =
+ CudaFusionManager::getManager().registerOrGetCacheId(graph);
+ fusion_node->i_(attr::cache_id, fusion_cache_id);
}
void runCudaFusionGroup(const Node* fusion_node, Stack& stack) {
- FUSER_PERF_SCOPE("nvFuser::Manager::runCudaFusionGroup");
-
- // Fallback to use if anything goes wrong
- auto take_fallback = [&]() {
- // copying graph here since we are eliminating shape information;
- auto copied_graph = fusion_node->g(attr::Subgraph)->copy();
- EraseShapeInformation(copied_graph);
- InterpreterState{Code(copied_graph, "fallback_cuda_fuser")}.run(stack);
- };
+ FUSER_PERF_SCOPE("runCudaFusionGroup");
- auto run_fusion = [&]() {
- TORCH_CHECK(
- fusion_node->kind() == prim::CudaFusionGroup,
- "prim::CudaFusionGroup expected");
- // TODO: should we support runtime compilation with updated dynamic shape;
- // shape inference would be needed so we can allocate output;
- TORCH_CHECK(
- fusion_node->hasAttribute(attr::cache_id),
- "node prim::CudaFusionGroup has not been compiled yet");
+ TORCH_CHECK(
+ fusion_node->kind() == prim::CudaFusionGroup,
+ "prim::CudaFusionGroup expected");
+  // TODO: should we support runtime compilation with updated dynamic shapes?
+  // Shape inference would be needed so we can allocate outputs.
+ TORCH_CHECK(
+ fusion_node->hasAttribute(attr::cache_id),
+ "node prim::CudaFusionGroup has not been compiled yet");
+ int32_t kernel_id = fusion_node->i(attr::cache_id);
- int32_t kernel_id = fusion_node->i(attr::cache_id);
- // Currently we just construct I/O tensors for static graph;
+  // Currently we just construct I/O tensors for the static graph.
- const auto nInputs = fusion_node->g(attr::Subgraph)->inputs().size();
+ const auto nInputs = fusion_node->g(attr::Subgraph)->inputs().size();
+ auto execute_lambda = [&]() {
at::ArrayRef<IValue> inputs = last(stack, nInputs);
auto outputs =
std::make_move_iterator(outputs.end()));
};
- if (useFallback()) {
+ const char* disable_fb_env = getenv("PYTORCH_CUDA_FUSER_DISABLE_FALLBACK");
+ int disable_fb_flag = disable_fb_env ? atoi(disable_fb_env) : 0;
+ if (disable_fb_flag) {
+ execute_lambda();
+ } else {
try {
- run_fusion();
+ execute_lambda();
} catch (...) {
TORCH_WARN(
- "FALLBACK path has been taken. This is an indication that codegen"
+        "FALLBACK path is taken. This is an indication that codegen "
        "failed for some reason. To debug, try disabling the codegen fallback "
        "path by setting the env variable "
- "`export PYTORCH_NVFUSER_DISABLE_FALLBACK=1`");
- take_fallback();
+ "`export PYTORCH_CUDA_FUSER_DISABLE_FALLBACK=1`");
+ // copying graph here since we are eliminating shape information;
+ auto copied_graph = fusion_node->g(attr::Subgraph)->copy();
+ EraseShapeInformation(copied_graph);
+ InterpreterState{Code(copied_graph, "fallback_cuda_fuser")}.run(stack);
}
- } else {
- run_fusion();
}
}
namespace fuser {
namespace cuda {
+void OptOutMutator::mutate(Fusion* fusion) {
+ std::vector<Expr*> orig_exprs = fusion->exprs();
+
+ /*
+   * We go through all the exprs in topologically sorted order, calling mutate
+   * on each, which can insert nodes, remove nodes, or both. These operations
+   * modify the DAG, and the Fusion keeps track of what has and hasn't been
+   * changed through the origin dependency tracking it performs. If an
+   * operation is added and its output is a Val that was previously the output
+   * of another expression, that older expression is removed, as a Val can
+   * only be assigned once due to our SSA restriction. Therefore we don't
+   * need to manually track which expressions stayed constant or were changed.
+ */
+
+ for (Statement* stmt : orig_exprs)
+ mutate(stmt);
+}
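+
+// A sketch of how this entry point is typically driven (hypothetical caller,
+// using only the APIs in this file). Any mutate() override that produces a
+// replacement registers it via registerMutation(), and downstream exprs pick
+// up the new Val through mutateAsVal(), preserving the SSA invariant above:
+//
+//   OptOutMutator mutator;
+//   mutator.mutate(fusion); // walks fusion->exprs() in topological order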
+
// MUTATE FUNCTIONS FOR VALS
Statement* OptOutMutator::mutate(IterDomain* id) {
Statement* OptOutMutator::mutate(TensorView* tv) {
TensorDomain* td = mutateAsVal(tv->domain())->as<TensorDomain>();
- if (!tv->domain()->sameAs(td)) {
+ TensorView* computeAtView = nullptr;
+ if (tv->hasComputeAt())
+ computeAtView = mutateAsVal(tv->getComputeAtView())->as<TensorView>();
+
+ if (!tv->domain()->sameAs(td) ||
+ (tv->hasComputeAt() && !tv->getComputeAtView()->sameAs(computeAtView))) {
TensorView* mutated_tv = new TensorView(td, tv->getDataType().value());
+ if (tv->hasComputeAt()) {
+ mutated_tv->setComputeAt(
+ computeAtView, (int)(tv->getRelativeComputeAtAxis()));
+ }
registerMutation(tv, mutated_tv);
return mutated_tv;
}
return tv;
}
+Statement* OptOutMutator::mutate(kir::TensorIndex* ti) {
+ return ti;
+}
+
Statement* OptOutMutator::mutate(Bool* b) {
return b;
}
-Statement* OptOutMutator::mutate(Double* d) {
- return d;
+Statement* OptOutMutator::mutate(Float* f) {
+ return f;
+}
+
+Statement* OptOutMutator::mutate(Half* h) {
+ return h;
}
Statement* OptOutMutator::mutate(Int* i) {
// MUTATE FUNCTIONS FOR EXPRESSIONS.
+Statement* OptOutMutator::mutate(kir::Allocate* a) {
+ return a;
+}
+
+Statement* OptOutMutator::mutate(kir::Sync* a) {
+ return a;
+}
+
Statement* OptOutMutator::mutate(Split* s) {
IterDomain* ot = mutateAsVal(s->outer())->as<IterDomain>();
IterDomain* inr = mutateAsVal(s->inner())->as<IterDomain>();
return s;
}
FusionGuard::getCurFusion()->removeExpr(s);
- return new Split(ot, inr, in, fact, s->innerSplit());
+ return new Split(ot, inr, in, fact);
}
Statement* OptOutMutator::mutate(Merge* m) {
return new ReductionOp(rop->getReductionOpType(), init, out, in);
}
-namespace {
-__inline__ bool compareOptional(Val* a, Val* b) {
- if (!a || !b) {
- return (!a && !b);
- }
- return a->sameAs(b);
-}
-
-} // namespace
-
-Statement* OptOutMutator::mutate(WelfordOp* wop) {
- Val* out_avg = mutateAsVal(wop->outAvg())->asVal();
- Val* out_var = mutateAsVal(wop->outVar())->asVal();
- Val* out_N = mutateAsVal(wop->outN())->asVal();
-
- Val* in_avg = mutateAsVal(wop->inAvg())->asVal();
- Val* in_var = wop->inVar() ? mutateAsVal(wop->inVar())->asVal() : nullptr;
- Val* in_N = mutateAsVal(wop->inN())->asVal();
-
- Val* init_avg =
- wop->initAvg() ? mutateAsVal(wop->initAvg())->asVal() : nullptr;
- Val* init_var =
- wop->initVar() ? mutateAsVal(wop->initVar())->asVal() : nullptr;
- Val* init_N = mutateAsVal(wop->initN())->asVal();
-
- const bool out_compare = out_avg->sameAs(wop->outAvg()) &&
- out_var->sameAs(wop->outVar()) && out_N->sameAs(wop->outN());
- const bool in_compare = in_avg->sameAs(wop->inAvg()) &&
- compareOptional(in_var, wop->inVar()) && in_N->sameAs(wop->inN());
- const bool init_compare = compareOptional(init_avg, wop->initAvg()) &&
- compareOptional(init_var, wop->initVar()) && init_N->sameAs(wop->initN());
-
- if (out_compare && init_compare && in_compare) {
- return wop;
- } else {
- return new WelfordOp(
- out_avg,
- out_var,
- out_N,
- init_avg,
- init_var,
- init_N,
- in_avg,
- in_var,
- in_N);
- }
+Statement* OptOutMutator::mutate(kir::GridReduction* gr) {
+ return gr;
}
Statement* OptOutMutator::mutate(BroadcastOp* bop) {
return bop;
}
-Statement* OptOutMutator::mutate(TransposeOp* top) {
- return top;
+Statement* OptOutMutator::mutate(kir::ForLoop* fl) {
+ return fl;
}
-Statement* OptOutMutator::mutate(ShiftOp* sop) {
- Val* out = mutateAsVal(sop->out())->asVal();
- Val* in = mutateAsVal(sop->in())->asVal();
-
- if (out->sameAs(sop->out()) && in->sameAs(sop->in()))
- return sop;
- auto offsets = sop->offsets();
- FusionGuard::getCurFusion()->removeExpr(sop);
- return new ShiftOp(out, in, offsets);
-}
-
-Statement* OptOutMutator::mutate(GatherOp* op) {
- Val* out = mutateAsVal(op->out())->asVal();
- Val* in = mutateAsVal(op->in())->asVal();
-
- if (out->sameAs(op->out()) && in->sameAs(op->in()))
- return op;
- auto window_shape = op->windowShape();
- auto pad_width = op->padWidth();
- FusionGuard::getCurFusion()->removeExpr(op);
- return new GatherOp(out, in, window_shape, pad_width);
+Statement* OptOutMutator::mutate(kir::IfThenElse* ite) {
+ return ite;
}
} // namespace cuda
+++ /dev/null
-#pragma once
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/ops/composite.h>
-#include <torch/csrc/jit/codegen/cuda/ops/normalization.h>
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/ops/composite.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-ForwardDropoutResult dropout(TensorView* x, Val* prob) {
- auto p1m = sub(new Double(1.), prob);
- auto zero_check = add(eq(p1m, new Double(0.)), p1m);
- auto scale = div(new Double(1.), zero_check);
- return dropout(x, p1m, scale);
-}
-
-ForwardDropoutResult dropout(TensorView* x, Val* prob, Val* scale) {
- TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
- TORCH_INTERNAL_ASSERT(
- prob != nullptr && prob->getDataType().has_value() &&
- prob->getDataType().value() == DataType::Double,
- "Probability is not a valid Double.");
- TORCH_INTERNAL_ASSERT(
- scale != nullptr && scale->getDataType().has_value() &&
- scale->getDataType().value() == DataType::Double,
- "Scale is not a valid Double.");
-
- auto rand_vals = unaryOp(UnaryOpType::RandLike, x);
- auto mask = lt(rand_vals, prob);
- auto apply_mask = mul(x, mask);
- auto y = mul(apply_mask, scale);
-
- return {y, mask};
-}
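-
-// The inverted-dropout math implemented above, written out. Note that in
-// this overload `prob` is the keep probability (the two-arg overload passes
-// p1m = 1 - p):
-//
-//   mask_i = 1[rand_i < 1 - p],   y_i = x_i * mask_i / (1 - p)
-//   E[y_i] = (1 - p) * x_i / (1 - p) = x_i
-//
-// zero_check guards the division when p == 1: p1m == 0 there, so
-// 1[p1m == 0] + p1m == 1 and scale == 1, while the mask zeroes everything.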
-
-TensorView* dropout_backward(TensorView* dy, TensorView* mask, Val* scale) {
- TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
- TORCH_INTERNAL_ASSERT(mask != nullptr, "Mask is invalid");
- TORCH_INTERNAL_ASSERT(
- scale != nullptr && scale->getDataType().has_value() &&
- scale->getDataType().value() == DataType::Double,
- "Scale is not a valid Double.");
-
- auto grad_mask = mul(dy, mask);
- auto dx = mul(grad_mask, scale);
-
- return dx;
-}
-
-Val* softplus(Val* x, Val* beta, Val* threshold) {
- TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
- TORCH_INTERNAL_ASSERT(beta != nullptr, "Beta is invalid.");
- TORCH_INTERNAL_ASSERT(
- threshold != nullptr, "Threshold is not a valid Double.");
-
- auto op_beta = mul(x, beta);
- auto maybe_result = div(
- unaryOp(UnaryOpType::Log1p, unaryOp(UnaryOpType::Exp, op_beta)), beta);
- auto y = where(gt(op_beta, threshold), x, maybe_result);
- return y;
-}
-
-LstmResult lstm(
- TensorView* prev_cell,
- TensorView* in_x,
- TensorView* forget_x,
- TensorView* cell_x,
- TensorView* out_x) {
- TORCH_INTERNAL_ASSERT(
- prev_cell != nullptr, "Previous cell state is invalid.");
- TORCH_INTERNAL_ASSERT(in_x != nullptr, "In-gate input is invalid");
- TORCH_INTERNAL_ASSERT(forget_x != nullptr, "Forget-gate input is invalid");
- TORCH_INTERNAL_ASSERT(cell_x != nullptr, "Cell-gate input is invalid");
- TORCH_INTERNAL_ASSERT(out_x != nullptr, "Out-gate input is invalid");
-
- const auto in_gate = unaryOp(UnaryOpType::Sigmoid, in_x);
- const auto forget_gate = unaryOp(UnaryOpType::Sigmoid, forget_x);
- const auto cell_gate = unaryOp(UnaryOpType::Tanh, cell_x);
- const auto out_gate = unaryOp(UnaryOpType::Sigmoid, out_x);
-
- const auto cell = add(mul(forget_gate, prev_cell), mul(in_gate, cell_gate));
- const auto hidden = mul(out_gate, unaryOp(UnaryOpType::Tanh, cell));
-
- return {cell, hidden};
-}
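-
-// In standard LSTM-cell notation, the above computes
-//
-//   i = sigmoid(in_x)   f = sigmoid(forget_x)
-//   g = tanh(cell_x)    o = sigmoid(out_x)
-//
-//   cell   = f * prev_cell + i * g
-//   hidden = o * tanh(cell)
-//
-// with * denoting elementwise multiplication.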
-
-Val* fast_gelu(Val* x) {
- TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid");
-
- constexpr double kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
- constexpr double kKappa = 0.044715;
-
- auto x_cube = mul(x, mul(x, x));
-
- auto inner_1 = mul(new Double(kKappa), x_cube);
- auto inner_2 = add(x, inner_1);
- auto inner_3 = mul(new Double(kBeta), inner_2);
- auto tanh_inner = unaryOp(UnaryOpType::Tanh, inner_3);
-
- auto out = mul(x, add(new Double(1.), tanh_inner));
- auto y = mul(new Double(0.5), out);
- return y;
-}
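-
-// This is the tanh approximation of GELU:
-//
-//   y = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
-//
-// where kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 = sqrt(2) * (2/sqrt(pi)) * 0.5
-// = sqrt(2/pi) and kKappa = 0.044715.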
-
-Val* fast_gelu_backward(Val* dy, Val* x) {
- TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
- TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid");
-
- constexpr double kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
- constexpr double kKappa = 0.044715;
-
- auto x_sq = mul(x, x);
- auto x_cube = mul(x, x_sq);
-
- auto inner_1 = mul(new Double(kKappa), x_cube);
- auto inner_2 = add(x, inner_1);
- auto inner_3 = mul(new Double(kBeta), inner_2);
- auto tanh_inner = unaryOp(UnaryOpType::Tanh, inner_3);
-
- auto left = mul(new Double(0.5), x);
- auto right = add(new Double(1.), tanh_inner);
-
- auto left_derivative = mul(new Double(0.5), right);
-
- auto tanh_inner_sq = mul(tanh_inner, tanh_inner);
- auto tanh_derivative = sub(new Double(1), tanh_inner_sq);
-
- auto constant_mul_x_sq = mul(new Double(kBeta * 3 * kKappa), x_sq);
- auto inner_derivative = add(new Double(kBeta), constant_mul_x_sq);
- auto right_derivative = mul(left, mul(tanh_derivative, inner_derivative));
-
- auto dx = mul(dy, add(left_derivative, right_derivative));
- return dx;
-}
-
-Val* gelu_backward(Val* dy, Val* x) {
- TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
- TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid");
-
- constexpr double kAlpha = M_2_SQRTPI * M_SQRT1_2 * 0.5;
- const double kHalf = 0.5;
-
- auto cdf_1 = mul(x, new Double(M_SQRT1_2));
- auto cdf_2 = unaryOp(UnaryOpType::Erf, cdf_1);
- auto cdf_3 = add(cdf_2, new Double(1.));
- auto cdf_4 = mul(cdf_3, new Double(kHalf));
-
- auto pdf_1 = mul(x, x);
- auto pdf_2 = mul(pdf_1, new Double(-kHalf));
- auto pdf_3 = unaryOp(UnaryOpType::Exp, pdf_2);
-
- auto out = addcmul(cdf_4, x, pdf_3, new Double(kAlpha));
- auto dx = mul(out, dy);
- return dx;
-}
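-
-// The exact-GELU derivative computed above: with gelu(x) = x * Phi(x), where
-// Phi is the standard normal CDF and phi its density,
-//
-//   d/dx gelu(x) = Phi(x) + x * phi(x)
-//                = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) * kAlpha
-//
-// with kAlpha = 1 / sqrt(2 * pi), matching addcmul(cdf_4, x, pdf_3, kAlpha).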
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
-
-//
-// The operations defined in this header are intended as user-facing functions.
-// The user will provide the necessary input TensorViews and the function will
-// create the correct intermediate nodes and return the output TensorViews.
-//
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-struct ForwardDropoutResult {
- TensorView* output = nullptr;
- TensorView* mask = nullptr;
-};
-
-TORCH_CUDA_CU_API ForwardDropoutResult dropout(TensorView* x, Val* prob);
-
-TORCH_CUDA_CU_API ForwardDropoutResult
-dropout(TensorView* x, Val* prob, Val* scale);
-
-TORCH_CUDA_CU_API TensorView* dropout_backward(
- TensorView* dy,
- TensorView* mask,
- Val* scale);
-
-TORCH_CUDA_CU_API Val* softplus(Val* x, Val* beta, Val* threshold);
-
-struct LstmResult {
- TensorView* cell = nullptr;
- TensorView* hidden = nullptr;
-};
-
-TORCH_CUDA_CU_API LstmResult lstm(
- TensorView* prev_cell,
- TensorView* in_x,
- TensorView* forget_x,
- TensorView* cell_x,
- TensorView* out_x);
-
-TORCH_CUDA_CU_API Val* fast_gelu(Val* x);
-TORCH_CUDA_CU_API Val* fast_gelu_backward(Val* dy, Val* x);
-TORCH_CUDA_CU_API Val* gelu_backward(Val* dy, Val* x);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/ops/normalization.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-TensorView* softmax(TensorView* x, int dim) {
- TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-
- const int kNumberOfDims =
- TensorDomain::noReductions(x->getRootDomain()).size();
- const int kReductionAxis = (dim < 0) ? dim + kNumberOfDims : dim;
- TORCH_INTERNAL_ASSERT(kReductionAxis >= 0 && kReductionAxis < kNumberOfDims);
-
- std::vector<bool> broadcast_mask(kNumberOfDims, false);
- broadcast_mask[kReductionAxis] = true;
-
- auto max_val = max(x, {kReductionAxis});
- auto bcast_max = broadcast(max_val, broadcast_mask);
- auto x_max_sub = sub(x, bcast_max);
- auto exp = unaryOp(UnaryOpType::Exp, x_max_sub);
- auto sum_exp = sum(exp, {kReductionAxis});
- auto bcast_sum = broadcast(sum_exp, broadcast_mask);
- auto y = div(exp, bcast_sum);
-
- return y;
-}
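-
-// Numerically stable softmax, as computed above:
-//
-//   softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
-//
-// Subtracting the max leaves the result unchanged (the exp(-max(x)) factor
-// cancels between numerator and denominator) but keeps exp from overflowing
-// for large inputs.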
-
-TensorView* softmax_backward(
- TensorView* dy,
- TensorView* y,
- int dim,
- TensorView* x) {
- TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
- TORCH_INTERNAL_ASSERT(y != nullptr, "Output is invalid.");
- TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-
- const int kNumberOfDims =
- TensorDomain::noReductions(x->getRootDomain()).size();
- const int kReductionAxis = (dim < 0) ? dim + kNumberOfDims : dim;
- TORCH_INTERNAL_ASSERT(kReductionAxis >= 0 && kReductionAxis < kNumberOfDims);
-
- std::vector<bool> broadcast_mask(kNumberOfDims, false);
- broadcast_mask[kReductionAxis] = true;
-
- auto new_grad = mul(dy, y);
- auto sum_new_grad = sum(new_grad, {kReductionAxis});
- auto bcast_sum = broadcast(sum_new_grad, broadcast_mask);
- auto output_sum_mul = mul(y, bcast_sum);
- auto dx = sub(new_grad, output_sum_mul);
-
- return dx;
-}
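-
-// The gradient computed above, in index form (sums over the reduction axis):
-//
-//   dx_i = y_i * (dy_i - sum_j dy_j * y_j)
-//
-// i.e. new_grad = dy * y, minus y times the broadcast sum of new_grad.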
-
-ForwardNormResult layer_norm(
- TensorView* x,
- const std::vector<int64_t>& norm_shape,
- TensorView* weight,
- TensorView* bias,
- Val* eps) {
- return layer_norm(x, norm_shape.size(), weight, bias, eps);
-}
-
-ForwardNormResult layer_norm(
- TensorView* x,
- const size_t kNormShapeNumDims,
- TensorView* weight,
- TensorView* bias,
- Val* eps) {
- TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
- TORCH_INTERNAL_ASSERT(
- eps != nullptr && eps->getDataType().has_value() &&
- eps->getDataType().value() == DataType::Double,
- "Epsilon (eps) is not a valid Double.");
-
- // (B, C, H, W, D) tensor
- // norm_shape = [H, W, D]
- // M = outer = product of remaining dimensions = B * C
- // N = reduction = product of norm_shape = H * W * D
- // weight = bias = norm_shape tensor
- const size_t kNumberOfDims =
- TensorDomain::noReductions(x->getRootDomain()).size();
- const size_t kOuterNumDims = kNumberOfDims - kNormShapeNumDims;
-
- std::vector<int> outer_reduction_axes(kOuterNumDims);
- std::vector<bool> outer_broadcast_mask(kNumberOfDims, false);
- for (size_t idx = 0; idx < kOuterNumDims; ++idx) {
- outer_reduction_axes[idx] = idx;
- outer_broadcast_mask[idx] = true;
- }
-
- std::vector<int> inner_reduction_axes(kNormShapeNumDims);
- std::vector<bool> inner_broadcast_mask(kNumberOfDims, false);
- Val* num_features = new Double(1);
- for (size_t idx = 0; idx < kNormShapeNumDims; ++idx) {
- const size_t axis = kNumberOfDims - 1 - idx;
- inner_reduction_axes[idx] = axis;
- inner_broadcast_mask[axis] = true;
- num_features = mul(num_features, x->domain()->domain()[axis]->extent());
- }
-
- // Main algorithm
- auto welford_out = Welford(x, inner_reduction_axes);
- auto mean_bcast = broadcast(welford_out.avg, inner_broadcast_mask);
- auto x_sub_mean = sub(x, mean_bcast);
-
- auto var_sum_bcast = broadcast(welford_out.var_sum, inner_broadcast_mask);
- auto var = div(var_sum_bcast, num_features);
- auto var_eps = add(var, eps);
- auto invstd = unaryOp(UnaryOpType::Rsqrt, var_eps);
-
- auto y = mul(x_sub_mean, invstd);
-
- // Optional: norm * weight
- if (weight != nullptr) {
- auto weight_bcast = broadcast(weight, outer_broadcast_mask);
- y = mul(y, weight_bcast);
- }
-
- // Optional: norm * weight + bias
- if (bias != nullptr) {
- auto bias_bcast = broadcast(bias, outer_broadcast_mask);
- y = add(y, bias_bcast);
- }
-
- return {y, mean_bcast, invstd};
-}
-
-BackwardNormResult layer_norm_backward(
- TensorView* dy,
- TensorView* x,
- const std::vector<int64_t>& norm_shape,
- TensorView* mean,
- TensorView* invstd,
- TensorView* weight,
- TensorView* bias,
- const std::vector<bool>& output_mask) {
- TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
- TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
- TORCH_INTERNAL_ASSERT(mean != nullptr, "Mean is invalid.");
- TORCH_INTERNAL_ASSERT(invstd != nullptr, "Inv std is invalid.");
-
- // (B, C, H, W, D) tensor
- // norm_shape = [H, W, D]
- // M = outer = product of remaining dimensions = B * C
- // N = reduction = product of norm_shape = H * W * D
- // weight = bias = norm_shape tensor
- const size_t kNumberOfDims =
- TensorDomain::noReductions(x->getRootDomain()).size();
- const size_t kNormShapeNumDims = norm_shape.size();
- const size_t kOuterNumDims = kNumberOfDims - kNormShapeNumDims;
-
- std::vector<int> outer_reduction_axes(kOuterNumDims);
- std::vector<bool> outer_broadcast_mask(kNumberOfDims, false);
- for (size_t idx = 0; idx < kOuterNumDims; ++idx) {
- outer_reduction_axes[idx] = idx;
- outer_broadcast_mask[idx] = true;
- }
-
- std::vector<int> inner_reduction_axes(kNormShapeNumDims);
- std::vector<bool> inner_broadcast_mask(kNumberOfDims, false);
- Val* num_features = new Double(1);
- for (size_t idx = 0; idx < kNormShapeNumDims; ++idx) {
- const size_t axis = kNumberOfDims - 1 - idx;
- inner_reduction_axes[idx] = axis;
- inner_broadcast_mask[axis] = true;
- num_features = mul(num_features, x->domain()->domain()[axis]->extent());
- }
-
- auto x_hat = mul(sub(x, mean), invstd);
-
- TensorView* grad_x_hat = nullptr;
- if (weight != nullptr) {
- auto* bcast_weight = broadcast(weight, outer_broadcast_mask);
- grad_x_hat = mul(dy, bcast_weight);
- } else {
- grad_x_hat = dy;
- }
-
- auto a = mul(num_features, grad_x_hat);
-
- auto b = sum(grad_x_hat, inner_reduction_axes);
- auto bcast_b = broadcast(b, inner_broadcast_mask);
-
- auto c1 = mul(grad_x_hat, x_hat);
- auto c2 = sum(c1, inner_reduction_axes);
- auto bcast_c2 = broadcast(c2, inner_broadcast_mask);
- auto c3 = mul(x_hat, bcast_c2);
-
- auto inner = sub(sub(a, bcast_b), c3);
- auto reciprocal_size = unaryOp(UnaryOpType::Reciprocal, num_features);
-
- TensorView* dx = nullptr;
- if (output_mask[0]) {
- dx = mul(mul(reciprocal_size, invstd), inner);
- }
-
- TensorView* dw = nullptr;
- if (output_mask[1] && weight != nullptr) {
- dw = sum(mul(dy, x_hat), outer_reduction_axes);
- }
-
- TensorView* db = nullptr;
- if (output_mask[2] && bias != nullptr) {
- db = sum(dy, outer_reduction_axes);
- }
- return {dx, dw, db};
-}
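-
-// The gradients computed above, with x_hat = (x - mean) * invstd,
-// g = grad_x_hat, and N = num_features (inner sums over the normalized axes,
-// outer sums over the remaining axes):
-//
-//   dx = (invstd / N) * (N * g - sum(g) - x_hat * sum(g * x_hat))
-//   dw = sum_outer(dy * x_hat)
-//   db = sum_outer(dy)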
-
-ForwardNormResult batch_norm(
- TensorView* x,
- TensorView* weight,
- TensorView* bias,
- TensorView* running_mean,
- TensorView* running_var,
- const bool kTraining,
- Val* momentum,
- Val* eps) {
- auto fusion = FusionGuard::getCurFusion();
-
- TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-
- TORCH_INTERNAL_ASSERT(
- !((running_var == nullptr) ^ (running_mean == nullptr)),
-      "running stats should come in pairs");
-
- TORCH_INTERNAL_ASSERT(
- momentum != nullptr && momentum->getDataType().has_value() &&
- momentum->getDataType().value() == DataType::Double,
- "Momentum is not a valid Double.");
-
- TORCH_INTERNAL_ASSERT(
- eps != nullptr && eps->getDataType().has_value() &&
- eps->getDataType().value() == DataType::Double,
- "Epsilon (eps) is not a valid Double.");
-
- // (B, C, H, W, D) tensor
- // M = outer = channels
- // N = reduction = B * H * W * D
- // weight = bias = (C) tensor
- // const size_t kChannelsDim = 1;
- const size_t kNumberOfDims =
- TensorDomain::noReductions(x->getRootDomain()).size();
-
- std::vector<int> reduction_axes;
- std::vector<bool> broadcast_mask(kNumberOfDims, false);
- Val* num_features = new Double(1);
- for (size_t axis = 0; axis < kNumberOfDims; ++axis) {
- if (axis != 1) {
- reduction_axes.push_back(axis);
- broadcast_mask[axis] = true;
- num_features = mul(num_features, x->domain()->domain()[axis]->extent());
- }
- }
-
- TensorView* y = nullptr;
- TensorView* mean = nullptr;
- TensorView* invstd = nullptr;
- if (kTraining || running_mean == nullptr) {
- // Algorithm
- auto welford_out = Welford(x, reduction_axes);
-
- // updating running mean and running var
- if (running_mean != nullptr && running_var != nullptr) {
- auto rev_momentum = sub(new Double(1.0), momentum);
- auto current_mean_hat = mul(welford_out.avg, momentum);
- auto mean_hat = mul(running_mean, rev_momentum);
- auto new_mean_hat = add(mean_hat, current_mean_hat);
- fusion->addOutput(new_mean_hat);
- fusion->aliasOutputToInput(new_mean_hat, running_mean);
-
- auto num_feature_decrement = sub(num_features, new Int(1));
- auto unbiased_var = div(welford_out.var_sum, num_feature_decrement);
- auto current_var_hat = mul(unbiased_var, momentum);
- auto var_hat = mul(running_var, rev_momentum);
- auto new_var_hat = add(var_hat, current_var_hat);
- fusion->addOutput(new_var_hat);
- fusion->aliasOutputToInput(new_var_hat, running_var);
- }
-
- mean = welford_out.avg;
- auto mean_bcast = broadcast(mean, broadcast_mask);
- auto x_sub_mean = sub(x, mean_bcast);
-
- auto var = div(welford_out.var_sum, num_features);
- auto var_eps = add(var, eps);
- invstd = unaryOp(UnaryOpType::Rsqrt, var_eps);
- auto invstd_bcast = broadcast(invstd, broadcast_mask);
-
- y = mul(x_sub_mean, invstd_bcast);
- } else {
- // This is inference mode with running stats
- auto r_mean_bcasted = broadcast(running_mean, broadcast_mask);
- auto x_sub_mean = sub(x, r_mean_bcasted);
-
- auto var_eps = add(running_var, eps);
- auto unbiased_invstd = unaryOp(UnaryOpType::Rsqrt, var_eps);
- auto invstd_bcast = broadcast(unbiased_invstd, broadcast_mask);
-
- // During inference, mean/invstd output are empty tensors
- mean = TensorViewBuilder().shape({0}).build();
- invstd = TensorViewBuilder().shape({0}).build();
- y = mul(x_sub_mean, invstd_bcast);
- }
-
- // Optional: norm * weight
- if (weight) {
- auto weight_bcast = broadcast(weight, broadcast_mask);
- y = mul(y, weight_bcast);
- }
-
- // Optional: norm * weight + bias
- if (bias) {
- auto bias_bcast = broadcast(bias, broadcast_mask);
- y = add(y, bias_bcast);
- }
- return {y, mean, invstd};
-}
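-
-// The running-stat update implemented above for training mode:
-//
-//   running_mean = (1 - momentum) * running_mean + momentum * batch_mean
-//   running_var  = (1 - momentum) * running_var  + momentum * batch_var
-//
-// where batch_var is the unbiased estimate var_sum / (N - 1), while the
-// normalization itself uses the biased var_sum / N.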
-
-BackwardNormResult batch_norm_backward(
- TensorView* x,
- TensorView* dy,
- TensorView* weight,
- TensorView* running_mean,
- TensorView* running_var,
- TensorView* save_mean,
- TensorView* save_invstd,
- const bool kTraining,
- Val* eps,
- const std::vector<bool>& output_mask) {
- TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
- TORCH_INTERNAL_ASSERT(dy != nullptr, "Grad Output is invalid.");
- TORCH_INTERNAL_ASSERT(
- eps != nullptr && eps->getDataType().has_value() &&
- eps->getDataType().value() == DataType::Double,
- "Epsilon (eps) is not a valid Double.");
-
- // (B, C, H, W, D) tensor
- // M = outer = channels
- // N = reduction = B * H * W * D
- // weight = bias = (C) tensor
- const size_t kChannelsDim = 1;
- const size_t kNumberOfDims =
- TensorDomain::noReductions(x->getRootDomain()).size();
-
- std::vector<int> reduction_axes;
- std::vector<bool> broadcast_mask(kNumberOfDims, false);
- Val* num_features = new Double(1);
- for (size_t axis = 0; axis < kNumberOfDims; ++axis) {
- if (axis != kChannelsDim) {
- reduction_axes.push_back(axis);
- broadcast_mask[axis] = true;
- num_features = mul(num_features, x->domain()->domain()[axis]->extent());
- }
- }
-
- Val* bcast_weight = nullptr;
- if (weight != nullptr) {
- bcast_weight = broadcast(weight, broadcast_mask);
- } else {
- bcast_weight = new Double(1);
- }
-
- TensorView* dx = nullptr;
- TensorView* dw = nullptr;
- TensorView* db = nullptr;
- if (kTraining) {
- TORCH_INTERNAL_ASSERT(
- save_mean != nullptr && save_invstd != nullptr,
- "When training=True, save_mean and save_invstd are required.");
-
- auto bcast_rstd = broadcast(save_invstd, broadcast_mask);
- auto bcast_mean = broadcast(save_mean, broadcast_mask);
- auto x_hat = mul(sub(x, bcast_mean), bcast_rstd);
- auto grad_x_hat = mul(dy, bcast_weight);
-
- auto a = mul(num_features, grad_x_hat);
-
- auto b = sum(grad_x_hat, reduction_axes);
- auto bcast_b = broadcast(b, broadcast_mask);
-
- auto c1 = mul(grad_x_hat, x_hat);
- auto c2 = sum(c1, reduction_axes);
- auto bcast_c2 = broadcast(c2, broadcast_mask);
- auto c3 = mul(x_hat, bcast_c2);
-
- auto inner = sub(sub(a, bcast_b), c3);
-
- auto reciprocal_size = unaryOp(UnaryOpType::Reciprocal, num_features);
-
- if (output_mask[0]) {
- dx = mul(mul(reciprocal_size, bcast_rstd), inner);
- }
-
- if (output_mask[1]) {
- dw = sum(mul(dy, x_hat), reduction_axes);
- }
- } else {
-    // TODO: is this a legitimate assumption? Can't we run with
-    // track_running_stats == false && training == false,
-    // which should just go through the case above?
- TORCH_INTERNAL_ASSERT(
- running_mean != nullptr && running_var != nullptr,
- "When training=False, running_mean and running_invstd are required.");
-
- auto bcast_var = broadcast(running_var, broadcast_mask);
- auto var_eps = add(bcast_var, eps);
- auto bcast_rstd = unaryOp(UnaryOpType::Rsqrt, var_eps);
- auto bcast_mean = broadcast(running_mean, broadcast_mask);
-
- if (output_mask[0]) {
- dx = mul(mul(dy, bcast_rstd), bcast_weight);
- }
-
- if (output_mask[1]) {
- auto x_hat = mul(sub(x, bcast_mean), bcast_rstd);
- dw = sum(mul(dy, x_hat), reduction_axes);
- }
- }
-
- if (output_mask[2]) {
- db = sum(dy, reduction_axes);
- }
-
- return {dx, dw, db};
-}
-
-ForwardNormResult instance_norm(
- TensorView* x,
- TensorView* weight,
- TensorView* bias,
- TensorView* running_mean,
- TensorView* running_var,
- const bool kUseInputStats,
- Val* momentum,
- Val* eps) {
- auto fusion = FusionGuard::getCurFusion();
-
- TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
-
- TORCH_INTERNAL_ASSERT(
- !((running_var == nullptr) ^ (running_mean == nullptr)),
-      "running stats should come in pairs");
-
- TORCH_INTERNAL_ASSERT(
- momentum != nullptr && momentum->getDataType().has_value() &&
- momentum->getDataType().value() == DataType::Double,
- "Momentum is not a valid Double.");
-
- TORCH_INTERNAL_ASSERT(
- eps != nullptr && eps->getDataType().has_value() &&
- eps->getDataType().value() == DataType::Double,
- "Epsilon (eps) is not a valid Double.");
-
- // (B, C, H, W, D) tensor
- // M = outer = B * C
- // N = reduction = H * W * D
- // weight = bias = C tensor
- const size_t kBatchDim = 0;
- const size_t kChannelsDim = 1;
- const size_t kNumberOfDims =
- TensorDomain::noReductions(x->getRootDomain()).size();
-
- std::vector<int> x_reduction_axes;
- std::vector<bool> x_broadcast_mask(kNumberOfDims, false);
- Val* N = new Double(1);
- for (size_t axis = 0; axis < kNumberOfDims; ++axis) {
- if (axis != kBatchDim && axis != kChannelsDim) {
- x_reduction_axes.push_back(axis);
- x_broadcast_mask[axis] = true;
- N = mul(N, x->domain()->domain()[axis]->extent());
- }
- }
- Val* B = new Double(1);
- B = mul(B, x->domain()->domain()[kBatchDim]->extent());
-
- std::vector<bool> channels_only_broadcast_mask(kNumberOfDims, false);
- for (size_t axis = 0; axis < kNumberOfDims; ++axis) {
- if (axis != kChannelsDim) {
- channels_only_broadcast_mask[axis] = true;
- }
- }
-
- TensorView* y = nullptr;
- TensorView* mean = nullptr;
- TensorView* invstd = nullptr;
- if (kUseInputStats || running_mean == nullptr) {
- // Algorithm
- auto welford_out = Welford(x, x_reduction_axes);
-
- // updating running mean and running var
- if (running_mean != nullptr && running_var != nullptr) {
- auto rev_momentum = sub(new Double(1.0), momentum);
- auto current_mean_hat = mul(welford_out.avg, momentum);
- auto mean_hat = mul(running_mean, rev_momentum);
- auto new_mean_hat = add(mean_hat, current_mean_hat);
-
- auto new_mean_sum = sum(new_mean_hat, {kBatchDim});
- auto new_mean_channels_only = div(new_mean_sum, B);
- fusion->addOutput(new_mean_channels_only);
- fusion->aliasOutputToInput(new_mean_channels_only, running_mean);
-
- auto num_feature_decrement = sub(N, new Int(1));
- auto unbiased_var = div(welford_out.var_sum, num_feature_decrement);
- auto current_var_hat = mul(unbiased_var, momentum);
- auto var_hat = mul(running_var, rev_momentum);
- auto new_var_hat = add(var_hat, current_var_hat);
-
- auto new_var_sum = sum(new_var_hat, {kBatchDim});
- auto new_var_channels_only = div(new_var_sum, B);
- fusion->addOutput(new_var_channels_only);
- fusion->aliasOutputToInput(new_var_channels_only, running_var);
- }
-
- mean = welford_out.avg;
- auto mean_bcast = broadcast(mean, x_broadcast_mask);
- auto x_sub_mean = sub(x, mean_bcast);
-
- auto var = div(welford_out.var_sum, N);
- auto var_eps = add(var, eps);
- invstd = unaryOp(UnaryOpType::Rsqrt, var_eps);
- auto invstd_bcast = broadcast(invstd, x_broadcast_mask);
-
- y = mul(x_sub_mean, invstd_bcast);
- } else {
- // This is inference mode with running stats
- auto r_mean_bcasted = broadcast(running_mean, channels_only_broadcast_mask);
- auto x_sub_mean = sub(x, r_mean_bcasted);
-
- auto var_eps = add(running_var, eps);
- auto unbiased_invstd = unaryOp(UnaryOpType::Rsqrt, var_eps);
- auto invstd_bcast =
- broadcast(unbiased_invstd, channels_only_broadcast_mask);
-
- // During inference, mean/invstd output are empty tensors
- mean = TensorViewBuilder().shape({0}).build();
- invstd = TensorViewBuilder().shape({0}).build();
- y = mul(x_sub_mean, invstd_bcast);
- }
-
- // Optional: norm * weight
- if (weight) {
- auto weight_bcast = broadcast(weight, channels_only_broadcast_mask);
- y = mul(y, weight_bcast);
- }
-
- // Optional: norm * weight + bias
- if (bias) {
- auto bias_bcast = broadcast(bias, channels_only_broadcast_mask);
- y = add(y, bias_bcast);
- }
- return {y, mean, invstd};
-}
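-
-// Unlike batch_norm above, the Welford reduction here is taken per (batch,
-// channel) pair over only the spatial axes, so mean/invstd are shaped (B, C).
-// The running stats are additionally summed over the batch dimension and
-// divided by B, keeping them per-channel as in batch norm.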
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
-
-//
-// The operations defined in this header are intended as user-facing functions.
-// The user will provide the necessary input TensorViews and the function will
-// create the correct intermediate nodes and return the output TensorViews.
-//
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-struct ForwardNormResult {
- TensorView* output = nullptr;
- TensorView* mean = nullptr;
- TensorView* invstd = nullptr;
-};
-
-struct BackwardNormResult {
- TensorView* grad_input = nullptr;
- TensorView* grad_weight = nullptr;
- TensorView* grad_bias = nullptr;
-};
-
-TORCH_CUDA_CU_API TensorView* softmax(TensorView* x, int dim);
-
-TORCH_CUDA_CU_API TensorView* softmax_backward(
- TensorView* dy,
- TensorView* y,
- const int dim,
- TensorView* x);
-
-TORCH_CUDA_CU_API ForwardNormResult layer_norm(
- TensorView* x,
- const std::vector<int64_t>& norm_shape,
- TensorView* weight,
- TensorView* bias,
- Val* eps);
-
-TORCH_CUDA_CU_API ForwardNormResult layer_norm(
- TensorView* x,
- const size_t kNormShapeNumDims,
- TensorView* weight,
- TensorView* bias,
- Val* eps);
-
-TORCH_CUDA_CU_API BackwardNormResult layer_norm_backward(
- TensorView* dy,
- TensorView* x,
- const std::vector<int64_t>& norm_shape,
- TensorView* mean,
- TensorView* rstd,
- TensorView* weight,
- TensorView* bias,
- const std::vector<bool>& output_mask);
-
-TORCH_CUDA_CU_API ForwardNormResult batch_norm(
- TensorView* x,
- TensorView* weight,
- TensorView* bias,
- TensorView* running_mean,
- TensorView* running_var,
- const bool kTraining,
- Val* momentum,
- Val* eps);
-
-TORCH_CUDA_CU_API BackwardNormResult batch_norm_backward(
- TensorView* x,
- TensorView* dy,
- TensorView* weight,
- TensorView* running_mean,
- TensorView* running_var,
- TensorView* save_mean,
- TensorView* save_invstd,
- const bool kTraining,
- Val* eps,
- const std::vector<bool>& output_mask);
-
-TORCH_CUDA_CU_API ForwardNormResult instance_norm(
- TensorView* x,
- TensorView* weight,
- TensorView* bias,
- TensorView* running_mean,
- TensorView* running_var,
- const bool kUseInputStats,
- Val* momentum,
- Val* eps);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/parallel_dimension_map.h>
-
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/lower2device.h>
-
-#include <sstream>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-void ParallelDimensionMap::build(Fusion* fusion) {
- // Scan all TVs to build ParallelType maps
- auto all_vals = fusion->usedMathVals();
- for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
- for (auto id : tv->domain()->domain()) {
- registerConstantExtent(id);
- if (!isParallelTypeThread(id->getParallelType())) {
- continue;
- }
- handleParallelDomain(id);
- }
- }
-
- // Populate the dimension map for each parallel type
- for (const auto& kv : concrete_dom_map_) {
- auto pt = kv.first;
- const auto& concrete_dom_set = kv.second;
- TORCH_INTERNAL_ASSERT(!concrete_dom_set.empty());
- if (concrete_dom_set.size() == 1) {
- populateDimensionMapWithSingleCASet(pt, concrete_dom_set);
- } else {
- populateDimensionMapWithMultipleCASet(pt, concrete_dom_set);
- }
- }
-}
-
-void ParallelDimensionMap::registerConstantExtent(IterDomain* id) {
- ExpressionEvaluator ee(id->fusion());
- auto extent_int = ee.evaluate(id->extent());
- if (!extent_int.has_value()) {
- // Nothing to do if not constant
- return;
- }
-
- auto const_extent = extent_int.value();
-
- // Ignore if this is derived from a size-1 domain as it is likely a
- // size-1 broadcast domain and that does not represent the actual
-  // dimension even if it's constant. Being size-1 may not always mean
-  // it's a broadcast domain, but it is safe to assume that is mostly
-  // the case. If it is not a broadcast, ignoring this domain does not
- // impact the correctness.
- auto extent_inputs = InputsOf::output(id->fusion(), id->extent());
- if (std::any_of(extent_inputs.begin(), extent_inputs.end(), [](Val* input) {
- return input->isOneInt();
- })) {
- return;
- }
-
- auto concrete_id = getCAMappedConcreteDomain(id);
-
- auto existing_it = constant_extent_map_.find(id);
-
- // Adds the constant extent to the set for the concrete domain. If
- // multiple constants are found, this concrete domain has multiple
- // distinctive extents, which can happen with broadcast.
- if (existing_it == constant_extent_map_.end()) {
- constant_extent_map_.insert({concrete_id, {const_extent}});
- } else {
- existing_it->second.insert(const_extent);
- }
-}
-
-// Adds the concrete domain of id to the mapped set for its
-// parallel type.
-void ParallelDimensionMap::handleParallelDomain(IterDomain* id) {
- auto pt = id->getParallelType();
- TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt));
- auto concrete_id = getCAMappedConcreteDomain(id);
-
- auto it = concrete_dom_map_.find(pt);
- if (it == concrete_dom_map_.end()) {
- concrete_dom_map_.insert({pt, {concrete_id}});
- } else {
- it->second.insert(concrete_id);
- }
-}
-
-void ParallelDimensionMap::populateDimensionMapWithSingleCASet(
- ParallelType pt,
- const std::unordered_set<IterDomain*>& dom_set) {
- TORCH_INTERNAL_ASSERT(dom_set.size() == 1);
-
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- // pt is used by only one concrete domain
- auto id = *dom_set.begin();
- auto it = constant_extent_map_.find(id);
-
- if (it != constant_extent_map_.end()) {
- if (it->second.size() == 1) {
- dim_map_.insert({pt, ir_builder.create<kir::Int>(*(it->second.begin()))});
- exact_types_.insert(pt);
- } else {
- // Multiple constant dimensions found; Use the corresponding
- // symbolic parallel dim
- dim_map_.insert({pt, kir::NamedScalar::getParallelDim(pt)});
- }
- } else {
- // Prefer to use blockDim/gridDim if not constant
- dim_map_.insert({pt, kir::NamedScalar::getParallelDim(pt)});
- exact_types_.insert(pt);
- }
-}
-
-void ParallelDimensionMap::populateDimensionMapWithMultipleCASet(
- ParallelType pt,
- const std::unordered_set<IterDomain*>& dom_set) {
- TORCH_INTERNAL_ASSERT(dom_set.size() > 1);
-
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- bool all_equal = true;
- kir::Val* known_dimension =
- gpu_lower->lowerValue((*dom_set.begin())->extent());
-  // Set it to -1 to signal it's not initialized yet.
- int64_t known_const = -1;
-
-  // Check all concrete domains to see whether they all match.
- for (auto concrete_id : dom_set) {
- // If this concrete domain has a constant extent, check if it
- // matches with the known constant extent.
- auto it = constant_extent_map_.find(concrete_id);
- if (it != constant_extent_map_.end()) {
- const auto& const_extent_set = it->second;
- // If multiple constants are detected, it's not exact.
- if (const_extent_set.size() > 1) {
- all_equal = false;
- break;
- }
- auto this_const = *(const_extent_set.begin());
- // known_const is initialized to -1
- if (known_const == -1) {
- known_const = this_const;
- } else if (known_const == this_const) {
-        // Matched the previously known const. The extent of this
-        // domain must be equal to the previously known one.
- continue;
- } else {
-        // Unmatched. The extents of this dom_set may not be unique.
- all_equal = false;
- break;
- }
- }
-
- // At this point, it still remains undetermined whether this id
- // matches with those previously looked at. Constant check failed,
- // but symbolic matching may succeed.
- if (!equalDim(
- known_dimension, gpu_lower->lowerValue(concrete_id->extent()))) {
- all_equal = false;
- break;
- }
- }
-
-  // If all_equal is still true, the dimension of this parallel type
- // must be exact.
- if (all_equal) {
- exact_types_.insert(pt);
- }
- // Use the const value, if found, as its dimension
- if (all_equal && known_const != -1) {
- dim_map_.insert({pt, ir_builder.create<kir::Int>(known_const)});
- } else {
- dim_map_.insert({pt, kir::NamedScalar::getParallelDim(pt)});
- }
-}
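-
-// A sketch of the resulting mapping (hypothetical fusion, using only the
-// interface above): if every domain parallelized with TIDx CA-maps to
-// concrete domains whose extents are all the constant 128, then
-//
-//   ParallelDimensionMap pdm;
-//   pdm.build(fusion);
-//   pdm.get(ParallelType::TIDx);     // kir::Int(128)
-//   pdm.isExact(ParallelType::TIDx); // true
-//
-// whereas with differing extents the map falls back to the symbolic
-// blockDim.x via NamedScalar::getParallelDim and reports the type non-exact.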
-
-kir::Val* ParallelDimensionMap::get(ParallelType pt) const {
- TORCH_INTERNAL_ASSERT(isParallelTypeThread(pt), "Invalid ParallelType: ", pt);
- auto it = dim_map_.find(pt);
- if (it == dim_map_.end()) {
- return nullptr;
- } else {
- return it->second;
- }
-}
-
-bool ParallelDimensionMap::isExact(ParallelType pt) const {
- return exact_types_.find(pt) != exact_types_.end();
-}
-
-IterDomain* ParallelDimensionMap::getCAMappedConcreteDomain(IterDomain* id) {
- const auto gpu_lower = GpuLower::current();
- const auto& ca_map = gpu_lower->caIndexMap();
- return ca_map.getConcreteMappedID(id);
-}
-
-// Symbolically compares equality of two KIR vals. Comparison is done
-// conservatively, so returning false does not guarantee non-equality.
-bool ParallelDimensionMap::equalDim(kir::Val* dim1, kir::Val* dim2) {
- TORCH_INTERNAL_ASSERT(dim1 != nullptr && dim2 != nullptr);
-
- if (dim1 == dim2) {
- return true;
- }
-
-  // When both are Int, they are the same if both have the same constant.
- auto dim1_int = dynamic_cast<kir::Int*>(dim1);
- auto dim2_int = dynamic_cast<kir::Int*>(dim2);
- if (dim1_int && dim2_int) {
- if (dim1_int->isConst() && dim2_int->isConst()) {
- return dim1_int->value() == dim2_int->value();
- }
- }
-
-  // When both are NamedScalar, they are the same if both have the
-  // same name.
- auto dim1_ns = dynamic_cast<kir::NamedScalar*>(dim1);
- auto dim2_ns = dynamic_cast<kir::NamedScalar*>(dim2);
- if (dim1_ns && dim2_ns) {
- return dim1_ns->name() == dim2_ns->name();
- }
-
- // Check recursively their definitions
-
- auto dim1_def = dim1->definition();
- auto dim2_def = dim2->definition();
-
- if (dim1_def == nullptr || dim2_def == nullptr) {
- return false;
- }
-
- // If both are BinaryOp or UnaryOp, check their inputs. Since these
- // Vals are IterDomain extents, UnaryOp should not occur, but
- // checking shouldn't be harmful.
- if ((dim1_def->isA<kir::BinaryOp>() && dim2_def->isA<kir::BinaryOp>() &&
- (dim1_def->as<kir::BinaryOp>()->operation() ==
- dim2_def->as<kir::BinaryOp>()->operation())) ||
- (dim1_def->isA<kir::UnaryOp>() && dim2_def->isA<kir::UnaryOp>() &&
- (dim1_def->as<kir::UnaryOp>()->operation() ==
- dim2_def->as<kir::UnaryOp>()->operation()))) {
- for (size_t i = 0; i < dim1_def->inputs().size(); ++i) {
-      if (!equalDim(dim1_def->inputs()[i], dim2_def->inputs()[i])) {
- return false;
- }
- }
- return true;
- }
-
- return false;
-}
-
-std::string ParallelDimensionMap::toString() const {
- std::stringstream ss;
-
- const std::array<ParallelType, 6> ptypes{
- ParallelType::BIDx,
- ParallelType::BIDy,
- ParallelType::BIDz,
- ParallelType::TIDx,
- ParallelType::TIDy,
- ParallelType::TIDz};
-
- for (auto pt : ptypes) {
- ss << pt << ": ";
- auto dim = get(pt);
- if (dim != nullptr) {
- ss << kir::toString(dim);
- if (isExact(pt)) {
- ss << ", exact";
- } else {
- ss << ", non-exact";
- }
- } else {
- ss << "unused";
- }
- ss << "\n";
- }
-
- return ss.str();
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-
-#include <deque>
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Maps TID/BID to its dimension. By default this is blockDim/gridDim,
-//! but if every use of a ParallelType maps to a unique constant
-//! extent, the constant value is used instead, since it is presumably
-//! more efficient.
-class TORCH_CUDA_CU_API ParallelDimensionMap {
- public:
- void build(Fusion* fusion);
-
- //! Returns the dimension of a ParallelType. nullptr is returned if
- //! a ParallelType is unused.
- kir::Val* get(ParallelType pt) const;
-
- //! True if the dimension of a ParallelType is known to be exact
- bool isExact(ParallelType pt) const;
-
- std::string toString() const;
-
- //! Symbolically analyze if two extent vals are equal
- static bool equalDim(kir::Val* dim1, kir::Val* dim2);
-
- private:
-  //! Register the extent of an IterDomain if it's constant
- void registerConstantExtent(IterDomain* id);
-
- void handleParallelDomain(IterDomain* id);
-
- void populateDimensionMapWithSingleCASet(
- ParallelType pt,
- const std::unordered_set<IterDomain*>& dom_set);
-
- void populateDimensionMapWithMultipleCASet(
- ParallelType pt,
- const std::unordered_set<IterDomain*>& dom_set);
-
- static IterDomain* getCAMappedConcreteDomain(IterDomain* id);
-
- private:
- //! Maps from parallel types to dimensions, which are constant if
- //! a unique value is found.
- std::unordered_map<ParallelType, kir::Val*, TypeHash> dim_map_;
- //! Set of parallel types whose dimensions are identified to be
- //! exactly the same as extents of mapped domains.
- std::unordered_set<ParallelType, TypeHash> exact_types_;
-
- // Below are temporary maps to build the ParallelType-to-dimension
- // map. Only used during build().
-
- //! Map from a parallel type to a set of concrete domains where the
- //! parallel type is used.
- std::unordered_map<ParallelType, std::unordered_set<IterDomain*>, TypeHash>
- concrete_dom_map_;
- //! Keep track of constant extents found for a CA domain set
- //! represented by the concrete domain.
- std::unordered_map<IterDomain*, std::unordered_set<int64_t>>
- constant_extent_map_;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-const std::unordered_map<ParallelType, int, TypeHash>
- ParallelTypeBitmap::pt_to_offset_{
- {ParallelType::BIDx, 0},
- {ParallelType::BIDy, 1},
- {ParallelType::BIDz, 2},
- {ParallelType::TIDx, 3},
- {ParallelType::TIDy, 4},
- {ParallelType::TIDz, 5}};
-
-const std::unordered_map<int, ParallelType> ParallelTypeBitmap::offset_to_pt_ =
- {{0, ParallelType::BIDx},
- {1, ParallelType::BIDy},
- {2, ParallelType::BIDz},
- {3, ParallelType::TIDx},
- {4, ParallelType::TIDy},
- {5, ParallelType::TIDz}};
-
-bool ParallelTypeBitmap::get(ParallelType pt) const {
- if (pt_to_offset_.find(pt) == pt_to_offset_.end()) {
- TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type.");
- }
- return bitset_[pt_to_offset_.at(pt)];
-}
-
-bool ParallelTypeBitmap::set(ParallelType pt, bool new_val) {
- if (pt_to_offset_.find(pt) == pt_to_offset_.end()) {
- TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type.");
- }
- bool old_val = bitset_[pt_to_offset_.at(pt)];
- bitset_[pt_to_offset_.at(pt)] = new_val;
- return old_val;
-}
-
-ParallelTypeBitmap ParallelTypeBitmap::operator&=(
- const ParallelTypeBitmap& other) {
- bitset_ &= other.bitset_;
- return *this;
-}
-
-ParallelTypeBitmap ParallelTypeBitmap::operator|=(
- const ParallelTypeBitmap& other) {
- bitset_ |= other.bitset_;
- return *this;
-}
-
-ParallelTypeBitmap ParallelTypeBitmap::operator^=(
- const ParallelTypeBitmap& other) {
- bitset_ ^= other.bitset_;
- return *this;
-}
-
-ParallelTypeBitmap ParallelTypeBitmap::operator~() const {
- return ParallelTypeBitmap(~bitset_);
-}
-
-bool ParallelTypeBitmap::none() const {
- return bitset_.none();
-}
-
-bool ParallelTypeBitmap::any() const {
- return bitset_.any();
-}
-
-bool ParallelTypeBitmap::all() const {
- return bitset_.all();
-}
-
-bool ParallelTypeBitmap::operator[](size_t pos) const {
- TORCH_INTERNAL_ASSERT(
- pos < num_p_type, "Invalid index to ParallelTypeBitset: ", pos);
- return bitset_[pos];
-}
-
-bool ParallelTypeBitmap::hasTID() const {
- for (auto pt : {ParallelType::TIDx, ParallelType::TIDy, ParallelType::TIDz}) {
- if (get(pt)) {
- return true;
- }
- }
- return false;
-}
-
-bool ParallelTypeBitmap::hasBID() const {
- for (auto pt : {ParallelType::BIDx, ParallelType::BIDy, ParallelType::BIDz}) {
- if (get(pt)) {
- return true;
- }
- }
- return false;
-}
-
-std::map<ParallelType, bool> ParallelTypeBitmap::getMap() const {
- std::map<ParallelType, bool> map;
- for (const auto& pt_offset : pt_to_offset_) {
- map.emplace(pt_offset.first, bitset_[pt_offset.second]);
- }
- return map;
-}
-
-ParallelTypeBitmap operator&(
- const ParallelTypeBitmap& lhs,
- const ParallelTypeBitmap& rhs) {
- auto x = lhs;
- x &= rhs;
- return x;
-}
-
-ParallelTypeBitmap operator|(
- const ParallelTypeBitmap& lhs,
- const ParallelTypeBitmap& rhs) {
- auto x = lhs;
- x |= rhs;
- return x;
-}
-
-ParallelTypeBitmap operator^(
- const ParallelTypeBitmap& lhs,
- const ParallelTypeBitmap& rhs) {
- auto x = lhs;
- x ^= rhs;
- return x;
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
-
-#include <bitset>
-#include <map>
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Represents mapping to bool from BIDx, BIDy, BIDz, TIDx, TIDy and TIDz.
-class ParallelTypeBitmap {
- public:
- static constexpr int num_p_type = 6;
-
- ParallelTypeBitmap() = default;
-
- //! Return true if pt is included
- bool get(ParallelType pt) const;
- //! Set the mapping of pt
- bool set(ParallelType pt, bool);
- //! Assign logical AND with other
- ParallelTypeBitmap operator&=(const ParallelTypeBitmap& other);
- //! Assign logical OR with other
- ParallelTypeBitmap operator|=(const ParallelTypeBitmap& other);
-  //! Assign logical XOR with other
-  ParallelTypeBitmap operator^=(const ParallelTypeBitmap& other);
-  //! Return logical complement
- ParallelTypeBitmap operator~() const;
-  //! Return true if none of the mappings is true
-  bool none() const;
-  //! Return true if any of the mappings is true
-  bool any() const;
-  //! Return true if all of the mappings are true
-  bool all() const;
- //! Return true if the parallel type corresponding to a position
- //! defined in offset_to_pt_ is true
- bool operator[](size_t pos) const;
- //! Return an equivalent std::map
- std::map<ParallelType, bool> getMap() const;
- //! Return true if TIDx/y/z is included
- bool hasTID() const;
- //! Return true if BIDx/y/z is included
- bool hasBID() const;
-
- private:
- ParallelTypeBitmap(const std::bitset<num_p_type>& bs) : bitset_(bs) {}
-
- private:
- std::bitset<num_p_type> bitset_;
- //! Map of ParallelType to bit positions
- const static std::unordered_map<ParallelType, int, TypeHash> pt_to_offset_;
- //! Map of bit positions to ParallelType
- const static std::unordered_map<int, ParallelType> offset_to_pt_;
-};
-
-ParallelTypeBitmap operator&(
- const ParallelTypeBitmap& lhs,
- const ParallelTypeBitmap& rhs);
-
-ParallelTypeBitmap operator|(
- const ParallelTypeBitmap& lhs,
- const ParallelTypeBitmap& rhs);
-
-ParallelTypeBitmap operator^(
- const ParallelTypeBitmap& lhs,
- const ParallelTypeBitmap& rhs);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
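A minimal standalone sketch (plain STL, hypothetical names; not part of the fuser sources) of how a six-bit parallel-type mask like the removed class composes under the bitwise operators declared above:

#include <bitset>
#include <cassert>

// Toy stand-in for the six parallel types the removed class tracks.
enum class PT { BIDx, BIDy, BIDz, TIDx, TIDy, TIDz };

int main() {
  std::bitset<6> a;
  std::bitset<6> b;
  a.set(static_cast<size_t>(PT::TIDx)); // a: thread-x parallelized
  b.set(static_cast<size_t>(PT::BIDy)); // b: block-y parallelized
  const auto both = a | b;              // union of parallel dimensions
  assert(both.count() == 2);
  assert((a & b).none());               // the two masks do not overlap
  assert((~a).test(static_cast<size_t>(PT::BIDy))); // complement flips bits
  return 0;
}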
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/jit/ir/constants.h>
namespace fuser {
namespace cuda {
-constexpr auto kNumUnaryOps = 32;
-constexpr auto kNumBinaryOps = 29;
+constexpr auto kNumUnaryOps = 31;
+constexpr auto kNumBinaryOps = 24;
constexpr auto kNumBinaryOpsWithAlpha = 4;
constexpr auto kNumLerpOps = 2;
-constexpr auto kNumLayernormFwd = 2;
-constexpr auto kNumBatchnormFwd = 3;
-constexpr auto kNumInstancenormFwd = 1;
-constexpr auto kNumSumToSize = 2;
-// constexpr auto kNumAutocastOps = 2;
namespace {
-#define REGISTER_PARSE_RULE(op, func_body, ...) \
- registerParseRule( \
- op, \
- [](const Node* node, \
- std::unordered_map<size_t, CgValue>& value_map) -> void func_body, \
- __VA_ARGS__)
-
-const auto& sizeAttr = Symbol::attr("profiled_size");
-const auto& intListAttr = Symbol::attr("profiled_int_list");
-const auto& intAttr = Symbol::attr("profiled_int");
-const auto& boolListAttr = Symbol::attr("profiled_bool_list");
-const auto& boolAttr = Symbol::attr("profiled_bool");
-
typedef Val* CgValue;
typedef Expr* CgOp;
// TODO: add a mutex to make it thread safe.
class IrParser {
- enum class OperatorType {
- ElementWise,
- Reduction,
- ReductionToSize,
- Normalization
- };
- typedef OperatorType (*OperatorTypeFuncPtr)(const Node*);
-
class RegistrationEntry {
public:
- RegistrationEntry(
- ParseFuncPtr parse_f,
- MergeQueryFuncPtr merge_f = nullptr,
- OperatorTypeFuncPtr type_f = nullptr)
- : parse_f_(parse_f), merge_f_(merge_f), type_f_(type_f) {}
+ RegistrationEntry(ParseFuncPtr parse_f, MergeQueryFuncPtr merge_f = nullptr)
+ : parse_f_(parse_f), merge_f_(merge_f) {}
- void parse(const Node* node, std::unordered_map<size_t, CgValue>& values)
- const {
+ void parse(const Node* node, std::unordered_map<size_t, CgValue>& values) {
parse_f_(node, values);
}
- bool isCompatible(const Node* node) const {
+ bool is_compatible(const Node* node) {
if (merge_f_ == nullptr) {
return true;
}
return merge_f_(node);
}
- bool isType(const Node* node, OperatorType type) const {
- auto n_type =
- type_f_ == nullptr ? OperatorType::ElementWise : type_f_(node);
- return n_type == type;
- }
-
private:
ParseFuncPtr parse_f_;
MergeQueryFuncPtr merge_f_;
- OperatorTypeFuncPtr type_f_;
};
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
IrParser(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
- initRegistry();
+ if (init_registry_) {
+ registerJitOperator();
+ init_registry_ = false;
+ }
}
+ // Fuses pointwise ops with loop unrolling (factor = 4).
std::unique_ptr<Fusion> parse() {
auto fusion = std::make_unique<Fusion>();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
for (auto val : block->inputs()) {
TORCH_INTERNAL_ASSERT(
registerValue(val),
- "Failure when register value: ",
- *(val->node()),
- " with type: ",
- val->type());
+ "Error trying to register value with code generation.");
fusion->addInput(value_map_[val->unique()]);
auto opt_dtype = value_map_[val->unique()]->getDataType();
}
}
+ // TODO: disable unroll to ensure rand_like generates identical output as
+ // with eager mode
+ bool disable_unroll = false;
+ bool has_reduction = false;
// compose nodes in topo order;
for (const JitOp* node : block->nodes()) {
processJitNode(node);
+ if (node->kind() == aten::rand_like) {
+ // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
+ disable_unroll = true;
+ }
+ if (node->kind() == aten::sum) {
+ // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
+ has_reduction = true;
+ }
}
- auto alias_indices = fusion->getInputAliasIndices();
// mark output;
for (auto jit_output : block->outputs()) {
}
fusion->addOutput(out);
}
-
return fusion;
}
- // return nullptr if entry does not exist
- static const RegistrationEntry* lookupInRegistry(const Node* node) {
- // we need to use maybeSchema for nodes like prim::Constant, which doesn't
- // have a schema
- auto schema_ptr = node->maybeSchema();
- if (schema_ptr != nullptr) {
- // search cached entry first
- auto cache_it = cached_registry_lookup_.find(schema_ptr);
- if (cache_it != cached_registry_lookup_.end()) {
- return cache_it->second;
- } else {
- // match signature
- auto schema_str = canonicalSchemaString(*schema_ptr);
-
- auto iter = jit_operator_registry_.find(schema_str);
- if (iter != jit_operator_registry_.end()) {
- // update cache entry
- cached_registry_lookup_.insert(cache_it, {schema_ptr, &iter->second});
- return &iter->second;
- }
- }
- }
- return nullptr;
- }
-
- static void initRegistry() {
+ static bool canParseNode(const Node* node) {
if (init_registry_) {
// TODO: mutex this guy;
registerJitOperator();
init_registry_ = false;
}
- }
-
- static bool canParseNode(const Node* node) {
- initRegistry();
// match signature.
- auto schema_ptr = node->maybeSchema();
- if (schema_ptr == nullptr) {
+ auto iter = jit_operator_registry_.find(node->kind());
+ if (iter == jit_operator_registry_.end()) {
return false;
}
- auto reg_entry = lookupInRegistry(node);
- return reg_entry != nullptr && reg_entry->isCompatible(node);
- }
-
- static bool isReductionToSizeNode(const Node* node) {
- initRegistry();
-
- auto reg_entry = lookupInRegistry(node);
- return reg_entry != nullptr &&
- reg_entry->isType(node, OperatorType::ReductionToSize);
+ for (auto& pair_op_func : iter->second) {
+ if (node->matches(pair_op_func.first->schema())) {
+ return pair_op_func.second.is_compatible(node);
+ }
+ }
+ return false;
}
static bool isReductionNode(const Node* node) {
- initRegistry();
-
- auto reg_entry = lookupInRegistry(node);
- return reg_entry != nullptr &&
- (reg_entry->isType(node, OperatorType::Reduction) ||
- reg_entry->isType(node, OperatorType::ReductionToSize));
- }
-
- static bool isNormalizationNode(const Node* node) {
- initRegistry();
-
- auto reg_entry = lookupInRegistry(node);
- return reg_entry != nullptr &&
- reg_entry->isType(node, OperatorType::Normalization);
- }
-
- static bool isElementWiseNode(const Node* node) {
- initRegistry();
+ if (init_registry_) {
+ // TODO: mutex this guy;
+ registerJitOperator();
+ init_registry_ = false;
+ }
- auto reg_entry = lookupInRegistry(node);
- return reg_entry != nullptr &&
- reg_entry->isType(node, OperatorType::ElementWise);
+ return jit_reduction_op_registry_.count(node->kind());
}
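// The repeated "TODO: mutex this guy" comments flag that this lazy,
// flag-guarded registration is not thread-safe. A minimal sketch of the
// std::call_once alternative the TODOs hint at (illustrative names only,
// not this file's code):
//
// #include <mutex>
//
// namespace {
// void registerAll() {
//   // populate the operator registration tables exactly once
// }
// } // namespace
//
// void ensureRegistryInit() {
//   static std::once_flag once;
//   std::call_once(once, registerAll);
// }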
// TODO: is_reduction is too hacky here. we should categorize operation types
std::shared_ptr<Operator>& op,
ParseFuncPtr parse_fn,
MergeQueryFuncPtr merge_query_fn = nullptr,
- OperatorTypeFuncPtr type_fn = nullptr) {
- jit_operator_registry_.emplace(
- std::piecewise_construct,
- std::forward_as_tuple(canonicalSchemaString(op->schema())),
- std::forward_as_tuple(parse_fn, merge_query_fn, type_fn));
+ bool is_reduction = false) {
+ jit_operator_registry_[Symbol::fromQualString(op->schema().name())]
+ .emplace_back(
+ std::piecewise_construct,
+ std::forward_as_tuple(op),
+ std::forward_as_tuple(parse_fn, merge_query_fn));
+ if (is_reduction) {
+ jit_reduction_op_registry_.emplace(
+ Symbol::fromQualString(op->schema().name()));
+ }
}
private:
"aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor"};
for (auto signature : BinaryOpWithAlpha) {
auto ptr_op = getOperatorForLiteral(signature);
- REGISTER_PARSE_RULE(
+ registerParseRule(
ptr_op,
- {
+ [](const Node* node,
+ std::unordered_map<size_t, CgValue>& value_map) -> void {
using BinaryOpWithAlphaType = Val* (*)(Val*, Val*, Val*);
static std::unordered_map<
Symbol,
auto out = op_mapping[node->kind()].second(lhs, rhs, alpha);
value_map.emplace(node->output()->unique(), out);
}
- },
- nullptr,
- nullptr);
+ });
}
std::array<const char*, kNumBinaryOps> BinaryOp = {
"aten::pow(Scalar self, Tensor exponent) -> Tensor",
"aten::remainder(Tensor self, Tensor other) -> Tensor",
"aten::fmod(Tensor self, Tensor other) -> Tensor",
- "aten::__and__(Tensor self, Tensor other) -> Tensor",
- "aten::__or__(Tensor self, Tensor other) -> Tensor",
- "aten::__xor__(Tensor self, Tensor other) -> Tensor",
- "aten::__lshift__(Tensor self, Tensor other) -> Tensor",
- "aten::__rshift__(Tensor self, Tensor other) -> Tensor",
"aten::eq(Tensor self, Tensor other) -> Tensor",
"aten::eq(Tensor self, Scalar other) -> Tensor",
"aten::ne(Tensor self, Tensor other) -> Tensor",
"aten::lt(Tensor self, Scalar other) -> Tensor"};
for (auto signature : BinaryOp) {
auto ptr_op = getOperatorForLiteral(signature);
- REGISTER_PARSE_RULE(
+ registerParseRule(
ptr_op,
- {
+ [](const Node* node,
+ std::unordered_map<size_t, CgValue>& value_map) -> void {
static std::unordered_map<Symbol, BinaryOpType> op_mapping(
{{aten::div, BinaryOpType::Div},
{aten::mul, BinaryOpType::Mul},
{aten::gt, BinaryOpType::GT},
{aten::ge, BinaryOpType::GE},
{aten::ne, BinaryOpType::NE},
- {aten::eq, BinaryOpType::Eq},
- {aten::__and__, BinaryOpType::And},
- {aten::__or__, BinaryOpType::Or},
- {aten::__xor__, BinaryOpType::Xor},
- {aten::__lshift__, BinaryOpType::Lshift},
- {aten::__rshift__, BinaryOpType::Rshift}});
+ {aten::eq, BinaryOpType::Eq}});
auto lhs = value_map[node->inputs()[0]->unique()];
auto rhs = value_map[node->inputs()[1]->unique()];
auto out = binaryOp(op_mapping[node->kind()], lhs, rhs);
value_map.emplace(node->output()->unique(), out);
- },
- nullptr,
- nullptr);
+ });
}
// TODO: cast operations should be merged in.
"aten::floor(Tensor self) -> Tensor",
"aten::round(Tensor self) -> Tensor",
"aten::trunc(Tensor self) -> Tensor",
- "aten::bitwise_not(Tensor self) -> Tensor",
"aten::frac(Tensor self) -> Tensor",
"aten::reciprocal(Tensor self) -> Tensor",
"aten::relu(Tensor self) -> Tensor",
"aten::sigmoid(Tensor self) -> Tensor",
- "aten::silu(Tensor self) -> Tensor",
+ "aten::gelu(Tensor self) -> Tensor",
};
for (auto signature : UnaryOp) {
auto ptr_op = getOperatorForLiteral(signature);
- REGISTER_PARSE_RULE(
+ registerParseRule(
ptr_op,
- {
+ [](const Node* node,
+ std::unordered_map<size_t, CgValue>& value_map) -> void {
static std::unordered_map<Symbol, UnaryOpType> op_mapping({
{aten::neg, UnaryOpType::Neg},
{aten::abs, UnaryOpType::Abs},
{aten::floor, UnaryOpType::Floor},
{aten::round, UnaryOpType::Round},
{aten::trunc, UnaryOpType::Trunc},
- {aten::bitwise_not, UnaryOpType::Not},
{aten::frac, UnaryOpType::Frac},
{aten::reciprocal, UnaryOpType::Reciprocal},
{aten::relu, UnaryOpType::Relu},
{aten::sigmoid, UnaryOpType::Sigmoid},
- {aten::silu, UnaryOpType::Silu},
+ {aten::gelu, UnaryOpType::Gelu},
});
auto operand = value_map[node->input()->unique()];
auto out = unaryOp(op_mapping[node->kind()], operand);
value_map.emplace(node->output()->unique(), out);
- },
- nullptr,
- nullptr);
+ });
}
{
auto ptr_op = getOperatorForLiteral(
"aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
- REGISTER_PARSE_RULE(
+ registerParseRule(
ptr_op,
- {
+ [](const Node* node,
+ std::unordered_map<size_t, CgValue>& value_map) -> void {
auto operand = value_map[node->inputs()[0]->unique()];
auto out = unaryOp(UnaryOpType::RandLike, operand);
value_map.emplace(node->output()->unique(), out);
- },
- nullptr,
- nullptr);
- }
-
- {
- auto ptr_op = getOperatorForLiteral(
- "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto operand = value_map[node->inputs()[0]->unique()];
- auto beta = value_map[node->inputs()[1]->unique()];
- auto threshold = value_map[node->inputs()[2]->unique()];
- auto out = softplus(operand, beta, threshold);
- value_map.emplace(node->output()->unique(), out);
- },
- nullptr,
- nullptr);
+ });
}
{
auto ptr_op = getOperatorForLiteral(
"aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor");
- REGISTER_PARSE_RULE(
+ registerParseRule(
ptr_op,
- {
+ [](const Node* node,
+ std::unordered_map<size_t, CgValue>& value_map) -> void {
auto operand = value_map[node->inputs()[0]->unique()];
auto th = value_map[node->inputs()[1]->unique()];
auto value = value_map[node->inputs()[2]->unique()];
auto out = threshold(operand, th, value);
value_map.emplace(node->output()->unique(), out);
- },
- nullptr,
- nullptr);
+ });
}
{
auto ptr_op = getOperatorForLiteral(
"aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor");
- REGISTER_PARSE_RULE(
+ registerParseRule(
ptr_op,
- {
+ [](const Node* node,
+ std::unordered_map<size_t, CgValue>& value_map) -> void {
auto operand = value_map[node->inputs()[0]->unique()];
// TODO: we need to get a proper lower bound per dtype in operand.
auto low = value_map.count(node->inputs()[1]->unique()) != 0
? value_map[node->inputs()[1]->unique()]
- : new Double(std::numeric_limits<float>::min());
+ : new Float(std::numeric_limits<float>::min());
auto high = value_map.count(node->inputs()[2]->unique()) != 0
? value_map[node->inputs()[2]->unique()]
- : new Double(std::numeric_limits<float>::max());
+ : new Float(std::numeric_limits<float>::max());
auto out = clamp(operand, low, high);
value_map.emplace(node->output()->unique(), out);
- },
- nullptr,
- nullptr);
+ });
}
{
auto ptr_op = getOperatorForLiteral(
"aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor");
- REGISTER_PARSE_RULE(
+ registerParseRule(
ptr_op,
- {
+ [](const Node* node,
+ std::unordered_map<size_t, CgValue>& value_map) -> void {
auto condition = value_map[node->inputs()[0]->unique()];
auto x = value_map[node->inputs()[1]->unique()];
auto y = value_map[node->inputs()[2]->unique()];
auto out = where(condition, x, y);
value_map.emplace(node->output()->unique(), out);
- },
- nullptr,
- nullptr);
+ });
}
{
"aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor"};
for (auto signature : LerpOp) {
auto ptr_op = getOperatorForLiteral(signature);
- REGISTER_PARSE_RULE(
+ registerParseRule(
ptr_op,
- {
+ [](const Node* node,
+ std::unordered_map<size_t, CgValue>& value_map) -> void {
auto self = value_map[node->inputs()[0]->unique()];
auto end = value_map[node->inputs()[1]->unique()];
auto weight = value_map[node->inputs()[2]->unique()];
auto out = lerp(self, end, weight);
value_map.emplace(node->output()->unique(), out);
- },
- nullptr,
- nullptr);
+ });
}
}
{
auto ptr_op = getOperatorForLiteral(
"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor");
- REGISTER_PARSE_RULE(
+ registerParseRule(
ptr_op,
- {
+ [](const Node* node,
+ std::unordered_map<size_t, CgValue>& value_map) -> void {
auto self = value_map[node->inputs()[0]->unique()];
auto tensor1 = value_map[node->inputs()[1]->unique()];
auto tensor2 = value_map[node->inputs()[2]->unique()];
auto out = addcmul(self, tensor1, tensor2, value);
value_map.emplace(node->output()->unique(), out);
- },
- nullptr,
- nullptr);
- }
-
- {
- auto ptr_op = getOperatorForLiteral(
- "aten::dropout(Tensor input, float p, bool train) -> Tensor");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto input = value_map[node->input(0)->unique()]->as<TensorView>();
- auto train = constant_as<bool>(node->input(2));
- TORCH_INTERNAL_ASSERT(
- train.has_value(), "dropout needs constant `train` flag");
-
- if (train.value()) {
- auto prob = value_map[node->input(1)->unique()];
- auto result = dropout(input, prob);
-
- value_map.emplace(node->output()->unique(), result.output);
- } else {
- value_map.emplace(node->output()->unique(), input);
- }
- },
- nullptr,
- nullptr);
- }
-
- {
- std::array<const char*, kNumInstancenormFwd> InstanceNormFwd = {
- "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor"};
- for (auto signature : InstanceNormFwd) {
- auto ptr_op = getOperatorForLiteral(signature);
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto fusion = FusionGuard::getCurFusion();
-
- auto input =
- value_map[node->input(0)->unique()]->as<TensorView>();
-
- TensorView* weight = nullptr;
- if (!node->input(1)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- weight = value_map[node->input(1)->unique()]->as<TensorView>();
- }
-
- TensorView* bias = nullptr;
- if (!node->input(2)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- bias = value_map[node->input(2)->unique()]->as<TensorView>();
- }
-
- TensorView* running_mean = nullptr;
- if (!node->input(3)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- running_mean =
- value_map[node->input(3)->unique()]->as<TensorView>();
- TORCH_INTERNAL_ASSERT(
- fusion->hasInput(running_mean),
- "IO_tensor `batch_norm::running_mean` can only be input tensor to fusion");
- }
-
- TensorView* running_var = nullptr;
- if (!node->input(4)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- running_var =
- value_map[node->input(4)->unique()]->as<TensorView>();
- TORCH_INTERNAL_ASSERT(
- fusion->hasInput(running_var),
- "IO_tensor `batch_norm::running_var` can only be input tensor to fusion");
- }
-
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- auto use_input_stats = constant_as<bool>(node->input(5));
- TORCH_INTERNAL_ASSERT(
- use_input_stats.has_value(),
- "The use_input_stats (bool) parameter is required.");
- const bool kUseInputStats = use_input_stats.value();
-
- Val* momentum_ptr = nullptr;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- if (auto momentum = constant_as<float>(node->input(6))) {
- momentum_ptr = new Double(momentum.value());
- } else {
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- momentum_ptr = value_map[node->input(6)->unique()];
- }
-
- Val* eps_ptr = nullptr;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- if (auto eps = constant_as<float>(node->input(7))) {
- eps_ptr = new Double(eps.value());
- } else {
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- eps_ptr = value_map[node->input(7)->unique()];
- }
-
- auto result = instance_norm(
- input,
- weight,
- bias,
- running_mean,
- running_var,
- kUseInputStats,
- momentum_ptr,
- eps_ptr);
-
- if (node->kind() ==
- c10::Symbol::fromQualString("aten::instance_norm")) {
- value_map.emplace(node->output()->unique(), result.output);
- }
- },
- [](const Node* node) -> bool { return true; },
- [](const Node* node) -> OperatorType {
- return OperatorType::Normalization;
- });
- }
- }
-
- {
- std::array<const char*, kNumBatchnormFwd> BatchNormFwd = {
- "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)",
- "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
- "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor"};
- for (auto signature : BatchNormFwd) {
- auto ptr_op = getOperatorForLiteral(signature);
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto fusion = FusionGuard::getCurFusion();
-
- auto input =
- value_map[node->input(0)->unique()]->as<TensorView>();
-
- TensorView* weight = nullptr;
- if (!node->input(1)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- weight = value_map[node->input(1)->unique()]->as<TensorView>();
- }
-
- TensorView* bias = nullptr;
- if (!node->input(2)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- bias = value_map[node->input(2)->unique()]->as<TensorView>();
- }
-
- TensorView* running_mean = nullptr;
- if (!node->input(3)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- running_mean =
- value_map[node->input(3)->unique()]->as<TensorView>();
- TORCH_INTERNAL_ASSERT(
- fusion->hasInput(running_mean),
- "IO_tensor `batch_norm::running_mean` can only be input tensor to fusion");
- }
-
- TensorView* running_var = nullptr;
- if (!node->input(4)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- running_var =
- value_map[node->input(4)->unique()]->as<TensorView>();
- TORCH_INTERNAL_ASSERT(
- fusion->hasInput(running_var),
- "IO_tensor `batch_norm::running_var` can only be input tensor to fusion");
- }
-
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- auto training = constant_as<bool>(node->input(5));
- TORCH_INTERNAL_ASSERT(
- training.has_value(),
- "The training (bool) parameter is required.");
- const bool kTraining = training.value();
-
- Val* momentum_ptr = nullptr;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- if (auto momentum = constant_as<float>(node->input(6))) {
- momentum_ptr = new Double(momentum.value());
- } else {
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- momentum_ptr = value_map[node->input(6)->unique()];
- }
-
- Val* eps_ptr = nullptr;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- if (auto eps = constant_as<float>(node->input(7))) {
- eps_ptr = new Double(eps.value());
- } else {
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- eps_ptr = value_map[node->input(7)->unique()];
- }
-
- auto result = batch_norm(
- input,
- weight,
- bias,
- running_mean,
- running_var,
- kTraining,
- momentum_ptr,
- eps_ptr);
-
- if (node->kind() ==
- c10::Symbol::fromQualString("aten::native_batch_norm")) {
- value_map.emplace(node->output(0)->unique(), result.output);
-
- value_map.emplace(node->output(1)->unique(), result.mean);
-
- value_map.emplace(node->output(2)->unique(), result.invstd);
- } else if (
- node->kind() ==
- c10::Symbol::fromQualString("aten::batch_norm")) {
- value_map.emplace(node->output()->unique(), result.output);
- } else if (
- node->kind() ==
- c10::Symbol::fromQualString("aten::_batch_norm_impl_index")) {
- value_map.emplace(node->output(0)->unique(), result.output);
-
- value_map.emplace(node->output(1)->unique(), result.mean);
-
- value_map.emplace(node->output(2)->unique(), result.invstd);
-
- // TODO: output 3 & 4 are not created
- // we are not creating these outputs because codegen
- // currently lacks the support.
- }
- },
- [](const Node* node) -> bool { return true; },
- [](const Node* node) -> OperatorType {
- return OperatorType::Normalization;
- });
- }
- }
-
- {
- auto ptr_op = getOperatorForLiteral(
- "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- // discard impl_index and reservedSpace since we don't use them
-
- auto input = value_map[node->input(1)->unique()]->as<TensorView>();
-
- auto grad_out =
- value_map[node->input(2)->unique()]->as<TensorView>();
-
- TensorView* weight = nullptr;
- if (!node->input(3)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- weight = value_map[node->input(3)->unique()]->as<TensorView>();
- }
-
- TensorView* running_mean = nullptr;
- if (!node->input(4)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- running_mean =
- value_map[node->input(4)->unique()]->as<TensorView>();
- }
-
- TensorView* running_var = nullptr;
- if (!node->input(5)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- running_var =
- value_map[node->input(5)->unique()]->as<TensorView>();
- }
-
- TensorView* save_mean = nullptr;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- if (!node->input(6)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- save_mean = value_map[node->input(6)->unique()]->as<TensorView>();
- }
-
- TensorView* save_invstd = nullptr;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- if (!node->input(7)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- save_invstd =
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- value_map[node->input(7)->unique()]->as<TensorView>();
- }
-
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- auto training = constant_as<bool>(node->input(8));
- TORCH_INTERNAL_ASSERT(
- training.has_value(),
- "The training (bool) parameter is required.");
- const bool kTraining = training.value();
-
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- Val* eps_ptr = nullptr;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- if (auto eps = constant_as<float>(node->input(9))) {
- eps_ptr = new Double(eps.value());
- } else {
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- eps_ptr = value_map[node->input(7)->unique()];
- }
-
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- auto out_mask_list = constant_as<c10::List<bool>>(node->input(10));
- TORCH_INTERNAL_ASSERT(
- out_mask_list.has_value(),
- "output mask for batch_norm_backward");
- std::vector<bool> output_mask;
- for (const auto value : out_mask_list->vec()) {
- output_mask.emplace_back(static_cast<bool>(value));
- }
-
- // TODO: merge this loop below.
- if (kTraining) {
- TORCH_INTERNAL_ASSERT(
- save_mean != nullptr && save_invstd != nullptr,
- "When training=True, save_mean and save_invstd are required.");
- } else {
- // TODO: this is not a legit assumption? Can't we run with
- // track_running_stats == false && training == false
- // which should just run through the case above.
- TORCH_INTERNAL_ASSERT(
- running_mean != nullptr && running_var != nullptr,
- "When training=False, running_mean and running_invstd are required.");
- }
-
- auto grads = batch_norm_backward(
- input,
- grad_out,
- weight,
- running_mean,
- running_var,
- save_mean,
- save_invstd,
- kTraining,
- eps_ptr,
- output_mask);
-
- if (output_mask[0]) {
- TORCH_INTERNAL_ASSERT(grads.grad_input != nullptr);
- value_map.emplace(node->output(0)->unique(), grads.grad_input);
- } else {
- TORCH_INTERNAL_ASSERT(grads.grad_input == nullptr);
- value_map.emplace(
- node->output(1)->unique(), TensorViewBuilder().build());
- }
-
- if (output_mask[1]) {
- TORCH_INTERNAL_ASSERT(grads.grad_weight != nullptr);
- value_map.emplace(node->output(1)->unique(), grads.grad_weight);
- } else {
- TORCH_INTERNAL_ASSERT(grads.grad_weight == nullptr);
- value_map.emplace(
- node->output(1)->unique(), TensorViewBuilder().build());
- }
-
- if (output_mask[2]) {
- TORCH_INTERNAL_ASSERT(grads.grad_bias != nullptr);
- value_map.emplace(node->output(2)->unique(), grads.grad_bias);
- } else {
- TORCH_INTERNAL_ASSERT(grads.grad_bias == nullptr);
- value_map.emplace(
- node->output(2)->unique(), TensorViewBuilder().build());
- }
- },
- [](const Node* node) -> bool { return true; },
- [](const Node* node) -> OperatorType {
- return OperatorType::Normalization;
- });
- }
-
- {
- std::array<const char*, kNumLayernormFwd> LayerNormFwd = {
- "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)",
- "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"};
- for (auto signature : LayerNormFwd) {
- auto ptr_op = getOperatorForLiteral(signature);
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto input =
- value_map[node->input(0)->unique()]->as<TensorView>();
-
- auto norm_shape_optional =
- constant_as<c10::List<int64_t>>(node->input(1));
- TORCH_INTERNAL_ASSERT(
- norm_shape_optional.has_value(),
- "The Normalized_Shape list is required.");
- auto norm_shape = norm_shape_optional->vec();
-
- TensorView* weight = nullptr;
- if (!node->input(2)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- weight = value_map[node->input(2)->unique()]->as<TensorView>();
- }
-
- TensorView* bias = nullptr;
- if (!node->input(3)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- bias = value_map[node->input(3)->unique()]->as<TensorView>();
- }
-
- Val* eps_ptr = nullptr;
- if (auto eps = constant_as<float>(node->input(4))) {
- eps_ptr = new Double(eps.value());
- } else {
- eps_ptr = value_map[node->input(4)->unique()];
- }
-
- auto result =
- layer_norm(input, norm_shape, weight, bias, eps_ptr);
-
- if (node->kind() ==
- c10::Symbol::fromQualString("aten::native_layer_norm")) {
- value_map.emplace(node->output(0)->unique(), result.output);
- value_map.emplace(node->output(1)->unique(), result.mean);
- value_map.emplace(node->output(2)->unique(), result.invstd);
- } else if (
- node->kind() ==
- c10::Symbol::fromQualString("aten::layer_norm")) {
- value_map.emplace(node->output()->unique(), result.output);
- }
- },
- // TODO: #ProfileIValue List should update this
- [](const Node* node) -> bool { return true; },
- [](const Node* node) -> OperatorType {
- return OperatorType::Normalization;
- });
- }
- }
-
- {
- auto ptr_op = getOperatorForLiteral(
- "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto grad_out =
- value_map[node->input(0)->unique()]->as<TensorView>();
-
- auto input = value_map[node->input(1)->unique()]->as<TensorView>();
-
- auto norm_shape_optional =
- constant_as<c10::List<int64_t>>(node->input(2));
- TORCH_INTERNAL_ASSERT(
- norm_shape_optional.has_value(),
- "The Normalized_Shape list is required.");
- auto norm_shape = norm_shape_optional->vec();
-
- auto mean = value_map[node->input(3)->unique()]->as<TensorView>();
- auto rstd = value_map[node->input(4)->unique()]->as<TensorView>();
-
- TensorView* weight = nullptr;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- if (!node->input(5)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- weight = value_map[node->input(5)->unique()]->as<TensorView>();
- }
-
- TensorView* bias = nullptr;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- if (!node->input(6)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- bias = value_map[node->input(6)->unique()]->as<TensorView>();
- }
-
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- auto output_mask_optional =
- constant_as<c10::List<bool>>(node->input(7));
- TORCH_INTERNAL_ASSERT(
- output_mask_optional.has_value(),
- "output mask for layer_norm_backward");
- std::vector<bool> output_mask = output_mask_optional->vec();
-
- auto grad = layer_norm_backward(
- grad_out,
- input,
- norm_shape,
- mean,
- rstd,
- weight,
- bias,
- output_mask);
-
- if (output_mask[0]) {
- TORCH_INTERNAL_ASSERT(grad.grad_input != nullptr);
- value_map.emplace(node->output(0)->unique(), grad.grad_input);
- } else {
- TORCH_INTERNAL_ASSERT(grad.grad_input == nullptr);
- value_map.emplace(
- node->output(0)->unique(), TensorViewBuilder().build());
- }
-
- if (output_mask[1] && weight != nullptr) {
- TORCH_INTERNAL_ASSERT(grad.grad_weight != nullptr);
- value_map.emplace(node->output(1)->unique(), grad.grad_weight);
- } else {
- TORCH_INTERNAL_ASSERT(grad.grad_weight == nullptr);
- value_map.emplace(
- node->output(1)->unique(), TensorViewBuilder().build());
- }
-
- if (output_mask[2] && bias != nullptr) {
- TORCH_INTERNAL_ASSERT(grad.grad_bias != nullptr);
- value_map.emplace(node->output(2)->unique(), grad.grad_bias);
- } else {
- TORCH_INTERNAL_ASSERT(grad.grad_bias == nullptr);
- value_map.emplace(
- node->output(2)->unique(), TensorViewBuilder().build());
- }
- },
- // TODO: #ProfileIValue List should update this
- [](const Node* node) -> bool { return true; },
- [](const Node* node) -> OperatorType {
- return OperatorType::Normalization;
- });
- }
-
- {
- auto ptr_op = getOperatorForLiteral(
- "aten::softmax.int(Tensor self, int dim, int? dtype) -> Tensor");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto input = value_map[node->input(0)->unique()]->as<TensorView>();
-
- auto dim_value = constant_as<int>(node->input(1));
- TORCH_INTERNAL_ASSERT(
- dim_value.has_value(), "dim in softmax is not valid");
-
- auto output = softmax(input, dim_value.value());
- value_map.emplace(node->output()->unique(), output);
- },
- [](const Node* node) -> bool {
- if (node->inputs()[1]->node()->kind() != prim::Constant) {
- return false;
- }
- if (!node->inputs()[2]->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- return false;
- }
- return true;
- },
- [](const Node* node) -> OperatorType {
- return OperatorType::Normalization;
- });
- }
-
- {
- auto ptr_op = getOperatorForLiteral(
- "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto grad_output =
- value_map[node->input(0)->unique()]->as<TensorView>();
-
- auto output = value_map[node->input(1)->unique()]->as<TensorView>();
-
- auto dim_value = constant_as<int>(node->input(2));
- TORCH_INTERNAL_ASSERT(
- dim_value.has_value(), "dim in softmax is not valid");
-
- auto input = value_map[node->input(3)->unique()]->as<TensorView>();
-
- auto grad_input =
- softmax_backward(grad_output, output, dim_value.value(), input);
- value_map.emplace(node->output()->unique(), grad_input);
- },
- [](const Node* node) -> bool {
- if (node->inputs()[2]->node()->kind() != prim::Constant) {
- return false;
- }
- return true;
- },
- [](const Node* node) -> OperatorType {
- return OperatorType::Normalization;
});
}
{
auto ptr_op = getOperatorForLiteral(
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
- REGISTER_PARSE_RULE(
+ registerParseRule(
ptr_op,
- {
+ [](const Node* node,
+ std::unordered_map<size_t, CgValue>& value_map) -> void {
auto self = value_map[node->input(0)->unique()];
auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
TORCH_INTERNAL_ASSERT(
- dims_list.has_value(),
- "aten::sum cannot be fused with dynamic axes");
- std::vector<int> dims;
- for (const auto dim : dims_list->vec()) {
- dims.emplace_back(static_cast<int>(dim));
- }
+ dims_list.has_value(), "requires static reduce axes");
auto keepdim = constant_as<bool>(node->input(2));
- TORCH_INTERNAL_ASSERT(
- keepdim.has_value(),
- "aten::sum cannot be fused with dynamic keepdim");
- auto out = sum(self->as<TensorView>(), dims, keepdim.value());
- value_map.emplace(node->output()->unique(), out);
- },
- [](const Node* node) -> bool {
- // TODO: support cast of output types
- if (!node->inputs()[3]->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- // We can only handle output as half, float, and double;
- if (const auto opt_ivalue = toIValue(node->input(3))) {
- const auto scalar_type = opt_ivalue->toScalarType();
- if (scalar_type == at::ScalarType::Double ||
- scalar_type == at::ScalarType::Float ||
- scalar_type == at::ScalarType::Half) {
- return true;
- }
- }
- return false;
- }
- // we don't support dynamic reduction axes;
- if (node->inputs()[1]->node()->kind() != prim::Constant) {
- return false;
- }
- // we don't support dynamic keepdim yet;
- if (node->inputs()[2]->node()->kind() != prim::Constant) {
- return false;
- }
- return true;
- },
- [](const Node* node) -> OperatorType {
- return OperatorType::Reduction;
- });
- }
-
- {
- auto ptr_op = getOperatorForLiteral(
- "aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto self = value_map[node->input(0)->unique()]->as<TensorView>();
- auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
- TORCH_INTERNAL_ASSERT(
- dims_list.has_value(),
- "aten::mean cannot be fused with dynamic axes");
std::vector<int> dims;
for (const auto dim : dims_list->vec()) {
dims.emplace_back(static_cast<int>(dim));
}
- auto keepdim = constant_as<bool>(node->input(2));
TORCH_INTERNAL_ASSERT(
- keepdim.has_value(),
- "aten::mean cannot be fused with dynamic keepdim");
- auto o_sum = sum(self, dims, keepdim.value());
- Val* num_features = new Double(1);
- for (const auto axis : dims) {
- num_features =
- mul(num_features, self->domain()->domain()[axis]->extent());
- }
- auto out = div(o_sum, num_features);
+ keepdim.has_value() && !keepdim.value(),
+ "Keep dim in reduction is not a const false");
+ auto out = sum(self->as<TensorView>(), dims);
value_map.emplace(node->output()->unique(), out);
},
[](const Node* node) -> bool {
- // TODO: support cast of output types
+ // TODO: we do not support cast of output types yet;
if (!node->inputs()[3]->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
- // We can only handle output as half, float, and double;
+ // We can only handle output as half and float;
if (const auto opt_ivalue = toIValue(node->input(3))) {
const auto scalar_type = opt_ivalue->toScalarType();
- if (scalar_type == at::ScalarType::Double ||
- scalar_type == at::ScalarType::Float ||
+ if (scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half) {
return true;
}
if (node->inputs()[1]->node()->kind() != prim::Constant) {
return false;
}
- // we don't support dynamic keepdim yet;
- if (node->inputs()[2]->node()->kind() != prim::Constant) {
+ // we don't support keepdim yet;
+ if (node->inputs()[2]->node()->kind() != prim::Constant ||
+ *constant_as<bool>(node->input(2))) {
return false;
}
return true;
},
- [](const Node* node) -> OperatorType {
- return OperatorType::Reduction;
- });
- }
- {
- std::array<const char*, kNumSumToSize> SumToSize = {
- "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)",
- "aten::sum_to_size(Tensor self, int[] size) -> Tensor"};
- for (auto signature : SumToSize) {
- auto ptr_op = getOperatorForLiteral(signature);
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto self = value_map[node->input(0)->unique()];
- auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
- TORCH_INTERNAL_ASSERT(
- size_to.has_value(),
- "aten::sum cannot be fused with dynamic axes");
- if (!size_to->empty()) {
- auto out = sum_to(self->as<TensorView>(), size_to->vec());
- value_map.emplace(node->output()->unique(), out);
- } else {
- // We are introducing alias here!
- value_map.emplace(node->output()->unique(), self);
- }
- },
- [](const Node* node) -> bool {
- // we don't support dynamic reduction axes;
- if (node->inputs()[1]->node()->kind() != prim::Constant) {
- return false;
- }
- return true;
- // auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
- // return size_to.has_value() && !size_to->empty();
- },
- [](const Node* node) -> OperatorType {
- auto size_to = constant_as<c10::List<int64_t>>(node->input(1));
- // technically size_to->empty() should never occur, as specialized
- // _grad_sum_to_size should have been removed by optimization pass
- if (size_to->empty()) {
- return OperatorType::ElementWise;
- } else {
- return OperatorType::ReductionToSize;
- }
- });
- }
- }
-
- // Limiting aten::to implementation to only change the dtype of a tensor
- {
- auto ptr_op = getOperatorForLiteral(
- "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- const auto self = value_map[node->input(0)->unique()];
-
- // we need static type for cast
- TORCH_INTERNAL_ASSERT(
- node->input(1)->node()->kind() == prim::Constant);
- auto dtype = toIValue(node->input(1))->toScalarType();
-
- // We want to keep our internal fusion math in FP32
- // Shape Inference will continue to propagate the right
- // type to outputs unchanged.
- if (dtype == at::ScalarType::Half) {
- dtype = at::ScalarType::Float;
- }
-
- auto out = castOp(aten_to_data_type(dtype), self);
- value_map.emplace(node->output()->unique(), out);
- },
- nullptr,
- nullptr);
- }
-
- {
- auto ptr_op = getOperatorForLiteral(
- "aten::type_as(Tensor self, Tensor other) -> Tensor");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto self = value_map[node->inputs()[0]->unique()];
-
- // TODO: switch to PyTorch dtype as it's closer to truth.
- // For now, reality is that PyTorch IR profiling information could
- // be missing even with profiling executor, due to upstream
- // transformations between profiling runs to fusion pass.
- auto opt_dtype =
- value_map[node->inputs()[1]->unique()]->getDataType();
- TORCH_INTERNAL_ASSERT(opt_dtype.has_value());
-
- auto out = castOp(opt_dtype.value(), self);
- value_map.emplace(node->output()->unique(), out);
- },
- nullptr,
- nullptr);
- }
-
- {
- // We are not fusing `linear` yet, because we can't codegen efficient gemm
- // However, we still need this here, so PE would insert profile node for
- // this node.
- // During fusion pass, We decompose linear into gemm + elementwise.
- auto ptr_op = getOperatorForLiteral(
- "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- // this entry is created so we do profile input tensors;
- TORCH_INTERNAL_ASSERT(false, "not implemented yet");
- },
- [](const Node* node) -> bool {
- // We only profile `linear` layer with bias.
- if (node->input(2)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- return false;
- }
- return true;
- });
- }
-
- {
- auto ptr_op = getOperatorForLiteral(
- "prim::add_optional(Tensor(a) input, Tensor? bias) -> Tensor(a)");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- // this entry is created so we do profile input tensors;
- if (node->input(1)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- // forwarding the value;
- value_map.emplace(
- node->output()->unique(),
- value_map[node->inputs()[0]->unique()]);
- } else {
- auto lhs = value_map[node->inputs()[0]->unique()];
- auto rhs = value_map[node->inputs()[1]->unique()];
-
- auto out = binaryOp(BinaryOpType::Add, lhs, rhs);
- value_map.emplace(node->output()->unique(), out);
- }
- },
- nullptr,
- nullptr);
- }
-
- {
- auto ptr_op = getOperatorForLiteral("aten::gelu(Tensor self) -> Tensor");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto self = value_map[node->inputs()[0]->unique()];
- auto output = unaryOp(UnaryOpType::Gelu, self);
- value_map.emplace(node->output()->unique(), output);
- },
- nullptr,
- nullptr);
- }
-
- {
- auto ptr_op = getOperatorForLiteral(
- "aten::gelu_backward(Tensor grad, Tensor self) -> Tensor");
- REGISTER_PARSE_RULE(
- ptr_op,
- {
- auto grad = value_map[node->inputs()[0]->unique()];
- auto self = value_map[node->inputs()[1]->unique()];
- auto grad_in = gelu_backward(grad, self);
- value_map.emplace(node->output()->unique(), grad_in);
- },
- nullptr,
- nullptr);
+ true);
}
}
*node);
}
} else {
- auto reg_entry = lookupInRegistry(node);
+ auto iter = IrParser::jit_operator_registry_.find(node->kind());
+ // make sure we have a parser for the op;
+ TORCH_INTERNAL_ASSERT(
+ iter != IrParser::jit_operator_registry_.end(),
+ "CudaFusionGroup Parser doesn't handle operator kind(): ",
+ node->kind().toDisplayString());
+ for (auto& pair_op_func : iter->second) {
+ if (node->matches(pair_op_func.first->schema())) {
+ pair_op_func.second.parse(node, value_map_);
+ return;
+ }
+ }
TORCH_INTERNAL_ASSERT(
- reg_entry != nullptr,
- "CudaFusionGroup Parser doesn't handle node: ",
+ false,
+ "CudaFusionGroup Parser doesn't recognize operator overload:",
canonicalSchemaString(node->schema()));
- reg_entry->parse(node, value_map_);
}
}
if (val->type()->isSubtypeOf(static_cast<c10::TypePtr>(FloatType::get()))) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
CgValue cg_val;
- if (auto ival = constant_as<double>(val)) {
- cg_val = new Double(ival.value());
+ // NOLINTNEXTLINE(bugprone-branch-clone)
+ if (auto ival = constant_as<float>(val)) {
+ cg_val = new Float(ival.value());
} else {
- cg_val = new Double();
+ cg_val = new Float();
}
value_map_.emplace(val->unique(), cg_val);
return true;
static_cast<c10::TypePtr>(IntType::get()))) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
CgValue cg_val;
- if (auto ival = constant_as<int64_t>(val)) {
+ // NOLINTNEXTLINE(bugprone-branch-clone)
+ if (auto ival = constant_as<int>(val)) {
cg_val = new Int(ival.value());
} else {
cg_val = new Int();
static_cast<c10::TypePtr>(BoolType::get()))) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
CgValue cg_val;
+ // NOLINTNEXTLINE(bugprone-branch-clone)
if (auto ival = constant_as<bool>(val)) {
cg_val = new Bool(ival.value());
} else {
bool registerTensor(const JitValue* val) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
CgValue cg_val;
- // Don't register if we don't support the type
- if (auto tensor_type = val->type()->cast<c10::TensorType>()) {
- if (!tensor_type->scalarType().has_value()) {
- return false;
- }
-
- if (aten_to_data_type(tensor_type->scalarType().value()) ==
- DataType::Null) {
- return false;
- }
-
+ if (auto tensor_type = val->type()->cast<TensorType>()) {
// TODO: make this a static function in Tensor class;
// create tensor;
cg_val = new TensorView(tensor_type);
// maps from JitValue::unique() to fusion Val;
std::unordered_map<size_t, CgValue> value_map_;
// parsing rule registry.
- static std::unordered_map<std::string, RegistrationEntry>
- jit_operator_registry_; // NOLINT
-
- // pointing cached entry stored in `jit_operator_registry_`
- static std::unordered_map<const FunctionSchema*, const RegistrationEntry*>
- cached_registry_lookup_; // NOLINT
-
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+ static std::unordered_map<
+ Symbol,
+ std::vector<std::pair<std::shared_ptr<Operator>, RegistrationEntry>>>
+ jit_operator_registry_;
+ static std::unordered_set<Symbol> jit_reduction_op_registry_;
static bool init_registry_;
};
-std::unordered_map<std::string, IrParser::RegistrationEntry>
- IrParser::jit_operator_registry_; // NOLINT
-std::unordered_map<const FunctionSchema*, const IrParser::RegistrationEntry*>
- IrParser::cached_registry_lookup_; // NOLINT
-
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+std::unordered_map<
+ Symbol,
+ std::vector<
+ std::pair<std::shared_ptr<Operator>, IrParser::RegistrationEntry>>>
+ IrParser::jit_operator_registry_;
+std::unordered_set<Symbol> IrParser::jit_reduction_op_registry_;
bool IrParser::init_registry_ = true;
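A simplified sketch of the symbol-keyed overload lookup this registry implements, with plain STL stand-ins for Symbol, Operator, and Node: entries are grouped under the operator name, and a concrete node dispatches by matching its full schema among the registered overloads.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

using ParseFn = std::function<void()>;

int main() {
  // registry: operator name -> (schema, parse rule) overloads
  std::unordered_map<std::string,
                     std::vector<std::pair<std::string, ParseFn>>>
      registry;
  registry["aten::eq"].push_back(
      {"aten::eq(Tensor self, Tensor other) -> Tensor",
       [] { std::cout << "parse tensor-tensor eq\n"; }});
  registry["aten::eq"].push_back(
      {"aten::eq(Tensor self, Scalar other) -> Tensor",
       [] { std::cout << "parse tensor-scalar eq\n"; }});

  // Lookup mirrors canParseNode/processJitNode: find by kind, then match
  // the schema of the node being parsed among the overloads.
  const std::string kind = "aten::eq";
  const std::string schema = "aten::eq(Tensor self, Scalar other) -> Tensor";
  auto it = registry.find(kind);
  if (it != registry.end()) {
    for (auto& overload : it->second) {
      if (overload.first == schema) {
        overload.second(); // dispatch to the registered parse rule
        break;
      }
    }
  }
  return 0;
}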
-ProfileIValueOp* insertProfileIValueOp(
- Node* node,
- size_t offset,
- ProfilingRecord* pr) {
- auto in_val = node->input(offset);
- auto pn = pr->createProfileIValueNode(in_val);
- pn->insertBefore(node);
- node->replaceInput(offset, pn->output());
- return pn;
-}
-
-void profileSize(ProfilingRecord* pr, Node* node, size_t offset) {
- auto pn = insertProfileIValueOp(node, offset, pr);
-
- const auto ivalue_profiler = [pr, pn](Stack& stack) {
- std::lock_guard<std::mutex> lock(pr->mutex_);
-
- // TODO: we don't care about merging multiple profiling runs as we don't
- // support it at all;
- int64_t frame_id = 0;
- pop(stack, frame_id);
- IValue value;
- pop(stack, value);
-
- std::vector<int64_t> size_vec;
- if (value.isIntList()) {
- size_vec = value.toIntVector();
- } else if (value.isNone()) {
- size_vec.clear();
- } else {
- TORCH_INTERNAL_ASSERT(
- false, "profileSize does not support data type: ", value.tagKind());
- }
- if (!pn->hasAttribute(sizeAttr)) {
- pn->is_(sizeAttr, size_vec);
- } else {
- auto profiled_ints = pn->is(sizeAttr);
- TORCH_INTERNAL_ASSERT(
- profiled_ints.size() == size_vec.size() &&
- std::equal(
- profiled_ints.begin(), profiled_ints.end(), size_vec.begin()),
- "profiling ivalue doesn't support merge");
- }
- push(stack, value);
- };
- pn->setCallback(ivalue_profiler);
-}
-
-void profileIntList(ProfilingRecord* pr, Node* node, size_t offset) {
- auto pn = insertProfileIValueOp(node, offset, pr);
-
- const auto ivalue_profiler = [pr, pn](Stack& stack) {
- std::lock_guard<std::mutex> lock(pr->mutex_);
-
- // TODO: we don't care about merging multiple profiling runs as we don't
- // support it at all;
- int64_t frame_id = 0;
- pop(stack, frame_id);
- IValue value;
- pop(stack, value);
- TORCH_INTERNAL_ASSERT(
- value.isIntList(), "profiling seeing the wrong data type");
- if (!pn->hasAttribute(intListAttr)) {
- pn->is_(intListAttr, value.toIntVector());
- } else {
- auto profiled_ints = pn->is(intListAttr);
- auto input_ints = value.toIntList();
- TORCH_INTERNAL_ASSERT(
- profiled_ints.size() == input_ints.size() &&
- std::equal(
- profiled_ints.begin(),
- profiled_ints.end(),
- input_ints.begin()),
- "profiling ivalue doesn't support merge");
- }
- push(stack, value);
- };
-
- pn->setCallback(ivalue_profiler);
-}
-
-void profileBool(ProfilingRecord* pr, Node* node, size_t offset) {
- auto pn = insertProfileIValueOp(node, offset, pr);
-
- const auto ivalue_profiler = [pr, pn](Stack& stack) {
- std::lock_guard<std::mutex> lock(pr->mutex_);
-
- // TODO: we don't care about merging multiple profiling runs as we don't
- // support it at all;
- int64_t frame_id = 0;
- pop(stack, frame_id);
- IValue value;
- pop(stack, value);
- TORCH_INTERNAL_ASSERT(
- value.isBool(), "profiling seeing the wrong data type");
- if (!pn->hasAttribute(boolAttr)) {
- pn->i_(boolAttr, value.toBool());
- } else {
- auto profiled_bool = pn->i(boolAttr);
- auto input_bool = value.toBool();
- TORCH_INTERNAL_ASSERT(
- input_bool == profiled_bool,
- "profiling ivalue doesn't support merge");
- }
- push(stack, value);
- };
-
- pn->setCallback(ivalue_profiler);
-}
-
-void profileInt(ProfilingRecord* pr, Node* node, size_t offset) {
- auto pn = insertProfileIValueOp(node, offset, pr);
-
- const auto ivalue_profiler = [pr, pn](Stack& stack) {
- std::lock_guard<std::mutex> lock(pr->mutex_);
-
- // TODO: we don't care about merging multiple profiling runs as we don't
- // support it at all;
- int64_t frame_id = 0;
- pop(stack, frame_id);
- IValue value;
- pop(stack, value);
- TORCH_INTERNAL_ASSERT(
- value.isInt(), "profiling seeing the wrong data type");
- if (!pn->hasAttribute(intAttr)) {
- pn->i_(intAttr, value.toInt());
- } else {
- auto profiled_int = pn->i(intAttr);
- auto input_int = value.toInt();
- TORCH_INTERNAL_ASSERT(
- input_int == profiled_int, "profiling ivalue doesn't support merge");
- }
- push(stack, value);
- };
-
- pn->setCallback(ivalue_profiler);
-}
-
-void profileBoolList(ProfilingRecord* pr, Node* node, size_t offset) {
- auto pn = insertProfileIValueOp(node, offset, pr);
-
- const auto ivalue_profiler = [pr, pn](Stack& stack) {
- std::lock_guard<std::mutex> lock(pr->mutex_);
-
- // TODO: we don't care about merging multiple profiling runs as we don't
- // support it at all;
- int64_t frame_id = 0;
- pop(stack, frame_id);
- IValue value;
- pop(stack, value);
- TORCH_INTERNAL_ASSERT(
- value.isBoolList(), "profiling seeing the wrong data type");
- if (!pn->hasAttribute(boolListAttr)) {
- auto list = value.toBoolList();
- std::vector<int64_t> val(list.begin(), list.end());
- pn->is_(boolListAttr, val);
- } else {
- auto profiled_ints = pn->is(boolListAttr);
- auto input_bools = value.toBoolList();
- TORCH_INTERNAL_ASSERT(
- profiled_ints.size() == input_bools.size() &&
- std::equal(
- input_bools.begin(),
- input_bools.end(),
- profiled_ints.begin()),
- "profiling ivalue doesn't support merge");
- }
- push(stack, value);
- };
-
- pn->setCallback(ivalue_profiler);
-}
+} // namespace
-bool anyInBlock(
- const Block* block,
- const std::function<bool(const Node*)>& fn) {
+bool hasReductionNode(const Block* block) {
for (auto node : block->nodes()) {
- if (fn(node)) {
+ if (isReductionNode(node)) {
return true;
}
for (auto block : node->blocks()) {
- if (anyInBlock(block, fn)) {
+ if (hasReductionNode(block)) {
return true;
}
}
return false;
}
-} // namespace
-
-bool hasReductionNode(const Block* block) {
- return anyInBlock(block, isReductionNode);
-}
-
bool isReductionNode(const Node* node) {
return IrParser::isReductionNode(node);
}
-bool isReductionToSizeNode(const Node* node) {
- return IrParser::isReductionToSizeNode(node);
-}
-
-bool hasNormalizationNode(const Block* block) {
- return anyInBlock(block, isNormalizationNode);
-}
-
-bool isNormalizationNode(const Node* node) {
- return IrParser::isNormalizationNode(node);
-}
-
-bool isElementWiseNode(const Node* node) {
- return IrParser::isElementWiseNode(node);
-}
-
bool isNodeParsible(const Node* node) {
return IrParser::canParseNode(node);
}
-bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
- // is skip constant necessary?
- if (node->input(offset)->node()->kind() == prim::Constant) {
- return false;
- }
-
- static auto dropout_schema =
- getOperatorForLiteral(
- "aten::dropout(Tensor input, float p, bool train) -> Tensor")
- ->schema();
- if (node->matches(dropout_schema)) {
- switch (offset) {
- // argument 2: Is training?
- case 2:
- profileBool(pr, node, offset);
- break;
- default:
- return false;
- }
- return true;
- }
-
- static auto reduction_operator_schema =
- getOperatorForLiteral(
- "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)")
- ->schema();
- if (node->matches(reduction_operator_schema)) {
- switch (offset) {
- // argument 1: reduction axes;
- case 1:
- profileIntList(pr, node, offset);
- break;
- // argument 2: keepdim;
- case 2:
- profileBool(pr, node, offset);
- break;
- default:
- return false;
- }
- return true;
- }
-
- static auto sum_to_size_schema =
- getOperatorForLiteral(
- "aten::sum_to_size(Tensor self, int[] size) -> Tensor")
- ->schema();
- static auto grad_sum_to_size_schema =
- getOperatorForLiteral(
- "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)")
- ->schema();
- if (node->matches(sum_to_size_schema) ||
- node->matches(grad_sum_to_size_schema)) {
- switch (offset) {
- // argument 1: reduction sizes;
- case 1:
- // TODO(profile_size): double check optional[size]?
- profileSize(pr, node, offset);
- break;
- default:
- return false;
- }
- return true;
- }
-
- static auto batch_norm_impl_index_schema =
- getOperatorForLiteral(
- "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)")
- ->schema();
- static auto native_batch_norm_schema =
- getOperatorForLiteral(
- "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")
- ->schema();
- static auto batch_norm_schema =
- getOperatorForLiteral(
- "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")
- ->schema();
- static auto instance_norm_schema =
- getOperatorForLiteral(
- "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor")
- ->schema();
- if (node->matches(native_batch_norm_schema) ||
- node->matches(batch_norm_impl_index_schema) ||
- node->matches(batch_norm_schema) || node->matches(instance_norm_schema)) {
- switch (offset) {
- // argument 5: training;
- case 5:
- profileBool(pr, node, offset);
- break;
- default:
- return false;
- }
- return true;
- }
-
- static auto native_layer_norm_schema =
- getOperatorForLiteral(
- "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)")
- ->schema();
- static auto layer_norm_schema =
- getOperatorForLiteral(
- "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor")
- ->schema();
- if (node->matches(native_layer_norm_schema) ||
- node->matches(layer_norm_schema)) {
- switch (offset) {
- case 1:
- profileIntList(pr, node, offset);
- break;
- default:
- return false;
- }
- return true;
- }
-
- static auto native_batch_norm_backward_schema =
- getOperatorForLiteral(
- "aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)")
- ->schema();
- if (node->matches(native_batch_norm_backward_schema)) {
- switch (offset) {
- // argument 7: training;
- case 7:
- profileBool(pr, node, offset);
- break;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- case 9:
- profileBoolList(pr, node, offset);
- break;
- default:
- return false;
- }
- return true;
- }
-
- static auto batch_norm_impl_index_backward_schema =
- getOperatorForLiteral(
- "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)")
- ->schema();
- if (node->matches(batch_norm_impl_index_backward_schema)) {
- switch (offset) {
- // TODO: guard impl_index, but I think that's not needed;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- case 8: // argument 8: training;
- profileBool(pr, node, offset);
- break;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- case 10:
- profileBoolList(pr, node, offset);
- break;
- default:
- return false;
- }
- return true;
- }
-
- static auto native_layer_norm_backward_schema =
- getOperatorForLiteral(
- "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)")
- ->schema();
- if (node->matches(native_layer_norm_backward_schema)) {
- switch (offset) {
- case 2:
- profileIntList(pr, node, offset);
- break;
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- case 7:
- profileBoolList(pr, node, offset);
- break;
- default:
- return false;
- }
- return true;
- }
-
- static auto to_dtype_schema =
- getOperatorForLiteral(
- "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor")
- ->schema();
- if (node->matches(to_dtype_schema)) {
- switch (offset) {
- case 1:
- profileInt(pr, node, offset);
- return true;
- default:
- return false;
- }
- }
-
- return false;
-}
-
-void insertProfileNodesForCUDAFuser_(Block* block, ProfilingRecord* pr) {
- for (const auto& n : block->nodes()) {
- for (size_t offset = 0; offset < n->inputs().size(); offset++) {
- insertProfileIValue(pr, n, offset);
- }
-
- for (auto ib : n->blocks()) {
- insertProfileNodesForCUDAFuser_(ib, pr);
- }
- }
-}
-
-void InsertProfileNodes(ProfilingRecord* pr) {
- insertProfileNodesForCUDAFuser_(pr->profiled_graph_->block(), pr);
-}
-
std::unique_ptr<Fusion> parseJitIR(const std::shared_ptr<Graph>& graph) {
FUSER_PERF_SCOPE("parseJitIR");
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/ir/ir.h>
-#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
constexpr int kNonFcdReductionThreadY = 32;
TORCH_CUDA_CU_API bool hasReductionNode(const Block* block);
-TORCH_CUDA_CU_API bool isReductionToSizeNode(const Node* node);
-TORCH_CUDA_CU_API bool isReductionNode(const Node* node);
-
-TORCH_CUDA_CU_API bool hasNormalizationNode(const Block* block);
-TORCH_CUDA_CU_API bool isNormalizationNode(const Node* node);
-TORCH_CUDA_CU_API bool isElementWiseNode(const Node* node);
+TORCH_CUDA_CU_API bool isReductionNode(const Node* node);
// returns whether or not a parsing function exists for the given node type.
TORCH_CUDA_CU_API bool isNodeParsible(const Node* node);
-void InsertProfileNodes(ProfilingRecord* pr);
-
// lowers PyTorch jit graph to `Fusion`.
TORCH_CUDA_CU_API std::unique_ptr<Fusion> parseJitIR(
const std::shared_ptr<Graph>& graph);
#include <torch/csrc/jit/codegen/cuda/partition.h>
#include <ATen/core/jit_type.h>
-#include <ATen/cuda/CUDAContext.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/parser.h>
namespace {
-bool hasNonElementWiseOperation(const Node* node) {
- if (node->kind() == prim::CudaFusionGroup) {
- for (auto n : node->g(attr::Subgraph)->nodes()) {
- if (hasNonElementWiseOperation(n)) {
- return true;
- }
- }
- } else {
- // prim::Constant is not parsible, but it is also not nonElementWise
- if (node->kind() != prim::Constant && !isElementWiseNode(node)) {
- return true;
- }
- }
- return false;
-}
-
// Check all outputs are:
// 1. TensorType
// 2. on the same device;
return c10::nullopt;
}
-static bool isFusibleDevice(const Node* node, const c10::Device device) {
+static bool isFusableDevice(const Node* node, const c10::Device device) {
for (auto value : node->outputs()) {
auto output_device = getDevice(value);
if (output_device.has_value() && output_device.value() != device) {
}
// TODO: we need to check input type when we handle `to()`
-static bool isFusibleDevice(const Node* node) {
+static bool isFusableDevice(const Node* node) {
auto device = getDevice(node);
if (!device.has_value()) {
return true;
}
- return device->is_cuda() &&
- (at::cuda::getDeviceProperties(device->index())->major >= 7 ||
- !hasNonElementWiseOperation(node));
-}
-
-bool compatibleType(const torch::jit::Value* val) {
- if (auto tensor_type = val->type()->cast<c10::TensorType>()) {
- if (tensor_type->scalarType().has_value()) {
- if (aten_to_data_type(tensor_type->scalarType().value()) ==
- DataType::Null) {
- return false;
- }
- }
- }
- return true;
+ return device->is_cuda();
}
-bool checkInputTensorTypes(const Node* node) {
- for (size_t i = 0; i < node->inputs().size(); i++) {
- const auto& val = node->inputs()[i];
- if (!compatibleType(val)) {
- // special case on aten::_batch_norm_impl_index_backward, the 11th output
- // is going to be discarded, so no need to check data type there.
- if (node->kind() ==
- c10::Symbol::fromQualString(
- "aten::_batch_norm_impl_index_backward") &&
- i == 11) {
- continue;
- }
- return false;
- }
- }
- return true;
-}
-
-bool checkOutputTensorTypes(const Node* node) {
- for (size_t i = 0; i < node->outputs().size(); i++) {
- const auto& val = node->outputs()[i];
- if (!compatibleType(val)) {
- // special case on aten::_batch_norm_impl_index, the 4th output
- // is going to be discarded, so no need to check data type there.
- if (node->kind() ==
- c10::Symbol::fromQualString("aten::_batch_norm_impl_index") &&
- i == 3) {
- continue;
- }
- return false;
- }
- }
- return true;
+inline bool isFusableNode(const Node* node) {
+ // checks if node is compatible with parser:
+ // 1. if we have a parsing rule; or 2. if the node is already a fusion group.
+ return (isNodeParsible(node) || node->kind() == prim::CudaFusionGroup);
}
-inline bool isFusibleNode(const Node* node) {
- if (node->kind() == prim::CudaFusionGroup)
+bool hasReductionOperation(const Node* node) {
+ if (isReductionNode(node)) {
return true;
- // Check we have a parsing rule
- bool isFusible = isNodeParsible(node);
- // Check if we have a tensor type it's one we support
- isFusible = isFusible && checkInputTensorTypes(node);
- isFusible = isFusible && checkOutputTensorTypes(node);
- // Check if already part of a fusion group
- return isFusible;
-}
-
-bool maybeBroadcast(
- const TensorTypePtr& type,
- const std::vector<c10::optional<int64_t>>& shape) {
- if (type->dim()) {
- if (type->dim().value() < shape.size()) {
- // no broadcast for reduction operation;
- return false;
- } else if (type->dim().value() > shape.size()) {
- // increased rank means there is reduction;
- return true;
- } else {
- // same rank, we need to iterate through sizes and check if size-1
- // exists in input `shape`
- for (const auto& opt_size : shape) {
-        // TODO: not sure if we need to check for output size != 1, since we
-        // are currently marking all size-1 dimensions as broadcast in codegen.
- if (opt_size.has_value() && opt_size.value() == 1) {
- return true;
- }
+ }
+ if (node->kind() == prim::CudaFusionGroup) {
+ for (auto n : node->g(attr::Subgraph)->nodes()) {
+ if (hasReductionOperation(n)) {
+ return true;
}
}
}
bool maybeBroadcastOnShape(
const Node* n,
const std::vector<c10::optional<int64_t>>& shape) {
- // TODO: we are only checking output 0. This means that our current check for
- // normalization is not complete.
+ TORCH_INTERNAL_ASSERT(
+ n->outputs().size() == 1,
+ "not expecting multiple outputs from a node, graph partitioning logic needs to be updated");
// assumes that if output is not a tensor type, it's not broadcasting
if (auto out_type = n->output(0)->type()->cast<TensorType>()) {
- return maybeBroadcast(out_type, shape);
- }
- return false;
-};
-
-// return true if node is pointwise operation and input tensors all have
-// identical shape.
-bool isNonBroadcastElementWise(const Node* n) {
- if (hasNonElementWiseOperation(n)) {
- return false;
- }
-
- for (const auto output : n->outputs()) {
- const auto& n_output_type = output->type()->cast<TensorType>();
-
-    // TODO: we need to stay on the safe side instead of defaulting to return
-    // true when shape information is not available. Change that when we
-    // enable profiling on autodiff FW execution.
- if (n_output_type != nullptr && n_output_type->sizes().sizes()) {
- const std::vector<c10::optional<int64_t>>& n_output_shape =
- n_output_type->sizes().sizes().value();
-
- for (auto input : n->inputs()) {
- if (auto t_type = input->type()->cast<TensorType>()) {
- if (maybeBroadcast(t_type, n_output_shape)) {
- return false;
+ if (out_type->dim()) {
+ if (out_type->dim().value() < shape.size()) {
+ // no broadcast for reduction operation;
+ return false;
+ } else if (out_type->dim().value() > shape.size()) {
+ // increased rank means there is reduction;
+ return true;
+ } else {
+ // same rank, we need to iterate through sizes and check if size-1
+ // exists in input `shape`
+ for (const auto& opt_size : shape) {
+          // TODO: not sure if we need to check for output size != 1, since we
+          // are currently marking all size-1 dimensions as broadcast in codegen.
+ if (opt_size.has_value() && opt_size.value() == 1) {
+ return true;
}
}
}
}
}
-
- return true;
-}
+ return false;
+};
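// The helper below is a standalone sketch of the same rank/size-1 heuristic,
// written against std::optional instead of c10::optional; the name
// `mayBroadcastOnShape` and its free-function form are assumptions made for
// illustration only.
#include <cstdint>
#include <optional>
#include <vector>

bool mayBroadcastOnShape(
    size_t out_rank,
    const std::vector<std::optional<int64_t>>& shape) {
  if (out_rank < shape.size()) {
    // rank dropped relative to the reference shape: treated as no broadcast
    return false;
  }
  if (out_rank > shape.size()) {
    // rank grew relative to the reference shape: treated as maybe-broadcast
    return true;
  }
  // same rank: any known size-1 dimension in `shape` may broadcast
  for (const auto& opt_size : shape) {
    if (opt_size.has_value() && opt_size.value() == 1) {
      return true;
    }
  }
  return false;
}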
//! [ Note - tricky broadcasting ]
//!
} // namespace
-bool isFusibleCudaFusionGroup(const Node* node) {
- FUSER_PERF_SCOPE("isFusibleCudaFusionGroup");
+bool isFusableCudaFusionGroup(const Node* node) {
+ FUSER_PERF_SCOPE("isFusableCudaFusionGroup");
- if (isFusibleNode(node)) {
- auto ret = isFusibleDevice(node);
- return ret;
+ if (isFusableNode(node)) {
+ return isFusableDevice(node);
}
return false;
}
-bool isFusibleCudaFusionGroup(const Node* fusion, const Node* node) {
- FUSER_PERF_SCOPE("isFusibleCudaFusionGroup");
- bool fused = false;
+bool isFusableCudaFusionGroup(const Node* fusion, const Node* node) {
+ FUSER_PERF_SCOPE("isFusableCudaFusionGroup");
+
// TODO: lift the restriction of not fusing producer containing reduction when
// we have proper scheduling.
- if (isFusibleCudaFusionGroup(node)) {
+ if (isFusableCudaFusionGroup(node) && !hasReductionOperation(node) &&
+ !createTrickyBroadcast(fusion, node)) {
// ensure if the node has a designated device, it's on the same device with
// fusion.
    // TODO: is there a danger of us fusing operations that are supposed to be on
// separate GPUs? And is that necessarily bad?
auto device = getDevice(fusion);
- fused = (!device.has_value() || isFusibleDevice(node, device.value()));
+ return (!device.has_value() || isFusableDevice(node, device.value()));
}
-
- return fused;
+ return false;
}
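// A minimal sketch of how a partitioning pass might consult the two overloads
// above; `sketchPartition` is a hypothetical function for illustration, not
// the actual graph fuser pass, and the merge step is deliberately omitted.
void sketchPartition(Block* block) {
  Node* fusion = nullptr;
  for (Node* node : block->nodes()) {
    if (fusion == nullptr) {
      if (isFusableCudaFusionGroup(node)) {
        fusion = node; // seed a new fusion group with this node
      }
    } else if (isFusableCudaFusionGroup(fusion, node)) {
      // `node` can be merged into `fusion` (merging logic omitted)
    }
  }
}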
} // namespace cuda
namespace fuser {
namespace cuda {
-TORCH_CUDA_CU_API bool isFusibleCudaFusionGroup(const Node* node);
+TORCH_CUDA_CU_API bool isFusableCudaFusionGroup(const Node* node);
// consider if `node` could be fused into `fusion`
-TORCH_CUDA_CU_API bool isFusibleCudaFusionGroup(
+TORCH_CUDA_CU_API bool isFusableCudaFusionGroup(
const Node* fusion,
const Node* node);
#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/index_compute.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
#include <c10/util/irange.h>
namespace fuser {
namespace cuda {
-namespace {
-
-// find the first (and only) TensorView output
-//
-// TODO(kir): same question as ir_utils::getTvOutput():
-// why do we assume a single TV output?
-//
-kir::TensorView* firstTensorViewOutput(const kir::Expr* expr) {
- TORCH_INTERNAL_ASSERT(expr != nullptr);
- for (auto out : expr->outputs()) {
- if (out->isA<kir::TensorView>()) {
- return out->as<kir::TensorView>();
- } else if (out->isA<kir::TensorIndex>()) {
- return out->as<kir::TensorIndex>()->view();
+std::vector<kir::Bool*> PredicateCompute::computePredicates(
+ const TensorView* tv,
+ const std::vector<Val*>& indices,
+ bool use_rfactor) {
+ FUSER_PERF_SCOPE("computePredicates");
+
+ const std::vector<IterDomain*>& root =
+ use_rfactor ? tv->getMaybeRFactorDomain() : tv->getRootDomain();
+
+ TORCH_INTERNAL_ASSERT(root.size() == indices.size());
+
+ bool no_pred_needed = true;
+ for (auto id : tv->domain()->domain()) {
+ if (id->getOrigin() != nullptr) {
+ no_pred_needed = false;
}
}
- TORCH_INTERNAL_ASSERT(false, "Missing kir::TensorView output");
-}
-bool isTensorIndexOp(kir::Expr* expr) {
- const auto& outputs = expr->outputs();
- return outputs.size() >= 1 && outputs[0]->isA<kir::TensorIndex>();
-}
+ if (no_pred_needed) {
+ return {};
+ }
-bool isOutputLocal(const kir::Expr* expr) {
- return std::all_of(
- expr->outputs().begin(),
- expr->outputs().end(),
- [](const kir::Val* output) {
- return !output->isA<kir::TensorView>() ||
- output->as<kir::TensorView>()->memoryType() == MemoryType::Local;
- });
-}
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
+
+ auto true_bool = ir_builder.create<kir::Bool>(true);
+ std::vector<kir::Bool*> preds(root.size(), true_bool);
+ Val* extent = nullptr;
-} // namespace
+ for (const auto i : c10::irange(indices.size())) {
+ const bool zero_ind = indices[i]->isZeroInt();
+ const bool simple_ind = indices[i]->getOrigin() == nullptr;
+
+ if (root[i]->isBroadcast()) {
+ continue;
+ } else if (simple_ind && !zero_ind) {
+ extent = nullptr;
+ continue;
+ } else if (zero_ind) {
+ if (root[i]->extent()->isOneInt()) {
+ continue;
+ }
+ const auto lowered_extent = GpuLower::lowerValue(root[i]->extent());
+ if (extent == nullptr) {
+ extent = lowered_extent;
+ } else {
+ extent = ir_builder.mulExpr(extent, lowered_extent);
+ }
+ } else {
+ auto local_extent = GpuLower::lowerValue(root[i]->extent());
+ if (extent != nullptr) {
+ local_extent = ir_builder.mulExpr(extent, local_extent);
+ }
+ auto pred = ir_builder.ltExpr(indices[i], local_extent);
+ extent = nullptr;
+ TORCH_INTERNAL_ASSERT(
+ pred->getValType().value() == ValType::KirScalar &&
+ pred->getDataType().value() == DataType::Bool);
+ preds[i] = pred->as<kir::Bool>();
+ }
+ }
+ return preds;
+}
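// For example, with root extents {N, M} and two derived (non-trivial)
// indices {i0, i1}, the vector above comes back as { i0 < N, i1 < M }. An
// axis indexed by literal zero instead folds its extent into the bound of
// the next non-trivial axis, and broadcast axes keep the constant true.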
kir::Bool* PredicateCompute::getInlinePredicate(
- const kir::Expr* expr,
+ Expr* expr,
const std::vector<kir::ForLoop*>& loops,
kir::Bool* thread_pred,
- PredicateType pred_type) {
- FUSER_PERF_SCOPE("GpuLower::Lower::getInlinePredicate");
+ bool ignore_block_grid_reductions) {
+ FUSER_PERF_SCOPE("getInlinePredicate");
- const auto gpu_lower = GpuLower::current();
- kir::IrBuilder ir_builder(gpu_lower->kernel());
-
- // If outputs are registers, no need to predicate for threads
- if (isOutputLocal(expr)) {
- thread_pred = ir_builder.trueVal();
- }
+ kir::IrBuilder ir_builder(GpuLower::current()->kernel());
if (loops.empty()) {
- TORCH_INTERNAL_ASSERT(thread_pred != nullptr);
- return thread_pred;
+ return ir_builder.create<kir::Bool>(true);
}
- auto out_tv = firstTensorViewOutput(expr);
-
- if (gpu_lower->predicateElimination().canOmitPredicate(expr)) {
- return thread_pred;
+ // Handle these elsewhere
+ if (ignore_block_grid_reductions &&
+ expr->getExprType() == ExprType::ReductionOp &&
+ (expr->as<ReductionOp>()->out()->as<TensorView>()->hasBlockReduction() ||
+ expr->as<ReductionOp>()->out()->as<TensorView>()->hasGridReduction())) {
+ return ir_builder.create<kir::Bool>(true);
}
- auto all_preds = Index::getReferenceRootPredicates(out_tv, loops);
+ TORCH_INTERNAL_ASSERT(
+ ir_utils::isTVOp(expr),
+ "Cannot generate predicate based on operation without a TensorView.");
- std::vector<kir::Bool*> preds;
+ auto out_tv = ir_utils::getTVOutput(expr);
- auto is_true = [](const kir::Bool* p) {
- return p->isConst() && p->value().value();
- };
-
- // When pred_type is ReductionWrite, filter out predicates for
- // reduction axes. For blockReduce, this is necessary when reduction
- // axes start at non-zero offsets and parallelized with TID since
- // blockReduce returns a valid output only at offset-zero
- // threads. Similarly, for gridReduce, the last block to store the
- // output may be predicated out with the read predicate, so the
- // write predicate needs to ignore the reduction axes.
- bool non_zero_start_found = false;
- for (size_t i = 0; i < all_preds.first.size(); ++i) {
- auto pred = all_preds.first[i];
- if (pred_type == PredicateType::ReductionWrite) {
- const auto& concrete_root_ids = all_preds.second[i];
- bool pred_for_reduction_axis = false;
- for (auto pred_root_id : concrete_root_ids) {
- auto kir_pred_root_id =
- gpu_lower->lowerValue(pred_root_id)->as<kir::IterDomain>();
- auto it = std::find_if(
- out_tv->domain()->rootDomain().begin(),
- out_tv->domain()->rootDomain().end(),
- [&](const auto& out_root_id) {
- return gpu_lower->caIndexMap().areMapped(
- kir_pred_root_id, out_root_id);
- });
- TORCH_INTERNAL_ASSERT(
- it != out_tv->domain()->rootDomain().end(),
- "No corresponding root ID found for ",
- pred_root_id);
- auto out_root_id = *it;
- if (out_root_id->isReduction()) {
- if (!out_root_id->start()->isZeroInt()) {
- non_zero_start_found = true;
- }
- pred_for_reduction_axis = true;
- break;
- }
- }
- // Don't add the predicate if it corresponds to a reduction axis
- if (pred_for_reduction_axis) {
- continue;
- }
+ auto pred_contiguity = out_tv->domain()->contiguity();
+
+ for (auto inp : expr->inputs()) {
+ if (!ir_utils::isTV(inp)) {
+ continue;
}
- if (!is_true(pred)) {
- preds.push_back(pred);
+ auto inp_tv = inp->as<TensorView>();
+ // NOLINTNEXTLINE(bugprone-branch-clone)
+ if (inp_tv->domain()->hasRFactor()) {
+ continue;
+ } else if (
+ inp_tv->getMemoryType() == MemoryType::Shared ||
+ inp_tv->getMemoryType() == MemoryType::Local) {
+ continue;
+ } else {
+ pred_contiguity = IndexCompute::contiguityAnd(
+ pred_contiguity,
+ IndexCompute::contiguityPasC(inp_tv->domain(), out_tv->domain()));
}
}
- // When generating a predicate for blockReduce writes and not for
- // gridReduce, if all reduction axes start with zero, we can just
- // use the same predicate for reads. nullptr is returned then.
- if (pred_type == PredicateType::ReductionWrite && !non_zero_start_found &&
- !out_tv->fuserTv()->domain()->hasGridReduction()) {
- return nullptr;
+ auto pred_inds =
+ Index::getConsumerRootPredIndices(out_tv, loops, pred_contiguity);
+ auto root_indices = pred_inds.first;
+ bool use_maybe_rfactor = pred_inds.second;
+
+ if (out_tv->getMemoryType() == MemoryType::Local && out_tv->hasReduction() &&
+ !use_maybe_rfactor) {
+ auto tv_filter_inp_view =
+ ir_utils::filterByType<TensorView>(expr->inputs());
+ auto has_tv_inputs = tv_filter_inp_view.begin() != tv_filter_inp_view.end();
+    // If the predicate doesn't need maybe_rfactor, but the output has
+    // reduction axes and expr has no inputs, we're pretty confident we're
+    // initializing a reduction buffer. When initializing a reduction buffer,
+    // don't generate an inline predicate.
+ if (!has_tv_inputs) {
+ return ir_builder.create<kir::Bool>(true);
+ }
}
- if (thread_pred != nullptr && !is_true(thread_pred)) {
- preds.push_back(thread_pred);
+ auto all_preds = PredicateCompute::computePredicates(
+ out_tv, root_indices, use_maybe_rfactor);
+
+ // If we have thread predicates, add those
+ if (thread_pred != nullptr) {
+ all_preds.push_back(thread_pred);
}
+ std::vector<kir::Bool*> preds;
+
+  // Keep only predicates that are not statically known to be true.
+  for (auto pred : all_preds) {
+    if (!(pred->isConst() && pred->value().value())) {
+      preds.push_back(pred);
+    }
+  }
+
if (preds.empty()) {
- return ir_builder.trueVal();
+ return ir_builder.create<kir::Bool>(true);
}
- kir::Val* cond = preds[0];
- for (size_t i = 1; i < preds.size(); i++) {
+ Val* cond = preds[0];
+
+ for (decltype(preds.size()) i{1}; i < preds.size(); i++) {
cond = ir_builder.andExpr(cond, preds[i]);
}
+ TORCH_INTERNAL_ASSERT(
+ cond->getValType().value() == ValType::KirScalar &&
+ cond->getDataType().value() == DataType::Bool,
+ "Error computing predicate, should be returning a Bool, but returning ",
+ cond->getDataType().value());
+
return cond->as<kir::Bool>();
}
-kir::Bool* UnswitchPredicate::get(
+kir::Bool* UnrollPredicate::get(
const std::vector<kir::ForLoop*>& outer_loops,
- kir::ForLoop* unrolled_loop) {
- FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::get");
+ kir::ForLoop* unrolled_loop,
+ const std::unordered_map<IterDomain*, IterDomain*>& p2c_root_map) {
+ FUSER_PERF_SCOPE("UnrollPredicate::get");
kir::IrBuilder ir_builder(GpuLower::current()->kernel());
- UnswitchPredicate up(outer_loops, unrolled_loop);
+ UnrollPredicate up(outer_loops, unrolled_loop, p2c_root_map);
- kir::Val* unroll_pred = nullptr;
- for (auto pred : up.predicates_) {
- if (pred->isConst() && pred->value().value()) {
- continue;
- } else if (unroll_pred == nullptr) {
+ std::unordered_set<kir::Bool*> pred_set;
+ for (auto entry : up.predicates_) {
+ pred_set.emplace(entry.second);
+ }
+
+ if (up.predicates_.empty()) {
+ return ir_builder.create<kir::Bool>(true);
+ }
+
+ Val* unroll_pred = nullptr;
+ for (auto pred : pred_set) {
+ if (unroll_pred == nullptr) {
unroll_pred = pred;
} else {
unroll_pred = ir_builder.andExpr(unroll_pred, pred);
}
}
-
- return unroll_pred == nullptr ? ir_builder.trueVal()
- : unroll_pred->as<kir::Bool>();
+ TORCH_INTERNAL_ASSERT(
+ // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
+ unroll_pred->getValType().value() == ValType::KirScalar &&
+ unroll_pred->getDataType().value() == DataType::Bool);
+ return unroll_pred->as<kir::Bool>();
}
-void UnswitchPredicate::predicateOn(kir::Expr* tv_expr) {
- FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::predicateOn");
+void UnrollPredicate::predicateOn(Expr* tv_expr) {
+ FUSER_PERF_SCOPE("UnrollPredicate::predicateOn");
if (for_loops_.empty()) {
return;
}
- const auto gpu_lower = GpuLower::current();
-
- if (gpu_lower->predicateElimination().canOmitPredicate(tv_expr)) {
- return;
- }
-
- auto out_tv = firstTensorViewOutput(tv_expr);
+ auto out_tv = ir_utils::getTVOutput(tv_expr);
- auto pred_info = Index::getReferenceRootPredicates(out_tv, for_loops_, true);
+ auto pred_contiguity = out_tv->domain()->contiguity();
- for (auto i : c10::irange(pred_info.first.size())) {
- auto pred = pred_info.first[i];
- if (pred->isConst() && pred->value()) {
+ for (auto inp : tv_expr->inputs()) {
+ if (!ir_utils::isTV(inp)) {
+ continue;
+ }
+ auto inp_tv = inp->as<TensorView>();
+ // NOLINTNEXTLINE(bugprone-branch-clone)
+ if (inp_tv->domain()->hasRFactor()) {
continue;
+ } else if (
+ inp_tv->getMemoryType() == MemoryType::Shared ||
+ inp_tv->getMemoryType() == MemoryType::Local) {
+ continue;
+ } else {
+ pred_contiguity = IndexCompute::contiguityAnd(
+ pred_contiguity,
+ IndexCompute::contiguityPasC(inp_tv->domain(), out_tv->domain()));
}
+ }
- const auto& root_ids = pred_info.second[i];
+ auto pred_inds = Index::getConsumerRootPredIndices(
+ out_tv, for_loops_, pred_contiguity, true);
+ auto root_indices = pred_inds.first;
+ auto use_rfactor = pred_inds.second;
- bool add_pred = false;
+ auto all_preds =
+ PredicateCompute::computePredicates(out_tv, root_indices, use_rfactor);
- for (auto root_id : root_ids) {
- auto kir_root_id = gpu_lower->lowerValue(root_id)->as<kir::IterDomain>();
+ auto root_dom =
+ use_rfactor ? out_tv->getMaybeRFactorDomain() : out_tv->getRootDomain();
- if (kir_root_id->isBroadcast()) {
- continue;
- }
+ TORCH_INTERNAL_ASSERT(
+ all_preds.size() == root_dom.size(),
+ "Predicates should be produced for every dimension, even if it's simply set as true.");
- if (std::find(
- predicated_iter_dom_.begin(),
- predicated_iter_dom_.end(),
- kir_root_id) == predicated_iter_dom_.end()) {
- add_pred = true;
- predicated_iter_dom_.push_back(kir_root_id);
- }
- }
- if (add_pred) {
- predicates_.push_back(pred);
+ for (const auto i : c10::irange(all_preds.size())) {
+ if (all_preds[i]->isConst() && all_preds[i]->value().value()) {
+ continue;
}
+ auto term_id = loop_utils::getTermIDInMap(root_dom[i], p2c_root_map_);
+ predicates_[term_id] = all_preds[i];
}
}
-void UnswitchPredicate::openLoop(kir::ForLoop* fl) {
- FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::openLoop");
+void UnrollPredicate::openLoop(kir::ForLoop* fl) {
+ FUSER_PERF_SCOPE("UnrollPredicate::openLoop");
for_loops_.push_back(fl);
for (auto expr : fl->body().exprs()) {
- if (ir_utils::isTVOp(expr) || isTensorIndexOp(expr)) {
+ if (ir_utils::isTVOp(expr)) {
predicateOn(expr);
- } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
- openIte(ite);
- } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
- openLoop(for_loop);
+ } else if (expr->getExprType().value() == ExprType::ForLoop) {
+ openLoop(expr->as<kir::ForLoop>());
}
}
for_loops_.pop_back();
}
-void UnswitchPredicate::openIte(kir::IfThenElse* ite) {
- FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::openIte");
-
- // only expand the ite thenBody
- for (auto expr : ite->thenBody().exprs()) {
- if (ir_utils::isTVOp(expr) || isTensorIndexOp(expr)) {
- predicateOn(expr);
- } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
- openIte(ite);
- } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
- openLoop(for_loop);
- }
- }
-}
-
-UnswitchPredicate::UnswitchPredicate(
+UnrollPredicate::UnrollPredicate(
std::vector<kir::ForLoop*> outer_loops,
- kir::ForLoop* unrolled_loop)
- : for_loops_(std::move(outer_loops)) {
+ kir::ForLoop* unrolled_loop,
+ const std::unordered_map<IterDomain*, IterDomain*>& _p2c_root_map)
+ : for_loops_(std::move(outer_loops)), p2c_root_map_(_p2c_root_map) {
openLoop(unrolled_loop);
}
#pragma once
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
-#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
+#include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+/*
+ * Predicate compute takes a TensorView and a set of indices. The number of
+ * indices and the number of root dimensions of the TensorView are required to
+ * match. Predicate compute should be run after index compute, and the result
+ * of index compute should be used for the indices entry.
+ *
+ * A vector of Bool values is returned, each being the output of the operation
+ * index[i] < get_root(TV)->domain()->axis(i)->size()
+ *
+ * It is assumed that no predicate is required if index[i] is an index directly
+ * from a for loop. This will not catch all cases if we actually have static
+ * size information, for example:
+ *
+ * TV[I].split(4)
+ * would produce the code:
+ * for(i : I/4)
+ *   for(j : 4)
+ *     if( i * 4 + j < TV.size(0))
+ *       TV[i * 4 + j]...
+ *
+ * However, if we had TV.size(0) = 16 at "compile time" we wouldn't need the
+ * predicate, yet we would still generate:
+ *
+ * for(i : 4)
+ *   for(j : 4)
+ *     if( i * 4 + j < TV.size(0))
+ *       TV[i * 4 + j]...
+ */
namespace torch {
namespace jit {
class PredicateCompute {
public:
- // ignore_internal_syncthread_ops will prevent creation of predicates on
- // block/grid broadcast/reduce as these have syncthread calls within them
- // so all threads need to execute the function.
+  // Returns the series of predicates; if an axis doesn't need a predicate,
+  // returns a true value for that position.
+ static std::vector<kir::Bool*> computePredicates(
+ const TensorView* tv,
+ const std::vector<Val*>& indices,
+ bool use_rfactor);
+
static kir::Bool* getInlinePredicate(
- const kir::Expr* expr,
+ Expr* expr,
const std::vector<kir::ForLoop*>& loops,
kir::Bool* thread_pred,
- PredicateType pred_type);
+ bool ignore_block_grid_reductions = true);
};
-class TORCH_CUDA_CU_API UnswitchPredicate {
+class TORCH_CUDA_CU_API UnrollPredicate {
public:
static kir::Bool* get(
const std::vector<kir::ForLoop*>& outer_loops,
- kir::ForLoop* unrolled_loop);
+ kir::ForLoop* unrolled_loop,
+ const std::unordered_map<IterDomain*, IterDomain*>& p2c_root_map);
private:
- UnswitchPredicate(
+ UnrollPredicate(
std::vector<kir::ForLoop*> outer_loops,
- kir::ForLoop* unrolled_loop);
+ kir::ForLoop* unrolled_loop,
+ const std::unordered_map<IterDomain*, IterDomain*>& _p2c_root_map);
- void predicateOn(kir::Expr*);
+ void predicateOn(Expr*);
void openLoop(kir::ForLoop*);
- void openIte(kir::IfThenElse*);
-
private:
- // Track which iter domains have been predicated, uses concrete_id from
- // caLoopMap.
- std::vector<kir::IterDomain*> predicated_iter_dom_;
-
- // The predicates that have been generated.
- std::vector<kir::Bool*> predicates_;
-
+ std::unordered_map<IterDomain*, kir::Bool*> predicates_;
std::vector<kir::ForLoop*> for_loops_;
+
+ const std::unordered_map<IterDomain*, IterDomain*>& p2c_root_map_;
};
} // namespace cuda
public:
RegisterInterface() {
auto ptr = getFuserInterface();
- ptr->fn_compile_n = &compileCudaFusionGroup;
- ptr->fn_run_n_s = &runCudaFusionGroup;
- ptr->fn_fuse_graph = &CudaFuseGraph;
- ptr->fn_can_fuse_n = &isFusibleCudaFusionGroup;
- ptr->fn_insert_profile_inodes = &InsertProfileNodes;
+ ptr->fn_compile_n_ = &compileCudaFusionGroup;
+ ptr->fn_run_n_s_ = &runCudaFusionGroup;
+ ptr->fn_fuse_graph_ = &CudaFuseGraph;
+ ptr->fn_can_fuse_n_ = &isFusableCudaFusionGroup;
+
+ RegisterProfilingNode(canFuseNode);
}
};
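// RegisterInterface follows a common plugin pattern: a process-wide table of
// function pointers that a static constructor fills in at load time. Below is
// a self-contained sketch of that pattern with hypothetical names, not the
// actual torch fuser interface.
struct FuserHooks {
  void (*fn_fuse_graph)(int) = nullptr;
  bool (*fn_can_fuse)(int) = nullptr;
};

FuserHooks* getFuserHooks() {
  static FuserHooks hooks; // one shared table for the whole process
  return &hooks;
}

void fuseGraphImpl(int) {}
bool canFuseImpl(int) { return true; }

namespace {
struct RegisterHooks {
  RegisterHooks() {
    auto* hooks = getFuserHooks();
    hooks->fn_fuse_graph = &fuseGraphImpl;
    hooks->fn_can_fuse = &canFuseImpl;
  }
};
static RegisterHooks register_hooks; // runs during static initialization
} // namespace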
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-
-#include <sstream>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-std::unordered_map<IterDomain*, IterDomain*> RootDomainMap::
- mapProducerToConsumer(
- const TensorDomain* producer,
- const TensorDomain* consumer,
- const std::unordered_set<IterDomain*>& root_dims_to_map) const {
- return map(producer, consumer, root_dims_to_map, true);
-}
-
-std::unordered_map<IterDomain*, IterDomain*> RootDomainMap::
- mapProducerToConsumer(
- const TensorDomain* producer,
- const TensorDomain* consumer) const {
- std::unordered_set<IterDomain*> root_dims_to_map(
- producer->getMaybeRFactorDomain().begin(),
- producer->getMaybeRFactorDomain().end());
- return mapProducerToConsumer(producer, consumer, root_dims_to_map);
-}
-
-std::unordered_map<IterDomain*, IterDomain*> RootDomainMap::
- mapConsumerToProducer(
- const TensorDomain* consumer,
- const TensorDomain* producer,
- const std::unordered_set<IterDomain*>& root_dims_to_map) const {
- return map(producer, consumer, root_dims_to_map, false);
-}
-
-std::unordered_map<IterDomain*, IterDomain*> RootDomainMap::
- mapConsumerToProducer(
- const TensorDomain* consumer,
- const TensorDomain* producer) const {
- std::unordered_set<IterDomain*> root_dims_to_map(
- consumer->getRootDomain().begin(), consumer->getRootDomain().end());
- return mapConsumerToProducer(consumer, producer, root_dims_to_map);
-}
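// Typical use of the four overloads above, sketched under the assumption of
// some producer-consumer TensorView pair from one fusion; `exampleRootMapping`
// is a hypothetical helper, not code from this file.
void exampleRootMapping(TensorView* producer_tv, TensorView* consumer_tv) {
  PairwiseRootDomainMap root_map(producer_tv, consumer_tv);
  auto p2c = root_map.mapProducerToConsumer(
      producer_tv->domain(), consumer_tv->domain());
  for (const auto& kv : p2c) {
    // kv.first: a producer root IterDomain*; kv.second: its consumer match
    (void)kv;
  }
}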
-
-PairwiseRootDomainMap::PairwiseRootDomainMap(
- const TensorView* producer,
- const TensorView* consumer)
- : producer_tv_(producer), consumer_tv_(consumer) {
- TORCH_INTERNAL_ASSERT(producer != nullptr);
- TORCH_INTERNAL_ASSERT(consumer != nullptr);
- TORCH_INTERNAL_ASSERT(producer->fusion() == consumer->fusion());
- // Make sure they are really a producer and its consumer
- TORCH_INTERNAL_ASSERT(
- producer->isConsumerOf(consumer),
- "Not a producer-consumer pair: ",
- producer,
- ", ",
- consumer);
-}
-
-std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
- const TensorDomain* producer,
- const TensorDomain* consumer,
- const std::unordered_set<IterDomain*>& root_dims_to_map,
- bool producer_to_consumer) const {
- // Sanity check that the given producer and consumer domains are
- // really the TensorDomains of the producer and consumer TensorViews
- // given to the constructor.
- TORCH_INTERNAL_ASSERT(producer_tv_->domain() == producer);
- TORCH_INTERNAL_ASSERT(consumer_tv_->domain() == consumer);
-
- if (consumer_tv_->definition()->isA<TransposeOp>()) {
- return mapTranspose(
- producer, consumer, root_dims_to_map, producer_to_consumer);
- }
-
- std::vector<bool> broadcast_flags;
- if (BroadcastOp* bop =
- dynamic_cast<BroadcastOp*>(consumer_tv_->definition())) {
- broadcast_flags = bop->getBroadcastDimFlags();
- }
-
- std::unordered_map<IterDomain*, IterDomain*> dom_map;
- const auto producer_root =
- TensorDomain::noReductions(producer->getMaybeRFactorDomain());
- const auto& consumer_root = consumer->getRootDomain();
- size_t itc = 0, itp = 0;
- while (itc < consumer_root.size() && itp < producer_root.size()) {
- IterDomain* producer_id = producer_root[itp];
- IterDomain* consumer_id = consumer_root[itc];
-
- // When the consumer ID is a new broadcast domain, there is no
- // mapping for it.
- if (!broadcast_flags.empty() && broadcast_flags.at(itc)) {
- TORCH_INTERNAL_ASSERT(consumer_id->isBroadcast());
- itc++;
- continue;
- }
-
- IterDomain* map_key_id = producer_id;
- IterDomain* map_value_id = consumer_id;
- if (!producer_to_consumer) {
- std::swap(map_key_id, map_value_id);
- }
-
- if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) {
- dom_map.insert(std::make_pair(map_key_id, map_value_id));
- }
- itc++;
- itp++;
- }
- return dom_map;
-}
-
-std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::
- mapTranspose(
- const TensorDomain* producer,
- const TensorDomain* consumer,
- const std::unordered_set<IterDomain*>& root_dims_to_map,
- bool producer_to_consumer) const {
- const auto producer_root =
- TensorDomain::noReductions(producer->getMaybeRFactorDomain());
- const auto& consumer_root = consumer->getRootDomain();
-
- std::unordered_map<IterDomain*, IterDomain*> dom_map;
-
- TransposeOp* top = dynamic_cast<TransposeOp*>(consumer_tv_->definition());
- TORCH_INTERNAL_ASSERT(top != nullptr);
-
- const auto& new2old = top->new2old();
- for (size_t i = 0; i < consumer_root.size(); ++i) {
- IterDomain* map_key_id = producer_root[new2old[i]];
- IterDomain* map_value_id = consumer_root[i];
- if (!producer_to_consumer) {
- std::swap(map_key_id, map_value_id);
- }
- if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) {
- dom_map.insert(std::make_pair(map_key_id, map_value_id));
- }
- }
- return dom_map;
-}
-
-std::string toString(const PairwiseRootDomainMap& root_map) {
- std::stringstream ss;
- ss << "{producer: " << root_map.producer()
- << ", consumer: " << root_map.consumer() << "}";
- return ss.str();
-}
-
-namespace {
-
-template <typename T>
-auto ensureMapping(
- T& m,
- const typename T::key_type& key,
- const typename T::mapped_type& init_value) {
- auto it = m.find(key);
- if (it == m.end()) {
- it = m.insert({key, init_value}).first;
- }
- return it;
-}
-
-} // namespace
-
-std::string toString(const DomainKey& key) {
- std::stringstream ss;
- ss << "{";
- if (key.td()) {
- ss << key.td() << " (root: " << key.td()->getRootDomain()
- << ", maybe rfactor: " << key.td()->getMaybeRFactorDomain() << ")";
- } else {
- ss << "null";
- }
- ss << ", ";
- if (key.id()) {
- ss << key.id();
- } else {
- ss << "null";
- }
- if (key.concreteId()) {
- ss << " (" << key.concreteId() << ")";
- }
- ss << "}";
- return ss.str();
-}
-
-UnmappableReductionDomains::UnmappableReductionDomains() {
- Fusion* fusion = FusionGuard::getCurFusion();
- traverse(fusion);
-}
-
-namespace {
-
-//! Find all domains that a given domain is dependent on
-class FindInputDomains : BackwardVisitor {
- private:
- FindInputDomains(TensorView* tv, const IterDomain* id)
- : BackwardVisitor(false), tv_(tv) {
- input_keys.insert(DomainKey(tv_->domain(), id));
- }
-
- DomainKeySet find() {
- traverseFrom(tv_->fusion(), {tv_});
- return input_keys;
- }
-
- void handle(Expr* expr) override {
- for (auto output : expr->outputs()) {
- if (!output->isA<TensorView>()) {
- continue;
- }
- for (auto input : expr->inputs()) {
- if (!input->isA<TensorView>()) {
- continue;
- }
- propagate(input->as<TensorView>(), output->as<TensorView>());
- }
- }
- }
-
- void propagate(TensorView* in_tv, TensorView* out_tv) {
- auto c2p = PairwiseRootDomainMap(in_tv, out_tv)
- .mapConsumerToProducer(out_tv->domain(), in_tv->domain());
- for (auto root_dom : out_tv->getRootDomain()) {
- DomainKey out_key({out_tv->domain(), root_dom});
- if (input_keys.find(out_key) == input_keys.end()) {
- continue;
- }
- auto input_id_it = c2p.find(root_dom);
- if (input_id_it == c2p.end()) {
- continue;
- }
- DomainKey input_key(in_tv->domain(), input_id_it->second);
- input_keys.insert(input_key);
- }
- }
-
- private:
- TensorView* tv_ = nullptr;
- DomainKeySet input_keys;
-
- public:
- static DomainKeySet find(TensorView* tv, const IterDomain* id) {
- return FindInputDomains(tv, id).find();
- }
-};
-
-} // namespace
-
-void UnmappableReductionDomains::handleReductionOutput(TensorView* out_tv) {
- std::vector<DomainKey> reduction_keys;
- for (const auto id : out_tv->getRootDomain()) {
- if (id->isReduction()) {
- DomainKey key(out_tv->domain(), id);
- reduction_keys.push_back(key);
- reduction_domains_.insert({key, {}});
- }
- }
- auto use_chains = DependencyCheck::getAllUseChains(out_tv);
- for (const auto& chain : use_chains) {
- for (const auto& tv : ir_utils::filterByType<TensorView>(chain)) {
- const auto& root_domain = tv->getRootDomain();
- for (const auto& id : root_domain) {
- DomainKey consumer_key(tv->domain(), id);
- for (const auto& reduction_key : reduction_keys) {
- reduction_domains_.at(reduction_key).insert(consumer_key);
- }
- }
- }
- }
- for (const auto& reduction_key : reduction_keys) {
- reduction_domain_inputs_.insert(
- {reduction_key, FindInputDomains::find(out_tv, reduction_key.id())});
- }
-}
-
-void UnmappableReductionDomains::handle(ReductionOp* op) {
- // Builds a map from reduction domains to consumer domains.
- TensorView* out_tv = op->out()->as<TensorView>();
- handleReductionOutput(out_tv);
-}
-
-void UnmappableReductionDomains::handle(WelfordOp* op) {
- // Builds a map from reduction domains to consumer domains.
- handleReductionOutput(op->outAvg()->as<TensorView>());
- handleReductionOutput(op->outVar()->as<TensorView>());
- handleReductionOutput(op->outN()->as<TensorView>());
-}
-
-bool UnmappableReductionDomains::isReductionOutputMapped(
- const std::vector<DomainKey>& consumer_domains,
- const ComputeAtRootDomainMap& root_map) const {
- for (const auto& kv : reduction_domains_) {
- const DomainKey& reduction_domain = kv.first;
- const DomainKeySet& incompatible_domains = kv.second;
- DomainKey consumer_domain_with_reduction;
- bool reduction_found = false;
- const auto& input_keys = reduction_domain_inputs_.at(reduction_domain);
- for (const DomainKey& consumer_domain : consumer_domains) {
- for (const auto& input_key : input_keys) {
- if (input_key == consumer_domain) {
- consumer_domain_with_reduction = consumer_domain;
- reduction_found = true;
- break;
- }
- }
- }
- if (!reduction_found) {
- continue;
- }
- // Make sure no incompatible domains will be merged with the reduction
- // domain.
- for (const auto& consumer_domain : consumer_domains) {
- if (consumer_domain == consumer_domain_with_reduction) {
- continue;
- }
- if (std::any_of(
- incompatible_domains.begin(),
- incompatible_domains.end(),
- [&](const DomainKey& incompatible_domain) {
- return root_map.canMap(
- consumer_domain.td(),
- consumer_domain.id(),
- incompatible_domain.td(),
- incompatible_domain.id());
- })) {
- return true;
- }
- }
- }
- return false;
-}
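// An illustrative case: for t1 = sum(t0, {1}) followed by a broadcast of t1
// back to t0's shape, mapping the re-broadcast axis onto t0's reduced axis
// would force t0 to be recomputed per consumer; that is the situation the
// check above rejects.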
-
-void ComputeAtRootDomainMap::build(bool map_through_reduction) {
- // Make sure we start from scratch. Throw away previous results.
- eq_set_.clear();
- bcast_map_.clear();
- new_broadcast_domains_.clear();
- ComputeAtRootDomainMapBuilder builder(*this, map_through_reduction);
-}
-
-bool ComputeAtRootDomainMap::canMap(
- const TensorDomain* td_a,
- const IterDomain* id_a,
- const TensorDomain* td_b,
- const IterDomain* id_b) const {
- TORCH_INTERNAL_ASSERT(
- id_a->definition() == nullptr || id_a->isRFactorProduct(),
- "Non-root domain is not supproted: ",
- id_a);
- TORCH_INTERNAL_ASSERT(
- id_b->definition() == nullptr || id_b->isRFactorProduct(),
- "Non-root domain is not supproted: ",
- id_b);
-
- // Forward to overloaded functions
- if (!id_a->isBroadcast() && !id_b->isBroadcast()) {
- return canMap(DomainKey(td_a, id_a), DomainKey(td_b, id_b));
- } else if (!id_a->isBroadcast()) {
- return canMap(DomainKey(td_a, id_a), td_b, id_b);
- } else if (!id_b->isBroadcast()) {
- return canMap(DomainKey(td_b, id_b), td_a, id_a);
- }
-
- // At this point, both are broadcast. Every pair of concrete IDs of
- // both id_a and id_b needs to be looked at. Whether they are
- // mappable depends on whether the concrete IDs are broadcast or
-  // not. Note that a broadcast axis is used as a concrete ID when it is
- // part of an output tensor domain, i.e., when it never gets
- // concretized with any non-broadcast axis.
-
-  // If there exists a pair of non-broadcast concrete IDs that is not
-  // mappable, id_a and id_b can't be mapped together. Otherwise, they
-  // can be mapped if any mappable pair is found.
- bool mappable_pair_found = false;
- for (const auto& key_a : getConcretizedKeys(td_a, id_a)) {
- for (const auto& key_b : getConcretizedKeys(td_b, id_b)) {
- const bool mappable = canMap(key_a, key_b);
- mappable_pair_found = mappable_pair_found || mappable;
- // If both concrete IDs are not broadcast, they must be
- // mappable. Also, if either of the concrete IDs is a reduction,
- // that means a trivial reduction (i.e., broadcast immediately
- // followed by reduction), which does not prevent any mapping.
- if (!key_a.concreteId()->isBroadcast() &&
- !key_b.concreteId()->isBroadcast() &&
- !key_a.concreteId()->isReduction() &&
- !key_b.concreteId()->isReduction() && !mappable) {
- return false;
- }
- }
- }
-
- return mappable_pair_found;
-}
-
-bool ComputeAtRootDomainMap::canMap(
- const DomainKey& key_a,
- const TensorDomain* td_b,
- const IterDomain* id_b) const {
- TORCH_INTERNAL_ASSERT(
- id_b->definition() == nullptr || id_b->isRFactorProduct(),
- "Non-root domain is not supproted: ",
- id_b);
-
- if (!id_b->isBroadcast()) {
- return canMap(key_a, DomainKey(td_b, id_b));
- }
-
- // If id_b is broadcast, look at all the concrete IDs that id_b may
- // be concretized to. Whether it is mappable with key_a depends on
- // whether key_a's concrete ID is also broadcast.
-  // 1) key_a's concrete ID is also broadcast: They are mappable when
-  // any mappable concrete ID exists in the concrete ID set of id_b.
-  // 2) key_a's concrete ID is not broadcast: Since key_a is indeed
-  // concrete, it must be mappable with every concrete ID of id_b,
-  // except when an id_b concrete ID is broadcast.
- const bool key_a_bcast =
- key_a.concreteId() && key_a.concreteId()->isBroadcast();
- const bool key_a_reduction =
- (key_a.concreteId() && key_a.concreteId()->isReduction()) ||
- key_a.id()->isReduction();
- bool mappable_pair_found = false;
- for (const auto& key_b : getConcretizedKeys(td_b, id_b)) {
- const bool mappable = canMap(key_a, key_b);
- mappable_pair_found = mappable_pair_found || mappable;
- // If both concrete IDs are not broadcast, they must be mappable.
- // However, if key_b's concrete ID is a reduction, the concrete ID
- // is a result of a trivial reduction, so it should not prevent
- // any other mapping. Similarly, if key_a is a reduction, it just
- // needs to find any concrete ID of key_b that can be mapped.
- if (!key_a_bcast && !key_b.concreteId()->isBroadcast() &&
- !key_b.concreteId()->isReduction() && !key_a_reduction && !mappable) {
- return false;
- }
- }
-
- return mappable_pair_found;
-}
-
-bool ComputeAtRootDomainMap::canMap(
- const DomainKey& key_a,
- const DomainKey& key_b) const {
- return key_a == key_b || eq_set_.areEquivalent(key_a, key_b);
-}
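// eq_set_ behaves like a union-find over DomainKeys: join() merges two
// equivalence classes and areEquivalent() tests co-membership. A minimal
// standalone analogue over ints, as a sketch (not the fuser's actual
// DisjointSet implementation):
#include <numeric>
#include <vector>

struct TinyDisjointSet {
  std::vector<int> parent;
  explicit TinyDisjointSet(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0); // each element starts alone
  }
  int find(int x) {
    // walk to the class representative, compressing the path on the way
    return parent[x] == x ? x : parent[x] = find(parent[x]);
  }
  void join(int a, int b) {
    parent[find(a)] = find(b); // merge the two classes
  }
  bool areEquivalent(int a, int b) {
    return find(a) == find(b);
  }
};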
-
-void ComputeAtRootDomainMap::setAlias(
- const TensorDomain* td,
- const TensorDomain* td_alias) {
- auto tmp_bcast_map = bcast_map_;
- for (const auto& kv : bcast_map_) {
- const auto& bcast_map_key = kv.first;
- const auto& bcast_concrete_id_set = kv.second;
- if (bcast_map_key.td() == td) {
- DomainKey alias_key(td_alias, bcast_map_key.id());
- tmp_bcast_map.insert({alias_key, bcast_concrete_id_set});
- }
- }
- bcast_map_ = tmp_bcast_map;
-
- for (const auto& key : eq_set_.getAllElements()) {
- if (key.td() == td) {
- DomainKey alias_key(td_alias, key.id(), key.concreteId());
- eq_set_.join(key, alias_key);
- }
- }
-
- auto tmp_new_broadcast_domains = new_broadcast_domains_;
- for (const auto& key : new_broadcast_domains_) {
- if (key.td() == td) {
- DomainKey alias_key(td_alias, key.id());
- tmp_new_broadcast_domains.insert(alias_key);
- }
- }
- new_broadcast_domains_ = tmp_new_broadcast_domains;
-}
-
-std::vector<DomainKey> ComputeAtRootDomainMap::getConcretizedKeys(
- const TensorDomain* td,
- const IterDomain* id) const {
- DomainKey key(td, id);
- auto it = bcast_map_.find(key);
- TORCH_INTERNAL_ASSERT(it != bcast_map_.end(), "Not found: ", toString(key));
- std::vector<DomainKey> domains;
- std::transform(
- it->second.begin(),
- it->second.end(),
- std::back_inserter(domains),
- [&](const IterDomain* concrete_id) {
- return DomainKey(td, id, concrete_id);
- });
- return domains;
-}
-
-std::unordered_set<const IterDomain*>& ComputeAtRootDomainMap::
- getConcretizedDomains(const TensorDomain* td, const IterDomain* id) {
- DomainKey key(td, id);
- auto it = bcast_map_.find(key);
- TORCH_INTERNAL_ASSERT(it != bcast_map_.end(), "Not found: ", toString(key));
- return it->second;
-}
-
-std::unordered_map<IterDomain*, IterDomain*> ComputeAtRootDomainMap::
- mapBestEffort(
- const TensorDomain* from_td,
- const std::vector<IterDomain*>& from_root,
- const TensorDomain* to_td,
- const std::vector<IterDomain*>& to_root) const {
- std::unordered_map<IterDomain*, IterDomain*> id_map;
- for (auto& from_id : from_root) {
- for (const auto& to_id : to_root) {
- if (canMap(from_td, from_id, to_td, to_id)) {
- TORCH_INTERNAL_ASSERT(
- id_map.insert({from_id, to_id}).second,
- "Multiple matching ID detected for ",
- from_id);
- }
- }
- }
- return id_map;
-}
-
-std::unordered_map<IterDomain*, IterDomain*> ComputeAtRootDomainMap::map(
- const TensorDomain* producer,
- const TensorDomain* consumer,
- const std::unordered_set<IterDomain*>& root_dims_to_map,
- bool producer_to_consumer) const {
- const auto& producer_root = producer->getMaybeRFactorDomain();
- const auto& consumer_root = consumer->getRootDomain();
- const TensorDomain* from_td = producer_to_consumer ? producer : consumer;
- const TensorDomain* to_td = producer_to_consumer ? consumer : producer;
- const auto& from_ids = producer_to_consumer ? producer_root : consumer_root;
- const auto& to_ids = producer_to_consumer ? consumer_root : producer_root;
- std::unordered_map<IterDomain*, IterDomain*> id_map =
- mapBestEffort(from_td, from_ids, to_td, to_ids);
- for (auto& from_id : from_ids) {
- if (root_dims_to_map.find(from_id) == root_dims_to_map.end()) {
- // Remove mapping if exists
- id_map.erase(from_id);
- continue;
- }
- if (id_map.find(from_id) != id_map.end()) {
- continue;
- }
- // Matching ID not found. It's an error unless: from_id is
- // reduction of a producer domain; from_id is a new broadcast of a
- // consumer domain; or from_id is a window axis of a consumer
- // domain.
- if ((producer_to_consumer && from_id->isReduction()) ||
- (!producer_to_consumer &&
- (new_broadcast_domains_.find(DomainKey(from_td, from_id)) !=
- new_broadcast_domains_.end() ||
- (window_axes_.count(from_id) > 0)))) {
- continue;
- }
- TORCH_INTERNAL_ASSERT(
- false,
- "Mapping IterDomain ",
- from_id,
- " of ",
- from_td,
- " not possible as it would require recomputing the source tensor.",
- " Producer root: ",
- producer_root,
- ". Consumer root: ",
- consumer_root,
- ". Mapping: ",
- toString(*this));
- }
- return id_map;
-}
-
-std::unordered_set<IterDomain*> ComputeAtRootDomainMap::getMappableDims(
- const TensorDomain* producer,
- const TensorDomain* consumer) const {
- const auto& producer_root = producer->getMaybeRFactorDomain();
- const auto& consumer_root = consumer->getRootDomain();
-
- std::unordered_map<IterDomain*, IterDomain*> id_map =
- mapBestEffort(producer, producer_root, consumer, consumer_root);
-
- std::unordered_set<IterDomain*> mappable_ids;
-
- for (auto& from_id : producer_root) {
- if (id_map.find(from_id) != id_map.end()) {
- mappable_ids.emplace(from_id);
- mappable_ids.emplace(id_map.at(from_id));
- }
- }
- return mappable_ids;
-}
-
-std::string toString(const ComputeAtRootDomainMap& root_map) {
- std::stringstream ss;
- root_map.eq_set_.print(ss);
- return ss.str();
-}
-
-ComputeAtRootDomainMapBuilder::ComputeAtRootDomainMapBuilder(
- ComputeAtRootDomainMap& root_map,
- bool map_through_reduction)
- : BackwardVisitor(false),
- root_map_(root_map),
- map_through_reduction_(map_through_reduction) {
- Fusion* fusion = FusionGuard::getCurFusion();
- TORCH_INTERNAL_ASSERT(fusion != nullptr);
- traverseFrom(fusion, fusion->outputs(), false);
- if (!pending_map_.empty()) {
- std::stringstream ss;
- ss << "pending map:\n";
- for (auto& kv : pending_map_) {
- ss << "\t" << toString(kv.first) << "\n";
- for (auto& dk : kv.second) {
- ss << "\t\t" << toString(dk) << "\n";
- }
- }
- std::cerr << ss.str();
- }
- TORCH_INTERNAL_ASSERT(pending_map_.empty());
-}
-
-// Set concrete domains for broadcast domains that never get joined
-// with a concrete domain. Just set the broadcast domain itself as its
-// concrete domain, which is not concrete but is sufficient for this analysis.
-void ComputeAtRootDomainMapBuilder::initializeBcastMap(
- const TensorView* tv,
- const IterDomain* id) {
- TORCH_INTERNAL_ASSERT(id->isBroadcast(), "Not a broadcast axis");
- auto key = DomainKey(tv->domain(), id);
- auto it = root_map_.bcast_map_.find(key);
- if (it != root_map_.bcast_map_.end()) {
- // already initialized.
- return;
- }
-
- // This initialization should be only used for fusion output tensors and
- // outputs of multi-consumer expressions that are not fusion outputs.
- TORCH_INTERNAL_ASSERT(
- tv->isFusionOutput() || tv->definition()->outputs().size() > 1,
- "Invalid tensor to initialize bcast map: t",
- tv->name());
- root_map_.bcast_map_.insert({key, {id}});
-}
-
-void ComputeAtRootDomainMapBuilder::addToPendingList(
- const DomainKey& producer,
- const DomainKey& consumer) {
- auto it = ensureMapping(pending_map_, producer, {});
- auto& consumer_set = it->second;
- consumer_set.insert(consumer);
-}
-
-void ComputeAtRootDomainMapBuilder::setMapped(
- const DomainKey& producer,
- const DomainKey& consumer) {
- root_map_.eq_set_.join(producer, consumer);
-}
-
-void ComputeAtRootDomainMapBuilder::setMaybeMapped(
- const TensorDomain* producer_td,
- const IterDomain* producer_id,
- const TensorDomain* consumer_td,
- const IterDomain* consumer_id) {
- const DomainKey producer_key(producer_td, producer_id);
- const DomainKey consumer_key(consumer_td, consumer_id);
-
- if (producer_id->isBroadcast()) {
- ensureMapping(root_map_.bcast_map_, producer_key, {});
- }
-
- if (consumer_id->isBroadcast()) {
- TORCH_INTERNAL_ASSERT(producer_id->isBroadcast());
- // Get bcast_map_ entry for consumer_id
- const auto consumer_bcast_domains =
- root_map_.getConcretizedKeys(consumer_td, consumer_id);
- auto& producer_domains =
- root_map_.getConcretizedDomains(producer_td, producer_id);
-
- // If consumer id is broadcasted, make sure to propagate its concrete_id(s)
- // to producer
- for (const auto& consumer_bcast_key : consumer_bcast_domains) {
- const auto concrete_id = consumer_bcast_key.concreteId();
- const DomainKey producer_bcast_key(producer_td, producer_id, concrete_id);
- producer_domains.insert(concrete_id);
- addToPendingList(producer_bcast_key, consumer_bcast_key);
- }
- } else {
- TORCH_INTERNAL_ASSERT(
- !consumer_id->isBroadcast(),
- "No concrete domain found for a broadcast domain: ",
- toString(consumer_key));
- auto producer_concrete_key = producer_key;
- if (producer_id->isBroadcast()) {
- const auto concrete_id = consumer_id;
- auto& producer_domains =
- root_map_.getConcretizedDomains(producer_td, producer_id);
- producer_concrete_key = DomainKey(producer_td, producer_id, concrete_id);
- producer_domains.insert(concrete_id);
- }
- addToPendingList(producer_concrete_key, consumer_key);
- }
-}
-
-void ComputeAtRootDomainMapBuilder::handle(Expr* e) {
- // Avoid visiting expressions multiple times
- if (visited_.find(e) != visited_.end()) {
- return;
- }
- BackwardVisitor::handle(e);
- visited_.insert(e);
-}
-
-void ComputeAtRootDomainMapBuilder::mapPointwiseOrReductionOp(Expr* e) {
- if (e->output(0)->getValType() != ValType::TensorView) {
- return;
- }
-
- // Broadcast is handled separately, so e should never be BroadcastOp.
- TORCH_INTERNAL_ASSERT(e->getExprType() != ExprType::BroadcastOp);
-
- TORCH_INTERNAL_ASSERT(e->outputs().size() >= 1);
- const TensorView* out_tv = e->output(0)->as<TensorView>();
- const TensorDomain* out_td = out_tv->domain();
- const auto& out_root = out_td->getRootDomain();
-
- // Record equalities from output to all the inputs
- // ignores un-concretizable broadcasts
- for (auto* i : ir_utils::filterByType<TensorView>(e->inputs())) {
- const TensorDomain* in_td = i->domain();
- std::vector<IterDomain*> in_root =
- TensorDomain::noReductions(i->getMaybeRFactorDomain());
- TORCH_INTERNAL_ASSERT(
- in_root.size() == out_root.size(),
- "\nExpression: ",
- e,
- "\nInput root domain: ",
- in_root,
- "\nOutput root domain: ",
- out_root);
- for (size_t it = 0; it < in_root.size(); it++) {
- if (e->outputs().size() > 1) {
- TORCH_INTERNAL_ASSERT(
- e->isA<WelfordOp>(), "Only supported multioutput op is welford");
- for (auto o : e->outputs()) {
- auto o_tv = o->as<TensorView>();
- auto o_td = o_tv->domain();
- auto o_root = o_td->getRootDomain();
- setMaybeMapped(in_td, in_root[it], o_td, o_root[it]);
- }
- } else {
- setMaybeMapped(in_td, in_root[it], out_td, out_root[it]);
- }
- }
- }
-}
-
-void ComputeAtRootDomainMapBuilder::handle(BroadcastOp* op) {
- const TensorDomain* in_td = op->in()->as<TensorView>()->domain();
- const TensorDomain* out_td = op->out()->as<TensorView>()->domain();
- const auto in_root = TensorDomain::noReductions(in_td->getRootDomain());
- const auto& out_root = out_td->getRootDomain();
- const auto& bcast_dim_flags = op->getBroadcastDimFlags();
- TORCH_INTERNAL_ASSERT(
- out_root.size() == bcast_dim_flags.size(),
- "dim flags: ",
- bcast_dim_flags,
- ", out root: ",
- out_root);
- auto in_it = in_root.begin();
- auto out_it = out_root.begin();
- while (in_it != in_root.end() && out_it != out_root.end()) {
- if (bcast_dim_flags.at(std::distance(out_root.begin(), out_it))) {
- // new broadcast dim. No matching dimension in the input
- // tensor.
- root_map_.new_broadcast_domains_.insert(DomainKey(out_td, *out_it));
- ++out_it;
- continue;
- }
- setMaybeMapped(in_td, *in_it, out_td, *out_it);
- ++in_it;
- ++out_it;
- }
- // At this point, the input domain should have been scanned
- // entirely.
- TORCH_INTERNAL_ASSERT(
- in_it == in_root.end(),
- "Unmatched domain detected: ",
- *in_it,
- " of ",
- in_td);
- // On the other hand, the output may still have some domains left,
- // and they must be new broadcast domains.
- for (; out_it != out_root.end(); ++out_it) {
- TORCH_INTERNAL_ASSERT(
- bcast_dim_flags.at(std::distance(out_root.begin(), out_it)),
- "Unmatched domain detected: ",
- *out_it,
- " of ",
- out_td);
- root_map_.new_broadcast_domains_.insert(DomainKey(out_td, *out_it));
- }
-}
-
-void ComputeAtRootDomainMapBuilder::handle(TransposeOp* op) {
- const TensorDomain* in_td = op->in()->as<TensorView>()->domain();
- std::vector<IterDomain*> in_root =
- TensorDomain::noReductions(in_td->getRootDomain());
-
- const TensorDomain* out_td = op->out()->as<TensorView>()->domain();
- const auto& out_root = out_td->getRootDomain();
-
- TORCH_INTERNAL_ASSERT(in_root.size() == out_root.size());
-
- const auto& new2old = op->new2old();
-
- for (size_t it = 0; it < out_root.size(); it++) {
- setMaybeMapped(in_td, in_root[new2old[it]], out_td, out_root[it]);
- }
-}
-
-void ComputeAtRootDomainMapBuilder::handle(GatherOp* op) {
- const TensorDomain* in_td = op->in()->as<TensorView>()->domain();
- const TensorDomain* out_td = op->out()->as<TensorView>()->domain();
- const auto in_root = TensorDomain::noReductions(in_td->getRootDomain());
- const auto& out_root = out_td->getRootDomain();
-
- // Only maps the input root axes. Do not map the new window axes.
- for (size_t it = 0; it < in_root.size(); it++) {
- setMaybeMapped(in_td, in_root[it], out_td, out_root[it]);
- }
-
- // Keep track of window axes so that they can be skipped when
- // mapping root domains
- for (size_t it = in_root.size(); it < out_root.size(); it++) {
- root_map_.window_axes_.insert(out_root[it]);
- }
-}
-
-bool ComputeAtRootDomainMapBuilder::mapAllConsumers(
- const DomainKey& producer_key) {
- auto it = pending_map_.find(producer_key);
- if (it == pending_map_.end()) {
- return false;
- }
- const auto& consumer_set = it->second;
- // All entries in key_set must be equivalent with each other.
- TORCH_INTERNAL_ASSERT(consumer_set.size() > 0);
- bool consistent = safeToMap(consumer_set);
- if (consistent) {
- for (const auto pending_consumer : consumer_set) {
- setMapped(producer_key, pending_consumer);
- }
- }
- // This entry should never be used again, so remove it.
- pending_map_.erase(it);
- return consistent;
-}
-
-void ComputeAtRootDomainMapBuilder::handle(TensorView* tv) {
- const TensorDomain* td = tv->domain();
- const auto root = TensorDomain::noReductions(td->getMaybeRFactorDomain());
- for (auto id : root) {
- if (id->isBroadcast()) {
- initializeBcastMap(tv, id);
- for (const auto& key : root_map_.getConcretizedKeys(td, id)) {
- mapAllConsumers(key);
- }
- } else {
- mapAllConsumers(DomainKey(td, id));
- }
- }
-}
-
-// Returns true if any iteration domain of one consumer is already
-// mapped to an iteration domain of another consumer that does not
-// correspond to the same producer iteration domain. Joining such
-// consumers would map the producer iteration domain to two different
-// consumer iteration domains, requiring recomputations.
-bool ComputeAtRootDomainMapBuilder::hasMatchingDomains(
- const std::vector<DomainKey>& unique_domains) {
- for (const auto& key : unique_domains) {
- for (const auto& other_key : unique_domains) {
- if (key == other_key) {
- continue;
- }
- const auto& other_root = other_key.td()->getRootDomain();
- if (std::any_of(
- other_root.begin(), other_root.end(), [&](const IterDomain* id) {
- return root_map_.canMap(key, other_key.td(), id);
- })) {
- return true;
- }
- }
- }
- return false;
-}
-
-// Checks whether all consumers of a producer can be joined without
-// introducing unsupported mappings, i.e., requiring recomputations.
-bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) {
- if (domains.size() <= 1) {
- return true;
- }
- // Filter out equivalent domains
- std::vector<DomainKey> unique_domains;
- for (const auto& domain : domains) {
- if (std::none_of(
- unique_domains.begin(),
- unique_domains.end(),
- [&](const auto& unique_dom) {
- return root_map_.canMap(domain, unique_dom);
- })) {
- unique_domains.push_back(domain);
- }
- }
- if (hasMatchingDomains(unique_domains)) {
- return false;
- }
- // Can't map if reduction output domains would be mapped
- if (incompatible_domains_.isReductionOutputMapped(
- unique_domains, root_map_) &&
- !map_through_reduction_) {
- return false;
- }
- return true;
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-
-#include <torch/csrc/WindowsTorchApiMacro.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-//! Generic interface for mapping root domains of a producer-consumer pair.
-class TORCH_CUDA_CU_API RootDomainMap : public PolymorphicBase {
- public:
- //! Return a map from a producer TensorDomain to a consumer
- //! TensorDomain
- //!
- //! \param producer A producer TensorDomain
- //! \param consumer A consumer TensorDomain
- //! \param root_dims_to_map Maps only producer root domains in this set
- std::unordered_map<IterDomain*, IterDomain*> mapProducerToConsumer(
- const TensorDomain* producer,
- const TensorDomain* consumer,
- const std::unordered_set<IterDomain*>& root_dims_to_map) const;
-
- //! Return a map from a producer TensorDomain to a consumer
- //! TensorDomain
- //!
- //! \param producer A producer TensorDomain
- //! \param consumer A consumer TensorDomain
- std::unordered_map<IterDomain*, IterDomain*> mapProducerToConsumer(
- const TensorDomain* producer,
- const TensorDomain* consumer) const;
-
- //! Return a map from a consumer TensorDomain to a producer
- //! TensorDomain
- //!
- //! \param consumer A consumer TensorDomain
- //! \param producer A producer TensorDomain
- //! \param root_dims_to_map Maps only consumer root domains in this set
- std::unordered_map<IterDomain*, IterDomain*> mapConsumerToProducer(
- const TensorDomain* consumer,
- const TensorDomain* producer,
- const std::unordered_set<IterDomain*>& root_dims_to_map) const;
-
- //! Return a map from a consumer TensorDomain to a producer
- //! TensorDomain
- //!
- //! \param consumer A consumer TensorDomain
- //! \param producer A producer TensorDomain
- std::unordered_map<IterDomain*, IterDomain*> mapConsumerToProducer(
- const TensorDomain* consumer,
- const TensorDomain* producer) const;
-
- protected:
- //! Return a map between root IterDomains of a producer-consumer
- //! pair.
- //!
- //! \param producer A producer TensorDomain
- //! \param consumer A consumer TensorDomain
- //! \param root_dims_to_map Maps only from IterDomains in this set
- //! \param producer_to_consumer Maps from producer to consumer if true
- virtual std::unordered_map<IterDomain*, IterDomain*> map(
- const TensorDomain* producer,
- const TensorDomain* consumer,
- const std::unordered_set<IterDomain*>& root_dims_to_map,
- bool producer_to_consumer) const = 0;
-};
-
-//! Maps root domains of a producer-consumer pair. This class only
-//! looks at the given pair of TensorViews and does not take into
-//! consideration the constraints of the computeAt transformation,
-//! i.e., unable to compute the same tensors multiple times. This
-//! should not be used for transformations implementing computeAt, but
-//! should be valid otherwise.
-class TORCH_CUDA_CU_API PairwiseRootDomainMap : public RootDomainMap {
- public:
- //! \param producer The producer tensor of a producer-consumer pair.
- //! \param consumer The consumer tensor of a producer-consumer pair.
- explicit PairwiseRootDomainMap(
- const TensorView* producer,
- const TensorView* consumer);
-
- const TensorView* producer() const {
- return producer_tv_;
- }
-
- const TensorView* consumer() const {
- return consumer_tv_;
- }
-
- protected:
- std::unordered_map<IterDomain*, IterDomain*> map(
- const TensorDomain* producer,
- const TensorDomain* consumer,
- const std::unordered_set<IterDomain*>& root_dims_to_map,
- bool producer_to_consumer) const override;
-
- std::unordered_map<IterDomain*, IterDomain*> mapTranspose(
- const TensorDomain* producer,
- const TensorDomain* consumer,
- const std::unordered_set<IterDomain*>& root_dims_to_map,
- bool producer_to_consumer) const;
-
- private:
- const TensorView* producer_tv_ = nullptr;
- const TensorView* consumer_tv_ = nullptr;
-};
-
-std::string toString(const PairwiseRootDomainMap& root_map);
-
-//! Represents an iteration domain of a TensorDomain. Only used for
-//! root domain mapping.
-//!
-//! Note that an IterDomain object may be reused
-//! across multiple TensorDomains, but an IterDomain in a
-//! TensorDomain may not necessarily be mappable to the same
-//! IterDomain used in a different TensorDomain. Thus, for the purpose
-//! of root domain mapping, an iteration domain needs to be identified
-//! with an IterDomain and its TensorDomain.
-class DomainKey {
- public:
- DomainKey() = default;
- DomainKey(
- const TensorDomain* td,
- const IterDomain* id,
- const IterDomain* concrete_id = nullptr)
- : td_(td), id_(id), concrete_id_(concrete_id) {}
- const TensorDomain* td() const {
- return td_;
- }
- const IterDomain* id() const {
- return id_;
- }
- const IterDomain* concreteId() const {
- return concrete_id_;
- }
- bool operator==(const DomainKey& other) const {
- return td() == other.td() && id() == other.id() &&
- concreteId() == other.concreteId();
- }
-
- private:
- const TensorDomain* td_ = nullptr;
- const IterDomain* id_ = nullptr;
- const IterDomain* concrete_id_ = nullptr;
-};
-
-std::string toString(const DomainKey& key);
-
-struct DomainKeyHash {
- std::size_t operator()(const DomainKey& key) const {
- return std::hash<const TensorDomain*>{}(key.td()) ^
- std::hash<const IterDomain*>{}(key.id());
- }
-};
-
-using DomainKeySet = std::unordered_set<DomainKey, DomainKeyHash>;
-
-template <typename Mapped>
-using DomainKeyMap = std::unordered_map<DomainKey, Mapped, DomainKeyHash>;
-
-class ComputeAtRootDomainMap;
-
-//! A helper class to find all DomainKeys that are consumers of
-//! reduction outputs. Such consumer IterDomains may not be mapped to
-//! the producer reduction domain since the corresponding reduction
-//! loop must be closed before any of the consumers can appear.
-class TORCH_CUDA_CU_API UnmappableReductionDomains : private IterVisitor {
- public:
- UnmappableReductionDomains();
- virtual ~UnmappableReductionDomains() = default;
-
- //! Returns true when mapping consumer domains would cause a
- //! reduction output domain to be mapped with a consumer domain of
-  //! the reduction. It needs to be avoided as computing consumers of
- //! reduction outputs within the corresponding reduction loop is not
- //! possible. This routine is used to build root domain mappings.
- bool isReductionOutputMapped(
- const std::vector<DomainKey>& consumer_domains,
- const ComputeAtRootDomainMap& root_map) const;
-
- private:
- using IterVisitor::handle;
- void handle(ReductionOp* op) override;
- void handle(WelfordOp* op) override;
-
- void handleReductionOutput(TensorView* out_tv);
-
- private:
- //! Map from Reduction output DomainKeys to consumer DomainKeys
- DomainKeyMap<DomainKeySet> reduction_domains_;
- //! Map from Reduction output DomainKeys to producer DomainKeys
- DomainKeyMap<DomainKeySet> reduction_domain_inputs_;
-};
-
-//! Models root-domain mappings for computeAt
-//!
-//! Two iteration domains are mapped when computeAt of one iteration
-//! domain is possible at another iteration domain. Consider a simple
-//! example:
-//!   T2 [i0,i1] = T1[i2,i3] + T0[i4,i5]
-//! This will create mappings between i0, i2 and i4, and similarly
-//! between i1, i3 and i5.
-class TORCH_CUDA_CU_API ComputeAtRootDomainMap : public RootDomainMap {
- friend class ComputeAtRootDomainMapBuilder;
- friend std::string toString(const ComputeAtRootDomainMap&);
-
- public:
- //! Builds a mapping table by analyzing the current
- //! fusion. Overwrite a previous table if any.
- //!
-  //! \param map_through_reduction If set
-  //! true, will disable the UnmappableReductionDomains check.
-  //! This is only for re-using logic in detecting
-  //! normalization fusions, which deviates slightly from the
-  //! intended use of this class. Should always be false
-  //! in compute_at use cases.
- void build(bool map_through_reduction = false);
-
-  //! Returns if key(td_a, id_a) and key(td_b, id_b) are mapped to each other
- //! (equivalent), or are the same key.
- //!
- //! \param td_a A TensorDomain
- //! \param id_a An IterDomain in td_a
- //! \param td_b Another TensorDomain
- //! \param id_b An IterDomain in td_b
- //! \returns Boolean representing if they are mapped
- bool canMap(
- const TensorDomain* td_a,
- const IterDomain* id_a,
- const TensorDomain* td_b,
- const IterDomain* id_b) const;
-
- //! Make a TensorDomain an alias of another TensorDomain
- //!
- //! This is for the computeAt transformation, where TensorViews are
- //! updated with new TensorDomains. Since they keep using the same
-  //! root domains, the root mapping remains valid but needs to
- //! reflect the use of new TensorDomains as aliases of the existing
- //! ones.
- //!
- //! \param td An existing TensorDomain
- //! \param td_alias An alias of td
- void setAlias(const TensorDomain* td, const TensorDomain* td_alias);
-
- //! Return a map between TensorDomains
- //!
- //! Unlike the other map functions, two TensorDomains do not need to
- //! be a producer-consumer pair. Since they may not be a
- //! producer-consumer pair, this function requires proper root
- //! domains, which may be root or rfactor domains. Also, no error
-  //! check is done as we do not assume a producer-consumer relationship.
- //!
- //! \param from_td A TensorDomain from which a map is created
- //! \param from_root A root domain of from_td
- //! \param to_td A TensorDomain to which a map is created
- //! \param to_root A root domain of to_td
- std::unordered_map<IterDomain*, IterDomain*> mapBestEffort(
- const TensorDomain* from_td,
- const std::vector<IterDomain*>& from_root,
- const TensorDomain* to_td,
- const std::vector<IterDomain*>& to_root) const;
-
-  //! Returns an unordered set of all iter domains in producer and consumer
-  //! that can be mapped to each other.
- std::unordered_set<IterDomain*> getMappableDims(
- const TensorDomain* producer,
- const TensorDomain* consumer) const;
-
- private:
-  //! Returns if key_a and key(td_b, id_b) are mapped to each other (equivalent),
- //! or are the same key.
- //!
- //! \param key_a A DomainKey
- //! \param td_b Another TensorDomain
- //! \param id_b An IterDomain in td_b
- //! \returns Boolean representing if they are mapped
- bool canMap(
- const DomainKey& key_a,
- const TensorDomain* td_b,
- const IterDomain* id_b) const;
-
-  //! Returns if key_a and key_b are mapped to each other (equivalent), or are
- //! the same key.
- bool canMap(const DomainKey& key_a, const DomainKey& key_b) const;
-
- //! Returns the set of (non-broadcast) DomainKeys that id in td is
- //! broadcasted to. Can result in more than one "concrete" DomainKey.
- std::vector<DomainKey> getConcretizedKeys(
- const TensorDomain* td,
- const IterDomain* id) const;
-
- //! Returns the set of (non-broadcast) iter domains that id in td is
- //! broadcasted to. Can result in more than one "concrete" iter domain.
- std::unordered_set<const IterDomain*>& getConcretizedDomains(
- const TensorDomain* td,
- const IterDomain* id);
-
- //! Return a map between root IterDomains of a producer-consumer
- //! pair.
- //!
- //! \param producer A producer TensorDomain
- //! \param consumer A consumer TensorDomain
- //! \param root_dims_to_map Maps only from IterDomains in this set
- //! \param producer_to_consumer Maps from producer to consumer if true
- std::unordered_map<IterDomain*, IterDomain*> map(
- const TensorDomain* producer,
- const TensorDomain* consumer,
- const std::unordered_set<IterDomain*>& root_dims_to_map,
- bool producer_to_consumer) const override;
-
- private:
- //! Disjoint set of all mapped <TD, ID> keys to determine axes equivalency
- DisjointSet<DomainKey, DomainKeyHash> eq_set_;
-
- //! All IterDomains in the mapping that are a broadcast ID
- DomainKeyMap<std::unordered_set<const IterDomain*>> bcast_map_;
-
-  //! Broadcast iter domain that does not match dimensions in its producer,
- //! meaning it is a brand new domain in its TensorDomain.
- DomainKeySet new_broadcast_domains_;
-
- //! Keep track of window axes so that the map function can ignore them.
- std::unordered_set<IterDomain*> window_axes_;
-};
-
-std::string toString(const ComputeAtRootDomainMap& root_map);
-
-//! Create a DisjointSet of root IterDomains by traversing the
-//! current fusion entirely. IterDomains that can be mapped each
-//! other with computeAt are grouped into the same subset in the
-//! DisjointSet.
-class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder
- : private BackwardVisitor {
- public:
- explicit ComputeAtRootDomainMapBuilder(
- ComputeAtRootDomainMap& root_map,
- bool map_through_reduction = false);
-
- private:
- //! Initialize the bcast map for fusion outputs
- void initializeBcastMap(const TensorView* tv, const IterDomain* id);
-
- //! Set a pair of producer-consumer domain keys as mappable
- void setMapped(const DomainKey& producer, const DomainKey& consumer);
-
- //! Track a pair of producer-consumer domains as potentially mappable. Inserts
- //! entries into pending_map_, but does not add anything into the root_map_
-  //! (added when handle is called on a TensorView). setMaybeMapped will,
-  //! however, immediately propagate broadcast iter domains.
- void setMaybeMapped(
- const TensorDomain* producer_td,
- const IterDomain* producer_id,
- const TensorDomain* consumer_td,
- const IterDomain* consumer_id);
-
- void addToPendingList(const DomainKey& producer, const DomainKey& consumer);
-
- //! Map pointwise IterDomains from inputs of expressions to outputs.
- //! Do not map reduction IterDomains in inputs.
- void mapPointwiseOrReductionOp(Expr* e);
-
- using BackwardVisitor::handle;
-
- void handle(Expr* e) override;
-
- void handle(UnaryOp* uop) override {
- mapPointwiseOrReductionOp(uop);
- }
-
- void handle(BinaryOp* bop) override {
- mapPointwiseOrReductionOp(bop);
- }
-
- void handle(TernaryOp* top) override {
- mapPointwiseOrReductionOp(top);
- }
-
- void handle(ReductionOp* op) override {
- mapPointwiseOrReductionOp(op);
- }
-
- void handle(WelfordOp* wop) override {
- mapPointwiseOrReductionOp(wop);
- }
-
- void handle(ShiftOp* op) override {
- mapPointwiseOrReductionOp(op);
- }
-
- void handle(BroadcastOp* op) override;
-
- void handle(TransposeOp* op) override;
-
- void handle(GatherOp* op) override;
-
- void handle(TensorView* tv) override;
-
- //! Maps all consumers with a producer.
-  //! This is called for each TensorView in a backward traversal,
- //! recursively building mappings from the output tensors to the
- //! input tensors.
- bool mapAllConsumers(const DomainKey& producer_key);
-
- bool hasMatchingDomains(const std::vector<DomainKey>& unique_domains);
-
- bool safeToMap(const DomainKeySet& domains);
-
- private:
- ComputeAtRootDomainMap& root_map_;
-  //! Keep track of what we want to try and map. Set in setMaybeMapped.
- DomainKeyMap<DomainKeySet> pending_map_;
- std::unordered_set<Expr*> visited_;
- UnmappableReductionDomains incompatible_domains_;
-
- //! Disable UnmappableReductions check, should
- //! always be false for compute_at use cases
- bool map_through_reduction_ = false;
-};
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
typename _dim3bd>
__device__ void blockReduce(
T& out,
- const T& inp_val,
+ const T inp_val,
Func reduction_op,
const _dim3ti& thread_idx,
const _dim3bd& block_dim,
T* shared_mem,
- bool read_pred,
- bool write_pred,
+ bool read_write_pred,
T init_val) {
unsigned int reduction_size = (X_REDUCE ? block_dim.x : 1) *
(Y_REDUCE ? block_dim.y : 1) * (Z_REDUCE ? block_dim.z : 1);
assert(reduction_stride != 0);
- if (read_pred) {
+ if (read_write_pred) {
shared_mem[linear_tid] = inp_val;
} else {
shared_mem[linear_tid] = init_val;
}
- block_sync::sync();
+ __syncthreads();
// Reduce down to nearest power of 2:
int np2 = 1 << (31 - __clz(reduction_size));
shared_mem[linear_tid + np2 * reduction_stride]);
}
}
- block_sync::sync();
-  // Loop-peel the final iteration to save one synchronization at the end
- for (int factor = np2 / 2; factor > 1; factor >>= 1) {
+ __syncthreads();
+ // for (int factor = np2/2; factor > contig_threads / 2; factor>>=1) {
+ for (int factor = np2 / 2; factor > 0; factor >>= 1) {
if (reduction_tid < factor) {
reduction_op(
shared_mem[linear_tid],
shared_mem[linear_tid + factor * reduction_stride]);
}
- block_sync::sync();
- }
-
- if (should_write && write_pred) {
- T result = out;
- reduction_op(result, shared_mem[linear_tid]);
- if (reduction_size > 1) {
- reduction_op(result, shared_mem[linear_tid + 1 * reduction_stride]);
- }
- out = result;
+ __syncthreads();
}
- block_sync::sync();
-}
-// Use the same pred for both reads and writes
-template <
- bool X_REDUCE,
- bool Y_REDUCE,
- bool Z_REDUCE,
- typename T,
- typename Func,
- typename _dim3ti,
- typename _dim3bd>
-__device__ void blockReduce(
- T& out,
- const T& inp_val,
- Func reduction_op,
- const _dim3ti& thread_idx,
- const _dim3bd& block_dim,
- T* shared_mem,
- bool read_write_pred,
- T init_val) {
- blockReduce<X_REDUCE, Y_REDUCE, Z_REDUCE, T, Func, _dim3ti, _dim3bd>(
- out,
- inp_val,
- reduction_op,
- thread_idx,
- block_dim,
- shared_mem,
- read_write_pred,
- read_write_pred,
- init_val);
+ if (should_write && read_write_pred)
+ out = shared_mem[linear_tid];
}
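
For intuition, here is a minimal host-side sketch of the reduction shape blockReduce implements (illustrative only, not part of the patch): fold the tail beyond the largest power of two into the front, then do a tree reduction by halving.

#include <cassert>
#include <vector>

// Mirrors blockReduce's structure: one fold down to the nearest power of
// two (np2), then repeated halving of the active range; each pass of the
// outer loop corresponds to one synchronized step on the GPU.
float treeReduce(std::vector<float> v) {
  const int n = static_cast<int>(v.size());
  int np2 = 1;
  while (np2 * 2 <= n) {
    np2 *= 2; // largest power of 2 <= n, like 1 << (31 - __clz(n))
  }
  for (int i = np2; i < n; ++i) {
    v[i - np2] += v[i]; // fold the tail beyond np2
  }
  for (int factor = np2 / 2; factor > 0; factor >>= 1) {
    for (int i = 0; i < factor; ++i) {
      v[i] += v[i + factor];
    }
  }
  return v[0];
}

int main() {
  assert(treeReduce({1, 2, 3, 4, 5, 6, 7}) == 28.0f); // n = 7, np2 = 4
  return 0;
}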
+++ /dev/null
-
-// Counter-based block synchronization. Only meant to be used for
-// debugging and validating synchronization. This should be replaced
-// with cuda::barrier::arrive_and_wait as that should be more robust.
-
-namespace block_sync {
-
-using CounterType = unsigned int;
-static constexpr CounterType COUNTER_TYPE_MAX = ~(CounterType)0;
-__shared__ CounterType sync_counter;
-
-__device__ void init() {
- const unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x +
- threadIdx.z * blockDim.x * blockDim.y;
- if (tid == 0) {
- sync_counter = 0;
- }
- __syncthreads();
-}
-
-// Emulate __syncthreads() with a synchronization counter
-__device__ void sync() {
- unsigned int backoff = 8;
- const unsigned int backoff_max = 256;
- const unsigned int num_threads = blockDim.x * blockDim.y * blockDim.z;
-
- __threadfence_block();
-
- // Use counter range only up to a limit so that the next val won't
- // overflow.
-
- const auto counter_max = (COUNTER_TYPE_MAX / num_threads) * num_threads;
- const auto old = atomicInc(&sync_counter, counter_max - 1);
-
- const auto next = (old / num_threads) * num_threads + num_threads;
-
- auto local_sync_counter = *(volatile CounterType*)(&sync_counter);
-
- // sync_counter may wrap around, which means local_sync_counter
- // becomes smaller than old. In that case, it's guaranteed that all
- // threads have incremented the counter.
- while (local_sync_counter < next && old < local_sync_counter) {
- __nanosleep(backoff);
- if (backoff < backoff_max) {
- backoff *= 2;
- }
- local_sync_counter = *(volatile CounterType*)(&sync_counter);
- }
-}
-
-} // namespace block_sync
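
As a hedged usage sketch (the kernel below is illustrative and assumes a one-dimensional block of 256 threads), the counter-based barrier above is a drop-in replacement for __syncthreads() in code like a shared-memory reduction:

// Illustrative kernel; blockSumKernel is not part of the patch.
__global__ void blockSumKernel(const float* in, float* out) {
  __shared__ float buf[256];
  block_sync::init(); // zero the shared counter once per block
  buf[threadIdx.x] = in[blockIdx.x * blockDim.x + threadIdx.x];
  block_sync::sync(); // emulated barrier in place of __syncthreads()
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      buf[threadIdx.x] += buf[threadIdx.x + stride];
    }
    block_sync::sync();
  }
  if (threadIdx.x == 0) {
    out[blockIdx.x] = buf[0];
  }
}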
+++ /dev/null
-
-// Default block synchronization. Just use __barrier_sync
-namespace block_sync {
-
-__forceinline__ __device__ void init() {}
-
-// Thread-block synchronization
-__forceinline__ __device__ void sync() {
- __barrier_sync(0);
-}
-
-} // namespace block_sync
-
namespace broadcast {
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
// out: Per-thread output location
//
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename T>
-__device__ void blockBroadcast(
- T& out,
- const T& inp_val,
- T* shared_mem,
- bool read_write_pred) {
+__device__ void blockBroadcast(T& out, T inp_val, T* shared_mem) {
const bool has_valid_data = (!X_THREAD || threadIdx.x == 0) &&
(!Y_THREAD || threadIdx.y == 0) && (!Z_THREAD || threadIdx.z == 0);
const auto shared_offset =
offset_of_source<X_THREAD, Y_THREAD, Z_THREAD>(blockDim, threadIdx);
- if (has_valid_data && read_write_pred) {
+ if (has_valid_data)
shared_mem[shared_offset] = inp_val;
- }
-
- block_sync::sync();
- if (read_write_pred) {
- out = shared_mem[shared_offset];
- }
+ __syncthreads();
- block_sync::sync();
+ out = shared_mem[shared_offset];
}
} // namespace broadcast
-
-#define __NVFUSER_HALF_TO_US(var) *(reinterpret_cast<unsigned short*>(&(var)))
-#define __NVFUSER_HALF_TO_CUS(var) \
- *(reinterpret_cast<const unsigned short*>(&(var)))
-
-struct __half;
-__device__ __half __float2half(const float);
+#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short*>(&(var)))
+#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short*>(&(var)))
struct __align__(2) __half {
- __half() = default;
-
- __device__ __half(const float f) {
- __x = __float2half(f).__x;
- }
+ __host__ __device__ __half() {}
protected:
unsigned short __x;
__device__ __half __float2half(const float f) {
__half val;
- asm("{ cvt.rn.f16.f32 %0, %1;}\n"
- : "=h"(__NVFUSER_HALF_TO_US(val))
- : "f"(f));
+ asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(val)) : "f"(f));
return val;
}
__device__ float __half2float(const __half h) {
float val;
- asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__NVFUSER_HALF_TO_CUS(h)));
+ asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
return val;
}
-
-// An aligned vector type generates vectorized loads/stores on CUDA
-template <typename scalar_t, int vec_size>
-struct alignas(sizeof(scalar_t) * vec_size) Array {
- scalar_t val[vec_size];
- __device__ void set(scalar_t v) {
- for (int i = 0; i < vec_size; ++i) {
- val[i] = v;
- }
- }
-};
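
A hedged sketch of what the Array helper above buys: since the struct is aligned to sizeof(scalar_t) * vec_size bytes, a plain struct assignment can be lowered to a single vectorized memory transaction. The copy4 function below is illustrative and assumes both pointers are 16-byte aligned:

// Copies four floats as one 16-byte access via the Array template above.
__device__ void copy4(const float* in, float* out, int i) {
  const auto* src = reinterpret_cast<const Array<float, 4>*>(in);
  auto* dst = reinterpret_cast<Array<float, 4>*>(out);
  dst[i] = src[i];
}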
// Utility functions
template <typename _dim3>
-__device__ __forceinline__ nvfuser_index_t size(const _dim3& d) {
- return (nvfuser_index_t)d.x * (nvfuser_index_t)d.y * (nvfuser_index_t)d.z;
+__device__ __forceinline__ size_t size(const _dim3& d) {
+ return (size_t)d.x * (size_t)d.y * (size_t)d.z;
}
-#define isize(d) ((d).x * (d).y * (d).z)
+#define isize(d) d.x* d.y* d.z
template <typename _dim3pos, typename _dim3dim>
-__device__ __forceinline__ nvfuser_index_t
+__device__ __forceinline__ size_t
offset(const _dim3pos& pos, const _dim3dim& dim) {
- return (nvfuser_index_t)pos.x +
- (nvfuser_index_t)pos.y * (nvfuser_index_t)dim.x +
- (nvfuser_index_t)pos.z * (nvfuser_index_t)dim.x * (nvfuser_index_t)dim.y;
+ return (size_t)pos.x + (size_t)pos.y * (size_t)dim.x +
+ (size_t)pos.z * (size_t)dim.x * (size_t)dim.y;
}
-#define ioffset(pos, dim) \
- ((pos).x + (pos).y * (dim).x + (pos).z * (dim).x * (dim).y)
+#define ioffset(pos, dim) pos.x + pos.y* dim.x + pos.z* dim.x* dim.y
// Returns dim3 of each reduction segment.
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
// Returns the number of blocks in each reduction segment.
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
-__device__ nvfuser_index_t size_of_reduction_segment(const _dim3& grid_dim) {
+__device__ size_t size_of_reduction_segment(const _dim3& grid_dim) {
return size(
dimension_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(grid_dim));
}
// Returns the total number of reduction segments.
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
-__device__ nvfuser_index_t number_of_reduction_segments(const _dim3& grid_dim) {
+__device__ size_t number_of_reduction_segments(const _dim3& grid_dim) {
return (X_BLOCK ? 1 : grid_dim.x) * (Y_BLOCK ? 1 : grid_dim.y) *
(Z_BLOCK ? 1 : grid_dim.z);
}
bool Z_BLOCK,
typename _dim3bi,
typename _dim3gd>
-__device__ nvfuser_index_t
+__device__ size_t
index_of_reduction_segment(const _dim3bi& block_idx, const _dim3gd& grid_dim) {
- nvfuser_index_t seg_idx = 0;
+ size_t seg_idx = 0;
if (!Z_BLOCK)
seg_idx += block_idx.z;
if (!Y_BLOCK)
bool Z_BLOCK,
typename _dim3bi,
typename _dim3gd>
-__device__ nvfuser_index_t
+__device__ size_t
offset_in_reduction_segment(const _dim3bi& block_idx, const _dim3gd& grid_dim) {
- nvfuser_index_t offset = 0;
+ size_t offset = 0;
if (Z_BLOCK)
offset = offset * grid_dim.z + block_idx.z;
if (Y_BLOCK)
__device__ void gridReduceLastBlock(
T& out,
const T* in,
- const nvfuser_index_t in_size,
+ const size_t in_size,
Func reduction_op,
T* shared_buf,
- bool write_pred,
+ bool read_write_pred,
T init_val) {
const int tid = ioffset(threadIdx, blockDim);
const int block_size = isize(blockDim);
if (tid < in_size) {
inp = in[tid];
}
- for (nvfuser_index_t i = tid + block_size; i < in_size; i += block_size) {
+ for (size_t i = tid + block_size; i < in_size; i += block_size) {
reduction_op(inp, in[i]);
}
if (rem_size > 1) {
const int rblock_offset = tid % rblock_size;
const int rblock_idx = tid / rblock_size;
- T inp_tmp = init_val;
blockReduce<false, true, false>(
- inp_tmp,
+ inp,
inp,
reduction_op,
dim3{(unsigned)rblock_offset, (unsigned)rblock_idx, 0},
shared_buf,
true,
init_val);
- block_sync::sync();
- inp = inp_tmp;
+ __syncthreads();
if (tid < rblock_size) {
shared_buf[tid] = inp;
}
- block_sync::sync();
+ __syncthreads();
if (should_write) {
inp = shared_buf[offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(
threadIdx, blockDim)];
}
}
- if (should_write && write_pred) {
- reduction_op(out, inp);
+ if (should_write && read_write_pred) {
+ out = inp;
}
}
typename Func>
__device__ bool gridReduce(
T& out,
- const T& inp_val,
+ T inp_val,
Func reduction_op,
volatile T* work_buf,
Tensor<int64_t, 1> sync_flags,
T* shared_buf,
- bool read_pred,
- bool write_pred,
+ bool read_write_pred,
T init_val) {
// Number of values to reduce in the grid dimensions
const auto seg_size =
offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(
threadIdx, blockDim);
auto work_buf_offset = rblock_size * rblock_offset + thread_offset;
- if (read_pred) {
+ if (read_write_pred) {
work_buf[work_buf_offset] = inp_val;
} else {
work_buf[work_buf_offset] = init_val;
}
}
- block_sync::sync();
+ __syncthreads();
__shared__ bool last_block;
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
last_block = old + 1 == seg_size;
// printf("Last_block = %d + 1 == %d\n", (int)old, (int)seg_size);
}
- block_sync::sync();
+ __syncthreads();
if (last_block) {
// printf("Last block %d %d %d %d\n", blockIdx.x, blockIdx.y, blockIdx.z);
seg_size * rblock_size,
reduction_op,
shared_buf,
- write_pred,
+ read_write_pred,
init_val);
return true;
} else {
}
} // namespace reduction
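
The grid reduction above hinges on a "last block" protocol: every block publishes its partial result to the work buffer, atomically bumps a per-segment counter, and whichever block observes the final count runs the last reduction. A minimal host-side analogue, using std::thread in place of thread blocks (illustrative only):

#include <atomic>
#include <thread>
#include <vector>

int main() {
  const int num_blocks = 8;
  std::vector<int> partials(num_blocks, 0);
  std::atomic<int> sync_flag{0};
  int result = 0;

  std::vector<std::thread> blocks;
  for (int b = 0; b < num_blocks; ++b) {
    blocks.emplace_back([&, b] {
      partials[b] = b + 1; // publish this block's partial result
      // The atomic increment orders the publication above; the thread
      // that sees the final count can safely read every partial.
      if (sync_flag.fetch_add(1) + 1 == num_blocks) {
        for (int p : partials) {
          result += p; // final reduction, done by exactly one thread
        }
      }
    });
  }
  for (auto& t : blocks) {
    t.join();
  }
  // result == 1 + 2 + ... + 8 == 36
  return 0;
}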
-
-#undef isize
-#undef ioffset
-#define NVFUSER_DEFINE_MAGIC_ZERO \
- __shared__ int nvfuser_zero_s; \
- if (threadIdx.x == 0) \
- nvfuser_zero_s = 0; \
- __syncthreads(); \
- atomicMin(&nvfuser_zero_s, threadIdx.x); \
- int nvfuser_zero = nvfuser_zero_s;
-
-#define NVFUSER_UPDATE_MAGIC_ZERO \
- do { \
- nvfuser_zero <<= 1; \
- } while (0);
-
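A note on the magic-zero helpers above: every thread stores 0 into nvfuser_zero_s, atomicMin(0, threadIdx.x) leaves it at 0, and nvfuser_zero <<= 1 keeps it at 0, so nvfuser_zero is provably zero at runtime. Routing the value through shared memory and an atomic hides that fact from the compiler; a hedged reading is that adding this opaque zero to index expressions keeps them from being constant-folded or hoisted.
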
__device__ constexpr int ceilDiv(int a, int b) {
return (a + b - 1) / b;
}
-__device__ constexpr int64_t ceilDiv(int64_t a, int64_t b) {
- return (a + b - 1) / b;
-}
-
-__device__ constexpr int64_t ceilDiv(int64_t a, int b) {
- return ceilDiv(a, (int64_t)b);
-}
-
-__device__ constexpr int64_t ceilDiv(int a, int64_t b) {
- return ceilDiv((int64_t)a, b);
-}
-
__device__ constexpr int alignBufferSize(int buffer, int size) {
return (buffer + (size - 1)) & ~(size - 1);
}
-__device__ double clamp(double x, double minv, double maxv) {
- return x < minv ? minv : (x > maxv ? maxv : x);
-}
-
-__device__ float clamp(float x, double minv, double maxv) {
+__device__ float clamp(float x, float minv, float maxv) {
return x < minv ? minv : (x > maxv ? maxv : x);
}
-__device__ double frac(double x) {
- return x - trunc(x);
-}
-
__device__ float frac(float x) {
- return x - trunc(x);
-}
-
-__device__ double gelu(double x) {
- return x * normcdf(x);
+ return x - truncf(x);
}
__device__ float gelu(float x) {
return x * normcdf(x);
}
-__device__ double reciprocal(double x) {
- return 1 / x;
-}
-
__device__ float reciprocal(float x) {
- return 1 / x;
-}
-
-__device__ double relu(double x) {
- return x <= 0 ? 0 : x;
+ return 1.f / x;
}
__device__ float relu(float x) {
- return x <= 0 ? 0 : x;
-}
-
-__device__ double remainder(double a, double b) {
- auto mod = ::fmod(a, b);
- if ((mod != 0) && ((b < 0) != (mod < 0)))
- mod += b;
- return mod;
+ return x <= 0.f ? 0.f : x;
}
__device__ float remainder(float a, float b) {
- auto mod = ::fmod(a, b);
- if ((mod != 0) && ((b < 0) != (mod < 0)))
- mod += b;
- return mod;
-}
-
-__device__ double sigmoid(double x) {
- return 1 / (1 + exp(-x));
+ return a - b * floorf(a / b);
}
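
Both forms of remainder compute a floored (Python-style) modulo. As a worked check with remainder(-3.f, 5.f): the deleted branch gives fmod(-3, 5) = -3, the signs of -3 and 5 differ, so the result is -3 + 5 = 2; the restored expression gives -3 - 5 * floorf(-0.6f) = -3 - 5 * (-1) = 2.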
__device__ float sigmoid(float x) {
- return 1 / (1 + exp(-x));
+ return 1.f / (1.f + expf(-x));
}
-__device__ double silu(double x) {
- return x * sigmoid(x);
-}
-
-__device__ float silu(float x) {
- return x * sigmoid(x);
-}
-
-__device__ double threshold(double x, double t, double v) {
+__device__ float threshold(float x, float t, float v) {
return x <= t ? v : x;
}
-__device__ float threshold(float x, double t, double v) {
- return x <= t ? v : x;
-}
-
-__device__ double where(bool c, double a, double b) {
- return c ? a : b;
-}
-
__device__ float where(bool c, float a, float b) {
return c ? a : b;
}
-__device__ int64_t where(bool c, int64_t a, int64_t b) {
- return c ? a : b;
-}
-
-__device__ double randLike(Philox rnd) {
- return uniform(rnd(), rnd());
-}
-
-__device__ float randLikef(Philox rnd) {
- return uniformf(rnd());
-}
-
-__device__ constexpr int64_t remainder(int64_t a, int64_t b) {
- auto mod = a % b;
- if ((mod != 0) && ((b < 0) != (mod < 0)))
- mod += b;
- return mod;
-}
-
-__device__ constexpr int remainder(int a, int b) {
- auto mod = a % b;
- if ((mod != 0) && ((b < 0) != (mod < 0)))
- mod += b;
- return mod;
+__device__ float randLike(Philox rnd) {
+ return uniform(rnd());
}
unsigned int STATE = 0;
};
-__device__ float uniformf(unsigned int x) {
+__device__ float uniform(unsigned int x) {
constexpr float kRanInvM32 = 2.3283064e-10f; // Inverse of 2^32.
return x * kRanInvM32;
}
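
The constant 2.3283064e-10f is 2^-32, so uniform maps a 32-bit value x to roughly x / 2^32: uniform(0) = 0.0f and uniform(1u << 31) = 0.5f. One caveat worth noting: for the largest inputs, float rounding can land exactly on 1.0f, so the output is only nominally in [0, 1).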
-
-__device__ double uniform(unsigned int x, unsigned int y) {
- constexpr double kRan2Pow53Inv = 1.1102230246251565e-16;
- const unsigned long long z =
- (unsigned long long)x ^ ((unsigned long long)y << (53 - 32));
- return z * kRan2Pow53Inv + (kRan2Pow53Inv / 2.0);
-}
+typedef unsigned char uint8_t;
+typedef signed char int8_t;
+typedef short int int16_t;
+typedef long long int int64_t;
+
template <typename T, int N>
struct Tensor {
- __device__ T& operator[](nvfuser_index_t ind) {
+ __device__ T& operator[](int64_t ind) {
return data[ind];
};
T* data;
- nvfuser_index_t size[N];
- nvfuser_index_t stride[N];
+ int64_t size[N];
+ int64_t stride[N];
};
// Specialization for 0-dim case as it does not need size and stride arrays.
// They will be an error as well since zero-length arrays are not allowed.
template <typename T>
struct Tensor<T, 0> {
- __device__ T& operator[](nvfuser_index_t) {
+ __device__ T& operator[](int64_t) {
return *data;
};
+++ /dev/null
-// -----------------------------------------------------------------------------------------------
-// Block Welford Primitives
-// -----------------------------------------------------------------------------------------------
-// Basic utility for a Welford update. Can be used to scan one value, or to
-// merge two Welford results
-template <typename T, typename TN>
-__inline__ __device__ void welfordCombine(
- T& a_avg,
- T& a_M2,
- TN& a_N,
- const T& b_avg,
- const T& b_M2,
- TN b_N) {
- if (b_N == 0) {
- return;
- }
- TN ab_N = a_N + b_N;
- T delta = b_avg - a_avg;
- a_avg += delta * b_N / ab_N;
- a_M2 += b_M2 + delta * delta * a_N * b_N / ab_N;
- a_N = ab_N;
-}
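
Since this file carries the core Welford math, here is a minimal host-side check (illustrative, not part of the patch) that the combine rule above matches a direct two-pass mean/M2 computation:

#include <cassert>
#include <cmath>
#include <vector>

// Same update rule as welfordCombine: Chan et al.'s pairwise merge of
// (mean, M2 = sum of squared deviations from the mean, count).
void combine(double& a_avg, double& a_M2, long& a_N,
             double b_avg, double b_M2, long b_N) {
  if (b_N == 0) {
    return;
  }
  long ab_N = a_N + b_N;
  double delta = b_avg - a_avg;
  a_avg += delta * b_N / ab_N;
  a_M2 += b_M2 + delta * delta * a_N * b_N / ab_N;
  a_N = ab_N;
}

int main() {
  const std::vector<double> xs = {1.0, 2.0, 4.0, 8.0, 16.0};
  // Scan one value at a time: each element enters as the triple (x, 0, 1).
  double avg = 0.0, M2 = 0.0;
  long N = 0;
  for (double x : xs) {
    combine(avg, M2, N, x, 0.0, 1);
  }
  // Two-pass reference: mean first, then sum of squared deviations.
  double mean = 0.0;
  for (double x : xs) mean += x;
  mean /= xs.size();
  double m2 = 0.0;
  for (double x : xs) m2 += (x - mean) * (x - mean);
  assert(std::abs(avg - mean) < 1e-12);
  assert(std::abs(M2 - m2) < 1e-9);
  return 0;
}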
-
-// [Z,Y,X]_THREADS is the number of participating threads in the z, y, x
-// dimension of the block.
-template <
- bool X_REDUCE,
- bool Y_REDUCE,
- bool Z_REDUCE,
- typename T,
- typename TN,
- typename _dim3ti,
- typename _dim3bd>
-__inline__ __device__ void blockWelford(
- T& out_avg,
- T& out_M2,
- TN& out_N,
- const T& in_avg,
- const T& in_M2,
- const TN& in_N,
- const _dim3ti& thread_idx,
- const _dim3bd& block_dim,
- T* shared_mem_avg,
- T* shared_mem_M2,
- TN* shared_mem_N,
- bool read_pred,
- bool write_pred,
- T init_val) {
- unsigned int reduction_size = (X_REDUCE ? block_dim.x : 1) *
- (Y_REDUCE ? block_dim.y : 1) * (Z_REDUCE ? block_dim.z : 1);
- // If this thread will output a final result
- bool should_write = true;
- if (X_REDUCE)
- should_write = should_write && thread_idx.x == 0;
- if (Y_REDUCE)
- should_write = should_write && thread_idx.y == 0;
- if (Z_REDUCE)
- should_write = should_write && thread_idx.z == 0;
- unsigned int reduction_stride;
- unsigned int reduction_tid;
- unsigned int linear_tid;
- if (X_REDUCE && !Y_REDUCE && Z_REDUCE) {
- // Transpose Z and Y in the shared memory so Z and X dims are contiguous in
- // smem
- reduction_stride = 1;
- linear_tid = threadIdx.y * blockDim.z * blockDim.x +
- threadIdx.z * blockDim.x + threadIdx.x;
- reduction_tid = threadIdx.z * blockDim.x + threadIdx.x;
- } else {
- // Normal reduction in order
- reduction_stride =
- (X_REDUCE ? 1
- : (Y_REDUCE ? block_dim.x
- : (Z_REDUCE ? block_dim.x * block_dim.y : 0)));
- linear_tid = thread_idx.z * block_dim.y * block_dim.x +
- thread_idx.y * block_dim.x + thread_idx.x;
- reduction_tid = (Z_REDUCE ? thread_idx.z : 0) *
- (Y_REDUCE ? block_dim.y : 1) * (X_REDUCE ? block_dim.x : 1) +
- (Y_REDUCE ? thread_idx.y : 0) * (X_REDUCE ? block_dim.x : 1) +
- (X_REDUCE ? thread_idx.x : 0);
- }
- assert(reduction_stride != 0);
- if (read_pred) {
- shared_mem_avg[linear_tid] = in_avg;
- shared_mem_M2[linear_tid] = in_M2;
- shared_mem_N[linear_tid] = in_N;
- } else {
- shared_mem_avg[linear_tid] = init_val;
- shared_mem_M2[linear_tid] = init_val;
- shared_mem_N[linear_tid] = 0;
- }
- block_sync::sync();
- // Reduce down to nearest power of 2:
- int np2 = 1 << (31 - __clz(reduction_size));
- if (reduction_tid < np2) {
- if (reduction_tid + np2 < reduction_size) {
- welfordCombine(
- shared_mem_avg[linear_tid],
- shared_mem_M2[linear_tid],
- shared_mem_N[linear_tid],
- shared_mem_avg[linear_tid + np2 * reduction_stride],
- shared_mem_M2[linear_tid + np2 * reduction_stride],
- shared_mem_N[linear_tid + np2 * reduction_stride]);
- }
- }
- block_sync::sync();
-
-  // Loop-peel the final iteration to save one synchronization at the end
- for (int factor = np2 / 2; factor > 1; factor >>= 1) {
- if (reduction_tid < factor) {
- welfordCombine(
- shared_mem_avg[linear_tid],
- shared_mem_M2[linear_tid],
- shared_mem_N[linear_tid],
- shared_mem_avg[linear_tid + factor * reduction_stride],
- shared_mem_M2[linear_tid + factor * reduction_stride],
- shared_mem_N[linear_tid + factor * reduction_stride]);
- }
- block_sync::sync();
- }
- if (should_write && write_pred) {
- T res_avg = out_avg;
- T res_M2 = out_M2;
- TN res_N = out_N;
- welfordCombine(
- res_avg,
- res_M2,
- res_N,
- shared_mem_avg[linear_tid],
- shared_mem_M2[linear_tid],
- shared_mem_N[linear_tid]);
- if (reduction_size > 1) {
- welfordCombine(
- res_avg,
- res_M2,
- res_N,
- shared_mem_avg[linear_tid + reduction_stride],
- shared_mem_M2[linear_tid + reduction_stride],
- shared_mem_N[linear_tid + reduction_stride]);
- }
- out_avg = res_avg;
- out_M2 = res_M2;
- out_N = res_N;
- }
- block_sync::sync();
-}
-
-// Use the same pred for both reads and writes
-template <
- bool X_REDUCE,
- bool Y_REDUCE,
- bool Z_REDUCE,
- typename T,
- typename TN,
- typename _dim3ti,
- typename _dim3bd>
-__inline__ __device__ void blockWelford(
- T& out_avg,
- T& out_M2,
- TN& out_N,
- const T& in_avg,
- const T& in_M2,
- const TN& in_N,
- const _dim3ti& thread_idx,
- const _dim3bd& block_dim,
- T* shared_mem_avg,
- T* shared_mem_M2,
- TN* shared_mem_N,
- bool read_write_pred,
- T init_val) {
- blockWelford<X_REDUCE, Y_REDUCE, Z_REDUCE, T, TN, _dim3ti, _dim3bd>(
- out_avg,
- out_M2,
- out_N,
- in_avg,
- in_M2,
- in_N,
- thread_idx,
- block_dim,
- shared_mem_avg,
- shared_mem_M2,
- shared_mem_N,
- read_write_pred,
- read_write_pred,
- init_val);
-}
-// -----------------------------------------------------------------------------------------------
-// Grid Welford Prototype
-// -----------------------------------------------------------------------------------------------
-namespace welford {
-// Utility functions
-template <typename _dim3>
-__host__ __device__ __forceinline__ nvfuser_index_t size(const _dim3& d) {
- return (nvfuser_index_t)d.x * (nvfuser_index_t)d.y * (nvfuser_index_t)d.z;
-}
-
-#define isize(d) ((d).x * (d).y * (d).z)
-
-template <typename _dim3pos, typename _dim3dim>
-__host__ __device__ __forceinline__ nvfuser_index_t
-offset(const _dim3pos& pos, const _dim3dim& dim) {
- return (nvfuser_index_t)pos.x +
- (nvfuser_index_t)pos.y * (nvfuser_index_t)dim.x +
- (nvfuser_index_t)pos.z * (nvfuser_index_t)dim.x * (nvfuser_index_t)dim.y;
-}
-
-#define ioffset(pos, dim) \
- ((pos).x + (pos).y * (dim).x + (pos).z * (dim).x * (dim).y)
-
-// Returns dim3 of each reduction segment.
-template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
-__host__ __device__ dim3 dimension_of_reduction_segment(const _dim3& grid_dim) {
- return dim3{
- X_BLOCK ? grid_dim.x : 1,
- Y_BLOCK ? grid_dim.y : 1,
- Z_BLOCK ? grid_dim.z : 1};
-}
-
-// Returns the number of blocks in each reduction segment.
-template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
-__host__ __device__ nvfuser_index_t
-size_of_reduction_segment(const _dim3& grid_dim) {
- return size(
- dimension_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(grid_dim));
-}
-
-// Returns the total number of reduction segments.
-template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
-__host__ __device__ nvfuser_index_t
-number_of_reduction_segments(const _dim3& grid_dim) {
- return (X_BLOCK ? 1 : grid_dim.x) * (Y_BLOCK ? 1 : grid_dim.y) *
- (Z_BLOCK ? 1 : grid_dim.z);
-}
-
-// Returns the 1-D index of the segment of thread block of block_idx.
-template <
- bool X_BLOCK,
- bool Y_BLOCK,
- bool Z_BLOCK,
- typename _dim3bi,
- typename _dim3gd>
-__host__ __device__ nvfuser_index_t
-index_of_reduction_segment(const _dim3bi& block_idx, const _dim3gd& grid_dim) {
- nvfuser_index_t seg_idx = 0;
- if (!Z_BLOCK)
- seg_idx += block_idx.z;
- if (!Y_BLOCK)
- seg_idx = seg_idx * grid_dim.y + block_idx.y;
- if (!X_BLOCK)
- seg_idx = seg_idx * grid_dim.x + block_idx.x;
- return seg_idx;
-}
-
-// Returns the offset of thread block in its reduction segment.
-template <
- bool X_BLOCK,
- bool Y_BLOCK,
- bool Z_BLOCK,
- typename _dim3bi,
- typename _dim3gd>
-__host__ __device__ nvfuser_index_t
-offset_in_reduction_segment(const _dim3bi& block_idx, const _dim3gd& grid_dim) {
- nvfuser_index_t offset = 0;
- if (Z_BLOCK)
- offset = offset * grid_dim.z + block_idx.z;
- if (Y_BLOCK)
- offset = offset * grid_dim.y + block_idx.y;
- if (X_BLOCK)
- offset = offset * grid_dim.x + block_idx.x;
- return offset;
-}
-
-// Returns dim3 of each reduction block.
-template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename _dim3>
-__host__ __device__ dim3 dimension_of_reduction_block(const _dim3& block_dim) {
- return dim3{
- X_THREAD ? block_dim.x : 1,
- Y_THREAD ? block_dim.y : 1,
- Z_THREAD ? block_dim.z : 1};
-}
-
-// Returns the number of threads of each reduction block.
-template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename _dim3>
-__host__ __device__ int size_of_reduction_block(const _dim3& block_dim) {
- auto tmp_dim =
- dimension_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(block_dim);
- return isize(tmp_dim);
-}
-
-// Returns the linear offset of a thread in a reduction block.
-template <
- bool X_THREAD,
- bool Y_THREAD,
- bool Z_THREAD,
- typename _dim3ti,
- typename _dim3bd>
-__host__ __device__ int offset_in_reduction_block(
- const _dim3ti& thread_idx,
- const _dim3bd& block_dim) {
- int offset = 0;
- if (Z_THREAD)
- offset += thread_idx.z;
- if (Y_THREAD)
- offset = offset * block_dim.y + thread_idx.y;
- if (X_THREAD)
- offset = offset * block_dim.x + thread_idx.x;
- return offset;
-}
-
-template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename T, typename TN>
-__device__ void gridWelfordLastBlock(
- T& out_avg,
- T& out_M2,
- TN& out_N,
- const T* in_avg,
- const T* in_M2,
- const TN* in_N,
- const nvfuser_index_t in_size,
- T* shared_buf_avg,
- T* shared_buf_M2,
- TN* shared_buf_N,
- bool write_pred,
- T init_val) {
- const int tid = ioffset(threadIdx, blockDim);
- const int block_size = isize(blockDim);
- const int rblock_size =
- size_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(blockDim);
-
- T inp_avg = init_val;
- T inp_M2 = init_val;
- TN inp_N = 0;
- if (tid < in_size) {
- inp_avg = in_avg[tid];
- inp_M2 = in_M2[tid];
- inp_N = in_N[tid];
- }
- for (nvfuser_index_t i = tid + block_size; i < in_size; i += block_size) {
- welfordCombine(inp_avg, inp_M2, inp_N, in_avg[i], in_M2[i], in_N[i]);
- }
- const auto should_write = (X_THREAD || threadIdx.x == 0) &&
- (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0);
-
- auto rem_size = block_size / rblock_size;
-
- if (rem_size > 1) {
- const int rblock_offset = tid % rblock_size;
- const int rblock_idx = tid / rblock_size;
- T inp_avg_tmp = init_val;
- T inp_M2_tmp = init_val;
- TN inp_N_tmp = 0;
- blockWelford<false, true, false>(
- inp_avg_tmp,
- inp_M2_tmp,
- inp_N_tmp,
- inp_avg,
- inp_M2,
- inp_N,
- dim3{(unsigned)rblock_offset, (unsigned)rblock_idx, 0},
- dim3{(unsigned)rblock_size, (unsigned)rem_size},
- shared_buf_avg,
- shared_buf_M2,
- shared_buf_N,
- true,
- init_val);
- block_sync::sync();
- if (tid < rblock_size) {
- shared_buf_avg[tid] = inp_avg_tmp;
- shared_buf_M2[tid] = inp_M2_tmp;
- shared_buf_N[tid] = inp_N_tmp;
- }
- block_sync::sync();
- if (should_write) {
- nvfuser_index_t offset_write =
- offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(
- threadIdx, blockDim);
- inp_avg = shared_buf_avg[offset_write];
- inp_M2 = shared_buf_M2[offset_write];
- inp_N = shared_buf_N[offset_write];
- }
- }
-
- if (should_write && write_pred) {
- welfordCombine(out_avg, out_M2, out_N, inp_avg, inp_M2, inp_N);
- }
-}
-
-// Grid welford combine
-template <
- bool X_BLOCK,
- bool Y_BLOCK,
- bool Z_BLOCK,
- bool X_THREAD,
- bool Y_THREAD,
- bool Z_THREAD,
- typename T,
- typename TN>
-__device__ bool gridWelford(
- T& out_avg,
- T& out_M2,
- TN& out_N,
- const T& inp_avg,
- const T& inp_M2,
- const TN& inp_N,
- volatile T* work_buf_avg,
- volatile T* work_buf_M2,
- volatile TN* work_buf_N,
- Tensor<int64_t, 1> sync_flags,
- T* shared_buf_avg,
- T* shared_buf_M2,
- TN* shared_buf_N,
- bool read_pred,
- bool write_pred,
- T init_val) {
- // Number of values to reduce in the grid dimensions
- const auto seg_size =
- size_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);
-
- // Index of the reduction we're performing out of the seg_size
- const auto seg_idx =
- index_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
-
-  // Number of threads we can use in the final reduction. Seems to assume all
-  // threads in the block participate
- const auto rblock_size =
- size_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(blockDim);
-
- work_buf_avg += seg_idx * seg_size * rblock_size;
- work_buf_M2 += seg_idx * seg_size * rblock_size;
- work_buf_N += seg_idx * seg_size * rblock_size;
-
- if ((X_THREAD || threadIdx.x == 0) && (Y_THREAD || threadIdx.y == 0) &&
- (Z_THREAD || threadIdx.z == 0)) {
- auto rblock_offset = offset_in_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(
- blockIdx, gridDim);
- auto thread_offset =
- offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(
- threadIdx, blockDim);
- auto work_buf_offset = rblock_size * rblock_offset + thread_offset;
- if (read_pred) {
- work_buf_avg[work_buf_offset] = inp_avg;
- work_buf_M2[work_buf_offset] = inp_M2;
- work_buf_N[work_buf_offset] = inp_N;
- } else {
- work_buf_avg[work_buf_offset] = init_val;
- work_buf_M2[work_buf_offset] = init_val;
- work_buf_N[work_buf_offset] = 0;
- }
- }
- block_sync::sync();
-
- __shared__ bool last_block;
- if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
- __threadfence();
- auto old = (int64_t)atomicAdd((unsigned long long*)&sync_flags[seg_idx], 1);
- last_block = old + 1 == seg_size;
- }
- block_sync::sync();
-
- if (last_block) {
- // final reduction
- gridWelfordLastBlock<X_THREAD, Y_THREAD, Z_THREAD>(
- out_avg,
- out_M2,
- out_N,
- (T*)work_buf_avg,
- (T*)work_buf_M2,
- (TN*)work_buf_N,
- seg_size * rblock_size,
- shared_buf_avg,
- shared_buf_M2,
- shared_buf_N,
- write_pred,
- init_val);
- return true;
- } else {
- return false;
- }
-}
-} // namespace welford
-
-#undef isize
-#undef ioffset
--- /dev/null
+#include <torch/csrc/jit/codegen/cuda/scheduler.h>
+
+#include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
+#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
+#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
+#include <torch/csrc/jit/codegen/cuda/parser.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/util/irange.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+constexpr int kUnrollFactor = 1;
+
+namespace {
+
+std::vector<int> reductionAxes(TensorView* tv) {
+ size_t n_dims = tv->nDims();
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ std::vector<int> reduction_axes;
+ for (const auto i : c10::irange(n_dims)) {
+ if (tv->axis(i)->isReduction()) {
+ reduction_axes.emplace_back(i);
+ }
+ }
+ return reduction_axes;
+}
+
+// Merges all reduction axes to the right side and returns the total number
+// of reduction axes
+size_t mergeReduction(TensorView* tv) {
+ int prev_i = -1;
+ size_t num_merged = 0;
+ for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
+ if (!tv->axis(i)->isReduction()) {
+ continue;
+ }
+ if (prev_i == -1) {
+ prev_i = i;
+ } else {
+ tv->merge(i, prev_i);
+ prev_i = i;
+ num_merged++;
+ }
+ }
+ if (prev_i == 0) {
+ tv->reorder({{prev_i, -1}});
+ }
+
+ return prev_i == -1 ? 0 : num_merged + 1;
+}
+
+// Merges all non-reduction axes to the left side and returns the total
+// number of iteration axes
+size_t mergeNonReduction(TensorView* tv) {
+ int prev_i = -1;
+ size_t num_merged = 0;
+ for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
+ if (tv->axis(i)->isReduction()) {
+ continue;
+ }
+ if (prev_i == -1) {
+ prev_i = i;
+ } else {
+ tv->merge(i, prev_i);
+ prev_i = i;
+ num_merged++;
+ }
+ }
+ if (prev_i != 0) {
+ tv->reorder({{prev_i, 0}});
+ }
+
+ return prev_i == -1 ? 0 : num_merged + 1;
+}
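
As a hedged trace of the two helpers above (assuming TensorView::merge(o, i) keeps the merged axis at the outer position o), a hypothetical tensor with axes [i0, r1, i2, r3], where the rN are reduction axes, evolves as:

  mergeReduction:    [i0, r1, i2, r3] -> [i0, r1*r3, i2]  (prev_i == 1, no reorder)
  mergeNonReduction: [i0, r1*r3, i2]  -> [i0*i2, r1*r3]   (prev_i == 0, no reorder)

After both calls the tensor is two-dimensional, with the merged iteration axis on the left and the merged reduction axis on the right, which is exactly what scheduleReduction asserts below.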
+
+} // namespace
+
+// This one is a total mess and it should go.
+bool scheduleFusion(Fusion* fusion, const at::ArrayRef<c10::IValue> inputs) {
+ FUSER_PERF_SCOPE("scheduleFusion");
+
+ FusionGuard fg(fusion);
+  // maybe has_reduction for scheduling should be done on a per-output-tensor
+  // basis.
+ TORCH_INTERNAL_ASSERT(
+ !fusion->hasReduction(), "This scheduler only handles pointwise ops.");
+ const bool disable_unroll = fusion->isStochastic();
+
+ for (auto out_val : fusion->outputs()) {
+ auto out = out_val->as<TensorView>();
+
+ // Merge all dimensions because we're only supporting pointwise
+ while (out->nDims() > 1) {
+ out->merge(-2, -1);
+ }
+ }
+
+  // Run through the outputs, grab all inputs of the outputs, and
+  // squeeze with computeAt to set the overall structure.
+ for (auto output : fusion->outputs()) {
+ if (output->getValType() != ValType::TensorView)
+ continue;
+ TensorView* out_tv = output->as<TensorView>();
+
+    // Split into 128, which will be blockDim.x
+ out_tv->split(0, kPwThreadX);
+    // Split by the unroll factor
+ auto ur_factor = disable_unroll ? 1 : kUnrollFactor;
+ out_tv->split(0, ur_factor);
+ }
+
+ for (auto output : fusion->outputs()) {
+ if (output->getValType() != ValType::TensorView)
+ continue;
+ TensorView* out_tv = output->as<TensorView>();
+ for (Val* inp : fusion->inputsOf(output)) {
+ if (inp->getValType().value() == ValType::TensorView)
+ inp->as<TensorView>()->computeAt(out_tv, -1);
+ }
+ out_tv->axis(0)->parallelize(ParallelType::BIDx);
+ out_tv->axis(1)->parallelize(ParallelType::Unroll);
+ out_tv->axis(2)->parallelize(ParallelType::TIDx);
+ }
+
+ return true;
+}
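
As a worked example of the pointwise schedule above (sizes are illustrative): a flattened output of 1,048,576 elements with kPwThreadX = 128 and unrolling disabled is split into [8192, 1, 128], parallelized as (BIDx, Unroll, TIDx), i.e. 8192 blocks of 128 threads with a degenerate unroll axis.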
+
+namespace {
+// Largest power of 2 less than or equal to n
+constexpr int lastPow2(int n) {
+ n |= (n >> 1);
+ n |= (n >> 2);
+ n |= (n >> 4);
+ n |= (n >> 8); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+ n |= (n >> 16); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+ return std::max(1, n - (n >> 1));
+}
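
A few sanity checks of lastPow2 (these static_asserts are illustrative, not part of the patch): the shifts smear the top set bit into every lower position, after which n - (n >> 1) isolates that bit.

static_assert(lastPow2(1) == 1, "");
static_assert(lastPow2(5) == 4, "");
static_assert(lastPow2(64) == 64, "");
static_assert(lastPow2(1000) == 512, "");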
+
+ReductionParams reductionHeuristic(
+ int red_elems,
+ int red_outputs,
+ bool red_on_fastest_dim) {
+ ReductionParams rparams;
+ rparams.fastest_dim = red_on_fastest_dim;
+
+ int gdimx = LaunchParams::UNINITIALIZED_VAL;
+ int gdimy = LaunchParams::UNINITIALIZED_VAL;
+ int bdimx = LaunchParams::UNINITIALIZED_VAL;
+ int bdimy = LaunchParams::UNINITIALIZED_VAL;
+
+ // 1. Initial Assumptions
+
+ // Evaluate Dimensions of Reduction TensorView
+ TORCH_INTERNAL_ASSERT(red_elems > 0 && red_outputs > 0);
+
+ // 2. Initial Definition of Block Dimensions
+
+ // Is fastest dimension a reduction dimension?
+ if (rparams.fastest_dim) {
+ if (red_elems < rparams.loop_unroll) {
+ rparams.loop_unroll = 1;
+ }
+ bdimx = ceilDiv(red_elems, rparams.loop_unroll);
+ bdimy = red_outputs;
+ } else {
+ bdimx = red_outputs;
+ bdimy = red_elems;
+ }
+
+ // 3. Applying Power of 2 Blocking based on the Maximum Number of threads
+
+ constexpr int kMaxNumThreads = 512;
+ int num_threads = kMaxNumThreads;
+ int device_warp_size = at::cuda::warp_size();
+
+ if (bdimx < num_threads) {
+ bdimx = lastPow2(bdimx);
+ } else {
+ bdimx = num_threads;
+ }
+
+ if (bdimy < num_threads) {
+ bdimy = lastPow2(bdimy);
+ } else {
+ bdimy = num_threads;
+ }
+
+ int bdimx_prev = bdimx;
+ bdimx = std::min(bdimx, device_warp_size);
+ bdimy = std::min(bdimy, num_threads / bdimx);
+ bdimx = std::min(bdimx_prev, num_threads / bdimy);
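+  // An illustrative trace of the clamping above (hypothetical numbers):
+  // starting from bdimx = 1000, bdimy = 300 with device_warp_size = 32 and
+  // num_threads = 512: bdimx is first capped to 512 and bdimy becomes
+  // lastPow2(300) = 256; then bdimx = min(512, 32) = 32,
+  // bdimy = min(256, 512 / 32) = 16, and finally
+  // bdimx = min(512, 512 / 16) = 32, yielding a 32x16 thread block.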
+
+ // 4. Distributing work across a block
+
+  // Magic numbers bounding the reduction elements handled per thread.
+ constexpr int kMinValuesPerThread = 16;
+ constexpr int kMaxValuesPerThread = 256;
+
+ int inputs_consumed_per_block_iter = 1;
+ int red_elems_per_thread = red_elems;
+
+ int outputs_produced_per_block_iter = 1;
+
+ // Reduction is performed across warp threads (cross-thread reduction)
+ if (rparams.fastest_dim) {
+ inputs_consumed_per_block_iter *= bdimx;
+ red_elems_per_thread =
+ ceilDiv(red_elems_per_thread, inputs_consumed_per_block_iter);
+ // Warp threads are applied across the output
+ } else {
+ outputs_produced_per_block_iter *= bdimx;
+ }
+
+ // Decision to do a cross-warp reduction per block
+ if (red_elems_per_thread >= (bdimy * kMinValuesPerThread) ||
+ red_elems_per_thread >= kMaxValuesPerThread || !rparams.fastest_dim) {
+ // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
+ inputs_consumed_per_block_iter *= bdimy;
+ red_elems_per_thread = ceilDiv(red_elems_per_thread, bdimy);
+ rparams.cross_block = true;
+ rparams.mul_reds_per_blk = false;
+ // Do multiple reductions per block
+ } else {
+ rparams.cross_block = false;
+ rparams.mul_reds_per_blk = true;
+ outputs_produced_per_block_iter *= bdimy;
+ }
+
+ // 5. Distributing work across blocks
+
+ // WARNING: Current device for codegen may not be the target device
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ int device_max_threads_per_multiprocessor =
+ at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor;
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ int device_multiprocessor_count =
+ at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+
+ int blocks_per_sm = device_max_threads_per_multiprocessor / (bdimx * bdimy);
+ int target_grid_size = device_multiprocessor_count * blocks_per_sm;
+
+ // Setting the number of blocks based on the number of outputs
+ gdimx = ceilDiv(red_outputs, outputs_produced_per_block_iter);
+
+ // Cross-block reductions (if necessary)
+ if (rparams.cross_block && red_elems_per_thread >= kMaxValuesPerThread &&
+ gdimx <= target_grid_size) {
+ int blks_per_out_1 = ceilDiv(target_grid_size, gdimx);
+ int blks_per_out_2 = ceilDiv(red_elems_per_thread, kMinValuesPerThread);
+ int blks_per_out_3 = ceilDiv(red_elems_per_thread, kMaxValuesPerThread);
+ int blks_per_output =
+ std::max(std::min(blks_per_out_1, blks_per_out_2), blks_per_out_3);
+
+ gdimy = std::max(1, blks_per_output);
+ // If a cross-block reduction was generated
+ if (blks_per_output > 1) {
+ rparams.cross_grid = true;
+ }
+ }
+
+ const char* debug_env = getenv("PYTORCH_CUDA_FUSER_RED_SCHED_DEBUG");
+ if (debug_env && atoi(debug_env)) {
+ std::cout << "\n===== Reduction Parameters ========" << std::endl
+ << "Inputs:" << std::endl
+ << "\tRed Elems: " << red_elems << " Red Outputs: " << red_outputs
+ << " Red On Fastest Dim? " << red_on_fastest_dim << std::endl
+ << "Reduction Characteristics:" << std::endl
+ << "\tMultiple Reds Per Block? " << rparams.mul_reds_per_blk
+ << " Cross Block? " << rparams.cross_block << " Cross Grid? "
+ << rparams.cross_grid << std::endl
+ << "Recommended Blocking:" << std::endl
+ << "\tGridX: " << gdimx << " GridY: " << gdimy
+ << " BlckX: " << bdimx << " BlckY: " << bdimy << std::endl
+ << "====================================" << std::endl;
+ }
+
+ rparams.lparams = LaunchParams(
+ LaunchParams::UNINITIALIZED_VAL,
+ gdimy,
+ LaunchParams::UNINITIALIZED_VAL,
+ bdimx,
+ bdimy,
+ LaunchParams::UNINITIALIZED_VAL);
+ return rparams;
+}
+} // anonymous namespace
+
+TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
+ Fusion* fusion,
+ const at::ArrayRef<c10::IValue>& fusion_inputs,
+ TensorView* red_tv) {
+  FUSER_PERF_SCOPE("getReductionHeuristics");
+
+ FusionGuard fg(fusion);
+
+ if (!fusion->hasReduction()) {
+ return c10::nullopt;
+ }
+
+  TORCH_INTERNAL_ASSERT(
+      red_tv != nullptr, "Reduction TensorView wasn't found.");
+
+  auto red_root_dom = red_tv->getRootDomain();
+  const bool red_on_fastest_dim =
+      red_root_dom[red_root_dom.size() - 1]->isReduction();
+
+ TORCH_INTERNAL_ASSERT(
+ red_tv->hasReduction(), "TensorView doesn't have a reduction.");
+ const auto red_expr = fusion->origin(red_tv);
+
+ TORCH_INTERNAL_ASSERT(
+ red_expr->getExprType() != c10::nullopt &&
+ red_expr->getExprType().value() == ExprType::ReductionOp,
+ "TensorView doesn't have a reduction.");
+
+ StatefulExpressionEvaluator evaluator(
+ executor_utils::statefulBindInputs(fusion_inputs, fusion));
+
+ int64_t red_outputs = 1;
+ int64_t red_elements = 1;
+
+ for (auto id : red_tv->getRootDomain()) {
+ auto inferred_val = evaluator.inferValue(id->rawExtent());
+ TORCH_INTERNAL_ASSERT(
+ inferred_val.has_value(), "Error inferring reduction size.");
+ if (id->isReduction()) {
+ red_elements *= inferred_val.value();
+ } else {
+ red_outputs *= inferred_val.value();
+ }
+ }
+
+ return reductionHeuristic(red_elements, red_outputs, red_on_fastest_dim);
+}
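+
+// Example (illustrative): reducing a contiguous [1024, 4096] tensor over its
+// last dimension yields red_outputs = 1024, red_elements = 4096, and
+// red_on_fastest_dim = true, which reductionHeuristic() above maps to block
+// and grid dimensions.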
+
+// fusion is the input IR that will be modified by this function
+void scheduleReduction(
+ Fusion* fusion,
+ const ReductionParams& rparams,
+ TensorView* red_tv,
+ std::vector<TensorView*> outs_of_red) {
+ FusionGuard fg(fusion);
+
+  // Coalesce all reduction axes to the right.
+ mergeReduction(red_tv);
+
+ // Merge all iteration dimensions
+ mergeNonReduction(red_tv);
+ for (auto iter_tv : outs_of_red) {
+ mergeNonReduction(iter_tv);
+ }
+
+ // Evaluate Dimensions of Reduction TensorView
+ auto red_ids = red_tv->domain()->domain();
+
+ TORCH_INTERNAL_ASSERT(
+ red_ids.size() == 2, "We coalesced all dimensions into 2 previously.");
+
+ constexpr int kLoopUnrollSplit = 4;
+
+ // Scheduling the Reduction
+ if (rparams.fastest_dim) {
+ // Do multiple reductions per block
+ if (rparams.mul_reds_per_blk) {
+ // Reduction Splits
+      //      [outputs, |rF-Leftover, X-Warp, rF-Unroll|]
+      // Idx:     0     |   1(-3)      2(-2)    3(-1)  |
+ // --------------------------------
+ // Reduction Dimensions
+ red_tv->split(1, rparams.loop_unroll);
+ red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+
+ // Output Splits
+ // [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
+      // Idx:  |     0             1      |   2(-3) -- 4(-1)
+ // ----------------------------
+ // Output Dimensions
+ red_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDy));
+ for (auto iter_tv : outs_of_red) {
+ iter_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDy));
+ }
+
+ auto red_tv_rf = red_tv->rFactor({-3, -1});
+
+ // WARNING: computeAt will coalesce the rFactored dimensions
+ // rFactored Reduction Tensor after computeAt():
+ // [<output dims>, | rF-Leftover, X-Warp, rF-Unroll|]
+ // Idx: 0 -- 1 | 2(-3) 3(-2) 4(-1) |
+ // ---------------------------------
+ // Reduction Dimensions
+ red_tv_rf->computeAt(red_tv, -1);
+
+ // After the Reduction Tensor has rFactoring applied
+ // Reduction Output Tensor:
+ // [Out-Leftover, Out-PerBlock, X-Warp]
+ // Idx: 0 1 2(-1)
+ if (!outs_of_red.empty()) {
+ red_tv->computeAt(outs_of_red[0], -1);
+ }
+
+ red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
+
+ red_tv->axis(0)->parallelize(ParallelType::BIDx);
+ for (auto iter_tv : outs_of_red) {
+ iter_tv->axis(0)->parallelize(ParallelType::BIDx);
+ }
+ red_tv->axis(1)->parallelize(ParallelType::TIDy);
+ for (auto iter_tv : outs_of_red) {
+ iter_tv->axis(1)->parallelize(ParallelType::TIDy);
+ }
+ red_tv->axis(-1)->parallelize(ParallelType::TIDx);
+
+ // Bind Inputs to Reduction
+ for (auto input : fusion->inputsOf(red_tv_rf)) {
+ if (input->getValType().value() == ValType::TensorView) {
+ input->as<TensorView>()->computeAt(red_tv_rf, -1);
+ }
+ }
+ // Do a cross-warp reduction per block
+ } else {
+ if (rparams.cross_grid) {
+ // Reduction Splits
+        //      [outputs, |rF-Leftover, X-Grid, X-Block, X-Warp, rF-Unroll|]
+ // Idx: 0 | 1(-5) 2(-4) 3(-3) 4(-2) 5(-1) |
+ // -------------------------------------------------
+ // Reduction Dimensions
+ red_tv->split(1, rparams.loop_unroll);
+ red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+ red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
+ red_tv->split(1, NamedScalar::getParallelDim(ParallelType::BIDy));
+
+ auto red_tv_rf = red_tv->rFactor(
+ {-5, -1}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+
+ // WARNING: computeAt will coalesce the rFactored dimensions
+ // rFactored Reduction Tensor after computeAt():
+ // [Outputs, |X-Grid, X-Block, X-Warp, rF-Leftover, rF-Unroll|]
+ // Idx: 0 | 1(-5) 2(-4) 3(-3) 4(-2) 5(-1) |
+ // -------------------------------------------------
+ // Reduction Dimensions
+ red_tv_rf->computeAt(red_tv, -1);
+
+ // After the Reduction Tensor has rFactoring applied
+ // Reduction Output Tensor:
+ // [Outputs, X-Grid, X-Block, X-Warp]
+ // Idx: 0 1(-3) 2(-2) 3(-1)
+
+ if (!outs_of_red.empty()) {
+ red_tv->computeAt(outs_of_red[0], -1);
+ }
+
+ red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
+
+ red_tv->axis(0)->parallelize(ParallelType::BIDx);
+ for (auto iter_tv : outs_of_red) {
+ iter_tv->axis(0)->parallelize(ParallelType::BIDx);
+ }
+ red_tv->axis(-1)->parallelize(ParallelType::TIDx);
+ red_tv->axis(-2)->parallelize(ParallelType::TIDy);
+ red_tv->axis(-3)->parallelize(ParallelType::BIDy);
+
+ // Bind Inputs to Reduction
+ for (auto input : fusion->inputsOf(red_tv_rf)) {
+ if (input->getValType().value() == ValType::TensorView) {
+ input->as<TensorView>()->computeAt(red_tv_rf, -1);
+ }
+ }
+ } else {
+ // Reduction Splits
+        //      [outputs, |rF-Leftover, X-Block, X-Warp, rF-Unroll|]
+ // Idx: 0 | 1(-4) 2(-3) 3(-2) 4(-1) |
+ // -----------------------------------------
+ // Reduction Dimensions
+ red_tv->split(1, rparams.loop_unroll);
+ red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+ red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
+
+ auto red_tv_rf = red_tv->rFactor({-4, -1});
+
+ // WARNING: computeAt will coalesce the rFactored dimensions
+ // rFactored Reduction Tensor after computeAt():
+ // [Outputs, |X-Block, X-Warp, rF-Leftover, rF-Unroll|]
+ // Idx: 0 | 1(-4) 2(-3) 3(-2) 4(-1) |
+ // -----------------------------------------
+ // Reduction Dimensions
+ red_tv_rf->computeAt(red_tv, -1);
+
+ // After the Reduction Tensor has rFactoring applied
+ // Reduction Output Tensor:
+ // [Outputs, X-Block, X-Warp]
+ // Idx: 0 1(-2) 2(-1)
+
+ if (!outs_of_red.empty()) {
+ red_tv->computeAt(outs_of_red[0], -1);
+ }
+
+ red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
+
+ red_tv->axis(0)->parallelize(ParallelType::BIDx);
+ for (auto iter_tv : outs_of_red) {
+ iter_tv->axis(0)->parallelize(ParallelType::BIDx);
+ }
+ red_tv->axis(-1)->parallelize(ParallelType::TIDx);
+ red_tv->axis(-2)->parallelize(ParallelType::TIDy);
+
+ // Bind Inputs to Reduction
+ for (auto input : fusion->inputsOf(red_tv_rf)) {
+ if (input->getValType().value() == ValType::TensorView) {
+ input->as<TensorView>()->computeAt(red_tv_rf, -1);
+ }
+ }
+ }
+ }
+ } else {
+ if (rparams.cross_block) {
+ if (rparams.cross_grid) {
+ // Reduction Splits
+        //      [outputs, |rF-Leftover, rF-Unroll, X-Grid, X-Block|]
+ // Idx: 0 | 1(-4) 2(-3) 3(-2) 4(-1) |
+ // -----------------------------------------
+ // Reduction Dimensions
+ red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
+ red_tv->split(1, NamedScalar::getParallelDim(ParallelType::BIDy));
+ red_tv->split(1, kLoopUnrollSplit);
+
+ // Reordering the Unroll dimension eases applying computeAt()
+        // for preceding operations and the rFactored Tensor.
+ // |--- Reordered ----|
+ // V V
+ // [outputs, |rF-Leftover, X-Block, X-Grid, rF-Unroll|]
+ // Idx: 0 | 1(-4) 2(-3) 3(-2) 4(-1) |
+ // -----------------------------------------
+ // Reduction Dimensions
+ red_tv->reorder({{-1, -3}, {-3, -1}});
+
+ // Output Splits
+ // [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
+ // Idx: | 0 1 | 2(-4) -- 5(-1)
+ // ----------------------------
+ // Output Dimensions
+ red_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
+ for (auto iter_tv : outs_of_red) {
+ iter_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
+ }
+
+ auto red_tv_rf = red_tv->rFactor({-4, -1});
+
+ // WARNING: computeAt will coalesce the rFactored dimensions
+ // rFactored Reduction Tensor after computeAt():
+ // [<output dims>, |X-Block, X-Grid, rF-Leftover, rF-Unroll|]
+ // Idx: 0 -- 1 | 2(-4) 3(-3) 4(-2) 5(-1) |
+ // -----------------------------------------
+ // Reduction Dimensions
+ red_tv_rf->computeAt(red_tv, -1);
+
+ // After the Reduction Tensor has rFactoring applied
+ // Reduction Output Tensor:
+ // [Out-Leftover, Out-PerBlock, X-Block, X-Grid]
+ // Idx: 0 1 2(-2) 3(-1)
+
+ if (!outs_of_red.empty()) {
+ red_tv->computeAt(outs_of_red[0], -1);
+ }
+
+ red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
+
+ red_tv->axis(0)->parallelize(ParallelType::BIDx);
+ for (auto iter_tv : outs_of_red) {
+ iter_tv->axis(0)->parallelize(ParallelType::BIDx);
+ iter_tv->axis(1)->parallelize(ParallelType::TIDx);
+ }
+
+ red_tv->axis(-3)->parallelize(ParallelType::TIDx);
+ red_tv->axis(-2)->parallelize(ParallelType::TIDy);
+ red_tv->axis(-1)->parallelize(ParallelType::BIDy);
+
+ // Bind Inputs to Reduction
+ for (auto input : fusion->inputsOf(red_tv_rf)) {
+ if (input->getValType().value() == ValType::TensorView) {
+ input->as<TensorView>()->computeAt(red_tv_rf, -1);
+ }
+ }
+ } else {
+ // Reduction Splits
+        //      [outputs, |rF-Leftover, rF-Unroll, X-Block|]
+ // Idx: 0 | 1(-3) 2(-2) 3(-1) |
+ // ---------------------------------
+ // Reduction Dimensions
+ red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
+ red_tv->split(1, kLoopUnrollSplit);
+
+ // Reordering the Unroll dimension eases applying computeAt()
+        // for preceding operations and the rFactored Tensor.
+ // |- Reordered -|
+ // V V
+ // [outputs, |rF-Leftover, X-Block, rF-Unroll|]
+ // Idx: 0 | 1(-3) 2(-2) 3(-1) |
+ // ---------------------------------
+ // Reduction Dimensions
+ red_tv->reorder({{-1, -2}, {-2, -1}});
+
+ // Output Splits
+ // [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
+ // Idx: | 0 1 | 2(-3) -- 4(-1)
+ // ----------------------------
+ // Output Dimensions
+ red_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
+ for (auto iter_tv : outs_of_red) {
+ iter_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
+ }
+
+ auto red_tv_rf = red_tv->rFactor({-3, -1});
+
+ // WARNING: computeAt will coalesce the rFactored dimensions
+ // rFactored Reduction Tensor after computeAt():
+ // [<output dims>, |X-Block, rF-Leftover, rF-Unroll|]
+ // Idx: 0 -- 1 | 2(-3) 3(-2) 4(-1) |
+ // ---------------------------------
+ // Reduction Dimensions
+ red_tv_rf->computeAt(red_tv, -1);
+
+ // After the Reduction Tensor has rFactoring applied
+ // Reduction Output Tensor:
+ // [Out-Leftover, Out-PerBlock, X-Block]
+ // Idx: 0 1 2(-1)
+
+ if (!outs_of_red.empty()) {
+ red_tv->computeAt(outs_of_red[0], -1);
+ }
+
+ red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
+
+ red_tv->axis(0)->parallelize(ParallelType::BIDx);
+ for (auto iter_tv : outs_of_red) {
+ iter_tv->axis(0)->parallelize(ParallelType::BIDx);
+ iter_tv->axis(1)->parallelize(ParallelType::TIDx);
+ }
+ red_tv->axis(-2)->parallelize(ParallelType::TIDx);
+ red_tv->axis(-1)->parallelize(ParallelType::TIDy);
+
+ // Bind Inputs to Reduction
+ for (auto input : fusion->inputsOf(red_tv_rf)) {
+ if (input->getValType().value() == ValType::TensorView) {
+ input->as<TensorView>()->computeAt(red_tv_rf, -1);
+ }
+ }
+ }
+ } else {
+ red_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
+ for (auto iter_tv : outs_of_red) {
+ iter_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
+ }
+
+ if (!outs_of_red.empty()) {
+ red_tv->computeAt(outs_of_red[0], -1);
+ }
+
+ red_tv->axis(0)->parallelize(ParallelType::BIDx);
+ red_tv->axis(1)->parallelize(ParallelType::TIDx);
+ for (auto iter_tv : outs_of_red) {
+ iter_tv->axis(0)->parallelize(ParallelType::BIDx);
+ iter_tv->axis(1)->parallelize(ParallelType::TIDx);
+ }
+
+ for (auto input : fusion->inputsOf(red_tv)) {
+ if (input->getValType().value() == ValType::TensorView) {
+ input->as<TensorView>()->computeAt(red_tv, -1);
+ }
+ }
+ }
+ }
+}
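+
+// Note: in the fastest-dim, multiple-reductions-per-block case the schedule
+// above binds [BIDx, TIDy, TIDx(reduction)]; the concrete launch dimensions
+// come from rparams.lparams as computed in reductionHeuristic().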
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
--- /dev/null
+#pragma once
+
+#include <ATen/core/ivalue.h>
+#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
+#include <torch/csrc/jit/codegen/cuda/fusion.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+// Returns whether the given fusion could be scheduled.
+TORCH_CUDA_CU_API bool scheduleFusion(
+ Fusion* fusion,
+ const at::ArrayRef<c10::IValue> inputs);
+
+// Parameters the reduction heuristic generates to describe the optimal
+// schedule. Warning: the equality operator is intended for use in caching the
+// kernel associated with these reduction parameters. It does not check
+// whether the launch parameters are equivalent!
+struct ReductionParams {
+ // Reducing inner most dimension?
+ bool fastest_dim = true;
+ // Reduce across the block?
+ bool cross_block = false;
+ // Reduce across the grid?
+ bool cross_grid = false;
+ // Perform multiple reductions per block?
+ bool mul_reds_per_blk = false;
+ // Unrolling factor
+ int loop_unroll = 4;
+
+ LaunchParams lparams;
+
+ // Warning: Does not check launch parameters!
+ bool operator==(const ReductionParams& other) const {
+ bool attr_equal = other.fastest_dim == fastest_dim &&
+ other.cross_block == cross_block && other.cross_grid == cross_grid &&
+ other.mul_reds_per_blk == mul_reds_per_blk &&
+ other.loop_unroll == loop_unroll;
+ return attr_equal;
+ }
+};
+
+// Warning: Hash is not based on launch parameters!
+class ReductionParamsHash {
+ public:
+ size_t operator()(const ReductionParams& rp) const {
+ constexpr size_t bits = sizeof(std::size_t) * 8;
+ size_t attr_hash = static_cast<size_t>(rp.fastest_dim) << (bits - 1) |
+ static_cast<size_t>(rp.cross_block) << (bits - 2) |
+ static_cast<size_t>(rp.cross_grid) << (bits - 3) |
+ static_cast<size_t>(rp.mul_reds_per_blk) << (bits - 4);
+ return attr_hash;
+ }
+};
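+
+// Sketch of the intended cache usage (hypothetical names; CompiledKernel is
+// not part of this header): key an unordered_map on the heuristic attributes
+// only, so two ReductionParams differing just in lparams hash and compare
+// equal, and therefore reuse the same compiled kernel:
+//
+//   std::unordered_map<ReductionParams, CompiledKernel, ReductionParamsHash>
+//       kernel_cache;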
+
+TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
+ Fusion* fusion,
+ const at::ArrayRef<c10::IValue>& fusion_inputs,
+ TensorView* red_tv);
+
+TORCH_CUDA_CU_API void scheduleReduction(
+ Fusion* fusion,
+ const ReductionParams& rparams,
+ TensorView* red_tv,
+ std::vector<TensorView*> outs_of_red);
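+
+// Typical call sequence (illustrative sketch; error handling elided):
+//
+//   auto rparams = getReductionHeuristics(fusion, fusion_inputs, red_tv);
+//   if (rparams.has_value()) {
+//     scheduleReduction(fusion, rparams.value(), red_tv, outs_of_red);
+//     // rparams->lparams now carries the block/grid dims for the launch.
+//   }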
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
+++ /dev/null
-#pragma once
-#include <torch/csrc/jit/codegen/cuda/scheduler/normalization.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/reduction.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-enum class TORCH_CUDA_CU_API ScheduleHeuristic {
- PointWise,
- Reduction,
- Normalization
-};
-
-}
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/scheduler/reduction.h>
-
-#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-
-#include <ATen/cuda/CUDAContext.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-// Copied from reduction scheduler, should generalize. Simply needed to take out
-// grid reductions.
-ReductionParams innerNormalizationHeuristic(
- const int64_t num_elems_in_reduction,
- const int64_t num_outputs_for_reduction,
- const int64_t n_tensor_inputs,
- const int64_t max_input_dtype_size,
- bool persistence_required,
- const int64_t max_persistent_buffer_size,
- size_t vectorize_factor) {
- // Set some targets for parallelization
- const int64_t n_elems = num_elems_in_reduction * num_outputs_for_reduction;
-
- // WARNING: Current device for codegen may not be the target device
- const int64_t device_max_threads_per_multiprocessor =
- (int64_t)at::cuda::getCurrentDeviceProperties()
- ->maxThreadsPerMultiProcessor;
-
- const int64_t device_multiprocessor_count =
- (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
-
- auto const max_unroll = ceilDiv(
- // Available unrolling based on size of data type
- (int64_t)16 / (int64_t)max_input_dtype_size,
- // Reduce unrolling if we have many inputs, start reduction at 4 inputs
- std::max(
- (scheduler_utils::lastPow2((int64_t)n_tensor_inputs - 1) >> 1),
- (int64_t)1));
-
- // Conservative value, could be set to larger based on arch if necessary.
- constexpr int64_t l1_cache = 32 * 1024;
- // Could change per generation, but for l1 we want to consider active threads,
- // not resident
- constexpr int64_t active_threads = 1024;
-
-  // If the data fits in l2 and we need more parallelization in the reduction
-  // dim, we can use a smaller warp size. While thread-local data fits in l1
-  // and the reduction dim is really small, we can use <32 threads per warp.
- const bool fits_in_l2 = n_elems * max_input_dtype_size * n_tensor_inputs <
- at::cuda::getCurrentDeviceProperties()->l2CacheSize;
-
-  // If it fits in l2, we just want to make sure each warp uses 32 bytes. Set
-  // the minimum warp to 16 threads instead of 32, since with a small
-  // reduction dim going a bit smaller than 32 usually helps.
- const int64_t warp_size_based_on_l2 =
- fits_in_l2 ? (int64_t)32 / max_input_dtype_size : 16;
-
-  // Check how many elements it would take per thread to start thrashing l1;
-  // set that to the minimum number we want to reduce per thread.
- const int64_t warp_size_based_on_l1 = std::min(
- ceilDiv(
- num_elems_in_reduction,
- std::max(
- l1_cache /
- (n_tensor_inputs * max_input_dtype_size * active_threads),
- (int64_t)1)),
- (int64_t)16);
-
- const int64_t warp_size =
- std::min(warp_size_based_on_l1, warp_size_based_on_l2);
-
- // Initialization
- int64_t target_blocks = 1;
- int64_t target_unroll = 1;
- int64_t target_iterations = 1;
-
-  // Try to set a minimum amount of work for each thread, as cross-thread
-  // communication is slow and shouldn't be done for every element in the
-  // reduction.
- int64_t min_target_iterations =
- std::max((int64_t)32 / (int64_t)max_input_dtype_size, (int64_t)1);
-
- // Start trying to break parallelization up across threads,
- // unrolling/iterations, and blocks.
-
- // max_threads_in_block is the cap on a thread block, the minimum is based on
- // warp_size
- int64_t max_threads_in_block = std::max(
- warp_size, ceilDiv(num_elems_in_reduction, min_target_iterations));
-
- // If we have one warp per block, check if that's enough to saturate the SMs
- target_blocks = ceilDiv(n_elems, warp_size);
-
- // If we have more than a wave of blocks, put parallelism into unrolling and
- // target iterations
- if (target_blocks > device_multiprocessor_count) {
- auto available_unroll = std::max(
- n_elems / (warp_size * device_multiprocessor_count), (int64_t)1);
-
- // Spread across unrolling and iterations, want a balance of the two so flip
- // back and forth to alternate adding to them.
- bool flip = true;
-
- while (available_unroll > 1 &&
- (target_unroll < max_unroll ||
- // Prefer unrolling
- target_iterations < ceilDiv(min_target_iterations, max_unroll))) {
- if (target_unroll * 2 <= max_unroll && flip) {
- target_unroll *= 2;
- }
-
- if (target_iterations * 2 <= ceilDiv(min_target_iterations, max_unroll) &&
- !flip) {
- target_iterations *= 2;
- }
-
- available_unroll = std::max(
- n_elems /
- (warp_size * device_multiprocessor_count * target_unroll *
- target_iterations),
- (int64_t)1);
-
- flip = !flip;
- }
-
- // Recompute target blocks
- target_blocks =
- ceilDiv(n_elems, warp_size * target_unroll * target_iterations);
- }
-
- // Cap target blocks to 4 waves
- target_blocks = std::min(target_blocks, device_multiprocessor_count * 4);
-
- if (target_blocks * target_unroll * target_iterations < n_elems) {
-    // targeting 4 waves, so try to use a quarter of available threads
- max_threads_in_block = std::min(
- ceilDiv(n_elems, target_blocks * target_unroll),
- ceilDiv(device_max_threads_per_multiprocessor, (int64_t)4));
- }
-
- // Compute maximum number of reductions we could do in the same kernel based
- // on persistent buffer size
- const int64_t max_multi_reduction_factor = std::max(
- (persistence_required ? (scheduler_utils::register_file_size * 3) /
- (max_persistent_buffer_size * 4)
- : std::numeric_limits<int64_t>::max()),
- (int64_t)1);
-
- // To get to target threads:
- // Prioritize
- // (1) x dim in reduction
- // (2) unrolling in reduction
- // (3) y in output
- // To get target blocks:
- // Prioritize
- // (1) x dim in multiple outputs
- // (2) y dim in multiple reductions
-
- // Blocks for outputs
- int64_t godim = 1;
-
- // Threads for outputs
- int64_t bdimy = 1;
- // Threads for reduction
- int64_t bdimx = 1;
-
- // Should we unroll from reduction axis, or outs axis
- bool unroll_reduction = true;
-
- // Unroll amount
- int64_t unroll_factor = 1;
-
- // Grab what we can out of reduction domain, but don't go over a warp size yet
- bdimx = std::min(num_elems_in_reduction, (int64_t)warp_size);
-
- // Put everything else in bdimy for now
- bdimy = std::min(
- std::max(max_threads_in_block / bdimx, (int64_t)1),
- max_multi_reduction_factor);
-
- int64_t remainder_in_reduction = ceilDiv(num_elems_in_reduction, bdimx);
- int64_t remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimy);
-
- // Adjust blocking and setup unrolling
- // Disable unrolling on iteration domain for persistent kernels for now.
- // TODO: Re-enable.
- if (remainder_in_reduction == 1 && !persistence_required) {
- // Small number of reduction elements, try unrolling output dimension
- unroll_factor = std::min(target_unroll, remainder_in_output);
-
- if (unroll_factor > 1) {
- unroll_reduction = false;
- remainder_in_output =
- ceilDiv(num_outputs_for_reduction, unroll_factor * bdimy);
- }
- } else {
- // If there are reduction elements left after unrolling a warp, re-adjust
- // the block dims to put more threads into the reduction
- bdimx = std::min(
- std::max(
- ceilDiv(num_elems_in_reduction, target_iterations * target_unroll),
- warp_size),
- max_threads_in_block);
-
- // Don't exceed target threads in a block.
- bdimy = std::min(
- std::max(max_threads_in_block / bdimx, (int64_t)1),
- max_multi_reduction_factor);
- remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimy);
-
- remainder_in_reduction = ceilDiv(num_elems_in_reduction, bdimx);
- unroll_factor = std::min(remainder_in_reduction, target_unroll);
-
- // If there's no longer any space for unrolling the reduction dimension, try
- // unrolling the iteration (output) dimension.
- // Disable unrolling on iteration domain for persistent kernels for now.
- // TODO: Re-enable.
- if (unroll_factor == 1 && !persistence_required) {
- // If we can't unroll reduction dim, unroll output dim
- unroll_factor = std::min(remainder_in_output, target_unroll);
- if (unroll_factor > 1) {
- unroll_reduction = false;
- }
- remainder_in_output =
- ceilDiv(num_outputs_for_reduction, bdimy * unroll_factor);
- // Clang-tidy
- // remainder_in_reduction =
- // ceilDiv(num_elems_in_reduction, bdimx *
- // target_iterations);
- }
- // else {
- // remainder_in_reduction = ceilDiv(
- // num_elems_in_reduction,
- // bdimx * std::max(unroll_factor, target_iterations));
- // }
- }
-
- godim = remainder_in_output;
-
- bool vectorize = false;
-
-  // Move the unrolling factor into vectorization, up to the vectorization limit.
- if (vectorize_factor > 1 && unroll_factor > 1 && unroll_reduction) {
- vectorize = true;
- unroll_factor = std::min(
- scheduler_utils::lastPow2(unroll_factor), (int64_t)vectorize_factor);
- }
-
- // Set size of persistent per thread buffer
- int64_t batches_per_block = ceilDiv(
- num_elems_in_reduction,
- bdimx * (unroll_reduction ? unroll_factor : (int64_t)1));
-  // Round up to a multiple of 8 or the next power of 2, whichever is smaller.
- auto round_up_pow2 = scheduler_utils::lastPow2(batches_per_block);
- if (round_up_pow2 < batches_per_block) {
- round_up_pow2 *= 2;
- }
-
- constexpr int64_t kEight = 8; // clang tidy
-
- auto round_up_8 = batches_per_block % kEight == 0
- ? batches_per_block
- : batches_per_block + (kEight - batches_per_block % kEight);
-
- batches_per_block = std::min(round_up_8, round_up_pow2);
-
- // Prefer putting iterations into unrolling over having a very large
- // persistent buffer. Likely this should be more carefully adjusted to not
- // blow out registers, but can revisit if we see any kernels with local memory
- // use.
- while (persistence_required && !vectorize && unroll_factor < max_unroll &&
- batches_per_block % 2 == 0) {
- batches_per_block /= 2;
- unroll_factor *= 2;
- }
-
- ReductionParams rparams;
- rparams.fastest_dim = true;
- rparams.cross_block = true;
- rparams.cross_grid = false;
- rparams.multiple_reds_per_blk =
- bdimy > 1 || (!unroll_reduction && unroll_factor);
- rparams.loop_unroll = unroll_factor;
- rparams.vectorize = vectorize;
- rparams.reduction_unroll = unroll_reduction;
- rparams.batches_per_block = batches_per_block;
- rparams.persistent_kernel = persistence_required;
-
- // Check if we need to split grid-x binding
- rparams.split_grid_dim = godim > scheduler_utils::x_grid_limit;
-
- rparams.lparams = LaunchParams(
- LaunchParams::UNINITIALIZED_VAL,
- LaunchParams::UNINITIALIZED_VAL,
- LaunchParams::UNINITIALIZED_VAL,
- persistence_required ? LaunchParams::UNINITIALIZED_VAL : bdimx,
- bdimy,
- LaunchParams::UNINITIALIZED_VAL);
-
- rparams.tag = persistence_required ? "Inner normalization heuristic.\n"
- : "Multi inner reduction (norm heuristic)";
-
- if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
- std::cerr << "\n===== Reduction Stats ========\n"
- << "num_elems_in_reduction: " << num_elems_in_reduction << "\n"
- << "num_outputs_for_reduction: " << num_outputs_for_reduction
- << "\n"
- << "n_tensor_inputs: " << n_tensor_inputs << "\n"
- << "max_input_dtype_size: " << max_input_dtype_size << "\n"
- << "persistence_required: " << persistence_required << "\n"
- << "max_persistent_buffer_size: " << max_persistent_buffer_size
- << std::endl;
- std::cerr << rparams.toString() << std::endl;
- }
-
- return rparams;
-}
-
-// Copied from reduction scheduler, should generalize. Simply needed to take out
-// grid reductions.
-ReductionParams OuterNormalizationHeuristic(
- const int64_t num_elems_in_reduction,
- const int64_t num_outputs_for_reduction,
- const int64_t n_tensor_inputs,
- const int64_t max_input_dtype_size,
- bool persistence_required,
- const int64_t max_persistent_buffer_size,
- size_t vectorize_factor) {
- // Set some targets for parallelization
- const int64_t n_elems = num_elems_in_reduction * num_outputs_for_reduction;
-
- // WARNING: Current device for codegen may not be the target device
- const int64_t device_max_threads_per_multiprocessor =
- (int64_t)at::cuda::getCurrentDeviceProperties()
- ->maxThreadsPerMultiProcessor;
-
- const int64_t device_multiprocessor_count =
- (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
-
- auto const max_unroll = ceilDiv(
- // Available unrolling based on size of data type
- (int64_t)16 / (int64_t)max_input_dtype_size,
- // Reduce unrolling if we have many inputs, start reduction at 4 inputs
- std::max(
- (scheduler_utils::lastPow2((int64_t)n_tensor_inputs - 1) >> 1),
- (int64_t)1));
-
-  // If it fits in l2, we just want to make sure each warp uses 32 bytes. Set
-  // the minimum warp to 16 threads instead of 32, since with a small
-  // reduction dim going a bit smaller than 32 usually helps.
- const int64_t warp_size = n_elems * max_input_dtype_size * n_tensor_inputs <
- at::cuda::getCurrentDeviceProperties()->l2CacheSize
- ? (int64_t)32 / max_input_dtype_size
- : 16;
-
- // Initialization
- int64_t target_blocks = 1;
- int64_t target_unroll = 1;
- int64_t max_threads_in_block = warp_size;
-
- // If we have one warp per block, check if that's enough to saturate the SMs
- target_blocks = ceilDiv(n_elems, (int64_t)warp_size);
-
- // If we have more than a wave of blocks, put parallelism into unrolling
- if (target_blocks > device_multiprocessor_count) {
- target_unroll = std::min(
- max_unroll, ceilDiv(target_blocks, device_multiprocessor_count));
- target_blocks = ceilDiv(target_blocks, target_unroll);
- }
-
- // Cap target blocks to 4 waves
- target_blocks = std::min(target_blocks, device_multiprocessor_count * 4);
-
- if (target_blocks * target_unroll * max_threads_in_block < n_elems) {
-    // targeting 4 waves, so try to use a quarter of available threads
- max_threads_in_block = std::min(
- ceilDiv(n_elems, target_blocks * target_unroll),
- ceilDiv(device_max_threads_per_multiprocessor, (int64_t)4));
- }
-
- // Compute maximum number of reductions we could do in the same kernel based
- // on persistent buffer size
-
- const int64_t max_multi_reduction_factor = std::max(
- (persistence_required ? (scheduler_utils::register_file_size * 3) /
- (max_persistent_buffer_size * 4)
- : std::numeric_limits<int64_t>::max()),
- (int64_t)1);
-
- // To get to target threads:
- // Prioritize
- // (1) x dim in iter domain
- // (2) unrolling in iter domain
- // (3) y in reduction domain
- // To get target blocks:
- // Prioritize
- // (1) x dim in multiple outputs
- // (2) y dim in multiple reductions - need to flip unrolling to reduction
- // domain for this
-
- // Blocks for outputs
- // int64_t gdimx = 1; // unused at this time, comment for clang tidy
-
- // Threads for reduction
- int64_t bdimy = 1;
- // Threads for output
- int64_t bdimx = 1;
-
- // Should we unroll from reduction axis, or outs axis
- bool unroll_reduction = false;
-
- // Unroll amount
- int64_t unroll_factor = 1;
-
- int64_t remainder_in_reduction = num_elems_in_reduction;
- int64_t remainder_in_output = num_outputs_for_reduction;
-
- if (ceilDiv(num_outputs_for_reduction, warp_size) <
- device_multiprocessor_count) {
- // If we can't hit a full wave, leave bdimx as warp_size, and prioritize
- // bdimy.
- bdimx = std::min(
- std::min(num_outputs_for_reduction, warp_size),
- max_multi_reduction_factor);
- } else {
- bdimx = std::min(
- max_threads_in_block,
- ceilDiv(num_outputs_for_reduction, target_blocks));
- bdimx = std::min(std::max(bdimx, warp_size), max_multi_reduction_factor);
- }
-
- // Fill bdimy with left over threads
- bdimy = std::min(
- std::max(max_threads_in_block / bdimx, (int64_t)1),
- num_elems_in_reduction);
-
- // Clang tidy
- // remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimx);
- remainder_in_reduction = ceilDiv(remainder_in_reduction, bdimy);
-
- if (num_outputs_for_reduction >=
- device_multiprocessor_count * max_threads_in_block) {
-    // If we easily saturate the GPU, don't use block dim y and unroll the
-    // output dimension. TODO: this could be a more gentle transition,
-    // starting earlier.
- bdimx = std::min(max_threads_in_block, max_multi_reduction_factor);
- remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimx);
-
- // TODO: This should probably still be based on max threads in a block
- // especially if we're limited by max_multi_reduction_factor
- bdimy = 1;
- remainder_in_reduction = num_elems_in_reduction;
-
- // Assume unroll in output, switch to remainder if cross grid
- // Don't unroll if we don't have 2 full waves
- //
- // Disable unrolling on iteration domain for persistent kernels for now.
- // TODO: Re-enable.
- unroll_factor = persistence_required
- ? 1
- : std::min(
- ceilDiv(remainder_in_output, device_multiprocessor_count * 2),
- target_unroll);
- if (unroll_factor == 1 && remainder_in_reduction > 1) {
- // Try unrolling in reduction dimension
- unroll_factor = std::min(remainder_in_reduction, unroll_factor);
- // Clang tidy
- // remainder_in_reduction = ceilDiv(remainder_in_reduction,
- // unroll_factor);
- if (unroll_factor > 1) {
- unroll_reduction = true;
- }
- }
- // else {
- // remainder_in_output =
- // ceilDiv(num_outputs_for_reduction, bdimx * unroll_factor);
- // unused, comment for clang tidy
- // }
- } else {
- // Not many output elements, try unrolling reduction dimension, would
- // typically go cross grid, but can't for multi-reduction and normalization
- // kernels.
- // TODO: Enable cross reduction for multi-reduction cases
- unroll_factor = std::min(max_unroll, remainder_in_reduction);
- if (unroll_factor > 1) {
- unroll_reduction = true;
- }
- }
-
- if (unroll_factor == 1) {
- unroll_reduction = true;
- }
-
- // Persistence size from buffers
- int64_t batches_per_block = 1;
- if (persistence_required) {
- batches_per_block = ceilDiv(
- num_elems_in_reduction,
- bdimy * (unroll_reduction ? unroll_factor : (int64_t)1));
-    // Round up to a multiple of 8 or the next power of 2, whichever is smaller.
- }
-
- auto round_up_pow2 = scheduler_utils::lastPow2(batches_per_block);
- if (round_up_pow2 < batches_per_block) {
- round_up_pow2 *= 2;
- }
-
- constexpr int64_t kEight = 8; // clang tidy
-
- auto round_up_8 = batches_per_block % kEight == 0
- ? batches_per_block
- : batches_per_block + (kEight - batches_per_block % kEight);
-
- batches_per_block = std::min(round_up_8, round_up_pow2);
-
- bool vectorize = false;
-
- if (vectorize_factor > 1 && unroll_factor > 1 && !unroll_reduction) {
- vectorize = true;
- unroll_factor = std::min(
- scheduler_utils::lastPow2(unroll_factor), (int64_t)vectorize_factor);
- }
-
- ReductionParams rparams;
- rparams.fastest_dim = false;
- rparams.cross_block = bdimy > 1;
- rparams.cross_grid = false;
- rparams.multiple_reds_per_blk =
- bdimx > 1 || (!unroll_reduction && unroll_factor);
- rparams.loop_unroll = unroll_factor;
- rparams.vectorize = vectorize;
- rparams.reduction_unroll = unroll_reduction;
- rparams.batches_per_block = batches_per_block;
- rparams.persistent_kernel = persistence_required;
-
- rparams.lparams = LaunchParams(
- LaunchParams::UNINITIALIZED_VAL,
- LaunchParams::UNINITIALIZED_VAL,
- LaunchParams::UNINITIALIZED_VAL,
- bdimx,
- persistence_required ? LaunchParams::UNINITIALIZED_VAL : bdimy,
- LaunchParams::UNINITIALIZED_VAL);
-
- rparams.tag = persistence_required ? "Outer normalization heuristic.\n"
- : "Multi outer reduction (norm heuristic)";
-
- if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
- std::cerr << "\n===== Reduction Stats ========\n"
- << "num_elems_in_reduction: " << num_elems_in_reduction << "\n"
- << "num_outputs_for_reduction: " << num_outputs_for_reduction
- << "\n"
- << "n_tensor_inputs: " << n_tensor_inputs << "\n"
- << "max_input_dtype_size: " << max_input_dtype_size << "\n"
- << "persistence_required: " << persistence_required << "\n"
- << "max_persistent_buffer_size: " << max_persistent_buffer_size
- << std::endl;
- std::cerr << rparams.toString() << std::endl;
- }
-
- return rparams;
-}
-
-} // namespace
-
-ReductionParams NormalizationHeuristic(
- int64_t num_elems_in_reduction,
- int64_t num_outputs_for_reduction,
- bool fastest_dim_reduction,
- size_t n_tensor_inputs,
- size_t max_input_dtype_size,
- bool persistence_required,
- const int64_t max_persistent_buffer_size,
- size_t vectorize_factor) {
- if (fastest_dim_reduction) {
- return innerNormalizationHeuristic(
- num_elems_in_reduction,
- num_outputs_for_reduction,
- n_tensor_inputs,
- max_input_dtype_size,
- persistence_required,
- max_persistent_buffer_size,
- vectorize_factor);
- } else {
- return OuterNormalizationHeuristic(
- num_elems_in_reduction,
- num_outputs_for_reduction,
- n_tensor_inputs,
- max_input_dtype_size,
- persistence_required,
- max_persistent_buffer_size,
- vectorize_factor);
- }
-}
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getNormalizationHeuristics(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache) {
- FUSER_PERF_SCOPE("getNormalizationHeuristics");
-
- FusionGuard fg(fusion);
-
- HeuristicCacheAccessor<std::vector<TensorView*>> reduction_tv_data;
- // TODO: move all these boilerplate code into the accessor class
- // (follow up)
- if (data_cache && !data_cache->isRecording()) {
- reduction_tv_data.writeTemporary(data_cache->getReductionTVs());
- } else {
- reduction_tv_data.writeNew(scheduler_utils::getReductionTvs(fusion));
- if (data_cache && data_cache->isRecording()) {
- data_cache->setReductionTVs(reduction_tv_data.read());
- }
- }
-
- auto& reduction_tvs = reduction_tv_data.read();
-
- TORCH_INTERNAL_ASSERT(
- !reduction_tvs.empty(), "Need reduction tensor views to schedule.");
-
- auto first_red_tv = reduction_tvs[0];
-
- TORCH_INTERNAL_ASSERT(
- first_red_tv != nullptr, "Reduction TensorView wasn't found.");
-
- TORCH_INTERNAL_ASSERT(
- first_red_tv->hasReduction(), "TensorView doesn't have a reduction.");
- const auto red_expr = first_red_tv->definition();
-
- TORCH_INTERNAL_ASSERT(
- red_expr->getExprType() != c10::nullopt &&
- (red_expr->getExprType().value() == ExprType::ReductionOp ||
- red_expr->getExprType().value() == ExprType::WelfordOp),
- "TensorView doesn't have a reduction.");
-
- size_t max_dtype_size = 1;
- size_t n_tensor_inputs = 0;
- for (auto inp : fusion->inputs()) {
- if (inp->isA<TensorView>()) {
- max_dtype_size =
- std::max(max_dtype_size, dataTypeSize(inp->getDataType().value()));
- n_tensor_inputs++;
- }
- }
-
- TORCH_INTERNAL_ASSERT(
- n_tensor_inputs > 0,
- "Tried to schedule a fusion with no tensor inputs, currently not supported.");
-
- HeuristicCacheAccessor<scheduler_utils::PersistentBufferInfo>
- persistent_buffer_data;
-
- // TODO: move all these boilerplate code into the accessor class
- // (follow up)
- if (data_cache && !data_cache->isRecording()) {
- persistent_buffer_data.writeTemporary(
- data_cache->getPersistentBufferInfo());
- } else {
- persistent_buffer_data.writeNew(scheduler_utils::persistentBuffers(fusion));
- if (data_cache && data_cache->isRecording()) {
- data_cache->setPersistentBufferInfo(persistent_buffer_data.read());
- }
- }
-
- auto& persistent_buffers = persistent_buffer_data.read();
- bool requires_persistence = !persistent_buffers.buffers.empty();
-
- auto properties =
- scheduler_utils::getProperties(fusion, runtime_info, first_red_tv);
-
- auto max_persistent_size = scheduler_utils::persistentBufferSize(
- fusion, runtime_info, persistent_buffers, data_cache);
-
- HeuristicCacheAccessor<std::vector<TensorView*>>
- vectorizable_inputs_outputs_data;
-
- // TODO: move all these boilerplate code into the accessor class
- // (follow up)
- if (data_cache && !data_cache->isRecording()) {
- vectorizable_inputs_outputs_data.writeTemporary(
- data_cache->getVectorizableInputsOutputs());
- } else {
- vectorizable_inputs_outputs_data.writeNew(
- scheduler_utils::getVectorizableInputsOutputs(first_red_tv));
- if (data_cache && data_cache->isRecording()) {
- data_cache->setVectorizableInputsOutputs(
- vectorizable_inputs_outputs_data.read());
- }
- }
-
- auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_data.read();
-
- // Vectorize as much as we can
- size_t vectorize_factor = std::numeric_limits<size_t>::max();
-
- for (auto tv : vectorizable_inputs_outputs) {
- const auto tv_vectorize_factor = runtime_info.getVectorizableWidth(tv);
- vectorize_factor = std::min(vectorize_factor, tv_vectorize_factor);
- }
-
- if (vectorize_factor == std::numeric_limits<size_t>::max()) {
- vectorize_factor = 1;
- }
-
- return NormalizationHeuristic(
- properties.reduction_numel,
- properties.iteration_numel,
- properties.fastest_dim_reduction,
- n_tensor_inputs,
- max_dtype_size,
- requires_persistence,
- max_persistent_size,
- vectorize_factor);
-}
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getNormalizationHeuristics(
- Fusion* fusion,
- const at::ArrayRef<c10::IValue>& runtime_inputs,
- HeuristicSummary* data_cache) {
- FUSER_PERF_SCOPE("getNormalizationHeuristicsFromIValue");
- SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true);
- return getNormalizationHeuristics(fusion, runtime_info, data_cache);
-}
-
-namespace {
-
-void schedulePersistentNormalization(
- Fusion* fusion,
- const ReductionParams& rparams) {
- FUSER_PERF_SCOPE("schedulePersistentNormalization");
- FusionGuard fg(fusion);
-  // Cache tensors before grabbing any references to reductions, as
-  // cache_before can invalidate those references: when applied to a
-  // reduction tensor view, the new tensor view contains the reduction and
-  // the original doesn't.
-
- // Cache inputs if unrolled
- auto cached_inputs =
- scheduler_utils::cacheInputs(fusion, rparams.loop_unroll > 1);
-
- // Cache and fork outputs
- std::vector<std::pair<TensorView*, TensorView*>> cached_outputs =
- scheduler_utils::cacheAndForkOutputs(fusion, rparams.loop_unroll > 1);
-
- // Make sure we don't have global memory set on intermediate tensors from
- // fusion segmentation
- scheduler_utils::clearMemorySpace(fusion);
-
- auto reduction_tvs = scheduler_utils::getReductionTvs(fusion);
-
- TORCH_INTERNAL_ASSERT(reduction_tvs.size());
- auto reduction_tv = reduction_tvs[0];
-
- auto dim_analysis =
- scheduler_utils::canonicalDimReduction(fusion, reduction_tv);
- bool has_iter_axis = dim_analysis.first;
- bool has_red_axis = dim_analysis.second;
-
- TORCH_INTERNAL_ASSERT(
- has_red_axis,
- "Could not find reduction axis in tensor used for reduction scheduler.");
-
- if (!has_iter_axis) {
- TORCH_INTERNAL_ASSERT(
- rparams.fastest_dim,
- "If all dims are reduction, should be sending it to fastest dim scheduler.");
- }
-
- TensorView* reference_tv = scheduler_utils::scheduleReductionTV(
- rparams, reduction_tv, has_iter_axis);
-
- // Reduction tensor views and rfactor tensor views are setup. Let's finish off
- // the scheduling, particularly inlining and unrolling.
- TORCH_INTERNAL_ASSERT(
- reference_tv != nullptr && reduction_tv != nullptr,
- "Need these two tensor views to finish the scheduling.");
-
- scheduler_utils::multiReductionInliner(
- fusion,
- rparams,
- reduction_tv,
- reference_tv,
- reduction_tvs,
- cached_inputs,
- cached_outputs);
-}
-
-void scheduleMultiReduction(Fusion* fusion, const ReductionParams& rparams) {
- FUSER_PERF_SCOPE("scheduleMultiReduction");
- FusionGuard fg(fusion);
-  // Cache tensors before grabbing any references to reductions, as
-  // cache_before can invalidate those references: when applied to a
-  // reduction tensor view, the new tensor view contains the reduction and
-  // the original doesn't.
-
- // Cache inputs if unrolled
- auto cached_inputs =
- scheduler_utils::cacheInputs(fusion, rparams.loop_unroll > 1);
-
- // Cache and fork outputs
- std::vector<std::pair<TensorView*, TensorView*>> cached_outputs =
- scheduler_utils::cacheAndForkOutputs(fusion, rparams.loop_unroll > 1);
-
- // Make sure we don't have global memory set on intermediate tensors from
- // fusion segmentation
- scheduler_utils::clearMemorySpace(fusion);
-
- auto reduction_tvs = scheduler_utils::getReductionTvs(fusion);
-
- TORCH_INTERNAL_ASSERT(reduction_tvs.size());
- auto reduction_tv = reduction_tvs[0];
-
- auto dim_analysis =
- scheduler_utils::canonicalDimReduction(fusion, reduction_tv);
- bool has_iter_axis = dim_analysis.first;
- bool has_red_axis = dim_analysis.second;
-
- TORCH_INTERNAL_ASSERT(
- has_red_axis,
- "Could not find reduction axis in tensor used for reduction scheduler.");
-
- if (!has_iter_axis) {
- TORCH_INTERNAL_ASSERT(
- rparams.fastest_dim,
- "If all dims are reduction, should be sending it to fastest dim scheduler.");
- }
-
- TensorView* reference_tv = scheduler_utils::scheduleReductionTV(
- rparams, reduction_tv, has_iter_axis);
-
- // Reduction tensor views and rfactor tensor views are setup. Let's finish off
- // the scheduling, particularly inlining and unrolling.
- TORCH_INTERNAL_ASSERT(
- reference_tv != nullptr && reduction_tv != nullptr,
- "Need these two tensor views to finish the scheduling.");
-
- scheduler_utils::multiReductionInliner(
- fusion,
- rparams,
- reduction_tv,
- reference_tv,
- reduction_tvs,
- cached_inputs,
- cached_outputs);
-}
-} // namespace
-
-// fusion is the input IR that will be modified by this function
-TORCH_CUDA_CU_API void scheduleNormalization(
- Fusion* fusion,
- const ReductionParams& rparams) {
- if (rparams.persistent_kernel) {
- schedulePersistentNormalization(fusion, rparams);
- } else {
- scheduleMultiReduction(fusion, rparams);
- }
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <ATen/core/ivalue.h>
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>
-
-// TODO: If caching inputs would require persistence we are sending it to the
-// persistent kernel scheduler. This isn't necessary if the only persistent
-// buffers are inputs as we could re-read them from global memory. Need to
-// consider if this is worth implementing.
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class SchedulerRuntimeInfo;
-class HeuristicSummary;
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getNormalizationHeuristics(
- Fusion* fusion,
- const at::ArrayRef<c10::IValue>& runtime_inputs,
- HeuristicSummary* data_cache = nullptr);
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getNormalizationHeuristics(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr);
-
-TORCH_CUDA_CU_API void scheduleNormalization(
- Fusion* fusion,
- const ReductionParams& rparams);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise.h>
-
-#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-
-#include <ATen/cuda/CUDAContext.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-// constexpr int64_t x_grid_limit = ((int64_t)1 << (int64_t)31) - (int64_t)1;
-// Unused at the moment, commenting for clang tidy
-constexpr int64_t kThreadX = 128;
-} // namespace
-
-c10::optional<PointwiseParams> getPointwiseHeuristics(
- Fusion* fusion,
- const at::ArrayRef<c10::IValue>& runtime_inputs,
- HeuristicSummary* data_cache) {
- SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true);
- return getPointwiseHeuristics(fusion, runtime_info, data_cache);
-}
-
-c10::optional<PointwiseParams> getPointwiseHeuristics(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache) {
- FUSER_PERF_SCOPE("getPointwiseHeuristics");
-
- FusionGuard fg(fusion);
- TensorView* largest_out = nullptr;
- int max_dims = -1;
-
- auto in_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
- auto out_tvs_it = ir_utils::filterByType<TensorView>(fusion->outputs());
- // Will want to access this with direct indexing later, convert now.
- std::vector<TensorView*> out_tvs(out_tvs_it.begin(), out_tvs_it.end());
-
- for (auto out_tv : out_tvs) {
- int n_dims = 0;
- for (auto id : out_tv->getMaybeRFactorDomain()) {
- if (id->isReduction() || id->isBroadcast()) {
- continue;
- }
- n_dims++;
- }
- if (n_dims > max_dims) {
- largest_out = out_tv;
- max_dims = n_dims;
- }
- }
-
- TORCH_INTERNAL_ASSERT(largest_out != nullptr);
-
- // If zero dimensional, return default parameters
- if (TensorDomain::noReductions(
- TensorDomain::noBroadcasts(largest_out->domain()->domain()))
- .size() == 0) {
- if (data_cache && data_cache->isRecording()) {
- data_cache->setVectorizableInputsOutputs(std::vector<TensorView*>());
- data_cache->setMappedInputOutputDims(std::vector<int64_t>());
- }
- return PointwiseParams();
- }
-
- auto ref_root = largest_out->getMaybeRFactorDomain();
-
- std::vector<int64_t> elem_counts(ref_root.size(), 1);
- int64_t n_elems = 1;
- for (size_t ref_i = 0; ref_i < ref_root.size(); ref_i++) {
- auto inferred_val =
- runtime_info.expressionEvaluator().evaluate(ref_root[ref_i]->extent());
- TORCH_INTERNAL_ASSERT(
- inferred_val.has_value(),
- "Error inferring size for pointwise scheduler.");
- elem_counts[ref_i] = inferred_val.value();
- n_elems *= inferred_val.value();
- }
-
- const int64_t device_multiprocessor_count =
- (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
-
- // TODO: Set to 1?
- int64_t max_input_dtype_size = 2;
- size_t n_tensors = 0;
-
- for (auto inp : in_tvs) {
- max_input_dtype_size = std::max(
- max_input_dtype_size,
- (int64_t)dataTypeSize(inp->getDataType().value()));
- n_tensors++;
- }
- n_tensors += std::distance(out_tvs.begin(), out_tvs.end());
-
- constexpr int64_t kSixteen = 16; // clang tidy
-
- auto max_unroll_factor = ceilDiv(
- // Available unrolling based on size of data type
- (int64_t)kSixteen / max_input_dtype_size,
- // Reduce unrolling if we have many inputs, start reduction at 4 inputs
- std::max(
- (scheduler_utils::lastPow2((int64_t)n_tensors) >> 2), (int64_t)1));
-
- // Don't unroll at the cost of getting a full wave on the GPU
- if (n_elems < device_multiprocessor_count * kThreadX &&
- max_unroll_factor > 1) {
- max_unroll_factor = std::min(
- max_unroll_factor,
- ceilDiv(n_elems, device_multiprocessor_count * kThreadX));
- }
-
- // If we use RNG don't unroll so we can do correctness testing
- if (fusion->isStochastic() && disableRNGUnrolling()) {
- max_unroll_factor = 1;
- }
-
- PointwiseParams params;
- params.tag = "Pointwise heuristics";
-
- // Don't try to vectorize if it's not recommended
- params.inner_factor = 1;
-
- // Vectorize as much as we can
- size_t vectorize_factor = max_unroll_factor;
-
- HeuristicCacheAccessor<std::vector<TensorView*>>
- vectorizable_inputs_outputs_data;
-
- // TODO: move all these boilerplate code into the accessor class
- // (follow up)
- if (data_cache && !data_cache->isRecording()) {
- vectorizable_inputs_outputs_data.writeTemporary(
- data_cache->getVectorizableInputsOutputs());
- } else {
- vectorizable_inputs_outputs_data.writeNew(
- scheduler_utils::getVectorizableInputsOutputs(largest_out));
- if (data_cache && data_cache->isRecording()) {
- data_cache->setVectorizableInputsOutputs(
- vectorizable_inputs_outputs_data.read());
- }
- }
-
- auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_data.read();
-
- for (auto tv : vectorizable_inputs_outputs) {
- const auto tv_vectorize_factor = runtime_info.getVectorizableWidth(tv);
- vectorize_factor = std::min(vectorize_factor, tv_vectorize_factor);
- }
-
- if (vectorize_factor == 1) {
- params.vectorize = false;
- params.inner_factor = max_unroll_factor;
- } else {
- params.vectorize = true;
- params.inner_factor = vectorize_factor;
- }
- /*
- * 2D pointwise scheduling logic. What is expected is there's some
- * broadcasting pattern which would make scheduling as a 2D problem more
- * efficient than scheduling simply as a 1D problem.
- *
- * Mapping count holds how many bytes are in each dimension for both inputs
- * and outputs relative to the reference tensor. What we're looking for is a
- * break point in reference_tvs dimensions which separates the outer dimension
- * and inner dimension of the problem mapped to 2D.
- *
- * break_point is computed assuming no reuse, ignoring parallelization
- * limitations, and simply figures out which point best separates broadcasted
- * dimensions. In other words, where's the point where we isolate the most
- * broadcasted elements to one side.
- *
- * Once a break point is found, simply schedule the pointwise op as 2D
- * balancing parallelization as best as possible.
- */
-
- // Ideal break point location
- int64_t break_point = 0;
-
- // Elements on the right of break point (without break point all are on the
- // right)
- int64_t right_elem_count = 0;
-
- int64_t bdimx = kThreadX;
-
- // bdimy may be used if the right side of the break point is not large and we
- // need to expand block level parallelism into the left side of the break
- // point.
- int64_t bdimy = 1;
-
- // In 2D scheduler gdimx is used to parallelize the left side of the break
- // point.
- int64_t gdimx = 1;
-
- // gdimy is used if there's too much parallelization in the right side of the
- // break point. We will expand grid parallelization into the right side of the
- // break point with gdimx and use gdimy for the left side of the break point.
- int64_t gdimy = 1;
-
- HeuristicCacheAccessor<std::vector<int64_t>> mapping_count_accessor;
- // TODO: move all these boilerplate code into the accessor class
- // (follow up)
- if (data_cache && !data_cache->isRecording()) {
- mapping_count_accessor.writeTemporary(
- data_cache->getMappedInputOutputDims());
- } else {
- mapping_count_accessor.writeNew(
- scheduler_utils::mappedInputsOutputs(largest_out));
- if (data_cache && data_cache->isRecording()) {
- data_cache->setMappedInputOutputDims(mapping_count_accessor.read());
- }
- }
-
- auto mapping_count = mapping_count_accessor.read();
-
- {
- // How much would this transfer cost if it was done as a 1-D schedule
- int64_t transfer_size_1d = 1;
-
- auto max_dims =
- std::max_element(mapping_count.begin(), mapping_count.end());
-
- for (int64_t i = 0; i < (int64_t)ref_root.size(); i++) {
- transfer_size_1d = transfer_size_1d * elem_counts[i] * (*max_dims);
- }
-
- // If there isn't very much parallelism available, just use 1D scheduler
- if (true || n_elems * 2 > device_multiprocessor_count * kThreadX) {
- int64_t min_total_transfer = std::numeric_limits<int64_t>::max();
-
- for (int64_t break_point_i = 0; break_point_i < (int64_t)ref_root.size();
- break_point_i++) {
- // Number of elements in the right side of reference tv with
- // break_point_i
- int64_t cur_right_elem_count = 1;
- for (int64_t right_i = break_point_i;
- right_i < (int64_t)ref_root.size();
- right_i++) {
- cur_right_elem_count = cur_right_elem_count * elem_counts[right_i];
- }
-
- if (cur_right_elem_count <= 1) {
- continue;
- }
-
- auto cur_left_elem_count = n_elems / cur_right_elem_count;
- if (cur_left_elem_count <= 1) {
- continue;
- }
-
- auto left_max_dims = std::max_element(
- mapping_count.begin(), mapping_count.begin() + break_point_i);
-
- auto right_max_dims = std::max_element(
- mapping_count.begin() + break_point_i, mapping_count.end());
-
- // Estimate transfer cost with this break point
- int64_t cur_transfer_size = 1;
-
- for (int64_t left_i = 0; left_i < break_point_i; left_i++) {
- cur_transfer_size =
- cur_transfer_size * elem_counts[left_i] * (*left_max_dims);
- }
-
- for (int64_t right_i = break_point_i;
- right_i < (int64_t)ref_root.size();
- right_i++) {
- cur_transfer_size =
- cur_transfer_size * elem_counts[right_i] * (*right_max_dims);
- }
-
- // Continue if this break point doesn't save at least 10% of 1D
- // scheduling.
- if (cur_transfer_size >= min_total_transfer ||
- cur_transfer_size * 10 >= transfer_size_1d * 9) {
- continue;
- }
-
- // Don't limit unroll factor with break point
- if (cur_right_elem_count < max_unroll_factor) {
- continue;
- }
-
- bdimx = std::min(
- ceilDiv(cur_right_elem_count, max_unroll_factor), kThreadX);
- bdimy = 1;
- gdimy = 1;
- // Put remainder in bdimy if there's at least a wave of grid level
- // parallelism.
- if (cur_left_elem_count > device_multiprocessor_count) {
- bdimy = kThreadX / bdimx;
- }
- auto remainder_left = ceilDiv(cur_left_elem_count, bdimy);
- auto remainder_right =
- ceilDiv(cur_right_elem_count, bdimy * bdimx * max_unroll_factor);
-
- // Use this break point
- break_point = break_point_i;
- min_total_transfer = cur_transfer_size;
- right_elem_count = cur_right_elem_count;
-
- gdimx = remainder_left;
- if (remainder_right > 1 && bdimy <= 1) {
- gdimy = remainder_right;
- }
- }
- }
- }
-
- TORCH_INTERNAL_ASSERT(right_elem_count > 0 || params.break_point == 0);
-
- TORCH_INTERNAL_ASSERT(!(bdimy > 1 && gdimy > 1));
- params.break_point = break_point;
- params.split_block = bdimy > 1;
-
- params.lparams.bind(bdimx, ParallelType::TIDx);
- if (params.split_block) {
- params.lparams.bind(bdimy, ParallelType::TIDy);
- }
- if (gdimy > 65535) {
- params.split_grid_y_dim = true;
- }
-
- if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
- std::cerr << "\n===== Pointwise Stats ========\n"
- << "num_elems: " << n_elems << "\n"
- << "mapping_count: " << mapping_count << "\n"
- << "elem_counts: " << elem_counts << "\n"
- << "n_tensor_inputs: " << n_tensors << "\n"
- << "max_input_dtype_size: " << max_input_dtype_size << "\n"
- << "vectorize_factor: " << vectorize_factor << std::endl;
- std::cerr << params.toString() << std::endl;
- }
-
- return params;
-}
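-
-// Not from the original file: a minimal stand-in for ceilDiv, the rounded-up
-// integer division used throughout these heuristics. Assumes positive
-// operands; the real helper lives elsewhere in the codebase.
-namespace {
-int64_t ceilDivSketch(int64_t a, int64_t b) {
-  return (a + b - 1) / b; // e.g. ceilDivSketch(10, 4) == 3
-}
-} // namespace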
-
-// TODO: remove or return launch parameters
-LaunchParams schedulePointwise(
- Fusion* fusion,
- const at::ArrayRef<c10::IValue>& runtime_inputs) {
-  FUSER_PERF_SCOPE("schedulePointwise");
- auto params = getPointwiseHeuristics(fusion, runtime_inputs);
- TORCH_INTERNAL_ASSERT(
- params.has_value(), "Could not schedule pointwise operation.");
- schedulePointwise(fusion, params.value());
- return params.value().lparams;
-}
-
-namespace {
-// Returns number of non-reduction/non-broadcast dims in rfactor domain
-size_t nRootDims(const TensorView* tv) {
- auto root_dom = tv->getMaybeRFactorDomain();
- size_t tv_n_dims = 0;
- for (auto dim : root_dom) {
- if (!dim->isReduction() && !dim->isBroadcast()) {
- tv_n_dims++;
- }
- }
- return tv_n_dims;
-}
-} // namespace
-
-// TODO: Inline intermediate operations (avoid inlining unrolled/vectorized
-// input/output caches)
-void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
- FusionGuard fg(fusion);
- // fusion->printMath();
- // Make sure we don't have global memory set on intermediate tensors from
- // fusion segmentation
- scheduler_utils::clearMemorySpace(fusion);
-
-  // Maybe has_reduction for scheduling should be done on a per-output-tensor
-  // basis.
- TORCH_INTERNAL_ASSERT(
- !fusion->hasReduction(), "This scheduler only handles pointwise ops.");
-
- // For intermediate outputs, apply cache_fork
- auto outs = fusion->outputs();
- for (const auto output : outs) {
- if (!output->uses().empty() && output->definition() != nullptr) {
- if (output->getValType().value() == ValType::TensorView) {
- output->as<TensorView>()->cache_fork();
- }
- }
- }
-
- std::vector<TensorView*> input_tvs;
- {
- auto filtered_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
- // Remove hanging tensor views
- for (auto tv : filtered_tvs) {
- if (tv->uses().empty()) {
- continue;
- }
- input_tvs.push_back(tv);
- }
- }
- auto output_tvs = ir_utils::filterByType<TensorView>(fusion->outputs());
-
- size_t max_dims = 0;
- for (auto inp : input_tvs) {
- max_dims = std::max(nRootDims(inp), max_dims);
- }
-
- for (auto out : output_tvs) {
- max_dims = std::max(nRootDims(out), max_dims);
- }
-
- // If everything is zero dim tensors, just return.
- if (max_dims == 0) {
- return;
- }
-
- TensorView* reference_tv = nullptr;
- for (auto out : output_tvs) {
- if (out->definition() == nullptr) {
- continue;
- }
- if (nRootDims(out) == max_dims) {
- reference_tv = out;
- break;
- }
- }
-
- TORCH_INTERNAL_ASSERT(
- reference_tv != nullptr,
- "Could not find a fully broadcasted output to reference schedule on.");
-
- IterDomain* inner_most_id = nullptr;
- for (auto it = reference_tv->domain()->domain().rbegin();
- it != reference_tv->domain()->domain().rend();
- it++) {
- if ((*it)->isReduction()) {
- continue;
- }
- if ((*it)->isBroadcast() && inner_most_id == nullptr) {
- inner_most_id = *it;
- }
- inner_most_id = *it;
- break;
- }
-
- TORCH_INTERNAL_ASSERT(inner_most_id != nullptr);
- auto vectorizable_dims =
- scheduler_utils::FindAllMappedDims::from(reference_tv, inner_most_id);
-
- // Caches of inputs
- std::vector<TensorView*> cached_inputs;
-
- // Output, cache_before of output
- std::vector<std::pair<TensorView*, TensorView*>> cached_outputs;
-
- // Track what should be vectorized versus unrolled
- std::unordered_set<TensorView*> vectorized_tensor;
-
- // Figure out which inputs to cache for unrolling or vectorization
- for (auto inp : input_tvs) {
- if (inp->uses().empty()) {
- continue;
- }
- // Need to check before caching.
- bool vectorize = params.vectorize &&
- scheduler_utils::shouldVectorize(inp, vectorizable_dims);
- cached_inputs.emplace_back(inp->cache_after());
- if (vectorize) {
- vectorized_tensor.emplace(cached_inputs.back());
- }
- }
-
- // Figure out which outputs to cache for unrolling or vectorization
- for (auto out : output_tvs) {
- if (out->definition() == nullptr) {
- continue;
- }
- // Need to check before caching.
- bool vectorize = params.vectorize &&
- scheduler_utils::shouldVectorize(out, vectorizable_dims);
- cached_outputs.emplace_back(std::make_pair(out, out->cache_before()));
- if (vectorize) {
- vectorized_tensor.emplace(out);
- }
- }
-
- auto all_tvs = ir_utils::allTvs(fusion);
-
- // Merge right side of break point
- int rhs_i = -1;
- for (int i = (int)reference_tv->nDims(); i > (int)params.break_point; i--) {
- auto axis_i = i - 1;
- if (reference_tv->axis(axis_i)->isBroadcast() ||
- reference_tv->axis(axis_i)->isReduction()) {
- continue;
- }
- if (rhs_i == -1) {
- rhs_i = axis_i;
- } else {
- reference_tv->merge(axis_i, rhs_i);
- rhs_i = axis_i;
- }
- }
- if (rhs_i >= 0) {
- // If there's an rhs
- reference_tv->reorder({{rhs_i, -1}});
- }
-
- // Merge left side of break point
- int lhs_i = -1;
- for (int i = (int)params.break_point; i > 0; i--) {
- auto axis_i = i - 1;
- if (reference_tv->axis(axis_i)->isBroadcast() ||
- reference_tv->axis(axis_i)->isReduction()) {
- continue;
- }
- if (lhs_i == -1) {
- lhs_i = axis_i;
- } else {
- reference_tv->merge(axis_i, lhs_i);
- lhs_i = axis_i;
- }
- }
-
- // Right (inner merged) dimension is at inner most position, left (outer
- // merged) dimension is at lhs_i. Order as [lhs_i, rhs_i, unmerged...]
- reference_tv->reorder({{lhs_i, 0}, {-1, 1}});
-
- if (params.break_point) {
- // 2D parallelization scheme
- TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0);
-
- if (params.vectorize) {
- reference_tv->split(1, params.inner_factor);
- reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
- reference_tv->split(0, 1);
- // [outer, Unswitch | i-remainder, TIDx, Vectorization]
- reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
- reference_tv->axis(3)->parallelize(ParallelType::TIDx);
-
-      // Aggressively mark with Vectorize and clean up later. That way we
-      // don't have to manually specify parallelization outside the reference.
- reference_tv->axis(4)->parallelize(ParallelType::Vectorize);
-
- // [outer, Unswitch | i-remainder, TIDx, Vectorization]
- // To make consistent with unrolling:
- reference_tv->reorder({{1, 2}, {2, 1}, {3, 4}, {4, 3}});
- //[outer | i-remainder, Unswitch, Vectorization, TIDx]
- } else {
- reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
- reference_tv->split(1, params.inner_factor);
-
- reference_tv->split(0, 1);
- // [outer, unswitch | i-remainder, unroll, TIDx ]
- reference_tv->reorder({{1, 2}});
- // [outer, i-remainder, unswitch, unroll, TIDx ]
- reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
- reference_tv->axis(4)->parallelize(ParallelType::TIDx);
-
- //[outer | i-remainder, Unswitch, Unroll, TIDx]
- }
-
- // Move out of the way to furthest left point
- reference_tv->reorder({{1, 0}});
-
- //[i-remainder | outer | Unswitch, Unroll, TIDx]
- if (params.split_block) {
- reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
- // [i-remainder | BIDx TIDy | Unswitch, Unroll, TIDx]
- reference_tv->axis(1)->parallelize(ParallelType::BIDx);
- reference_tv->axis(2)->parallelize(ParallelType::TIDy);
- } else {
- // [BIDy | BIDx | Unswitch, Unroll, TIDx]
- reference_tv->axis(1)->parallelize(ParallelType::BIDx);
- if (params.split_grid_y_dim) {
- reference_tv->split(0, 65535);
- reference_tv->axis(1)->parallelize(ParallelType::BIDy);
- } else {
- reference_tv->axis(0)->parallelize(ParallelType::BIDy);
- }
- }
-
- } else {
- // 1D Scheduler
- TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i == -1);
-    // The right-hand side exists and is the only axis we care to schedule;
-    // move it from the innermost position to the leftmost.
- reference_tv->reorder({{-1, 0}});
-
- if (params.vectorize) {
- // Vectorize
- reference_tv->split(0, params.inner_factor);
- // Unswitch
- reference_tv->split(0, 1);
- // Threads
- reference_tv->split(0, kThreadX);
-
- reference_tv->axis(0)->parallelize(ParallelType::BIDx);
- reference_tv->axis(1)->parallelize(ParallelType::TIDx);
- reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
-      // Aggressively mark with Vectorize and clean up later. That way we
-      // don't have to manually specify parallelization outside the reference.
- reference_tv->axis(-1)->parallelize(ParallelType::Vectorize);
-
- //[BIDx, TIDx, Unswitch, Vectorization]
- // To make consistent with unrolling:
- reference_tv->reorder({{1, 3}, {2, 1}, {3, 2}});
- //[BIDx, Unswitch, Vectorization, TIDx]
- } else {
- // Threads
- reference_tv->split(0, kThreadX);
- // Unroll
- reference_tv->split(0, params.inner_factor);
- // Unswitch
- reference_tv->split(0, 1);
-
- // [BIDx, Unswitch, Unroll, TIDx]
- reference_tv->axis(0)->parallelize(ParallelType::BIDx);
- reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
- reference_tv->axis(3)->parallelize(ParallelType::TIDx);
- }
- }
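-
-  // A worked trace (not from the original source) of the 1D non-vectorized
-  // path above, with n_elems = 2^20, kThreadX = 128, inner_factor = 4:
-  //   start          [1048576]
-  //   split(0, 128)  [8192, 128]
-  //   split(0, 4)    [2048, 4, 128]
-  //   split(0, 1)    [2048, 1, 4, 128]
-  // which is then parallelized as [BIDx, Unswitch, Unroll, TIDx].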
- TransformPropagator::from(reference_tv);
- scheduler_utils::parallelizeAllLike(reference_tv, all_tvs);
-
- if (params.vectorize) {
- // Clear vectorize on tensors that shouldn't have it
- for (auto tv : all_tvs) {
- if (!vectorized_tensor.count(tv)) {
- for (auto id : tv->domain()->domain()) {
- if (id->getParallelType() == ParallelType::Vectorize) {
- id->parallelize(ParallelType::Serial);
- }
- }
- }
- }
- }
-
- // Compute at into cached inputs
- std::vector<TensorView*> consumers_of_cached_inputs;
- // Cache of input, and one of its consumers
- std::vector<std::pair<TensorView*, TensorView*>> input_cache_and_consumer;
- {
- // Avoid duplicate additions, so track what we add
- std::unordered_set<TensorView*> added;
- for (auto cached_input : cached_inputs) {
- auto consumer_tvs = ir_utils::consumerTvsOf(cached_input);
- TORCH_INTERNAL_ASSERT(
- consumer_tvs.size(),
- "Input was not succesfully filtered out for scheduling but wasn't used.");
-
-      // Grab a consumer that will be used to set up the computeAt structure
-      // of the cached input into that consumer
- input_cache_and_consumer.emplace_back(
- std::make_pair(cached_input, consumer_tvs[0]));
-
- // Grab all consumers which will be used for inlining computeAt for the
- // body of the computation (excluding caching inputs/outputs)
- for (auto consumer_tv : consumer_tvs) {
- // Don't duplicate
- if (added.insert(consumer_tv).second) {
- consumers_of_cached_inputs.emplace_back(consumer_tv);
- }
- }
- }
- }
-
- for (auto entry : input_cache_and_consumer) {
- // Compute at inside unswitch position:
- auto input_cache = entry.first;
- auto input_cache_consumer = entry.second;
-
- auto unswitch_it = std::find_if(
- input_cache_consumer->domain()->domain().begin(),
- input_cache_consumer->domain()->domain().end(),
- [](IterDomain* id) {
- return id->getParallelType() == ParallelType::Unswitch;
- });
- auto unswitch_pos =
- unswitch_it == input_cache_consumer->domain()->domain().end()
- ? -1
- : std::distance(
- input_cache_consumer->domain()->domain().begin(), unswitch_it) +
- 1;
-
- input_cache->computeAt(
- input_cache_consumer, unswitch_pos, ComputeAtMode::BestEffort);
- }
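-
-  // A worked example (not from the original source): with a consumer domain
-  // of [BIDx, Unswitch, Unroll, TIDx], the Unswitch axis is at index 1, so
-  // unswitch_pos = 2 and the cached input is computed at position 2: inside
-  // the BIDx and Unswitch loops but outside the Unroll and TIDx loops.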
-
- // Producers for inlined computeAt
- std::vector<TensorView*> compute_from = consumers_of_cached_inputs;
-
- // Consumers for inlined computeAt
- std::vector<TensorView*> compute_to;
- // Compute at cached outputs
- //[BIDx, Unswitch, Vectorization, TIDx]
- for (auto entry : cached_outputs) {
- auto cached_output = entry.second;
- auto output = entry.first;
-
- auto unswitch_it = std::find_if(
- output->domain()->domain().begin(),
- output->domain()->domain().end(),
- [](IterDomain* id) {
- return id->getParallelType() == ParallelType::Unswitch;
- });
- auto unswitch_pos = unswitch_it == output->domain()->domain().end()
- ? -1
- : std::distance(output->domain()->domain().begin(), unswitch_it) + 1;
-
- cached_output->computeAt(output, unswitch_pos, ComputeAtMode::BestEffort);
- compute_to.push_back(cached_output);
- }
-
- scheduler_utils::computeAtBetween(
- compute_from, compute_to, -1, ComputeAtMode::BestEffort);
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <ATen/core/ivalue.h>
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class SchedulerRuntimeInfo;
-class HeuristicSummary;
-
-TORCH_CUDA_CU_API c10::optional<PointwiseParams> getPointwiseHeuristics(
- Fusion* fusion,
- const at::ArrayRef<c10::IValue>& runtime_inputs,
- HeuristicSummary* data_cache = nullptr);
-
-TORCH_CUDA_CU_API c10::optional<PointwiseParams> getPointwiseHeuristics(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr);
-
-TORCH_CUDA_CU_API void schedulePointwise(
- Fusion* fusion,
- const PointwiseParams& params);
-
-TORCH_CUDA_CU_API LaunchParams schedulePointwise(
- Fusion* fusion,
- const at::ArrayRef<c10::IValue>& runtime_inputs);
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
-
-#include <sstream>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-// Parameters the pointwise heuristic generates to describe the optimal
-// schedule. Warning: the equality operator is intended for use in caching the
-// kernel associated with these pointwise parameters. It does not check if the
-// launch parameters are equivalent!
-class PointwiseParams {
- public:
- // vectorize if true, otherwise unroll
- bool vectorize = false;
-
-  // Treat the pointwise operation as 2-dimensional: this is the position at
-  // which the domain is split into an outer (left) and inner (right) section.
-  // 0 means the problem is treated as 1-D; a break point of 1 on a 3-D
-  // problem means the first dimension is the outer dimension and the
-  // remaining two form the inner dimension.
-
- // Split block across left and right dimension
- bool split_block = false;
-
-  // Split the grid's y dimension if it would otherwise be too large
- bool split_grid_y_dim = false;
-
- // Unroll or vectorization factor
- int64_t inner_factor = 1;
-
- std::string tag = "";
-
- LaunchParams lparams;
-
- // Warning: Does not check launch parameters!
- bool operator==(const PointwiseParams& other) const {
- bool attr_equal = other.vectorize == vectorize &&
- other.break_point == break_point && other.split_block == split_block &&
- other.split_grid_y_dim == split_grid_y_dim &&
- other.inner_factor == inner_factor;
- return attr_equal;
- }
-
- std::string toString() const {
- std::stringstream ss;
- ss << "\n===== Pointwise Parameters ========\n"
- << (tag == "" ? "" : "Tag: ") << tag << " Pointwise Characteristics:\n"
- << " Gridx: " << lparams.gdimx() << " BlckY: " << lparams.bdimy()
- << " BlckX: " << lparams.bdimx() << "\n";
- if (break_point) {
- ss << "2D Schedule\n"
- << " Bcast break point: " << break_point << "\n";
- if (split_block) {
- ss << "Split block into y-dim\n";
- }
- if (split_grid_y_dim) {
- ss << " Split y grid dim\n";
- }
- }
- if (inner_factor > 1) {
- if (vectorize) {
- ss << "Vectorize, Factor: " << inner_factor << "\n";
- } else {
- ss << "Unroll, Factor: " << inner_factor << "\n";
- }
- }
- ss << "====================================\n";
- return ss.str();
- }
-};
-
-// Warning: Hash is not based on launch parameters!
-class PointwiseParamsHash {
- public:
- size_t operator()(const PointwiseParams& pp) const {
- size_t attr_hash = static_cast<size_t>(pp.vectorize) ^
- static_cast<size_t>(pp.break_point) << 4 ^
- static_cast<size_t>(pp.split_block) << 5 ^
- static_cast<size_t>(pp.split_grid_y_dim) << 6 ^
- static_cast<size_t>(pp.inner_factor) << 9;
- return attr_hash;
- }
-};
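-
-// A minimal sketch (not from the original source) of how this hash functor is
-// typically consumed: keying a kernel cache on the scheduler parameters, with
-// operator== resolving hash collisions. CompiledKernelSketch is a
-// hypothetical placeholder for a compiled-kernel handle.
-#include <string>
-#include <unordered_map>
-
-struct CompiledKernelSketch {
-  std::string cubin;
-};
-
-using PointwiseKernelCacheSketch = std::unordered_map<
-    PointwiseParams,
-    CompiledKernelSketch,
-    PointwiseParamsHash>;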
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/scheduler/reduction.h>
-
-#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-
-#include <ATen/cuda/CUDAContext.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-ReductionParams innerReductionHeuristic(
- const int64_t num_elems_in_reduction,
- const int64_t num_outputs_for_reduction,
- const int64_t n_tensor_inputs,
- const int64_t max_input_dtype_size,
- size_t vectorize_factor) {
- // Set some targets for parallelization
-
- const int64_t n_elems = num_elems_in_reduction * num_outputs_for_reduction;
-
- // WARNING: Current device for codegen may not be the target device
- const int64_t device_max_threads_per_multiprocessor =
- (int64_t)at::cuda::getCurrentDeviceProperties()
- ->maxThreadsPerMultiProcessor;
-
- const int64_t device_multiprocessor_count =
- (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
-
- auto const max_unroll = ceilDiv(
- // Available unrolling based on size of data type
- (int64_t)16 / (int64_t)max_input_dtype_size,
-      // Reduce unrolling if we have many inputs; start reducing at 4 inputs
- std::max(
- (scheduler_utils::lastPow2((int64_t)n_tensor_inputs) >> 2),
- (int64_t)1));
-
- // Conservative value, could be set to larger based on arch if necessary.
- constexpr int64_t l1_cache = 32 * 1024;
-  // This could change per GPU generation, but for L1 we want to consider
-  // active threads, not resident threads
- constexpr int64_t active_threads = 1024;
-
-  // If the data fits in L2 and we need more parallelization in the reduction
-  // dim, we can use a smaller warp size. While thread-local data fits in L1
-  // and the reduction dim is really small, we can use fewer than 32 threads
-  // per warp.
- const bool fits_in_l2 = n_elems * max_input_dtype_size * n_tensor_inputs <
- at::cuda::getCurrentDeviceProperties()->l2CacheSize;
-
-  // If it fits in L2, we just want to make sure each thread uses 32 bytes.
- const int64_t warp_size_based_on_l2 =
- fits_in_l2 ? (int64_t)32 / max_input_dtype_size : 32;
-
-  // Check how many elements per thread it would take to start thrashing L1,
-  // and set that as the minimum number of elements to reduce per thread.
- const int64_t warp_size_based_on_l1 = std::min(
- ceilDiv(
- num_elems_in_reduction,
- std::max(
- l1_cache /
- (n_tensor_inputs * max_input_dtype_size * active_threads),
- (int64_t)1)),
- (int64_t)16);
-
- // Take the smaller
- const int64_t warp_size =
- std::min(warp_size_based_on_l1, warp_size_based_on_l2);
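-
-  // A worked example (not from the original source): with a single fp16
-  // input (2 bytes) and a 4096-element reduction that fits in L2:
-  //   warp_size_based_on_l2 = 32 / 2 = 16 threads (32 bytes per thread)
-  //   warp_size_based_on_l1 = min(ceilDiv(4096, 32768 / (1 * 2 * 1024)), 16)
-  //                         = min(256, 16) = 16
-  //   warp_size             = min(16, 16) = 16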
-
- // Initialization
- int64_t target_blocks = 1;
- int64_t target_unroll = 1;
- int64_t target_iterations = 1;
-
-  // Try to set a minimum amount of work for each thread, since cross-thread
-  // communication is slow and shouldn't be done for every element in the
-  // reduction.
- int64_t min_target_iterations =
- std::max((int64_t)32 / (int64_t)max_input_dtype_size, (int64_t)1);
-
- // Start trying to break parallelization up across threads,
- // unrolling/iterations, and blocks.
-
-  // max_threads_in_block is the cap on a thread block; the minimum is based
-  // on warp_size
- int64_t max_threads_in_block = std::max(
- warp_size, ceilDiv(num_elems_in_reduction, min_target_iterations));
-
- // If we have one warp per block, check if that's enough to saturate the SMs
- target_blocks = ceilDiv(n_elems, warp_size);
-
- // If we have more than a wave of blocks, put parallelism into unrolling and
- // target iterations
- if (target_blocks > device_multiprocessor_count) {
- auto available_unroll = std::max(
- n_elems / (warp_size * device_multiprocessor_count), (int64_t)1);
-
-    // Spread parallelism across unrolling and iterations; to balance the two,
-    // flip back and forth, alternately adding to each.
- bool flip = true;
-
- while (available_unroll > 1 &&
- (target_unroll < max_unroll ||
- // Prefer unrolling
- target_iterations < ceilDiv(min_target_iterations, max_unroll))) {
- if (target_unroll * 2 <= max_unroll && flip) {
- target_unroll *= 2;
- }
-
- if (target_iterations * 2 <= ceilDiv(min_target_iterations, max_unroll) &&
- !flip) {
- target_iterations *= 2;
- }
-
- available_unroll = std::max(
- n_elems /
- (warp_size * device_multiprocessor_count * target_unroll *
- target_iterations),
- (int64_t)1);
-
- flip = !flip;
- }
-
- // Recompute target blocks
- target_blocks =
- ceilDiv(n_elems, warp_size * target_unroll * target_iterations);
- }
-
- // Cap target blocks to 4 waves
- target_blocks = std::min(target_blocks, device_multiprocessor_count * 4);
-
- if (target_blocks * target_unroll * target_iterations < n_elems) {
-    // targeting 4 waves, so try to use a quarter of available threads
- max_threads_in_block = std::min(
- ceilDiv(n_elems, target_blocks * target_unroll),
- ceilDiv(device_max_threads_per_multiprocessor, (int64_t)4));
- }
-
- // To get to target threads:
- // Prioritize
- // (1) x dim in reduction
- // (2) unrolling in reduction
- // (3) y in output
- // To get target blocks:
- // Prioritize
- // (1) x dim in multiple outputs
- // (2) y dim in multiple reductions
-
- // Blocks for reductions
- int64_t grdim = 1;
- // Blocks for outputs
- int64_t godim = 1;
-
- // Threads for outputs
- int64_t bdimy = 1;
- // Threads for reduction
- int64_t bdimx = 1;
-
- // Should we unroll from reduction axis, or outs axis
- bool unroll_reduction = true;
-
- // Unroll amount
- int64_t unroll_factor = 1;
-
- // Grab what we can out of reduction domain, but don't go over a warp size yet
- bdimx = std::min(num_elems_in_reduction, (int64_t)warp_size);
- // Put everything else in bdimy for now
- bdimy = std::max(max_threads_in_block / bdimx, (int64_t)1);
- int64_t remainder_in_reduction = ceilDiv(num_elems_in_reduction, bdimx);
- int64_t remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimy);
-
- // Adjust blocking and setup unrolling
- if (remainder_in_reduction == 1) {
- // Small number of reduction elements, try unrolling output dimension
- unroll_factor = std::min(target_unroll, remainder_in_output);
- if (unroll_factor > 1) {
- unroll_reduction = false;
- remainder_in_output =
- ceilDiv(num_outputs_for_reduction, unroll_factor * bdimy);
- }
- } else {
- // If there are reduction elements left after unrolling a warp, re-adjust
- // the block dims to put more threads into the reduction
- bdimx = std::min(
- std::max(
- ceilDiv(num_elems_in_reduction, target_iterations * target_unroll),
- warp_size),
- max_threads_in_block);
-
- // Don't exceed target.
- bdimy = std::max(max_threads_in_block / bdimx, (int64_t)1);
- remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimy);
-
- remainder_in_reduction = ceilDiv(num_elems_in_reduction, bdimx);
- unroll_factor = std::min(remainder_in_reduction, target_unroll);
- if (unroll_factor == 1) {
- // If we can't unroll reduction dim, unroll output dim
- unroll_factor = std::min(remainder_in_output, target_unroll);
- if (unroll_factor > 1) {
- unroll_reduction = false;
- }
- remainder_in_output =
- ceilDiv(num_outputs_for_reduction, bdimy * unroll_factor);
- remainder_in_reduction =
- ceilDiv(num_elems_in_reduction, bdimx * target_iterations);
- } else {
- remainder_in_reduction = ceilDiv(
- num_elems_in_reduction,
- bdimx * std::max(unroll_factor, target_iterations));
- }
- }
-
- godim = remainder_in_output;
-
- // Clang tidy
- constexpr int64_t kEight = 8;
- constexpr int64_t kThirtyTwo = 32;
-
- // Cross grid reduction if we haven't hit our target blocks, and we have many
- // reduction elements.
- if ((godim < target_blocks && remainder_in_reduction > kEight &&
- remainder_in_reduction < kThirtyTwo) ||
- (remainder_in_reduction >= kThirtyTwo)) {
-    // Grid reductions do not support unrolling the iteration dimension;
-    // revert if set.
- if (!unroll_reduction) {
- unroll_reduction = true;
- unroll_factor = 1;
- remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimy);
- remainder_in_reduction =
- ceilDiv(num_elems_in_reduction, bdimx * target_iterations);
- }
- if (remainder_in_reduction >= kThirtyTwo) {
- // Do at least 2 iterations of unrolling per thread before we go cross
- // grid. Limit cross grid to a multiple of the block size so cleanup on
- // the last block doesn't take too long.
- grdim = std::min(
- ceilDiv(remainder_in_reduction, (int64_t)2), bdimx * bdimy * kEight);
- // Clang tidy
- // remainder_in_reduction = ceilDiv(remainder_in_reduction, grdim);
- } else {
- grdim = ceilDiv(remainder_in_reduction, (int64_t)4);
- }
- // Clang tidy
- //
- // remainder_in_reduction = ceilDiv(
- // num_elems_in_reduction,
- // bdimx *
- // std::max(
- // unroll_reduction ? unroll_factor : 1,
- // min_red_elems_per_thread) *
- // grdim);
- }
-
- // Try to do some cleanup of ragged waves on device
- // godim is a remainder of a split, so can only control bdimy
- if (
- // If we have less than 8 waves of blocks
- grdim * godim < device_multiprocessor_count * kEight &&
- // And we don't have an even divisible number of blocks
- (grdim * godim) % device_multiprocessor_count != 0 &&
- // And we have more than one wave
- grdim * godim > device_multiprocessor_count) {
- // round waves down
- auto waves =
- std::max((godim * grdim) / device_multiprocessor_count, (int64_t)1);
- auto new_grdim =
- std::max((waves * device_multiprocessor_count) / godim, (int64_t)1);
- if (
- // If difference is less than 25% of the original grdim
- (new_grdim - grdim) * 4 < grdim &&
- // and difference is less than 25% of the original number of blocks
- ((new_grdim * godim) - (grdim * godim)) * 4 < grdim * godim) {
- grdim = new_grdim;
- }
- }
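-
-  // A worked example (not from the original source) of this cleanup: with 80
-  // SMs, godim = 30, and grdim = 6 there are 180 blocks: two full waves of 80
-  // plus a mostly idle tail wave of 20.
-  //   waves     = max(180 / 80, 1) = 2
-  //   new_grdim = max((2 * 80) / 30, 1) = 5
-  // Both 25% guards pass, so grdim becomes 5 and the kernel runs in 150
-  // blocks (one full wave plus a 70-block wave), dropping the tail wave.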
-
- bool vectorize = false;
-
- if (vectorize_factor > 1 && unroll_factor > 1 && unroll_reduction) {
- vectorize = true;
- unroll_factor = std::min(
- scheduler_utils::lastPow2(unroll_factor), (int64_t)vectorize_factor);
- }
-
- ReductionParams rparams;
- rparams.fastest_dim = true;
- rparams.cross_block = true;
- rparams.cross_grid = grdim > 1;
- rparams.multiple_reds_per_blk = bdimy > 1;
- rparams.loop_unroll = unroll_factor;
- rparams.vectorize = vectorize;
- rparams.reduction_unroll = unroll_reduction;
-
- // If we have a cross grid case we want to have gdimy assigned to godim and
- // gdimx assigned to grdim. Otherwise it's helpful to pull godim into gdimx in
- // case it's larger than gdimy can hold, as not doing so can thrash the cache.
- int64_t gdimx = LaunchParams::UNINITIALIZED_VAL;
- int64_t gdimy = LaunchParams::UNINITIALIZED_VAL;
-
- if (rparams.cross_grid) {
- gdimx = grdim;
- rparams.split_grid_dim = gdimy > scheduler_utils::y_grid_limit;
- } else {
- rparams.split_grid_dim = gdimx > scheduler_utils::x_grid_limit;
- }
-
- rparams.lparams = LaunchParams(
- gdimx,
- gdimy,
- LaunchParams::UNINITIALIZED_VAL,
- bdimx,
- bdimy,
- LaunchParams::UNINITIALIZED_VAL);
- if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
- std::cerr << "\n===== Reduction Stats ========\n"
- << "num_elems_in_reduction: " << num_elems_in_reduction << "\n"
- << "num_outputs_for_reduction: " << num_outputs_for_reduction
- << "\n"
- << "n_tensor_inputs: " << n_tensor_inputs << "\n"
- << "max_input_dtype_size: " << max_input_dtype_size << std::endl;
- std::cerr << rparams.toString() << std::endl;
- }
-
- return rparams;
-}
-
-ReductionParams outerReductionHeuristic(
- const int64_t num_elems_in_reduction,
- const int64_t num_outputs_for_reduction,
- const int64_t n_tensor_inputs,
- const int64_t max_input_dtype_size,
- size_t vectorize_factor) {
- // Set some targets for parallelization
-
- const int64_t n_elems = num_elems_in_reduction * num_outputs_for_reduction;
- const int64_t l2_cache_size =
- at::cuda::getCurrentDeviceProperties()->l2CacheSize;
-
- const int64_t warp_size =
- n_elems * max_input_dtype_size * n_tensor_inputs < l2_cache_size
- ? (int64_t)32 / max_input_dtype_size
- : 32;
-
- int64_t target_blocks = 1;
- int64_t target_unroll = 1;
- int64_t max_threads_in_block = warp_size;
-
- // WARNING: Current device for codegen may not be the target device
- const int64_t device_max_threads_per_multiprocessor =
- (int64_t)at::cuda::getCurrentDeviceProperties()
- ->maxThreadsPerMultiProcessor;
-
- const int64_t device_multiprocessor_count =
- (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
-
- auto const max_unroll = ceilDiv(
- // Available unrolling based on size of data type
- (int64_t)16 / (int64_t)max_input_dtype_size,
-      // Reduce unrolling if we have many inputs; start reducing at 4 inputs
- std::max(
- (scheduler_utils::lastPow2((int64_t)n_tensor_inputs) >> 2),
- (int64_t)1));
-
- // If we have one warp per block, how many blocks would that be?
- target_blocks = ceilDiv(n_elems, (int64_t)warp_size);
-
- // If we have more than a wave, put parallelism into unrolling
- if (target_blocks > device_multiprocessor_count) {
- target_unroll = std::min(
- max_unroll, ceilDiv(target_blocks, device_multiprocessor_count));
- target_blocks = ceilDiv(target_blocks, target_unroll);
- }
-
- // Cap target blocks to 4 waves
- target_blocks = std::min(target_blocks, device_multiprocessor_count * 4);
-
- if (target_blocks * target_unroll * max_threads_in_block < n_elems) {
-    // targeting 4 waves, so try to use a quarter of available threads
- max_threads_in_block = std::min(
- ceilDiv(n_elems, target_blocks * target_unroll),
- ceilDiv(device_max_threads_per_multiprocessor, (int64_t)4));
- }
-
- // To get to target threads:
- // Prioritize
- // (1) x dim in iter domain
- // (2) unrolling in iter domain
- // (3) y in reduction domain
- // To get target blocks:
- // Prioritize
- // (1) x dim in multiple outputs
- // (2) y dim in multiple reductions - need to flip unrolling to reduction
- // domain for this
-
- // Blocks for reductions
- int64_t gdimy = 1;
- // Blocks for outputs
- int64_t gdimx = 1;
-
- // Threads for reduction
- int64_t bdimy = 1;
- // Threads for output
- int64_t bdimx = 1;
-
- // Should we unroll from reduction axis, or outs axis
- bool unroll_reduction = false;
-
- // Unroll amount
- int64_t unroll_factor = 1;
-
- int64_t remainder_in_reduction = num_elems_in_reduction;
- int64_t remainder_in_output = num_outputs_for_reduction;
-
- if (ceilDiv(num_outputs_for_reduction, warp_size) <
- device_multiprocessor_count) {
- // If we can't hit a full wave, leave bdimx as warp_size, and prioritize
- // bdimy. TODO: Re-evaluate, should it be bdimx = warp_size?
- bdimx = std::min(num_outputs_for_reduction, warp_size);
- } else {
- bdimx = std::min(
- max_threads_in_block,
- ceilDiv(num_outputs_for_reduction, target_blocks));
- bdimx = std::max(bdimx, warp_size);
- }
-
- bdimy = std::min(
- std::max(max_threads_in_block / bdimx, (int64_t)1),
- num_elems_in_reduction);
-
- // Clang tidy
- // remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimx);
- remainder_in_reduction = ceilDiv(remainder_in_reduction, bdimy);
-
- if (num_outputs_for_reduction >=
- device_multiprocessor_count * max_threads_in_block) {
-    // If we easily saturate the GPU, don't use block dim y; unroll the output
-    // dimension instead. This could be a more gradual transition that starts
-    // earlier.
- bdimx = max_threads_in_block;
- remainder_in_output = ceilDiv(num_outputs_for_reduction, bdimx);
-
- bdimy = 1;
- remainder_in_reduction = num_elems_in_reduction;
-
- // Assume unroll in output, switch to remainder if cross grid
- // Don't unroll if we don't have 2 full waves
- unroll_factor = std::min(
- ceilDiv(remainder_in_output, device_multiprocessor_count * 2),
- target_unroll);
-
- if (unroll_factor == 1 && remainder_in_reduction > 1) {
- // Try unrolling in reduction dimension
- unroll_factor = std::min(remainder_in_reduction, unroll_factor);
- // Clang tidy
- // remainder_in_reduction = ceilDiv(remainder_in_reduction,
- // unroll_factor);
- if (unroll_factor > 1) {
- unroll_reduction = true;
- }
- }
- // Clang tidy
- // else {
- // remainder_in_output =
- // ceilDiv(num_outputs_for_reduction, bdimx * unroll_factor);
- // }
- } else {
-    // Not many output elements, so we want to try to expand grid-level
-    // parallelism; first, go after unrolling
- unroll_factor = std::min(max_unroll, remainder_in_reduction);
- if (unroll_factor > 1) {
- unroll_reduction = true;
- }
-
- remainder_in_reduction =
- ceilDiv(num_elems_in_reduction, bdimy * unroll_factor);
-
- // Go cross grid
- gdimy = ceilDiv(remainder_in_reduction, (int64_t)4);
- // Clang tidy
- // remainder_in_reduction =
- // ceilDiv(num_elems_in_reduction, bdimy * unroll_factor * gdimy);
- }
-
- // Clang tidy
- constexpr int64_t kEight = 8;
- constexpr int64_t kSixteen = 16;
- constexpr int64_t kThirtyTwo = 32;
-
- if (ceilDiv(num_elems_in_reduction, bdimy * unroll_factor) >= kThirtyTwo) {
- // Many reduction elements, go cross grid
- int64_t min_gdimy = 1;
- if (gdimy > 1) {
- // already cross grid, don't go below target or what was already set
- min_gdimy = std::min(gdimy, ceilDiv(target_blocks, gdimx));
- }
- gdimy = std::max(
- min_gdimy,
- ceilDiv(
- ceilDiv(num_elems_in_reduction, bdimy * unroll_factor),
- (int64_t)kSixteen));
-    // Don't go too far above the number of threads in a block, since that's
-    // how many threads are available to do the final reduction iteration.
- gdimy = std::min(gdimy, bdimx * bdimy * kEight);
- }
-
- // Try to do some cleanup of ragged waves on device
- if (
- // If we have less than 8 waves of blocks
- gdimy * gdimx < device_multiprocessor_count * kEight &&
- // And we don't have an even divisible number of blocks
- (gdimy * gdimx) % device_multiprocessor_count != 0 &&
- // And we have more than one wave
- gdimy * gdimx > device_multiprocessor_count) {
- // round waves down
- auto waves =
- std::max((gdimx * gdimy) / device_multiprocessor_count, (int64_t)1);
- auto new_gdimy =
- std::max((waves * device_multiprocessor_count) / gdimx, (int64_t)1);
- if (
- // If difference is less than 25% of the original gdimy
- (new_gdimy - gdimy) * 4 < gdimy &&
- // and difference is less than 25% of the original number of blocks
- ((new_gdimy * gdimx) - (gdimy * gdimx)) * 4 < gdimy * gdimx) {
- gdimy = new_gdimy;
- }
- }
-
- // Cannot unroll with cross grid reductions
- if (gdimy > 1 && !unroll_reduction) {
- unroll_reduction = true;
- unroll_factor = 1;
- }
-
- bool vectorize = false;
-
- if (vectorize_factor > 1 && unroll_factor > 1 && !unroll_reduction) {
- vectorize = true;
- unroll_factor = std::min(
- scheduler_utils::lastPow2(unroll_factor), (int64_t)vectorize_factor);
- }
-
- ReductionParams rparams;
- rparams.fastest_dim = false;
- // cross grid implies cross block
- rparams.cross_block = bdimy > 1 || gdimy > 1;
- rparams.cross_grid = gdimy > 1;
- rparams.multiple_reds_per_blk = bdimx > 1;
- rparams.loop_unroll = unroll_factor;
- rparams.vectorize = vectorize;
- rparams.reduction_unroll = unroll_reduction;
-
- rparams.lparams = LaunchParams(
- LaunchParams::UNINITIALIZED_VAL,
- gdimy,
- LaunchParams::UNINITIALIZED_VAL,
- bdimx,
- bdimy,
- LaunchParams::UNINITIALIZED_VAL);
-
- if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
- std::cerr << "\n===== Reduction Stats ========\n"
- << "num_elems_in_reduction: " << num_elems_in_reduction << "\n"
- << "num_outputs_for_reduction: " << num_outputs_for_reduction
- << "\n"
- << "n_tensor_inputs: " << n_tensor_inputs << "\n"
- << "max_input_dtype_size: " << max_input_dtype_size << std::endl;
- std::cerr << rparams.toString() << std::endl;
- }
- return rparams;
-}
-
-} // namespace
-
-ReductionParams reductionHeuristic(
- int64_t num_elems_in_reduction,
- int64_t num_outputs_for_reduction,
- bool fastest_dim_reduction,
- size_t n_tensor_inputs,
- size_t max_input_dtype_size,
- size_t vectorize_factor) {
- if (fastest_dim_reduction) {
- return innerReductionHeuristic(
- num_elems_in_reduction,
- num_outputs_for_reduction,
- n_tensor_inputs,
- max_input_dtype_size,
- vectorize_factor);
- } else {
-    return outerReductionHeuristic(
- num_elems_in_reduction,
- num_outputs_for_reduction,
- n_tensor_inputs,
- max_input_dtype_size,
- vectorize_factor);
- }
-}
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
- Fusion* fusion,
- const at::ArrayRef<c10::IValue>& runtime_inputs,
- HeuristicSummary* data_cache) {
- FUSER_PERF_SCOPE("getReductionHeuristics");
-
- SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true);
-
- return getReductionHeuristics(fusion, runtime_info, data_cache);
-}
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache) {
- FUSER_PERF_SCOPE("getReductionHeuristics");
-
- FusionGuard fg(fusion);
-
- HeuristicCacheAccessor<std::vector<TensorView*>> reduction_tv_data;
-  // TODO: move all this boilerplate code into the accessor class
-  // (follow-up)
- if (data_cache && !data_cache->isRecording()) {
- reduction_tv_data.writeTemporary(data_cache->getReductionTVs());
- } else {
- reduction_tv_data.writeNew(scheduler_utils::getReductionTvs(fusion));
- if (data_cache && data_cache->isRecording()) {
- data_cache->setReductionTVs(reduction_tv_data.read());
- }
- }
-
- auto& reduction_tvs = reduction_tv_data.read();
-
- TORCH_INTERNAL_ASSERT(
-      reduction_tvs.size() == 1,
-      "Need exactly one reduction TensorView to schedule.");
-
- auto reduction_tv = reduction_tvs[0];
-
- TORCH_INTERNAL_ASSERT(reduction_tv != nullptr);
-
- auto red_root_dom = reduction_tv->getRootDomain();
- bool fastest_dim_reduction = true;
- for (size_t i = red_root_dom.size(); i > 0; i--) {
- if (red_root_dom[i - 1]->isBroadcast() ||
- red_root_dom[i - 1]->isTrivialReduction()) {
- continue;
- } else if (red_root_dom[i - 1]->isReduction()) {
- fastest_dim_reduction = true;
- break;
- } else {
- fastest_dim_reduction = false;
- break;
- }
- }
-
- TORCH_INTERNAL_ASSERT(
- reduction_tv != nullptr, "Reduction TensorView wasn't found.");
-
- TORCH_INTERNAL_ASSERT(
- reduction_tv->hasReduction(), "TensorView doesn't have a reduction.");
- const auto red_expr = reduction_tv->definition();
-
- TORCH_INTERNAL_ASSERT(
- red_expr->getExprType() != c10::nullopt &&
- (red_expr->getExprType().value() == ExprType::ReductionOp ||
- red_expr->getExprType().value() == ExprType::WelfordOp),
- "TensorView doesn't have a reduction.");
-
- int64_t num_outputs_for_reduction = 1;
- int64_t red_elements = 1;
-
- for (auto id : reduction_tv->getRootDomain()) {
- auto inferred_val =
- runtime_info.expressionEvaluator().evaluate(id->extent());
- TORCH_INTERNAL_ASSERT(
- inferred_val.has_value(), "Error inferring reduction size.");
- if (id->isReduction()) {
- red_elements *= inferred_val.value();
- } else {
- num_outputs_for_reduction *= inferred_val.value();
- }
- }
-
- size_t max_dtype_size = 1;
- size_t n_tensor_inputs = 0;
- for (auto inp : fusion->inputs()) {
- if (inp->isA<TensorView>()) {
- max_dtype_size =
- std::max(max_dtype_size, dataTypeSize(inp->getDataType().value()));
- n_tensor_inputs++;
- }
- }
-
- TORCH_INTERNAL_ASSERT(
- n_tensor_inputs > 0,
- "Tried to schedule a fusion with no tensor inputs, currently not supported.");
-
- auto vectorizable_inputs_outputs =
- scheduler_utils::getVectorizableInputsOutputs(reduction_tv);
-
- // Vectorize as much as we can
- size_t vectorize_factor = std::numeric_limits<size_t>::max();
-
- for (auto tv : vectorizable_inputs_outputs) {
- const auto tv_vectorize_factor = runtime_info.getVectorizableWidth(tv);
- vectorize_factor = std::min(vectorize_factor, tv_vectorize_factor);
- }
-
- if (vectorize_factor == std::numeric_limits<size_t>::max()) {
- vectorize_factor = 1;
- }
-
- return reductionHeuristic(
- red_elements,
- num_outputs_for_reduction,
- fastest_dim_reduction,
- n_tensor_inputs,
- max_dtype_size,
- vectorize_factor);
-}
-
-// fusion is the input IR that will be modified by this function
-void scheduleReduction(Fusion* fusion, const ReductionParams& rparams) {
- FUSER_PERF_SCOPE("scheduleReduction");
- FusionGuard fg(fusion);
-
- // Cache inputs if unrolled
- auto cached_inputs =
- scheduler_utils::cacheInputs(fusion, rparams.loop_unroll > 1);
-
- // Cache and fork outputs
- std::vector<std::pair<TensorView*, TensorView*>> cached_outputs =
- scheduler_utils::cacheAndForkOutputs(fusion, rparams.loop_unroll > 1);
-
- // Make sure we don't have global memory set on intermediate tensors from
- // fusion segmentation
- scheduler_utils::clearMemorySpace(fusion);
-
- auto reduction_tvs = scheduler_utils::getReductionTvs(fusion);
-
- TORCH_INTERNAL_ASSERT(
- reduction_tvs.size() <= 1,
- "Found multiple reductions sent to reduction heuristics",
- " (and reductions are not from a multi-output expr).");
- TORCH_INTERNAL_ASSERT(reduction_tvs.size());
-
- auto reduction_tv = reduction_tvs[0];
-
- auto dim_analysis =
- scheduler_utils::canonicalDimReduction(fusion, reduction_tv);
- bool has_iter_axis = dim_analysis.first;
- bool has_red_axis = dim_analysis.second;
-
- TORCH_INTERNAL_ASSERT(
- has_red_axis,
- "Could not find reduction axis in tensor used for reduction scheduler.");
-
- if (!has_iter_axis) {
- TORCH_INTERNAL_ASSERT(
- rparams.fastest_dim,
- "If all dims are reduction, should be sending it to fastest dim scheduler.");
- }
-
- TensorView* reference_tv = scheduler_utils::scheduleReductionTV(
- rparams, reduction_tv, has_iter_axis);
-
-  // Reduction tensor views and rfactor tensor views are set up. Let's finish
-  // off the scheduling, particularly inlining and unrolling.
- TORCH_INTERNAL_ASSERT(
- reference_tv != nullptr && reduction_tv != nullptr,
- "Need these two tensor views to finish the scheduling.");
- scheduler_utils::multiReductionInliner(
- fusion,
- rparams,
- reduction_tv,
- reference_tv,
- reduction_tvs,
- cached_inputs,
- cached_outputs);
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <ATen/core/ivalue.h>
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class SchedulerRuntimeInfo;
-class HeuristicSummary;
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
- Fusion* fusion,
- const at::ArrayRef<c10::IValue>& runtime_inputs,
- HeuristicSummary* data_cache = nullptr);
-
-TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr);
-
-TORCH_CUDA_CU_API void scheduleReduction(
- Fusion* fusion,
- const ReductionParams& rparams);
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
-
-#include <sstream>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-// Parameters the reduction heuristic generates to describe the optimal
-// schedule. Warning: the equality operator is intended for use in caching the
-// kernel associated with these reduction parameters. It does not check if the
-// launch parameters are equivalent!
-class ReductionParams {
- public:
- // Reducing inner most dimension?
- bool fastest_dim = true;
- // Reduce across the block?
- bool cross_block = false;
- // Reduce across the grid?
- bool cross_grid = false;
- // Perform multiple reductions per block?
- bool multiple_reds_per_blk = false;
- // Unrolling factor
- int64_t loop_unroll = 1;
- // Should unrolling be done on reduction dimension
- bool reduction_unroll = true;
- // vectorize instead of unroll
- bool vectorize = false;
- // Number of batches for each block
- int64_t batches_per_block = 1;
- // Number of warps per block
- // TODO: Remove or repurpose
- int64_t num_warps = 1;
- // Store input in shared memory or registers to reduce global memory reads
- bool persistent_kernel = false;
-
-  // Split the grid dim in case it's too large for CUDA
- bool split_grid_dim = false;
-
- std::string tag = "";
-
- LaunchParams lparams;
-
- public:
- // Warning: Does not check launch parameters!
- bool operator==(const ReductionParams& other) const {
- bool attr_equal = other.fastest_dim == fastest_dim &&
- other.cross_block == cross_block && other.cross_grid == cross_grid &&
- other.multiple_reds_per_blk == multiple_reds_per_blk &&
- other.loop_unroll == loop_unroll && other.vectorize == vectorize &&
- other.batches_per_block == batches_per_block &&
- other.num_warps == num_warps &&
- other.persistent_kernel == persistent_kernel &&
- other.reduction_unroll == reduction_unroll &&
- other.split_grid_dim == split_grid_dim;
- return attr_equal;
- }
-
- std::string toString() const {
- std::stringstream ss;
- ss << "\n===== Reduction Parameters ========\n"
- << (tag == "" ? "" : "Tag: ") << tag
- << (fastest_dim ? "Red On Fastest Dim\n" : "Red On Slow Dim\n")
- << "Reduction Characteristics:\n"
- << (multiple_reds_per_blk ? "Multiple Reds Per Block\n" : "")
- << (cross_block ? "Cross block reduction\n" : "")
- << (cross_grid ? "Cross grid reduction\n" : "");
- if (persistent_kernel) {
- ss << "Persistent Kernel\n"
- << "Batches per block: " << batches_per_block << "\n";
- }
- ss << "Blocking:\n"
- << " GridY: " << lparams.gdimy() << " BlckY: " << lparams.bdimy()
- << " BlckX: " << lparams.bdimx() << "\n";
- if (loop_unroll > 1) {
- ss << (vectorize ? "Vectorize " : "Unroll ")
- << (reduction_unroll ? " reduction dim, " : " iter dim, ")
- << "Factor: " << loop_unroll << "\n";
- }
- ss << "====================================\n";
- return ss.str();
- }
-};
-
-// Warning: Hash is not based on launch parameters!
-class ReductionParamsHash {
- public:
- size_t operator()(const ReductionParams& rp) const {
- constexpr size_t bits = sizeof(std::size_t) * 8;
- size_t attr_hash = static_cast<size_t>(rp.fastest_dim) << (bits - 1) ^
- static_cast<size_t>(rp.cross_block) << (bits - 2) ^
- static_cast<size_t>(rp.cross_grid) << (bits - 3) ^
- static_cast<size_t>(rp.multiple_reds_per_blk) << (bits - 4) ^
- static_cast<size_t>(rp.loop_unroll) ^
- static_cast<size_t>(rp.reduction_unroll) << (bits - 5) ^
- static_cast<size_t>(rp.vectorize) << (bits - 6) ^
- static_cast<size_t>(rp.batches_per_block) ^
- static_cast<size_t>(rp.num_warps) ^
- static_cast<size_t>(rp.persistent_kernel) << (bits - 7) ^
- static_cast<size_t>(rp.split_grid_dim) << (bits - 8);
- return attr_hash;
- }
-};
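-
-// A minimal sketch (not from the original source) of the bit-packing pattern
-// above: each boolean flag is parked in its own high bit, so flags never
-// collide with one another, while wider integer fields are XORed into the
-// low bits and so may collide with each other. The names are illustrative.
-struct FlagsSketch {
-  bool a = false;
-  bool b = false;
-  size_t n = 0;
-};
-
-inline size_t hashFlagsSketch(const FlagsSketch& f) {
-  constexpr size_t bits = sizeof(size_t) * 8;
-  return (static_cast<size_t>(f.a) << (bits - 1)) ^
-      (static_cast<size_t>(f.b) << (bits - 2)) ^ static_cast<size_t>(f.n);
-}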
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#include <c10/util/irange.h>
-#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-
-#include <limits>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-// TODO: Deduplicate from compute_at.cpp
-std::deque<std::deque<TensorView*>> tvChains(
- std::deque<std::deque<Val*>> val_chains) {
- std::deque<std::deque<TensorView*>> tv_chains(val_chains.size());
- for (size_t i = 0; i < val_chains.size(); i++) {
- auto tv_iterable = ir_utils::filterByType<TensorView>(val_chains[i]);
- tv_chains[i] =
- std::deque<TensorView*>(tv_iterable.begin(), tv_iterable.end());
- }
- return tv_chains;
-}
-
-class SchedulerTopologyChecker {
- public:
-  // Checks if any broadcasts are resolved after a reduction in a way that
-  // doesn't follow the normalization pattern
- static bool hasNonNormalizePostReductionBCast(Fusion* fusion) {
- auto all_vals = fusion->usedMathVals();
- std::vector<TensorView*> reduction_tvs;
- for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
- if (tv->hasReduction() && !fusion->hasInput(tv)) {
- reduction_tvs.push_back(tv);
- }
- }
-
- // All tensor views that are eventually consumed to produce a reduction,
- // includes reduction tensor views.
- std::unordered_set<TensorView*> pre_reduction_tvs;
-
- {
- auto pre_reduction_vals = DependencyCheck::getAllValsBetween(
- {fusion->inputs().begin(), fusion->inputs().end()},
- {reduction_tvs.begin(), reduction_tvs.end()});
- auto pre_reduction_tv_vector =
- ir_utils::filterByType<TensorView>(pre_reduction_vals);
- pre_reduction_tvs = std::unordered_set<TensorView*>(
- pre_reduction_tv_vector.begin(), pre_reduction_tv_vector.end());
- }
-
- // Track which tensor views we've validated so we don't do it again.
- std::unordered_set<TensorView*> validated_resolved_tvs;
-
- // Run forward (towards outputs) from reductions on any path that isn't
- // before another reduction. Look for resolved broadcasts. If a resolved
- // broadcast is found, start there and propagate backwards. Track the id's
- // that were resolved and make sure there's a mapping to a TensorView before
- // a reduction.
- for (auto red_tv : reduction_tvs) {
- auto forward_tv_chains =
- tvChains(DependencyCheck::getAllUseChains(red_tv));
- // Propagate forward from reduction through all uses of the reduction
- for (auto forward_tv_dep_chain : forward_tv_chains) {
- TensorView* forward_running_producer = nullptr;
- TensorView* forward_running_consumer = forward_tv_dep_chain.front();
- forward_tv_dep_chain.pop_front();
- while (!forward_tv_dep_chain.empty()) {
- forward_running_producer = forward_running_consumer;
- forward_running_consumer = forward_tv_dep_chain.front();
- forward_tv_dep_chain.pop_front();
-
- if (std::none_of(
- forward_running_producer->getMaybeRFactorDomain().begin(),
- forward_running_producer->getMaybeRFactorDomain().end(),
- [](IterDomain* id) { return id->isBroadcast(); })) {
-            // If there are no broadcast axes in the producer, it doesn't
-            // need to be checked
- continue;
- }
-
- // If consumer is before another reduction it doesn't need to be
- // checked
- if (pre_reduction_tvs.count(forward_running_consumer)) {
- break;
- }
-
- // If consumer was already validated it doesn't need to be checked
- if (validated_resolved_tvs.count(forward_running_consumer)) {
- continue;
- }
-
- auto forward_pairwise_root_map = PairwiseRootDomainMap(
- forward_running_producer, forward_running_consumer);
- auto forward_p2c_root_map =
- forward_pairwise_root_map.mapProducerToConsumer(
- forward_running_producer->domain(),
- forward_running_consumer->domain());
-
- // These are the ids we will have to resolve. As we resolve them we'll
- // remove them from this vector. If this vector ends up empty, then
- // we've resolved everything we need to. This is a pair so as we
- // traverse we can map the id through the traversal. The first entry
- // in the pair will be the original id so we can reset it if it's not
- // resolved before the next traversal. The second ID will be
- // propagated as we map the IDs through the backward traversal.
- std::vector<std::pair<IterDomain*, IterDomain*>> ids_to_resolve;
-
- // Check if any TensorViews have a resolved broadcast
- for (auto entry : forward_p2c_root_map) {
- auto p_id = entry.first;
- auto c_id = entry.second;
- if (p_id->isBroadcast() &&
- (!c_id->isBroadcast() && !c_id->isTrivialReduction())) {
- ids_to_resolve.emplace_back(std::make_pair(c_id, c_id));
- }
- }
-
- if (ids_to_resolve.empty()) {
- continue;
- }
-
- // Only because of api limitations in getAllDependencyChains
- auto inputs_of_forward_running_consumer =
- IterVisitor::getInputsTo({forward_running_consumer});
- auto tv_inputs_of_forward_running_consumer =
- ir_utils::filterByType<TensorView>(
- inputs_of_forward_running_consumer);
-
- for (auto input_of_forward_running_consumer :
- tv_inputs_of_forward_running_consumer) {
- if (pre_reduction_tvs.find(input_of_forward_running_consumer) ==
- pre_reduction_tvs.end()) {
- // If this input isn't an input to a reduction, no point
- // traversing the dependency chains as we know we can't validate
- // this broadcast through chains to this input
- continue;
- }
-
- auto backward_tv_chains =
- tvChains(DependencyCheck::getAllDependencyChains(
- input_of_forward_running_consumer,
- forward_running_consumer));
-
- for (auto backward_tv_chain : backward_tv_chains) {
- if (ids_to_resolve.empty()) {
- break;
- }
-
- for (auto& pair : ids_to_resolve) {
- pair.second = pair.first;
- }
-
- TensorView* backward_running_producer = backward_tv_chain.back();
- TensorView* backward_running_consumer = nullptr;
- backward_tv_chain.pop_back();
-
- TORCH_INTERNAL_ASSERT(
- backward_running_producer == forward_running_consumer);
-
- while (!backward_tv_chain.empty()) {
- backward_running_consumer = backward_running_producer;
- backward_running_producer = backward_tv_chain.back();
- backward_tv_chain.pop_back();
-
- std::vector<IterDomain*> running_resolved_ids;
-
- auto backward_pairwise_root_map = PairwiseRootDomainMap(
- backward_running_producer, backward_running_consumer);
-
- auto backward_c2p_root_map =
- backward_pairwise_root_map.mapConsumerToProducer(
- backward_running_consumer->domain(),
- backward_running_producer->domain());
-
- // Mark if producer is a producer of a reduction
- bool producer_resolves =
- pre_reduction_tvs.count(backward_running_producer);
-
-                bool at_least_one_id_mapped = false;
-                for (size_t entry_i = ids_to_resolve.size(); entry_i > 0;
-                     entry_i--) {
-                  auto orig_id = ids_to_resolve[entry_i - 1].first;
-                  auto running_id = ids_to_resolve[entry_i - 1].second;
-                  if (backward_c2p_root_map.find(running_id) !=
-                      backward_c2p_root_map.end()) {
-                    at_least_one_id_mapped = true;
-                    if (producer_resolves &&
-                        !backward_c2p_root_map.at(running_id)->isBroadcast()) {
-                      // If mapped, and producer is a producer of a reduction,
-                      // we can resolve this id
-                      ids_to_resolve.erase(
-                          ids_to_resolve.begin() + (entry_i - 1));
-                    } else {
-                      ids_to_resolve[entry_i - 1] = std::make_pair(
-                          orig_id, backward_c2p_root_map.at(running_id));
-                    }
-                  }
-                }
-                if (!at_least_one_id_mapped) {
- // If no id's map any more, go to the next chain
- break;
- }
-
- if (ids_to_resolve.empty()) {
- break;
- }
- }
- }
- } // for(auto input_of_forward_running_consumer :
- // tv_inputs_of_forward_running_consumer){
-
-          // If any ids remain unresolved, then we've found an instance of a
-          // bad broadcast resolution after a reduction
- if (ids_to_resolve.size()) {
- return true;
- }
-
- } // while (!forward_tv_dep_chain.empty()) {
- } // for (auto forward_tv_dep_chain : forward_tv_chains) {
- } // for (auto red_tv : reduction_tvs)
- return false;
- }
-
-  // Checks if any broadcasts are resolved after a reduction; this shouldn't
-  // be accepted in the single-reduction or multi-reduction scheduler
- static bool hasPostReductionBCast(Fusion* fusion) {
- auto all_vals = fusion->usedMathVals();
- for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
- // Welford can have 2 outputs, so do this on all found reduction tensor
- // views
- if (tv->hasReduction() && !tv->isFusionInput()) {
- auto tv_chains = tvChains(DependencyCheck::getAllUseChains(tv));
- // Propagate forward from reduction through all uses of the reduction
- for (auto tv_dep_chain : tv_chains) {
- TensorView* running_producer = nullptr;
- TensorView* running_consumer = tv_dep_chain.front();
- tv_dep_chain.pop_front();
- while (!tv_dep_chain.empty()) {
- running_producer = running_consumer;
- running_consumer = tv_dep_chain.front();
- tv_dep_chain.pop_front();
-
- auto pairwise_root_map =
- PairwiseRootDomainMap(running_producer, running_consumer);
- auto p2c_root_map = pairwise_root_map.mapProducerToConsumer(
- running_producer->domain(), running_consumer->domain());
-
- // Check if any TensorViews have a resolved broadcast
- for (auto entry : p2c_root_map) {
- auto p_id = entry.first;
- auto c_id = entry.second;
- if (p_id->isBroadcast() &&
- (!c_id->isBroadcast() && !c_id->isTrivialReduction())) {
- return true;
- }
- }
- }
- }
- }
- }
- return false;
- }
-
-  // Checks if there are any unsupported operations post reduction. For outer
-  // reductions we can fuse some pointwise ops as long as they don't require
-  // broadcasting (checked in hasPostReductionBCast). For inner reductions we
-  // cannot fuse any binary-like operation (this includes operations like
-  // shift that we're not fusing right now) involving "new" inputs (inputs
-  // that don't go through a reduction).
- static bool supportedPostReductionFusion(
- Fusion* fusion,
- std::vector<TensorView*> reduction_tvs) {
- TORCH_INTERNAL_ASSERT(reduction_tvs.size());
- bool fastest_dim_reduction = true;
- auto red_root_dom = reduction_tvs[0]->getRootDomain();
- for (size_t i = red_root_dom.size(); i > 0; i--) {
- if (red_root_dom[i - 1]->isBroadcast() ||
- red_root_dom[i - 1]->isTrivialReduction()) {
- continue;
- } else if (red_root_dom[i - 1]->isReduction()) {
- fastest_dim_reduction = true;
- break;
- } else {
- fastest_dim_reduction = false;
- break;
- }
- }
-
-    // If reductions are on the fastest dim, don't fuse any operations (after
-    // reductions) that require an input that is not also an input to the
-    // reductions.
- if (fastest_dim_reduction) {
- auto post_reduction_vals = DependencyCheck::getAllValsBetween(
- {reduction_tvs.begin(), reduction_tvs.end()},
- {fusion->outputs().begin(), fusion->outputs().end()});
-
- if (post_reduction_vals.empty()) {
- return true;
- }
-
- auto reduction_inputs = IterVisitor::getInputsTo(
- {reduction_tvs.begin(), reduction_tvs.end()});
-
- for (auto tv : ir_utils::filterByType<TensorView>(
- post_reduction_vals.begin(), post_reduction_vals.end())) {
- if (tv->definition() == nullptr) {
- continue;
- }
-
- auto tv_inputs = IterVisitor::getInputsTo({tv});
-
- if (std::any_of(
- tv_inputs.begin(),
- tv_inputs.end(),
- [&reduction_inputs](Val* inp) {
- return inp->isA<TensorView>() &&
- std::find(
- reduction_inputs.begin(),
- reduction_inputs.end(),
- inp) == reduction_inputs.end();
- })) {
- return false;
- }
- }
- }
-
- return true;
- }
-};
-} // namespace
-
-SchedulerRuntimeInfo::SchedulerRuntimeInfo(
- Fusion* complete_fusion,
- const at::ArrayRef<IValue>& inputs,
- bool create_expr_evaluator)
- : complete_fusion_(complete_fusion) {
- collectVectorizationInfo(inputs);
- if (create_expr_evaluator) {
- initializeExpressionEvaluator(inputs);
- }
- collectIndexModeInfo(inputs);
-}
-
-SchedulerRuntimeInfo::SchedulerRuntimeInfo(
- const SchedulerRuntimeInfo& copy_from)
- : complete_fusion_(copy_from.complete_fusion_),
- alignment_map_(copy_from.alignment_map_),
- common_alignment_size_(copy_from.common_alignment_size_) {
- expression_evaluator_ =
- std::make_unique<ExpressionEvaluator>(complete_fusion_);
-}
-
-size_t SchedulerRuntimeInfo::getAlignmentSize(TensorView* tv) {
- auto alignment_entry = alignment_map_.find(tv);
- if (alignment_entry == alignment_map_.end()) {
- return max_alignment_size_in_byte;
- } else {
- return alignment_entry->second;
- }
-}
-
-void SchedulerRuntimeInfo::initializeExpressionEvaluator(
- const at::ArrayRef<IValue>& inputs) {
- // TODO: refactor bindFusionInputs to better support this
- // use case, i.e. support construct and bind input.
- expression_evaluator_ =
- std::make_unique<ExpressionEvaluator>(complete_fusion_);
- *expression_evaluator_ =
- executor_utils::bindFusionInputs(inputs, complete_fusion_);
-}
-
-size_t SchedulerRuntimeInfo::collectAlignmentSize(
- const at::Tensor& tensor) const {
- const size_t address = reinterpret_cast<size_t>(tensor.data_ptr());
- size_t alignment_size = 1;
- size_t next_alignment_size = 2;
-
-  // Bound the *next* size so the result never exceeds the cap; bounding
-  // alignment_size instead would let the loop return up to 2x the cap.
-  while (next_alignment_size <= max_alignment_size_in_byte &&
-         address % next_alignment_size == 0) {
- alignment_size = next_alignment_size;
- next_alignment_size *= 2;
- }
-
- return alignment_size;
-}
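The loop above is equivalent to isolating the lowest set bit of the address and capping it at the maximum alignment. A minimal self-contained sketch of that equivalence (helper name hypothetical, not part of this file):

    #include <cstddef>
    #include <cstdint>

    // Hypothetical standalone equivalent of collectAlignmentSize: the largest
    // power of two dividing an address is its lowest set bit, capped at 16B.
    size_t alignmentOf(const void* ptr, size_t cap = 16) {
      const auto address = reinterpret_cast<uintptr_t>(ptr);
      if (address == 0) {
        return cap; // address 0 is trivially aligned to any power of two
      }
      const size_t lowest_bit = address & (~address + 1);
      return lowest_bit < cap ? lowest_bit : cap;
    }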
-
-void SchedulerRuntimeInfo::collectVectorizationInfo(
- const at::ArrayRef<IValue>& inputs) {
- common_alignment_size_ = max_alignment_size_in_byte;
- size_t number_of_inputs = complete_fusion_->inputs().size();
- std::unordered_map<TensorView*, size_t> cg_tensor_to_at_tensor_index;
-
- for (auto input_index : c10::irange(number_of_inputs)) {
- if (auto input_tensor = dynamic_cast<TensorView*>(
- complete_fusion_->inputs()[input_index])) {
- if (input_tensor->nDims() == 0) {
- // A 0-dim tensor input would not need vectorization
- continue;
- }
- if (input_tensor->domain()
- ->domain()[input_tensor->nDims() - 1]
- ->isBroadcast()) {
-        // Skip tensors whose innermost iterdomain is a broadcast,
-        // as we will not vectorize these.
- continue;
- }
-
- // Collect strides of the input tensor
- TORCH_INTERNAL_ASSERT(inputs[input_index].isTensor());
- const auto& at_tensor = inputs[input_index].toTensor();
-
- cg_tensor_to_at_tensor_index.emplace(
- std::make_pair(input_tensor, input_index));
-
- // Collect alignment of the input tensor
- auto alignment_size = collectAlignmentSize(at_tensor);
- common_alignment_size_ = std::min(alignment_size, common_alignment_size_);
- alignment_map_[input_tensor] = alignment_size;
- }
- }
-
-  // Compute the max vector word size for each input. Tensors with an
-  // innermost broadcast have already been filtered out, and
-  // common_alignment_size_ is fully computed at this point.
- for (auto it : cg_tensor_to_at_tensor_index) {
- vectorword_map_[it.first] = collectMaxVectorizeSize(
- inputs[it.second].toTensor(), common_alignment_size_);
- }
-}
-
-size_t SchedulerRuntimeInfo::collectMaxVectorizeSize(
- const at::Tensor& tensor,
- size_t max_vector_size_in_byte) {
- size_t vector_size = 1;
- size_t next_vector_size = 2;
- bool next_size_compatible = true;
-
- while (next_size_compatible &&
- next_vector_size * tensor.itemsize() <= max_vector_size_in_byte) {
-    // If the innermost dimension size is not divisible by the new word size
-    // then we cannot vectorize with this width. But we do not care if all
-    // dimensions of this tensor are 1, i.e. the input is actually an
-    // unsqueezed 0-dim tensor.
- for (size_t i = tensor.ndimension(); i > 0; i--) {
- if (tensor.size(i - 1) != 1) {
- if (tensor.size(tensor.ndimension() - 1) % next_vector_size != 0 ||
- tensor.stride(tensor.ndimension() - 1) != 1) {
- next_size_compatible = false;
- }
- break;
- }
- }
-
- if (!next_size_compatible) {
- break;
- }
-
- // If any stride is not divisible by the next word size,
- // we cannot vectorize with this width.
- for (auto stride : tensor.strides()) {
- if (stride != 1 && stride % next_vector_size != 0) {
- next_size_compatible = false;
- break;
- }
- }
-
- if (next_size_compatible) {
- vector_size = next_vector_size;
- next_vector_size *= 2;
- }
- }
-
- return vector_size;
-}
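To make the widening rules above concrete, here is a self-contained sketch driven by raw sizes/strides/itemsize rather than an at::Tensor (function name hypothetical; it omits the all-dimensions-are-1 special case handled above):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    size_t maxVectorWidth(
        const std::vector<int64_t>& sizes,
        const std::vector<int64_t>& strides,
        size_t itemsize,
        size_t max_bytes = 16) {
      size_t width = 1;
      for (size_t next = 2; next * itemsize <= max_bytes; next *= 2) {
        // The innermost extent must be divisible by the width and contiguous.
        if (sizes.empty() || sizes.back() % static_cast<int64_t>(next) != 0 ||
            strides.back() != 1) {
          break;
        }
        // Every other stride must also be divisible by the width.
        bool strides_ok = true;
        for (auto stride : strides) {
          if (stride != 1 && stride % static_cast<int64_t>(next) != 0) {
            strides_ok = false;
            break;
          }
        }
        if (!strides_ok) {
          break;
        }
        width = next;
      }
      return width;
    }

For example, a contiguous float tensor of shape [8, 6] (strides {6, 1}, itemsize 4) accepts width 2 (6 % 2 == 0) but rejects width 4 (6 % 4 != 0), so the result is 2.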
-
-size_t SchedulerRuntimeInfo::getVectorizableWidth(TensorView* tv) {
- auto recorded_size_it = vectorword_map_.find(tv);
- if (recorded_size_it != vectorword_map_.end()) {
- return recorded_size_it->second;
- }
-
-  // If we don't have a record, either it is a tv with an innermost
-  // broadcast, or it is an intermediate tensor allocated by the fuser
- auto tv_root = TensorDomain::noReductions(tv->getRootDomain());
- auto tv_root_size = tv_root.size();
-
- // Filter out 0-dim tensors
- if (tv_root_size < 1) {
- return 1;
- }
-
- // Filter out mismatched contiguity info
- if (tv_root_size != tv->domain()->contiguity().size()) {
- return 1;
- }
-
- // Filter out innermost broadcast tensors
- auto inner_dimension = tv_root[tv_root_size - 1];
- if (inner_dimension->isBroadcast()) {
- return 1;
- }
-
- // Handle intermediate or output tensors that
- // will be allocated by fuser
- auto maybe_data_type = tv->getDataType();
-
- // Do not vectorize on data with unknown type
- if (!maybe_data_type.has_value()) {
- return 1;
- }
-
- size_t item_size = dataTypeSize(maybe_data_type.value());
- // Assume we don't have non-divisible types for now.
- TORCH_INTERNAL_ASSERT(max_alignment_size_in_byte % item_size == 0);
- size_t max_vector_size = max_alignment_size_in_byte / item_size;
-
-  // Assuming intermediate tensors have friendly alignment, and
-  // all-true contiguity. Determine the largest power of 2 that divides the
-  // innermost dimension size as the word size for vectorization
- size_t vector_size = 1;
- size_t next_vector_size = 2;
- auto maybe_inner_dimension_size =
- expression_evaluator_->evaluate(inner_dimension->extent());
- TORCH_INTERNAL_ASSERT(maybe_inner_dimension_size.has_value());
- size_t inner_dimension_size = maybe_inner_dimension_size.value();
-
- while (next_vector_size <= max_vector_size &&
- next_vector_size <= inner_dimension_size &&
- inner_dimension_size % next_vector_size == 0) {
- vector_size = next_vector_size;
- next_vector_size *= 2;
- }
-
- // save output to avoid re-compute
- vectorword_map_[tv] = vector_size;
-
- return vector_size;
-}
-
-void SchedulerRuntimeInfo::collectIndexModeInfo(
- const at::ArrayRef<at::IValue>& inputs) {
- // Save 1 more bit besides the sign bit to be conservative
- constexpr int64_t most_positive_int32_index =
- std::numeric_limits<int>::max() / 2;
- constexpr int64_t most_negative_int32_index =
- std::numeric_limits<int>::min() / 2;
-
- // Start by setting index mode to int32
- index_mode_ = KernelIndexMode::INT32;
-
-  // Check all runtime inputs; if any input's index exceeds
-  // max_int32, fall back to int64 indexing
- for (auto ivalue_input : inputs) {
- if (ivalue_input.isTensor()) {
- auto tensor_input = ivalue_input.toTensor();
- int64_t tensor_most_positive_index = 0;
- int64_t tensor_most_negative_index = 0;
- for (auto dim_i = 0; dim_i < tensor_input.ndimension(); dim_i++) {
- // Ignore broadcast dimensions
- if (tensor_input.size(dim_i) > 1) {
- // accumulate based on the sign of stride
- if (tensor_input.stride(dim_i) > 0) {
-            // Accumulate positive stride
- tensor_most_positive_index +=
- (tensor_input.size(dim_i) - 1) * tensor_input.stride(dim_i);
- } else {
-            // Accumulate negative stride
- tensor_most_negative_index +=
- (tensor_input.size(dim_i) - 1) * tensor_input.stride(dim_i);
- }
- }
- }
-
- // Fall back to int64 if it can be either too positive
- // or too negative.
- if (tensor_most_positive_index > most_positive_int32_index ||
- tensor_most_negative_index < most_negative_int32_index) {
- index_mode_ = KernelIndexMode::INT64;
- return;
- }
- }
- }
-}
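The decision above reduces to one inequality per tensor on its extreme linear offsets. A self-contained sketch of the same test (function name hypothetical):

    #include <cstdint>
    #include <limits>
    #include <vector>

    // Accumulate the most positive and most negative linear offsets a tensor
    // can produce and compare them against half of the int32 range.
    bool fitsInt32Index(
        const std::vector<int64_t>& sizes,
        const std::vector<int64_t>& strides) {
      int64_t most_positive = 0;
      int64_t most_negative = 0;
      for (size_t i = 0; i < sizes.size(); ++i) {
        if (sizes[i] <= 1) {
          continue; // broadcast-like dimensions contribute no offset
        }
        const int64_t extent = (sizes[i] - 1) * strides[i];
        (strides[i] > 0 ? most_positive : most_negative) += extent;
      }
      return most_positive <= std::numeric_limits<int>::max() / 2 &&
          most_negative >= std::numeric_limits<int>::min() / 2;
    }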
-
-bool SchedulerEntry::sameAs(const SchedulerEntry* other) {
- if (heuristc_ != other->heuristc_) {
- return false;
- }
- if (index_mode_ != other->index_mode_) {
- return false;
- }
-  // Heuristics being equal should imply has_reduction_param_ being equal;
-  // we need to double-check that this is the case before removing
-  // the check below.
- if (has_reduction_param_ != other->has_reduction_param_) {
- return false;
- }
- if (has_reduction_param_) {
- return rparams_ == other->rparams_;
- } else {
- return pparams_ == other->pparams_;
- }
- return true;
-}
-
-namespace {
-template <typename REDUCTION_OP = ReductionOp>
-inline bool isTrivialReduction(REDUCTION_OP* red) {
- auto o_tv = red->out()->template as<TensorView>();
-  // Assuming the graph is unscheduled at this point.
- for (auto id : o_tv->getRootDomain()) {
- if (id->isReduction() && !id->extent()->isOneInt()) {
- return false;
- }
- }
- return true;
-}
-
-template <typename REDUCTION_OP = ReductionOp>
-std::vector<REDUCTION_OP*> findReductionOps(Fusion* fusion) {
- std::vector<REDUCTION_OP*> red_ops;
- for (auto expr : fusion->exprs()) {
- if (auto red = dynamic_cast<REDUCTION_OP*>(expr)) {
- if (!isTrivialReduction(red)) {
- red_ops.push_back(red);
- }
- }
- }
- return red_ops;
-}
-
-class SingleReductionScheduler : public SchedulerEntry {
- public:
- explicit SingleReductionScheduler(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr)
- : SchedulerEntry(ScheduleHeuristic::Reduction, true) {
- computeHeuristics(fusion, runtime_info, data_cache);
- }
-
-  //! Check if the reduction heuristics apply to the given fusion
- static bool canSchedule(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr) {
- if (data_cache) {
- return true;
- }
-
- auto red_ops = findReductionOps(fusion);
- auto welford_ops = findReductionOps<WelfordOp>(fusion);
- if (red_ops.size() + welford_ops.size() != 1) {
- return false;
- }
-
- bool is_welford = welford_ops.size() > 0;
-
- if (SchedulerTopologyChecker::hasPostReductionBCast(fusion)) {
- return false;
- }
-
- auto reduction_tv = is_welford ? welford_ops[0]->out()->as<TensorView>()
- : red_ops[0]->out()->as<TensorView>();
-
- if (!SchedulerTopologyChecker::supportedPostReductionFusion(
- fusion, {reduction_tv})) {
- return false;
- }
-
- return true;
- }
-
- void schedule(Fusion* fusion) override {
- FUSER_PERF_SCOPE("Schedule Single Reduction");
- scheduleReduction(fusion, rparams_);
- }
-
- private:
- void computeHeuristics(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr) {
- auto param = getReductionHeuristics(fusion, runtime_info, data_cache);
- TORCH_INTERNAL_ASSERT(param.has_value());
- rparams_ = param.value();
- }
-};
-
-class PointWiseScheduler : public SchedulerEntry {
- public:
- explicit PointWiseScheduler(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr)
- : SchedulerEntry(ScheduleHeuristic::PointWise, false) {
- computeHeuristics(fusion, runtime_info, data_cache);
- }
-
- static bool canSchedule(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr) {
- if (data_cache) {
- return true;
- }
- auto red_ops = findReductionOps(fusion);
- auto welford_ops = findReductionOps<WelfordOp>(fusion);
- return red_ops.empty() && welford_ops.empty();
- }
-
- void schedule(Fusion* fusion) override {
- FUSER_PERF_SCOPE("Schedule PointWise Fusion");
- schedulePointwise(fusion, pparams_);
- }
-
- void computeHeuristics(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr) {
- auto pparam = getPointwiseHeuristics(fusion, runtime_info, data_cache);
- TORCH_INTERNAL_ASSERT(pparam.has_value());
- pparams_ = pparam.value();
- }
-};
-
-class NormalizationScheduler : public SchedulerEntry {
- public:
- explicit NormalizationScheduler(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr)
- : SchedulerEntry(ScheduleHeuristic::Normalization, true) {
- computeHeuristics(fusion, runtime_info, data_cache);
- }
-
- void schedule(Fusion* fusion) override {
- FUSER_PERF_SCOPE("Schedule Normalization Fusion");
- scheduleNormalization(fusion, rparams_);
- }
-
- static bool canSchedule(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr) {
- FUSER_PERF_SCOPE("NormalizationScheduler::canSchedule");
-
- HeuristicCacheAccessor<std::vector<TensorView*>> reduction_tv_data;
-    // TODO: move all this boilerplate code into the accessor class
-    // (follow-up)
- if (data_cache && !data_cache->isRecording()) {
- reduction_tv_data.writeTemporary(data_cache->getReductionTVs());
- } else {
- reduction_tv_data.writeNew(scheduler_utils::getReductionTvs(fusion));
- if (data_cache && data_cache->isRecording()) {
- data_cache->setReductionTVs(reduction_tv_data.read());
- }
- }
-
- auto& reduction_tvs = reduction_tv_data.read();
-
- if (!data_cache) {
- if (reduction_tvs.size() == 0) {
- // Use single reduction or pointwise logic
- return false;
- }
-
- if (SchedulerTopologyChecker::hasNonNormalizePostReductionBCast(fusion)) {
- return false;
- }
-
-      // Before examining the reduction axes, we want to quickly check that
-      // the reductions have the same axis width, to avoid building a root
-      // domain map in the easier cases
- bool valid_axis_count = false;
- size_t axis_count = 0;
- auto reduction_root_size = [](TensorView* red_tv) {
- size_t count = 0;
- for (auto id : red_tv->getRootDomain()) {
- if (!id->isBroadcast()) {
- count++;
- }
- }
- return count;
- };
-
- for (auto red : reduction_tvs) {
- if (!valid_axis_count) {
- valid_axis_count = true;
- axis_count = reduction_root_size(red);
- } else {
- if (reduction_root_size(red) != axis_count) {
- return false;
- }
- }
- }
-
- // Use root domain map to check the reduction ops have the same axes
- FusionGuard fg(fusion);
- ComputeAtRootDomainMap root_map;
- root_map.build(true);
-
-      // reduction_tvs.size() >= 1 was checked above
- for (size_t it = 1; it < reduction_tvs.size(); it++) {
- if (!checkEquivalence(
- reduction_tvs[it - 1], reduction_tvs[it], root_map)) {
- return false;
- }
- }
- }
-
-    // TODO: move all this boilerplate code into the accessor class
-    // (follow-up)
-    // Note: this persistent buffer info is actually cached from
-    // getNormalizationHeuristics. We will need to create a separate
-    // cache entry if the two ever diverge.
- HeuristicCacheAccessor<scheduler_utils::PersistentBufferInfo>
- persistent_buffer_data;
-
- if (data_cache && !data_cache->isRecording()) {
- persistent_buffer_data.writeTemporary(
- data_cache->getPersistentBufferInfo());
- } else {
- persistent_buffer_data.writeNew(
- scheduler_utils::persistentBuffers(fusion));
- if (data_cache && data_cache->isRecording()) {
- data_cache->setPersistentBufferInfo(persistent_buffer_data.read());
- }
- }
- auto& persistent_buffers = persistent_buffer_data.read();
-
- auto persistent_buffer_size = scheduler_utils::persistentBufferSize(
- fusion, runtime_info, persistent_buffers, data_cache);
- if (persistent_buffer_size * 4 > scheduler_utils::register_file_size * 3) {
- return false;
- }
-
- // TODO: really need to make inserting an entry into data_cache easier to do
- HeuristicCacheAccessor<bool> has_post_reduction_bcast_data;
-
- if (data_cache && !data_cache->isRecording()) {
- has_post_reduction_bcast_data.writeTemporary(
- data_cache->getHasPostReductionBCast());
- } else {
- has_post_reduction_bcast_data.writeNew(
- SchedulerTopologyChecker::hasPostReductionBCast(fusion));
- if (data_cache && data_cache->isRecording()) {
- data_cache->setHasPostReductionBCast(
- has_post_reduction_bcast_data.read());
- }
- }
-
- HeuristicCacheAccessor<bool> supported_post_reduction_fusion_data;
-
- if (data_cache && !data_cache->isRecording()) {
- supported_post_reduction_fusion_data.writeTemporary(
- data_cache->getSupportedPostReductionFusion());
- } else {
- supported_post_reduction_fusion_data.writeNew(
- SchedulerTopologyChecker::supportedPostReductionFusion(
- fusion, reduction_tvs));
- if (data_cache && data_cache->isRecording()) {
- data_cache->setSupportedPostReductionFusion(
- supported_post_reduction_fusion_data.read());
- }
- }
-
- auto has_post_reduction_bcast = has_post_reduction_bcast_data.read();
- auto supported_post_reduction_fusion =
- supported_post_reduction_fusion_data.read();
-
- // Multi reduction scheduler has the same limitations as single reduction
- // scheduler here
- if (persistent_buffer_size <= 1) {
- if (has_post_reduction_bcast) {
- return false;
- }
-
- if (!supported_post_reduction_fusion) {
- return false;
- }
- }
-
- return true;
- }
-
- private:
- void computeHeuristics(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr) {
- auto rparams = getNormalizationHeuristics(fusion, runtime_info, data_cache);
- TORCH_INTERNAL_ASSERT(rparams.has_value());
- rparams_ = rparams.value();
- }
-
- static bool checkEquivalence(
- TensorView* out_tv0,
- TensorView* out_tv1,
- const ComputeAtRootDomainMap& root_map) {
- const auto& out_root0 = out_tv0->getRootDomain();
- const auto& out_root1 = out_tv1->getRootDomain();
- const auto domain0 = out_tv0->domain();
- const auto domain1 = out_tv1->domain();
-
- auto it0 = out_root0.begin();
- auto it1 = out_root1.begin();
-
- auto skip_broadcast = [&]() {
- while (it0 != out_root0.end() && (*it0)->isBroadcast()) {
- it0++;
- }
- while (it1 != out_root1.end() && (*it1)->isBroadcast()) {
- it1++;
- }
- };
-
- skip_broadcast();
- while (it0 != out_root0.end() && it1 != out_root1.end()) {
- if ((*it0)->isReduction() != (*it1)->isReduction()) {
- return false;
- }
- if (!root_map.canMap(domain0, (*it0), domain1, (*it1))) {
- return false;
- }
- it0++;
- it1++;
- skip_broadcast();
- }
-
- return it0 == out_root0.end() && it1 == out_root1.end();
- }
-};
-
-// Schedule Table
-const std::vector<ScheduleHeuristic>& all_heuristics() {
- static const std::vector<ScheduleHeuristic> hlist = {
- ScheduleHeuristic::Reduction,
- ScheduleHeuristic::PointWise,
- ScheduleHeuristic::Normalization};
- return hlist;
-}
-
-} // namespace
-
-// Simple dispatcher interface
-bool SchedulerEntry::canSchedule(
- ScheduleHeuristic sh,
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache) {
- switch (sh) {
- case ScheduleHeuristic::PointWise:
- return PointWiseScheduler::canSchedule(fusion, runtime_info, data_cache);
- case ScheduleHeuristic::Reduction:
- return SingleReductionScheduler::canSchedule(
- fusion, runtime_info, data_cache);
- case ScheduleHeuristic::Normalization:
- return NormalizationScheduler::canSchedule(
- fusion, runtime_info, data_cache);
- default:
- TORCH_INTERNAL_ASSERT(false, "unreachable");
- return false;
- }
- return false;
-}
-
-std::unique_ptr<SchedulerEntry> SchedulerEntry::makeEntry(
- ScheduleHeuristic sh,
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache) {
- std::unique_ptr<SchedulerEntry> scheduler_entry = nullptr;
- switch (sh) {
- case ScheduleHeuristic::PointWise:
- scheduler_entry = std::make_unique<PointWiseScheduler>(
- fusion, runtime_info, data_cache);
- break;
- case ScheduleHeuristic::Reduction:
- scheduler_entry = std::make_unique<SingleReductionScheduler>(
- fusion, runtime_info, data_cache);
- break;
- case ScheduleHeuristic::Normalization:
- scheduler_entry = std::make_unique<NormalizationScheduler>(
- fusion, runtime_info, data_cache);
- break;
- default:
- TORCH_INTERNAL_ASSERT(false, "unreachable");
- }
-
- scheduler_entry->index_mode_ = runtime_info.getIndexMode();
- return scheduler_entry;
-}
-
-// Simply loop through the list as baseline strategy
-c10::optional<ScheduleHeuristic> SchedulerEntry::proposeHeuristics(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info) {
- for (auto sh : all_heuristics()) {
- if (canSchedule(sh, fusion, runtime_info)) {
- return sh;
- }
- }
- return c10::nullopt;
-}
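Taken together, the dispatcher supports a simple caller-side flow: propose a heuristic, build the matching entry, and schedule. A sketch using only APIs declared in this file (assumes a `fusion` and its runtime `inputs` are in scope):

    SchedulerRuntimeInfo runtime_info(fusion, inputs, true);
    if (auto heuristic =
            SchedulerEntry::proposeHeuristics(fusion, runtime_info)) {
      auto entry = SchedulerEntry::makeEntry(
          heuristic.value(), fusion, runtime_info, nullptr);
      entry->schedule(fusion); // applies the chosen heuristic to the fusion
    }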
-
-size_t SchedulerEntryHash::operator()(const SchedulerEntry& se) const {
- if (se.hasReductionParam()) {
- return ReductionParamsHash()(se.reductionParams());
- } else {
- return PointwiseParamsHash()(se.pointwiseParams());
- }
-}
-
-std::string toString(ScheduleHeuristic sh) {
- switch (sh) {
- case ScheduleHeuristic::PointWise:
- return "pointwise";
- case ScheduleHeuristic::Reduction:
- return "reduction";
- case ScheduleHeuristic::Normalization:
- return "normalization";
- default:
- TORCH_INTERNAL_ASSERT(false, "undefined schedule");
- }
- return "";
-}
-
-HeuristicSummary::HeuristicSummary(
- Fusion* fusion,
- ScheduleHeuristic heuristic,
- SchedulerRuntimeInfo& runtime_info)
- : heuristic_(heuristic) {
- recording_ = true;
- switch (heuristic) {
- case ScheduleHeuristic::PointWise:
- getPointwiseHeuristics(fusion, runtime_info, this);
- PointWiseScheduler::canSchedule(fusion, runtime_info, this);
- break;
- case ScheduleHeuristic::Reduction:
- getReductionHeuristics(fusion, runtime_info, this);
- SingleReductionScheduler::canSchedule(fusion, runtime_info, this);
- break;
- case ScheduleHeuristic::Normalization:
- getNormalizationHeuristics(fusion, runtime_info, this);
- NormalizationScheduler::canSchedule(fusion, runtime_info, this);
- break;
- default:
- TORCH_INTERNAL_ASSERT(false, "unknown heuristic");
- }
- validate();
- recording_ = false;
-}
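This constructor is the "record" half of a record/replay scheme: it runs the heuristic and canSchedule passes once while isRecording() is true so they populate the cache entries, validates them, then flips to replay mode. A usage sketch (variables assumed to be in scope):

    SchedulerRuntimeInfo info(fusion, inputs, true);
    HeuristicSummary summary(fusion, ScheduleHeuristic::Normalization, info);
    // summary.isRecording() is now false; passing &summary back into
    // canSchedule() or getNormalizationHeuristics() reads the recorded
    // entries instead of re-traversing the fusion.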
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class SegmentedGroup;
-class ExpressionEvaluator;
-
-//! SchedulerRuntimeInfo is the abstraction introduced in
-//! this PR for passing runtime input-dependent information
-//! to the schedulers and kernel caches.
-//!
-//! Note:
-//!  If any additional info is needed, or maybe just the inputs themselves, it
-//!  could simply be added to this class, and it will be distributed to the
-//!  segmenter and schedulers.
-//! It is important that the input id encoding be kept up to date with any
-//! change to this class, to avoid launching compiled kernels with illegal
-//! inputs.
-class TORCH_CUDA_CU_API SchedulerRuntimeInfo {
- public:
- // Max vector size we will consider, in bytes,
- // currently set to 16B = 128b
- const size_t max_alignment_size_in_byte = 16;
-
- //! Create runtime info for given fusion and input. Creating and binding
- //! evaluator is optional. The evaluator is used to manage intermediate
- //! integers in the fusion. We need them for segmenter and schedulers,
- //! but we don't need them when we are just using this class to provide
- //! additional encoding for kernel cache lookup.
- SchedulerRuntimeInfo(
- Fusion* complete_fusion,
- const at::ArrayRef<at::IValue>& inputs,
- bool create_expr_evaluator = false);
-
- //! Create runtime info by copying all the global
- //! input meta data (i.e. alignment), but not the
- //! expression evaluator.
- SchedulerRuntimeInfo(const SchedulerRuntimeInfo& global_runtime_info);
-
-  //! Looks up the alignment size of the given tv. Currently only returns
-  //! actual alignment info for input tensors to the complete fusion;
-  //! for other intermediate/fuser-allocated tensors it will
-  //! return max_alignment_size_in_byte.
- size_t getAlignmentSize(TensorView* tv);
-
-  //! Takes the minimum of the input tv alignment sizes. This serves both as
-  //! information for vectorization and as a signature for kernel cache id
-  //! lookup. May need to be updated along with the vectorization logic.
- size_t getCommonAlignmentSize() const {
- return common_alignment_size_;
- }
-
-  //! Returns the max width to which the given tensor view can be vectorized.
-  //! For input tensors this uses the pre-computed value based on
-  //! the given tensor's alignment and strides. For intermediate tensors
-  //! it assumes they are contiguous and aligned to 128bit/16Byte
- size_t getVectorizableWidth(TensorView* tv);
-
- KernelIndexMode getIndexMode() {
- return index_mode_;
- }
-
- Fusion* fusion() {
- return complete_fusion_;
- }
-
- ExpressionEvaluator& expressionEvaluator() {
- TORCH_INTERNAL_ASSERT(expression_evaluator_ != nullptr);
- return *expression_evaluator_;
- }
-
- private:
- // Bind full fusion inputs to the internal expression evaluator
- void initializeExpressionEvaluator(const at::ArrayRef<at::IValue>& inputs);
-
- // Compute alignment data for all input tensors of full fusion
- void collectVectorizationInfo(const at::ArrayRef<at::IValue>& inputs);
-
- // Compute alignment data for given tensor
- size_t collectAlignmentSize(const at::Tensor& tensor) const;
-
-  // Compute max vectorization word size for each input tensor
- size_t collectMaxVectorizeSize(
- const at::Tensor& tensor,
- size_t max_word_size_in_byte);
-
- // check if input is compatible with 32b index mode
- void collectIndexModeInfo(const at::ArrayRef<at::IValue>& inputs);
-
- private:
- std::unique_ptr<ExpressionEvaluator> expression_evaluator_ = nullptr;
- Fusion* complete_fusion_;
- std::unordered_map<TensorView*, size_t> alignment_map_;
- std::unordered_map<TensorView*, size_t> vectorword_map_;
- size_t common_alignment_size_;
- KernelIndexMode index_mode_ = KernelIndexMode::INT64;
-};
-
-class HeuristicSummary;
-
-//! Virtual base class for schedule heuristics.
-//! Heuristic implementations derive from this
-//! class and implement a schedule(Fusion*)
-//! and a bool canSchedule(Fusion*) interface
-class TORCH_CUDA_CU_API SchedulerEntry {
- public:
- //! Fusion runtime facing API,
- //! builds a new entry with the given heuristics
- //! corresponding to the given fusion
- static std::unique_ptr<SchedulerEntry> makeEntry(
- ScheduleHeuristic sh,
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr);
-
- virtual ~SchedulerEntry() = default;
-
- //! External access for canSchedule utilities through SchedulerEntry
- //! to avoid exposing a single function to the namespace
- static bool canSchedule(
- ScheduleHeuristic sh,
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- HeuristicSummary* data_cache = nullptr);
-
-  //! Fusion segmenter facing API,
-  //! returns a schedule heuristic that applies to the given fusion, or
-  //! nullopt if no schedule in the registry can handle it.
- static c10::optional<ScheduleHeuristic> proposeHeuristics(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info);
-
-  //! Fusion runtime facing API,
-  //! schedules the given fusion with the heuristics owned
-  //! by this entry; concrete heuristics override this
- virtual void schedule(Fusion* fusion) = 0;
-
- //! Heuristic comparison
- bool sameAs(const SchedulerEntry* other);
-
- bool hasReductionParam() const {
- return has_reduction_param_;
- }
-
- ScheduleHeuristic heuristc() const {
- return heuristc_;
- }
-
- KernelIndexMode indexMode() const {
- return index_mode_;
- }
-
- const ReductionParams& reductionParams() const {
- TORCH_INTERNAL_ASSERT(
- has_reduction_param_, "This schedule heuristic is not reduction.");
- return rparams_;
- }
-
- const PointwiseParams& pointwiseParams() const {
- TORCH_INTERNAL_ASSERT(
- !has_reduction_param_, "This schedule heuristic is not pointwise.");
- return pparams_;
- }
-
- void updateLaunchConstraint(const LaunchParams& launch_params) {
- if (hasReductionParam()) {
- rparams_.lparams = launch_params;
- } else {
- pparams_.lparams = launch_params;
- }
- }
-
- protected:
- explicit SchedulerEntry(ScheduleHeuristic heuristic, bool has_reduction_param)
- : heuristc_(heuristic), has_reduction_param_(has_reduction_param) {}
-
- //! What kind of heuristics does this entry have?
- const ScheduleHeuristic heuristc_;
-
- //! Has reduction params if true, else has pointwise params
- const bool has_reduction_param_;
-
- //! Reduction parameters if applicable
- ReductionParams rparams_;
-
- //! Pointwise parameters if applicable
- PointwiseParams pparams_;
-
- //! Kernel Index Mode
- KernelIndexMode index_mode_;
-};
-
-//! Hash function for a scheduler entry
-class TORCH_CUDA_CU_API SchedulerEntryHash {
- public:
- size_t operator()(const SchedulerEntry& se) const;
-};
-
-//! Debug print function for heuristics
-std::string toString(ScheduleHeuristic sh);
-
-class TORCH_CUDA_CU_API HeuristicSummary {
- using ValToFactorMap = std::unordered_map<Val*, int>;
- using ValToFactorMapPtr = std::unique_ptr<ValToFactorMap>;
- using ScopedPersistenceFactorMap =
- std::unordered_map<Val*, ValToFactorMapPtr>;
-
- public:
- HeuristicSummary(
- Fusion* fusion,
- ScheduleHeuristic heuristic,
- SchedulerRuntimeInfo& runtime_info);
- // Recording scheme:
- bool isRecording() {
- return recording_;
- }
-
- // Validate post recording:
- // make sure we have collected all the needed fields
- void validate() {
- switch (heuristic_) {
- case ScheduleHeuristic::PointWise:
- TORCH_INTERNAL_ASSERT(vectorizable_inputs_outputs_);
- TORCH_INTERNAL_ASSERT(mapped_input_output_dims_);
- break;
- case ScheduleHeuristic::Reduction:
- TORCH_INTERNAL_ASSERT(reduction_tvs_);
- break;
- case ScheduleHeuristic::Normalization:
- TORCH_INTERNAL_ASSERT(vectorizable_inputs_outputs_);
- TORCH_INTERNAL_ASSERT(reduction_tvs_);
- TORCH_INTERNAL_ASSERT(persistent_buffer_info_);
- TORCH_INTERNAL_ASSERT(has_post_reduction_bcast_);
- TORCH_INTERNAL_ASSERT(supported_post_reduction_fusion_);
- break;
- }
- }
-
- // Accessors (un-protected for now)
- void setVectorizableInputsOutputs(const std::vector<TensorView*>& input) {
- TORCH_INTERNAL_ASSERT(recording_);
-
- if (!vectorizable_inputs_outputs_) {
- vectorizable_inputs_outputs_ =
- std::make_unique<std::vector<TensorView*>>(input);
- }
- }
-
- auto* getVectorizableInputsOutputs() {
- return vectorizable_inputs_outputs_.get();
- }
-
- void setReductionTVs(const std::vector<TensorView*>& input) {
- TORCH_INTERNAL_ASSERT(recording_);
-
- if (!reduction_tvs_) {
- reduction_tvs_ = std::make_unique<std::vector<TensorView*>>(input);
- }
- }
-
- auto* getReductionTVs() {
- return reduction_tvs_.get();
- }
-
- void setPersistentBufferInfo(
- const scheduler_utils::PersistentBufferInfo& input) {
- TORCH_INTERNAL_ASSERT(recording_);
-
- if (!persistent_buffer_info_) {
- persistent_buffer_info_ =
- std::make_unique<scheduler_utils::PersistentBufferInfo>(input);
- }
- }
-
- auto* getPersistentBufferInfo() {
- return persistent_buffer_info_.get();
- }
-
- void setSupportedPostReductionFusion(bool input) {
- TORCH_INTERNAL_ASSERT(recording_);
-
- if (!supported_post_reduction_fusion_) {
- supported_post_reduction_fusion_ = std::make_unique<bool>(input);
- }
- }
-
- auto* getSupportedPostReductionFusion() {
- return supported_post_reduction_fusion_.get();
- }
-
- void setHasPostReductionBCast(bool input) {
- TORCH_INTERNAL_ASSERT(recording_);
-
- if (!has_post_reduction_bcast_) {
- has_post_reduction_bcast_ = std::make_unique<bool>(input);
- }
- }
-
- auto* getHasPostReductionBCast() {
- return has_post_reduction_bcast_.get();
- }
-
- void setScopedPersistenceFactorMap(const ScopedPersistenceFactorMap& input) {
- TORCH_INTERNAL_ASSERT(recording_);
-
- scope_persistence_factor_map_ =
- std::make_unique<ScopedPersistenceFactorMap>();
- for (const auto& it : input) {
- ValToFactorMap& to_copy = *(it.second);
- scope_persistence_factor_map_->operator[](it.first) =
- std::make_unique<ValToFactorMap>(to_copy);
- }
- }
-
- auto* getScopedPersistenceFactorMap() {
- return scope_persistence_factor_map_.get();
- }
-
- void setMappedInputOutputDims(const std::vector<int64_t>& input) {
- TORCH_INTERNAL_ASSERT(recording_);
-
- if (!mapped_input_output_dims_) {
- mapped_input_output_dims_ = std::make_unique<std::vector<int64_t>>(input);
- }
- }
-
- auto* getMappedInputOutputDims() {
- return mapped_input_output_dims_.get();
- }
-
- private:
- ScheduleHeuristic heuristic_;
- bool recording_ = true;
-
- // Actual data payload, could be folded into subclasses later.
- std::unique_ptr<std::vector<TensorView*>> vectorizable_inputs_outputs_;
- std::unique_ptr<std::vector<TensorView*>> reduction_tvs_;
- std::unique_ptr<scheduler_utils::PersistentBufferInfo>
- persistent_buffer_info_;
- std::unique_ptr<bool> has_post_reduction_bcast_;
- std::unique_ptr<bool> supported_post_reduction_fusion_;
- std::unique_ptr<ScopedPersistenceFactorMap> scope_persistence_factor_map_;
- std::unique_ptr<std::vector<int64_t>> mapped_input_output_dims_;
-};
-
-// A temporary utility class to save some boilerplate code when
-// using HeuristicSummary. Can be significantly improved in a follow-up.
-template <typename T>
-class HeuristicCacheAccessor {
- public:
- HeuristicCacheAccessor() = default;
-
- T& read() {
- if (temporary_data_) {
- return *temporary_data_;
- } else {
- return *owned_data_;
- }
- }
-
- void writeNew(T data) {
- owned_data_ = std::make_unique<T>(std::move(data));
- }
-
- void takeNew(std::unique_ptr<T>& data) {
- owned_data_ = std::move(data);
- }
-
- void writeTemporary(T* data) {
- temporary_data_ = data;
- }
-
- private:
- std::unique_ptr<T> owned_data_ = nullptr;
- T* temporary_data_ = nullptr;
-};
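A usage sketch of the accessor's modes, mirroring the boilerplate in the canSchedule() implementations (variables assumed to be in scope):

    HeuristicCacheAccessor<std::vector<TensorView*>> accessor;
    if (data_cache && !data_cache->isRecording()) {
      // Replay: borrow the cached pointer without taking ownership.
      accessor.writeTemporary(data_cache->getReductionTVs());
    } else {
      // Compute fresh data and own it.
      accessor.writeNew(scheduler_utils::getReductionTvs(fusion));
      if (data_cache && data_cache->isRecording()) {
        data_cache->setReductionTVs(accessor.read()); // record for replay
      }
    }
    auto& reduction_tvs = accessor.read(); // valid in either mode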
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
-
-#include <torch/csrc/jit/codegen/cuda/arith.h>
-#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
-#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
-#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
-#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-namespace scheduler_utils {
-size_t mergeReduction(
- TensorView* tv,
- const std::unordered_set<IterDomain*>& dont_merge) {
- int prev_i = -1;
- size_t num_merged = 0;
- for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
- if (!tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
- continue;
- }
- if (prev_i == -1) {
- prev_i = i;
- } else {
- tv->merge(i, prev_i);
- prev_i = i;
- num_merged++;
- }
- }
-  if (prev_i == -1) {
-    // No reduction axes found; don't touch the domain (reorder({{-1, 0}})
-    // would otherwise move the last axis to the front).
-    return 0;
-  }
-  if (prev_i != 0) {
-    tv->reorder({{prev_i, 0}});
-  }
-
-  return num_merged + 1;
-}
-
-size_t mergeNonReduction(
- TensorView* tv,
- const std::unordered_set<IterDomain*>& dont_merge) {
- int prev_i = -1;
- size_t num_merged = 0;
- if (tv->nDims() == 0) {
- return 0;
- }
- for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
- if (tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
- continue;
- }
- if (prev_i == -1) {
- prev_i = i;
- } else {
- tv->merge(i, prev_i);
- prev_i = i;
- num_merged++;
- }
- }
-  if (prev_i == -1) {
-    // No non-reduction axes found; don't touch the domain.
-    return 0;
-  }
-  if (prev_i != 0) {
-    tv->reorder({{prev_i, 0}});
-  }
-
-  return num_merged + 1;
-}
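A worked example of the two merges on a hypothetical tensor view ordered [i0, r1, i2, r3], with nothing in dont_merge:

    // mergeReduction merges r1 with r3 and reorders the result to the front:
    //   [i0, r1, i2, r3]  ->  [r1*r3, i0, i2]   (returns 2)
    // mergeNonReduction then merges i0 with i2 and moves it to the front:
    //   [r1*r3, i0, i2]   ->  [i0*i2, r1*r3]    (returns 2)
    // Together they canonicalize any such domain to [iter, reduction].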
-
-void parallelizeAllLike(
- TensorView* reference_tv,
- const std::vector<TensorView*>& all_tvs) {
- FusionGuard fg(reference_tv->fusion());
-
- auto ca_loop_map = ComputeAtMap(ComputeAtMap::MappingMode::LOOP);
- ca_loop_map.build(FusionGuard::getCurFusion());
- for (auto id : reference_tv->domain()->domain()) {
- ca_loop_map.getConcreteMappedID(id)->parallelize(id->getParallelType());
- }
-
- for (auto tv : all_tvs) {
- if (tv->isFusionInput()) {
- continue;
- }
- for (size_t i = 0; i < tv->domain()->domain().size(); i++) {
- tv->axis(i)->parallelize(
- ca_loop_map.getConcreteMappedID(tv->axis(i))->getParallelType());
- }
- }
-}
-
-void computeAtInputs(TensorView* consumer, int pos, ComputeAtMode mode) {
- for (auto inp_tv : ir_utils::inputTvsOf(consumer)) {
- inp_tv->computeAt(consumer, pos, mode);
- }
-}
-
-void computeWithOutputs(TensorView* producer, int pos, ComputeAtMode mode) {
- for (auto out_tv : ir_utils::outputTvsOf(producer)) {
- producer->computeWith(out_tv, pos, mode);
- }
-}
-
-PersistentBufferInfo persistentBuffers(Fusion* fusion) {
- FusionGuard fg(fusion);
-
- PersistentBufferInfo info;
-
- ComputeAtRootDomainMap root_map;
- root_map.build();
-
- auto all_tvs = ir_utils::allTvs(fusion);
-
- for (auto producer : all_tvs) {
- bool mappable = true;
- auto consumers = ir_utils::consumerTvsOf(producer);
- if (consumers.empty()) {
- continue;
- }
-
- auto mappable_roots =
- root_map.getMappableDims(producer->domain(), consumers[0]->domain());
-
- auto p_root = producer->getMaybeRFactorDomain();
-
- for (auto p_root_id : p_root) {
- if (p_root_id->isReduction()) {
- continue;
- }
- if (!mappable_roots.count(p_root_id)) {
- mappable = false;
- info.unmappable_dims.emplace(p_root_id);
- }
- }
-
- if (!mappable) {
- info.buffers.push_back(producer);
- }
- }
- return info;
-}
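For intuition, a hedged softmax-style example:

    // With out = exp(x) / broadcast(sum(exp(x))), exp(x) is consumed both by
    // the reduction and, through the broadcast, after it. Its reduced root
    // dimension cannot be inlined into the post-broadcast consumer, so exp(x)
    // is reported in info.buffers and that dimension in info.unmappable_dims:
    // the whole row must stay live (persistent) across the reduction.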
-
-TvProperties getProperties(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- TensorView* tv) {
- TvProperties properties;
- FusionGuard fg(fusion);
-
- auto red_root_dom = tv->getRootDomain();
- for (size_t i = red_root_dom.size(); i > 0; i--) {
- if (red_root_dom[i - 1]->isBroadcast()) {
- continue;
- } else if (red_root_dom[i - 1]->isReduction()) {
- break;
- } else {
- properties.fastest_dim_reduction = false;
- break;
- }
- }
-
- bool hit_reduction = false;
- auto root_dom = tv->getMaybeRFactorDomain();
- for (auto it = root_dom.rbegin(); it != root_dom.rend(); ++it) {
- auto id = *it;
-
- auto inferred_val =
- runtime_info.expressionEvaluator().evaluate(id->extent());
- TORCH_INTERNAL_ASSERT(
- inferred_val.has_value(), "Error inferring reduction size.");
- if (id->isReduction()) {
- hit_reduction = true;
- properties.reduction_numel *= inferred_val.value();
- } else {
- auto dim_size = inferred_val.value();
- properties.iteration_numel *= dim_size;
- if (hit_reduction) {
- properties.iter_outside_red *= dim_size;
- } else {
- properties.iter_inside_red *= dim_size;
- }
- }
- }
-
- if (properties.reduction_numel == 1) {
- properties.iter_outside_red =
- properties.iter_outside_red * properties.iter_inside_red;
- properties.iter_inside_red = 1;
- properties.fastest_dim_reduction = true;
- }
-
- return properties;
-}
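A worked example, assuming a tensor with root domain [I0=64, I1=128, R2=1024] reduced over its innermost axis:

    // reduction_numel       = 1024            (product of reduction extents)
    // iteration_numel       = 64 * 128 = 8192
    // iter_outside_red      = 8192            (iteration ids encountered after
    //                                          the reduction in the
    //                                          right-to-left scan)
    // iter_inside_red       = 1
    // fastest_dim_reduction = true            (innermost non-broadcast id
    //                                          is a reduction)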
-
-void computeAtBetween(
- const std::vector<TensorView*>& producers,
- const std::vector<TensorView*>& overall_consumers,
- int pos,
- ComputeAtMode mode,
- std::unordered_set<IterDomain*> mapped_to_trivial_reduction) {
- for (auto producer : producers) {
-    // Figure out what's between producer and overall_consumers; this will not
-    // give back any consumers that are not downstream of producer
- auto all_vals_between = DependencyCheck::getAllValsBetween(
- {producer}, {overall_consumers.begin(), overall_consumers.end()});
-
- std::unordered_set<Val*> all_vals_between_set(
- all_vals_between.begin(), all_vals_between.end());
-
- for (auto consumer : overall_consumers) {
- if (all_vals_between_set.count(consumer)) {
- // The way we generate producers and consumers is that we inch away from
- // inputs/outputs. There's a chance we could meet in the middle.
- if (producer == consumer) {
- continue;
- }
-
- auto pos_it = std::find_if(
- consumer->domain()->domain().begin(),
- consumer->domain()->domain().end(),
- [&mapped_to_trivial_reduction](IterDomain* id) {
- return mapped_to_trivial_reduction.count(id);
- });
-
- pos = pos_it == consumer->domain()->domain().end()
- ? pos
- : std::min(
- (int)std::distance(
- consumer->domain()->domain().begin(), pos_it) +
- 1,
- (pos < 0 ? pos + (int)consumer->nDims() : pos));
- // Assume we don't want to reset computeAt on tensors that have already
- // performed it.
- producer->computeAt(consumer, pos, mode);
- }
- }
- }
-}
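A worked example of the position clamp above, with pos = 3 and a consumer domain [i0, t1, i2] where t1 is mapped to a trivial reduction:

    // pos_it lands on t1 (index 1), so the computeAt position becomes
    // min(1 + 1, 3) = 2, keeping the trivial-reduction id out of the
    // inlined loop nest.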
-
-int64_t persistentBufferSize(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- PersistentBufferInfo& persistent_buffers,
- HeuristicSummary* data_cache) {
- FUSER_PERF_SCOPE("scheduler_utils::persistentBufferSize");
-
- if (persistent_buffers.buffers.empty()) {
- return 0;
- }
-
- int64_t persistent_buffer_size = 0;
-
- using ValToFactorMap = std::unordered_map<Val*, int>;
- using ValToFactorMapPtr = std::unique_ptr<ValToFactorMap>;
- using ScopedPersistenceFactorMap =
- std::unordered_map<Val*, ValToFactorMapPtr>;
-
- HeuristicCacheAccessor<ScopedPersistenceFactorMap>
- scoped_persistent_factor_data;
-  // TODO: move all this boilerplate code into the accessor class
-  // (follow-up)
-
-  // Caching the traversal result in this case.
-  // This one is slightly more involved. The end result we want is all the
-  // concrete int values in scoped_persistence. Essentially:
-  //
-  //   scoped_persistence[val] = sum_over_all_persistent_tvs(
-  //       contribution_from_tv_to_val * persistent_size_of_tv)
-  //
-  // Here contribution_from_tv_to_val can be determined at compile time,
-  // while persistent_size_of_tv is a runtime value that nonetheless
-  // doesn't require heavy graph traversal.
-  // So in this cache entry we save a matrix of contribution factors,
-  //
-  //   new_persistent_factor_map[tv][val] = contribution_from_tv_to_val,
-  //
-  // at compile time, and we combine these factors with the runtime
-  // persistent buffer sizes at runtime.
- if (data_cache && !data_cache->isRecording()) {
- scoped_persistent_factor_data.writeTemporary(
- data_cache->getScopedPersistenceFactorMap());
- } else {
-    // Compute the new scoped persistence factors:
- auto new_persistent_factor_map_ptr =
- std::make_unique<ScopedPersistenceFactorMap>();
- auto& new_persistent_factor_map = *new_persistent_factor_map_ptr;
-
- for (auto tv : persistent_buffers.buffers) {
- auto& consumer_tv_to_factor_map_ptr = new_persistent_factor_map[tv];
- consumer_tv_to_factor_map_ptr = std::make_unique<ValToFactorMap>();
- auto& consumer_tv_to_factor_map = *consumer_tv_to_factor_map_ptr;
-
-      // All expressions between tv and its consumers must have tv's persistent
-      // buffer allocated. This is an optimistic view on how many registers we
-      // need allocated in the kernel, since if we ordered two persistent
-      // buffers that are completely independent to somehow overlap with
-      // each other we would assume we wouldn't need those two buffers active
-      // at the same time, even though they would be.
-      //
-      // Unfortunately this limitation is hard to work around as we would have
-      // to actually generate the kernel before we know if it would fit
-      // persistently in registers. In practice, though, this should not happen
-      // as inlining loop structures where the persistent buffer is used should
-      // prevent multiple persistent buffers from being merged together if not
-      // necessary.
- auto consumers_of_tv = ir_utils::consumerTvsOf(tv);
- for (auto val : DependencyCheck::getAllValsBetween(
- {tv}, {consumers_of_tv.begin(), consumers_of_tv.end()})) {
- // Persistent normalization kernels imply that all persistent buffers
- // have the same dimensionality. Assume if a persistent buffer is
- // consumed by another we can alias and reuse the memory.
- if (val == tv) {
- continue;
- }
-
- if (consumer_tv_to_factor_map.count(val)) {
- consumer_tv_to_factor_map.at(val) += 1;
- } else {
- consumer_tv_to_factor_map[val] = 1;
- }
- }
- }
-
-    // Caching boilerplate (to be cleaned up in a follow-up)
- scoped_persistent_factor_data.takeNew(new_persistent_factor_map_ptr);
- if (data_cache && data_cache->isRecording()) {
- data_cache->setScopedPersistenceFactorMap(
- scoped_persistent_factor_data.read());
- }
- }
-
- auto& scoped_persistence_factor = scoped_persistent_factor_data.read();
-
- // Runtime: convert the persistent factor to actual values
- std::unordered_map<Val*, int64_t> scoped_persistence;
-
- for (auto tv : persistent_buffers.buffers) {
- int64_t tv_persistent_numel = -1;
- for (auto id : tv->getMaybeRFactorDomain()) {
- if (id->isReduction() || id->isBroadcast()) {
- continue;
- }
- // Unmappable dimensions are those that we cannot inline into other
- // tensor views. So they're the ones that need to be persistent.
- if (!persistent_buffers.unmappable_dims.count(id)) {
- continue;
- }
-
- auto id_size = runtime_info.expressionEvaluator().evaluate(id->extent());
- TORCH_INTERNAL_ASSERT(
- id_size.has_value(),
- "Cannot generate heuristics if we don't have input information.");
- if (tv_persistent_numel == -1) {
- tv_persistent_numel = id_size.value();
- } else {
- tv_persistent_numel *= id_size.value();
- }
- }
-
- persistent_buffer_size =
- tv_persistent_numel * dataTypeSize(tv->getDataType().value());
-
- // Look up the contribution part from the cached matrix:
- auto scoped_factor_it = scoped_persistence_factor.find(tv);
- if (scoped_factor_it != scoped_persistence_factor.end()) {
- // now looking at scoped_persistence_factor[tv]
- for (auto val_to_factor_it : *(scoped_factor_it->second)) {
- // (val_to_factor_it) is (val, factor)
- int64_t persistent_buffer_size_contribution =
- persistent_buffer_size * val_to_factor_it.second;
-
- // try to write factor * persistent_buffer_size into
- // scoped_persistence[val]
- auto val_it = scoped_persistence.find(val_to_factor_it.first);
- if (val_it == scoped_persistence.end()) {
- scoped_persistence[val_to_factor_it.first] =
- persistent_buffer_size_contribution;
- } else {
- val_it->second += persistent_buffer_size_contribution;
- }
- }
- }
- }
-
- // Find the maximum persistent buffer use
- int64_t max_persistence_size = 0;
- for (auto persistent_entry : scoped_persistence) {
- max_persistence_size =
- std::max(max_persistence_size, persistent_entry.second);
- }
-
- return max_persistence_size;
-}
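A numeric sketch of the factor/size combination (shapes assumed): two float persistent buffers, tv1 with unmappable extent 1024 and tv2 with unmappable extent 2048:

    // tv1: 1024 * 4B = 4096B,  tv2: 2048 * 4B = 8192B
    // A val lying between both tvs and their consumers, with factor 1 from
    // each, accumulates scoped_persistence[val] = 4096 + 8192 = 12288B.
    // The function returns the maximum such sum over all vals.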
-
-std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion) {
- auto all_tvs = ir_utils::allTvs(fusion);
- std::unordered_set<IterDomain*> mapped_to_trivial_reduction;
- for (auto tv : all_tvs) {
- // root domain vs domain shouldn't matter as at this point we shouldn't have
- // any transformations.
- for (auto id : tv->getRootDomain()) {
- if (id->isTrivialReduction()) {
- mapped_to_trivial_reduction.emplace(id);
- }
- }
- }
-
- if (!mapped_to_trivial_reduction.empty()) {
- // Shouldn't matter which compute at map we use
- auto ca_index_map = ComputeAtMap(ComputeAtMap::MappingMode::INDEX);
- ca_index_map.build(fusion);
-    // Make a copy; we need to check the mappings of all ids
- auto trivial_ids = mapped_to_trivial_reduction;
- for (auto tv : all_tvs) {
- for (auto id : tv->getRootDomain()) {
- if (!id->extent()->isOneInt()) {
- continue;
- }
- if (std::any_of(
- trivial_ids.begin(),
- trivial_ids.end(),
- [&ca_index_map, &id](IterDomain* trivial_id) {
- return ca_index_map.areMapped(id, trivial_id);
- })) {
- mapped_to_trivial_reduction.emplace(id);
- }
- }
- }
- }
- return mapped_to_trivial_reduction;
-}
-
-std::pair<bool, bool> canonicalDimReduction(Fusion* fusion, TensorView* tv) {
- std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
- getTrivialReductionMap(fusion);
-
- TORCH_INTERNAL_ASSERT(tv != nullptr);
-
- // We coalesce all reduction axes to the right;
- bool has_red_axis = mergeReduction(tv, mapped_to_trivial_reduction) > 0;
-
- bool has_iter_axis = mergeNonReduction(tv, mapped_to_trivial_reduction) > 0;
- return {has_iter_axis, has_red_axis};
-}
-
-std::vector<TensorView*> getReductionTvs(Fusion* fusion) {
- auto all_tvs = ir_utils::allTvs(fusion);
- std::vector<TensorView*> reduction_tvs;
- for (auto tv : all_tvs) {
- if (!tv->isFusionInput() &&
- std::any_of(
- tv->domain()->domain().begin(),
- tv->domain()->domain().end(),
- [](IterDomain* id) {
- return id->isReduction() && !id->isTrivialReduction();
- })) {
- reduction_tvs.emplace_back(tv);
- }
- }
-
- // Remove multi outputs from reduction tensor views
- std::unordered_set<Expr*> seen_reduction_exprs;
- reduction_tvs.erase(
- std::remove_if(
- reduction_tvs.begin(),
- reduction_tvs.end(),
- [&seen_reduction_exprs](TensorView* tv) {
- TORCH_INTERNAL_ASSERT(
- tv->definition() != nullptr,
-              "Somehow a tensor view with a reduction but without a definition snuck into the scheduler reduction list.");
- if (!seen_reduction_exprs.emplace(tv->definition()).second) {
- return true;
- }
- return false;
- }),
- reduction_tvs.end());
- return reduction_tvs;
-}
-
-TensorView* scheduleReductionTV(
- const ReductionParams& rparams,
- TensorView* reduction_tv,
- bool has_iter_axis) {
- TensorView* reference_tv = nullptr;
- if (rparams.fastest_dim) {
- const int iter_axis = 0;
- const int reduce_axis = has_iter_axis ? 1 : 0;
-
- // Do multiple reductions per block
- if (rparams.multiple_reds_per_blk) {
- if (rparams.reduction_unroll) {
- // Fastest dim, multiple reductions per block
- // Output Dimensions
- // [x-BIDx, x-TIDy
- // 0 1
- //
- // Reduction Dimensions
- // rF-Remain, rf-Unswitch, rf-Unroll, X-TIDx]
- // 2(r) 3(r+1) 4(r+2) 5(r+3)
- // Reduction Dimensions
- // rF-Remain, rf-Unswitch, X-TIDx, rf-Vectorize]
- // 2(r) 3(r+1) 4(r+2) 5(r+3)
-
- // X-TIDx, rF-Remain, rf-Unswitch, rf-Unroll/Vect]
- // 2(r) 3(r+1) 4(r+2) 5(r+3)
-
- if (!rparams.persistent_kernel) {
- if (rparams.vectorize) {
- reduction_tv->split(reduce_axis, rparams.loop_unroll);
- reduction_tv->split(
- reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
- } else {
- reduction_tv->split(
- reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
- reduction_tv->split(reduce_axis, rparams.loop_unroll);
- }
- // Unswitch axis which gives us finer control on allocations with
- // unrolling
- reduction_tv->split(reduce_axis, 1);
- } else {
- if (rparams.vectorize) {
- reduction_tv->split(reduce_axis, rparams.batches_per_block, false);
- reduction_tv->split(reduce_axis + 1, rparams.loop_unroll);
- } else {
- reduction_tv->split(
- reduce_axis,
- rparams.batches_per_block * rparams.loop_unroll,
- false);
- reduction_tv->split(reduce_axis, rparams.loop_unroll);
- }
- // Unswitch axis which gives us finer control on allocations with
- // unrolling
- reduction_tv->split(reduce_axis, 1);
- }
-
- if (rparams.vectorize) {
- reduction_tv->reorder(
- {{reduce_axis, reduce_axis + 1},
- {reduce_axis + 1, reduce_axis + 2},
- {reduce_axis + 2, reduce_axis}});
- } else {
- reduction_tv->reorder(
- {{reduce_axis + 3, reduce_axis},
- {reduce_axis, reduce_axis + 1},
- {reduce_axis + 1, reduce_axis + 2},
- {reduce_axis + 2, reduce_axis + 3}});
- }
-
- reference_tv = ir_utils::rfactorHelper(
- reduction_tv, {reduce_axis + 1, reduce_axis + 2, reduce_axis + 3});
-
- reference_tv->axis(reduce_axis)->parallelize(ParallelType::TIDx);
-
- if (rparams.vectorize) {
- reference_tv->axis(reduce_axis + 3)
- ->parallelize(ParallelType::Vectorize);
- } else {
- reference_tv->axis(reduce_axis + 3)
- ->parallelize(ParallelType::Unroll);
- }
- reference_tv->axis(reduce_axis + 2)
- ->parallelize(ParallelType::Unswitch);
-
- if (has_iter_axis) {
- reference_tv->split(
- iter_axis, NamedScalar::getParallelDim(ParallelType::TIDy));
- reference_tv->axis(iter_axis + 1)->parallelize(ParallelType::TIDy);
- if (rparams.split_grid_dim) {
- reference_tv->split(iter_axis, x_grid_limit);
- reference_tv->axis(iter_axis + 1)->parallelize(ParallelType::BIDx);
- } else {
- reference_tv->axis(iter_axis)->parallelize(ParallelType::BIDx);
- }
- }
- } else {
- TORCH_INTERNAL_ASSERT(
- has_iter_axis,
- "This scheduler requires an outer dim to the reduction.");
- // Fastest dim, Multiple reductions per block iter unroll
- // Output Dimensions
- // [x-BIDx, x-Unswitch, x-Unroll, x-TIDy
- // 0 1 2 3
- //
- // Reduction Dimensions
- // rF-Remain, r-TIDx]
- // 4(r) 5(r+1)
- if (!rparams.persistent_kernel) {
- reduction_tv->split(
- 1, NamedScalar::getParallelDim(ParallelType::TIDx));
- } else {
- reduction_tv->split(1, rparams.batches_per_block, false);
- }
-
- reference_tv = ir_utils::rfactorHelper(reduction_tv, {1});
-
- reference_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDy));
- reference_tv->split(0, rparams.loop_unroll);
- // Unswitch axis which gives us finer control on allocations with
- // unrolling
- reference_tv->split(0, 1);
-
- // [x-BIDx, x-Unswitch, x-Unroll, x-TIDy, rF-Remain, r-TIDx]
- // 0 1 2 3 4 5
- // -> [x-BIDx, x-TIDy, rF-Remain, x-Unswitch, x-Unroll, r-TIDx]
- // 0 1 2 3 4 5
-
- reference_tv->reorder({{1, 3}, {2, 4}, {3, 1}, {4, 2}});
-
- reference_tv->axis(1)->parallelize(ParallelType::TIDy);
- reference_tv->axis(3)->parallelize(ParallelType::Unswitch);
- reference_tv->axis(4)->parallelize(ParallelType::Unroll);
- reference_tv->axis(5)->parallelize(ParallelType::TIDx);
-
- if (rparams.split_grid_dim) {
- reference_tv->split(0, x_grid_limit);
- reference_tv->axis(1)->parallelize(ParallelType::BIDx);
- } else {
- reference_tv->axis(0)->parallelize(ParallelType::BIDx);
- }
- }
- } else {
- // Not multiple reductions per block
- if (rparams.cross_grid) {
- TORCH_INTERNAL_ASSERT(
- rparams.reduction_unroll,
- "Unrolling on iter domain not supported in this scheduler.");
-
- TORCH_INTERNAL_ASSERT(
- !rparams.persistent_kernel,
- "Grid reductions not implemented yet for persistent kernels.");
-
- // Fastest dim, cross grid, cross block
- // [outputs,
- // Idx: 0
- // | rf-Remain, r-BIDx, r-TIDy, rf-Unswitch, rf-Unroll, r-TIDx]
- // 1(r) 2(r+1) 3(r+2) 4(r+3) 5(r+4) 6(r+5)|
- // | rf-Remain, r-BIDx, r-TIDy, rf-Unswitch, r-TIDx, r-Vectorize]
- // 1(r) 2(r+1) 3(r+2) 4(r+3) 5(r+4) 6(r+5)|
- // Reduction Dimensions
-
- // | r-BIDx, r-TIDy, r-TIDx, rf-Remain, rf-Unswitch, rf-Unroll/Vect]
- // 1(r) 2(r+1) 3(r+2) 4(r+3) 5(r+4) 6(r+5) |
- // Reduction Dimensions
-
- if (rparams.vectorize) {
- reduction_tv->split(reduce_axis, rparams.loop_unroll);
- reduction_tv->split(
- reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
- } else {
- reduction_tv->split(
- reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
- reduction_tv->split(reduce_axis, rparams.loop_unroll);
- }
- reduction_tv->split(reduce_axis, 1);
- // Unswitch axis which gives us finer control on allocations with
- // unrolling
- reduction_tv->split(
- reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDy));
- reduction_tv->split(
- reduce_axis, NamedScalar::getParallelDim(ParallelType::BIDx));
-
- if (rparams.vectorize) {
- reduction_tv->reorder(
- {{reduce_axis, reduce_axis + 3},
- {reduce_axis + 1, reduce_axis},
- {reduce_axis + 2, reduce_axis + 1},
- {reduce_axis + 3, reduce_axis + 4},
- {reduce_axis + 4, reduce_axis + 2}});
- } else {
- reduction_tv->reorder(
- {{reduce_axis, reduce_axis + 3},
- {reduce_axis + 1, reduce_axis},
- {reduce_axis + 2, reduce_axis + 1},
- {reduce_axis + 3, reduce_axis + 4},
- {reduce_axis + 4, reduce_axis + 5},
- {reduce_axis + 5, reduce_axis + 2}});
- }
-
- reference_tv = ir_utils::rfactorHelper(
- reduction_tv, {reduce_axis + 3, reduce_axis + 4, reduce_axis + 5});
-
- if (rparams.vectorize) {
- reference_tv->axis(reduce_axis + 5)
- ->parallelize(ParallelType::Vectorize);
- } else {
- reference_tv->axis(reduce_axis + 5)
- ->parallelize(ParallelType::Unroll);
- }
- reference_tv->axis(reduce_axis + 4)
- ->parallelize(ParallelType::Unswitch);
-
- reference_tv->axis(reduce_axis + 2)->parallelize(ParallelType::TIDx);
- reference_tv->axis(reduce_axis + 1)->parallelize(ParallelType::TIDy);
- reference_tv->axis(reduce_axis)->parallelize(ParallelType::BIDx);
-
- if (has_iter_axis) {
- if (rparams.split_grid_dim) {
- reference_tv->split(iter_axis, y_grid_limit);
- reference_tv->axis(iter_axis + 1)->parallelize(ParallelType::BIDy);
- } else {
- reference_tv->axis(iter_axis)->parallelize(ParallelType::BIDy);
- }
- }
-
- } else {
- // Not cross grid
- if (rparams.reduction_unroll) {
- // Fastest dim, Reduction unroll
- // Output Dimensions
- // [BIDx
- // 0
- //
- // Reduction Dimensions
- // rF-Remain, rf-Unswitch, rf-Unroll, r-TIDx]
- // 1(r) 2(r+1) 3(r+2) 4(r+3)
- // rF-Remain, rf-Unswitch, r-TIDx, rf-Vectorize]
- // 1(r) 2(r+1) 3(r+2) 4(r+3)
-
- // r-TIDx, rF-Leftover, rf-Unswitch, rf-Unroll]
- // 1(r) 2(r+1) 3(r+2) 4(r+3)
-
- if (!rparams.persistent_kernel) {
- if (rparams.vectorize) {
- reduction_tv->split(reduce_axis, rparams.loop_unroll);
- reduction_tv->split(
- reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
- } else {
- reduction_tv->split(
- reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
- reduction_tv->split(reduce_axis, rparams.loop_unroll);
- }
-          // Unswitch axis, which gives us finer control over allocations when
-          // unrolling
- reduction_tv->split(reduce_axis, 1);
- } else {
- if (rparams.vectorize) {
- reduction_tv->split(
- reduce_axis, rparams.batches_per_block, false);
- reduction_tv->split(reduce_axis + 1, rparams.loop_unroll);
- } else {
- reduction_tv->split(
- reduce_axis,
- rparams.batches_per_block * rparams.loop_unroll,
- false);
- reduction_tv->split(reduce_axis, rparams.loop_unroll);
- }
-          // Unswitch axis, which gives us finer control over allocations when
-          // unrolling
- reduction_tv->split(reduce_axis, 1);
- }
-
- if (rparams.vectorize) {
- reduction_tv->reorder(
- {{reduce_axis + 2, reduce_axis},
- {reduce_axis, reduce_axis + 1},
- {reduce_axis + 1, reduce_axis + 2}});
- } else {
- reduction_tv->reorder(
- {{reduce_axis + 3, reduce_axis},
- {reduce_axis, reduce_axis + 1},
- {reduce_axis + 1, reduce_axis + 2},
- {reduce_axis + 2, reduce_axis + 3}});
- }
-
- reference_tv = ir_utils::rfactorHelper(
- reduction_tv,
- {reduce_axis + 1, reduce_axis + 2, reduce_axis + 3});
-
- reference_tv->axis(reduce_axis)->parallelize(ParallelType::TIDx);
- if (rparams.vectorize) {
- reference_tv->axis(reduce_axis + 3)
- ->parallelize(ParallelType::Vectorize);
- } else {
- reference_tv->axis(reduce_axis + 3)
- ->parallelize(ParallelType::Unroll);
- }
- reference_tv->axis(reduce_axis + 2)
- ->parallelize(ParallelType::Unswitch);
-
- if (has_iter_axis) {
- if (rparams.split_grid_dim) {
- reference_tv->split(iter_axis, x_grid_limit);
- reference_tv->axis(iter_axis + 1)
- ->parallelize(ParallelType::BIDx);
- } else {
- reference_tv->axis(iter_axis)->parallelize(ParallelType::BIDx);
- }
- }
- } else {
- TORCH_INTERNAL_ASSERT(
- has_iter_axis, "Need iteration axis for iteration unroll.");
- // Fastest dim, Reduction Splits
- // Output Dimensions
- // [BIDx, x-Unswitch, x-Unroll
- // 0
- //
- // Reduction Dimensions
- // rF-Remain, r-TIDx]
- // 1(r) 2(r+1)
-
- if (!rparams.persistent_kernel) {
- reduction_tv->split(
- reduce_axis, NamedScalar::getParallelDim(ParallelType::TIDx));
- } else {
- reduction_tv->split(reduce_axis, rparams.batches_per_block, false);
- }
-
- reduction_tv->split(iter_axis, rparams.loop_unroll);
-          // Unswitch axis, which gives us finer control over allocations when
-          // unrolling
- reduction_tv->split(iter_axis, 1);
-
- // [x-BIDx, x-Unswitch, x-Unroll, rF-Remain, r-TIDx]
- // 0 1 2 3 4
- // -> [x-BIDx, rF-Remain, x-Unswitch, x-Unroll, r-TIDx]
- // 0 1 2 3 4
-
- reduction_tv->reorder({{1, 2}, {2, 3}, {3, 1}});
-
- reference_tv = ir_utils::rfactorHelper(reduction_tv, {1});
-
- reference_tv->axis(4)->parallelize(ParallelType::TIDx);
- reference_tv->axis(3)->parallelize(ParallelType::Unroll);
- reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
-
- if (rparams.split_grid_dim) {
- reference_tv->split(0, x_grid_limit);
- reference_tv->axis(1)->parallelize(ParallelType::BIDx);
- } else {
- reference_tv->axis(0)->parallelize(ParallelType::BIDx);
- }
- }
- }
- }
- } else {
- if (rparams.cross_block) {
- if (rparams.cross_grid) {
- TORCH_INTERNAL_ASSERT(
- rparams.reduction_unroll,
- "Unrolling on iter domain not supported in this scheduler.");
-
- TORCH_INTERNAL_ASSERT(
- !rparams.persistent_kernel,
- "Grid reductions not implemented yet for persistent kernels.");
-
- // Outer Dim, cross grid, cross block
-
- // Unrolling in this case can only be applied to the reduction
- // dimension since currently, grid reductions cannot be called
- // multiple times
- //
- // Output Dimensions
- // [x-BIDx, x-TIDx,
- // 0 1
- //
- // Reduction Dimensions
- // rF-Leftover, r-BIDy, r-TIDy, rf-Unswitch, rf-Unroll]
- // 2(-5) 3(-4) 4(-3) 5(-2) 6(-1)
-
- // r-BIDy, r-TIDy, rF-Leftover, rf-Unswitch, rf-Unroll]
- // 2(-5) 3(-4) 4(-3) 5(-2) 6(-1)
-
- reduction_tv->split(1, rparams.loop_unroll);
-    // Unswitch axis, which gives us finer control over allocations when
-    // unrolling
- reduction_tv->split(1, 1);
- reduction_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
- reduction_tv->split(1, NamedScalar::getParallelDim(ParallelType::BIDy));
-
- reduction_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
-
- reduction_tv->reorder({{2, 4}, {3, 2}, {4, 3}});
-
- reference_tv = ir_utils::rfactorHelper(
- reduction_tv,
- {4, 5, 6}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-
- reference_tv->axis(6)->parallelize(ParallelType::Unroll);
- reference_tv->axis(5)->parallelize(ParallelType::Unswitch);
- reference_tv->axis(3)->parallelize(ParallelType::TIDy);
- reference_tv->axis(2)->parallelize(ParallelType::BIDy);
- reference_tv->axis(1)->parallelize(ParallelType::TIDx);
- reference_tv->axis(0)->parallelize(ParallelType::BIDx);
- } else {
- if (rparams.reduction_unroll || rparams.loop_unroll == 1) {
- // Outer Dim, cross block, unroll reduction dimension
-
- // Reduction Splits
- // Output Dimensions
- // [x-BIDx, x-TIDx
- // 0 1
- //
- // Reduction Dimensions
- // rF-Leftover, r-TIDy, rf-Unswitch, rf-Unroll]
- // 2(-4) 3(-3) 4(-2) 5(-1)
-
- // r-TIDy, rF-Leftover, rf-Unswitch, rf-Unroll]
- // 2(-4) 3(-3) 4(-2) 5(-1)
- if (!rparams.persistent_kernel) {
- reduction_tv->split(1, rparams.loop_unroll);
-            // Unswitch axis, which gives us finer control over allocations
-            // when unrolling
- reduction_tv->split(1, 1);
- reduction_tv->split(
- 1, NamedScalar::getParallelDim(ParallelType::TIDy));
- } else {
- reduction_tv->split(1, rparams.batches_per_block, false);
- reduction_tv->split(2, rparams.loop_unroll);
- reduction_tv->split(2, 1);
- }
-
- reduction_tv->split(
- 0, NamedScalar::getParallelDim(ParallelType::TIDx));
-
- reduction_tv->reorder({{2, 3}, {3, 2}});
-
- reference_tv = ir_utils::rfactorHelper(
- reduction_tv,
- {3, 4, 5}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-
- reference_tv->axis(5)->parallelize(ParallelType::Unroll);
- reference_tv->axis(4)->parallelize(ParallelType::Unswitch);
- reference_tv->axis(2)->parallelize(ParallelType::TIDy);
- reference_tv->axis(1)->parallelize(ParallelType::TIDx);
- reference_tv->axis(0)->parallelize(ParallelType::BIDx);
- } else {
- // Outer Dim, cross block, unroll iter dimension
-
- // Output Dimensions
- // [x-BIDx, x-Unswitch, x-Unroll, x-TIDx
- // 0 1 2 3
- // [x-BIDx, x-Unswitch, x-TIDx, x-Vectorize
- // 0 1 2 3
- //
- // Reduction Dimensions
- // rF-Leftover, r-TIDy]
- // 4(-2) 5(-1)
-
- // The unroll/unswitch dimension needs to be within the rF-Leftover
- // dimension
- // [x-BIDx, x-Unswitch, x-Unroll, x-TIDx, rF-Leftover, r-TIDy]
- // 0(-6) 1(-5) 2(-4) 3(-3) 4(-2) 5(-1)
- // [x-BIDx, x-Unswitch, x-TIDx, x-Vectorize, rF-Leftover, r-TIDy]
- // 0(-6) 1(-5) 2(-4) 3(-3) 4(-2) 5(-1)
- // -> [x-BIDx, x-TIDx, rF-Leftover, x-Unswitch, x-Unroll/Vect,
- // r-TIDy]
- // 0(-6) 1(-5) 2(-4) 3(-3) 4(-2) 5(-1)
-
- if (!rparams.persistent_kernel) {
- reduction_tv->split(
- 1, NamedScalar::getParallelDim(ParallelType::TIDy));
- } else {
- reduction_tv->split(1, rparams.batches_per_block, false);
- }
- if (rparams.vectorize) {
- reduction_tv->split(0, rparams.loop_unroll);
- reduction_tv->split(
- 0, NamedScalar::getParallelDim(ParallelType::TIDx));
-
- } else {
- reduction_tv->split(
- 0, NamedScalar::getParallelDim(ParallelType::TIDx));
- reduction_tv->split(0, rparams.loop_unroll);
- }
-          // Unswitch axis, which gives us finer control over allocations when
-          // unrolling
- reduction_tv->split(0, 1);
-
- if (rparams.vectorize) {
- reduction_tv->reorder({{1, 3}, {2, 1}, {3, 4}, {4, 2}});
- } else {
- reduction_tv->reorder({{1, 3}, {2, 4}, {3, 1}, {4, 2}});
- }
-
- reference_tv = ir_utils::rfactorHelper(
- reduction_tv,
- {2}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-
- reference_tv->axis(5)->parallelize(ParallelType::TIDy);
- reference_tv->axis(1)->parallelize(ParallelType::TIDx);
- if (rparams.vectorize) {
- reference_tv->axis(4)->parallelize(ParallelType::Vectorize);
- } else {
- reference_tv->axis(4)->parallelize(ParallelType::Unroll);
- }
- reference_tv->axis(3)->parallelize(ParallelType::Unswitch);
- reference_tv->axis(0)->parallelize(ParallelType::BIDx);
- }
- }
- } else {
- if (rparams.reduction_unroll) {
- // Outer Dim, no parallelization on reduction, unroll reduction axis
- // Output Dimensions
- // [x-BIDx, x-TIDx
- // 0 1
- //
- // Reduction Dimensions
- // rf-Leftover, rf-Unswitch, r-Unroll]
- // 2 3 4
- if (rparams.persistent_kernel) {
- reduction_tv->split(1, rparams.batches_per_block, false);
- reduction_tv->split(2, rparams.loop_unroll);
- // Reduction Dimensions
- // rf-Leftover, r-TIDy, rf-Unroll]
- // 2 3 4
- } else {
- reduction_tv->split(1, rparams.loop_unroll);
-          // Unswitch axis, which gives us finer control over allocations when
-          // unrolling
- reduction_tv->split(1, 1);
- }
-
- reduction_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
-
- if (rparams.persistent_kernel) {
- // [x-BIDx, x-TIDx, rf-Leftover, r-TIDy, rf-Unroll]
- // 0 1 2 3 4
- reduction_tv->reorder({{3, 2}, {2, 3}});
- // [x-BIDx, x-TIDx, r-TIDy, rf-Leftover, rf-Unroll]
- // 0 1 2 3 4
- reference_tv = ir_utils::rfactorHelper(
- reduction_tv,
- {3, 4}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- reference_tv->axis(0)->parallelize(ParallelType::BIDx);
- reference_tv->axis(1)->parallelize(ParallelType::TIDx);
- reference_tv->axis(2)->parallelize(ParallelType::TIDy);
- reference_tv->axis(3)->parallelize(ParallelType::Unswitch);
- reference_tv->axis(4)->parallelize(ParallelType::Unroll);
- } else {
- reference_tv = ir_utils::rfactorHelper(
- reduction_tv,
- {2, 3}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- reference_tv->axis(0)->parallelize(ParallelType::BIDx);
- reference_tv->axis(1)->parallelize(ParallelType::TIDx);
- reference_tv->axis(3)->parallelize(ParallelType::Unswitch);
- reference_tv->axis(4)->parallelize(ParallelType::Unroll);
- }
- } else {
- // No parallelization on reduction, unroll iter axis
- // Output Dimensions
- // [x-BIDx, x-Unswitch, x-Unroll, x-TIDx
- // 0 1 2 3
- // [x-BIDx, x-Unswitch, x-TIDx, x-Vectorize
- // 0 1 2 3
- //
- // Reduction Dimensions
- // rf-Leftover, r-{1}]
- // 4(-1)
- //
- // Fake an rfactor to make scheduling more consistent.
- //
- // The unroll/unswitch dimension needs to be within the rF-Leftover
- // dimension
- if (rparams.persistent_kernel) {
- reduction_tv->split(1, rparams.batches_per_block, false);
- } else {
- reduction_tv->split(1, 1);
- }
-
- if (rparams.vectorize) {
- reduction_tv->split(0, rparams.loop_unroll);
- reduction_tv->split(
- 0, NamedScalar::getParallelDim(ParallelType::TIDx));
- } else {
- reduction_tv->split(
- 0, NamedScalar::getParallelDim(ParallelType::TIDx));
- reduction_tv->split(0, rparams.loop_unroll);
- }
-
- reduction_tv->split(0, 1);
-
- // [x-BIDx, x-Unswitch, x-Unroll, x-TIDx, rf-Leftover, r-1]
- // 0 1 2 3 4 5
- // [x-BIDx, x-Unswitch, x-TIDx, x-Vectorize, rf-Leftover, r-1]
- // 0 1 2 3 4 5
-
- if (rparams.vectorize) {
- reduction_tv->reorder({{1, 3}, {2, 1}, {3, 4}, {4, 2}});
- } else {
- reduction_tv->reorder({{1, 3}, {2, 4}, {3, 1}, {4, 2}});
- }
-
- // [x-BIDx, x-TIDx, rf-Leftover, x-Unswitch, x-Unroll, r-1(TIDy)]
- // 0 1 2 3 4 5
-
- reference_tv = ir_utils::rfactorHelper(reduction_tv, {2});
- if (rparams.persistent_kernel) {
- reference_tv->axis(5)->parallelize(ParallelType::TIDy);
- }
-
- reference_tv->axis(0)->parallelize(ParallelType::BIDx);
- reference_tv->axis(1)->parallelize(ParallelType::TIDx);
- reference_tv->axis(3)->parallelize(ParallelType::Unswitch);
- if (rparams.vectorize) {
- reference_tv->axis(4)->parallelize(ParallelType::Vectorize);
- } else {
- reference_tv->axis(4)->parallelize(ParallelType::Unroll);
- }
- }
- }
- }
- return reference_tv;
-}
-
-// Reset inputs and outputs to global memory, everything else to local.
-void clearMemorySpace(Fusion* fusion) {
- for (auto tv : ir_utils::allTvs(fusion)) {
- if (tv->isFusionInput() || tv->isFusionOutput()) {
- tv->setMemoryType(MemoryType::Global);
- } else {
- tv->setMemoryType(MemoryType::Local);
- }
- }
-}
-
-// Returns cache_after tensors of the fusion inputs if unrolling. Otherwise
-// returns an empty vector.
-std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll) {
- if (!unroll) {
- return {};
- }
-
- std::vector<TensorView*> cached_inputs;
- // If we're going to unroll, make a cache of the inputs
- auto in_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
- for (auto tv : in_tvs) {
- if (tv->uses().empty()) {
- continue;
- }
- auto cached_tv = tv->cache_after();
- cached_inputs.emplace_back(cached_tv);
- }
- return cached_inputs;
-}
-
-// Returns the pairs of <cache of each fusion output, corresponding output> for
-// all outputs.
-std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
- Fusion* fusion,
- bool unroll) {
- std::vector<std::pair<TensorView*, TensorView*>> cached_outputs;
- // For intermediate outputs, apply cache_fork
- for (const auto output :
- ir_utils::filterByType<TensorView>(fusion->outputs())) {
- if (output->definition() == nullptr) {
- continue;
- }
- if (!output->uses().empty()) {
- auto cached_output = output->as<TensorView>()->cache_fork();
- cached_outputs.emplace_back(std::make_pair(output, cached_output));
- } else if (unroll) {
- auto cached_output = output->as<TensorView>()->cache_before();
- cached_outputs.emplace_back(std::make_pair(cached_output, output));
- }
- }
- return cached_outputs;
-}
-
-void multiReductionInliner(
- Fusion* fusion,
- const ReductionParams& rparams,
- TensorView* reduction_tv,
- TensorView* reference_tv,
- std::vector<TensorView*> reduction_tvs,
- std::vector<TensorView*> cached_inputs,
- std::vector<std::pair<TensorView*, TensorView*>> cached_outputs) {
- TransformPropagator::from(reference_tv);
-
- // Apply rfactor to all reductions if applicable
- std::vector<TensorView*> rfactor_tvs;
-
- if (reference_tv != reduction_tv) {
- std::vector<int> rfactor_axes;
- for (size_t i = 0; i < reference_tv->nDims(); i++) {
- if (reference_tv->axis((int)i)->isReduction() &&
- reference_tv->axis((int)i)->isRFactorProduct()) {
- rfactor_axes.push_back((int)i);
- }
- }
-
- for (auto reduction_tv_ : reduction_tvs) {
- if (reduction_tv_ == reduction_tv) {
- // The reduction tv
- rfactor_tvs.push_back(reference_tv);
- continue;
- } else {
- rfactor_tvs.push_back(
- ir_utils::rfactorHelper(reduction_tv_, rfactor_axes));
- }
- }
-
- TORCH_INTERNAL_ASSERT(
- reduction_tvs.size() == rfactor_tvs.size(),
- "Expected all reductions to contain rfactor.");
- }
-
- // Propagate parallelization
- parallelizeAllLike(reference_tv, ir_utils::allTvs(fusion));
-
- // Find iter domains that are mapped to a trivial reduction, these should
- // never be inlined.
- std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
- getTrivialReductionMap(fusion);
-
- if (rparams.loop_unroll > 1) {
-    // Inline input caches to their consumers outside the unswitched/vectorized
-    // position. Inline consumers of input caches to rfactor tensors.
-
-    // Mark which tensor views are actual input caches, so vectorization is
-    // left on them
- std::unordered_set<TensorView*> keep_unrolled;
-
- std::vector<TensorView*> compute_from;
-
- // Grab all tensor views that should be vectorized
-    auto vectorizable_inputs_outputs =
-        getVectorizableInputsOutputs(reference_tv);
-
- // Inputs to cache
- for (auto cached_input : cached_inputs) {
- auto consumers_of_input_cache = ir_utils::consumerTvsOf(cached_input);
- for (auto consumer : consumers_of_input_cache) {
- auto unswitch_it = std::find_if(
- consumer->domain()->domain().begin(),
- consumer->domain()->domain().end(),
- [&mapped_to_trivial_reduction](IterDomain* id) {
- return id->getParallelType() == ParallelType::Unswitch ||
- id->getParallelType() == ParallelType::Unroll ||
- id->getParallelType() == ParallelType::Vectorize ||
- id->getParallelType() == ParallelType::MisalignedVectorize ||
- mapped_to_trivial_reduction.count(id);
- });
- auto unswitch_pos = unswitch_it == consumer->domain()->domain().end()
- ? -1
- : std::distance(consumer->domain()->domain().begin(), unswitch_it) +
- 1;
-
- cached_input->computeAt(
- consumer, unswitch_pos, ComputeAtMode::BestEffort);
- compute_from.push_back(consumer);
-
- if (rparams.vectorize) {
- auto producer_tvs = ir_utils::producerTvsOf(cached_input);
- if (producer_tvs.size() == 1 &&
- std::find(
-                  vectorizable_inputs_outputs.begin(),
-                  vectorizable_inputs_outputs.end(),
-                  producer_tvs[0]) != vectorizable_inputs_outputs.end()) {
- keep_unrolled.emplace(cached_input);
- }
- } else {
- keep_unrolled.emplace(cached_input);
- }
- }
- }
-
- // Inline output caches into outputs
- std::vector<TensorView*> compute_to;
- for (auto cached_output_pair : cached_outputs) {
- auto cached_output = cached_output_pair.first;
- auto output = cached_output_pair.second;
-
-      // If an output has multiple consumers, don't process it here; we only
-      // want terminating outputs
- if (cached_output->uses().size() > 1) {
- continue;
- }
-
- auto pos_it = std::find_if(
- output->domain()->domain().begin(),
- output->domain()->domain().end(),
- [&mapped_to_trivial_reduction](IterDomain* id) {
- return id->getParallelType() == ParallelType::Unswitch ||
- id->getParallelType() == ParallelType::Unroll ||
- id->getParallelType() == ParallelType::Vectorize ||
- id->getParallelType() == ParallelType::MisalignedVectorize ||
- mapped_to_trivial_reduction.count(id);
- });
- auto pos = pos_it == output->domain()->domain().end()
- ? -1
- : std::distance(output->domain()->domain().begin(), pos_it) + 1;
-
- cached_output->computeAt(output, pos, ComputeAtMode::BestEffort);
-
- compute_to.push_back(cached_output);
- if (rparams.vectorize) {
- if (std::find(
-              vectorizable_inputs_outputs.begin(),
-              vectorizable_inputs_outputs.end(),
-              output) != vectorizable_inputs_outputs.end()) {
- keep_unrolled.emplace(output);
- }
- } else {
- keep_unrolled.emplace(output);
- }
- }
-
- // Before compute at-ing the internal structure, remove vectorization
- // anywhere it doesn't belong. Otherwise it will mess up our inlining. Clear
- // explicit unroll or vectorization when not for input or output GMEM
- // transfers.
- for (auto tv : ir_utils::allTvs(fusion)) {
- if (!keep_unrolled.count(tv)) {
- for (size_t i = 0; i < tv->nDims(); i++) {
- auto id = tv->axis((int)i);
- if (id->getParallelType() == ParallelType::Unroll ||
- id->getParallelType() == ParallelType::Vectorize ||
- id->getParallelType() == ParallelType::MisalignedVectorize) {
- tv->axis((int)i)->parallelize(ParallelType::Serial);
- }
- }
- }
- }
-
-    // Make sure not to completely inline if there are trivial reductions in
-    // the fusion
- auto pos_it = std::find_if(
- reference_tv->domain()->domain().begin(),
- reference_tv->domain()->domain().end(),
- [&mapped_to_trivial_reduction](IterDomain* id) {
- return mapped_to_trivial_reduction.count(id);
- });
-
- auto pos = pos_it == reference_tv->domain()->domain().end()
- ? -1
- : std::distance(reference_tv->domain()->domain().begin(), pos_it) + 1;
-
- // Compute at inputs to rfactor dimensions
- computeAtBetween(
- compute_from, rfactor_tvs, pos, ComputeAtMode::MostInlined);
-
- // Inline rfactor into reduction
- if (reference_tv != reduction_tv) {
-      // Compute at rfactor into the following reduction; keep the position
-      // outside the first reduction iter domain in the rfactor tensor view
- for (size_t i = 0; i < rfactor_tvs.size(); i++) {
- if (!rparams.reduction_unroll) {
- auto rfactor_tv = rfactor_tvs[i];
- auto rfactor_tv_dom = rfactor_tv->domain()->domain();
- auto reduction_it = std::find_if(
- rfactor_tv_dom.begin(), rfactor_tv_dom.end(), [](IterDomain* id) {
- return id->isReduction();
- });
- TORCH_INTERNAL_ASSERT(
- reduction_it != rfactor_tv_dom.end(),
- "Expected reduction axis in ",
- rfactor_tv);
- auto pos = std::distance(rfactor_tv_dom.begin(), reduction_it);
- rfactor_tv->computeWith(
- reduction_tvs[i], pos, ComputeAtMode::Standard);
- } else {
- rfactor_tvs[i]->computeWith(
- reduction_tvs[i], -1, ComputeAtMode::BestEffort);
- }
- }
- }
-
- // Remove anything before a reduction from compute_from
- {
- auto producers_of_reductions = DependencyCheck::getAllValsBetween(
- {fusion->inputs().begin(), fusion->inputs().end()},
- {reduction_tvs.begin(), reduction_tvs.end()});
-
- auto producer_tvs_of_reductions =
- ir_utils::filterByType<TensorView>(producers_of_reductions);
- compute_from.erase(
- std::remove_if(
- compute_from.begin(),
- compute_from.end(),
- [&producer_tvs_of_reductions](TensorView* compute_from_tv) {
- return std::find(
- producer_tvs_of_reductions.begin(),
- producer_tvs_of_reductions.end(),
- compute_from_tv) != producer_tvs_of_reductions.end();
- }),
- compute_from.end());
- }
-
- // Add reduction tensor views to compute from
- compute_from.insert(
- compute_from.end(), reduction_tvs.begin(), reduction_tvs.end());
-
- // Compute between reductions and output caches
- computeAtBetween(
- compute_from,
- compute_to,
- -1,
- ComputeAtMode::BestEffort,
- mapped_to_trivial_reduction);
-
- } else {
-    // We want to inline, especially backwards based on reduction_tv; otherwise
-    // the rfactor tv may not be inlined correctly
- auto ref_tvs = rfactor_tvs.size() ? rfactor_tvs : reduction_tvs;
- for (auto red_tv : ref_tvs) {
- auto pos_it = std::find_if(
- red_tv->domain()->domain().begin(),
- red_tv->domain()->domain().end(),
- [&mapped_to_trivial_reduction](IterDomain* id) {
- return id->getParallelType() == ParallelType::Unswitch ||
- id->getParallelType() == ParallelType::Unroll ||
- id->getParallelType() == ParallelType::Vectorize ||
- id->getParallelType() == ParallelType::MisalignedVectorize ||
- mapped_to_trivial_reduction.count(id);
- });
- auto pos = pos_it == red_tv->domain()->domain().end()
- ? -1
- : std::distance(red_tv->domain()->domain().begin(), pos_it) + 1;
-
- computeAtInputs(red_tv, pos, ComputeAtMode::MostInlined);
- computeWithOutputs(red_tv, pos, ComputeAtMode::BestEffort);
- }
- }
-}
-
-FindAllMappedDims::FindAllMappedDims(TensorView* from, IterDomain* id)
- : starting_tv(from), starting_id(id) {
- std::deque<TensorView*> to_visit{starting_tv};
- std::unordered_set<TensorView*> visited;
- mapped_ids.emplace(std::make_pair(starting_tv, starting_id));
-
- // Propagate mapping of id
- while (!to_visit.empty()) {
- auto tv = to_visit.front();
- to_visit.pop_front();
-
- if (!visited.emplace(tv).second) {
- continue;
- }
-
- auto tv_id = mapped_ids.at(tv);
-
- for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
- if (visited.find(consumer_tv) != visited.end()) {
- continue;
- }
-
- if (mapped_ids.find(consumer_tv) != mapped_ids.end()) {
- continue;
- }
-
- PairwiseRootDomainMap root_map(tv, consumer_tv);
- auto p2c_map =
- root_map.mapProducerToConsumer(tv->domain(), consumer_tv->domain());
-
- auto c_it = p2c_map.find(tv_id);
- if (c_it != p2c_map.end()) {
- mapped_ids.emplace(std::make_pair(consumer_tv, c_it->second));
- to_visit.emplace_back(consumer_tv);
- }
- }
-
- for (auto producer_tv : ir_utils::producerTvsOf(tv)) {
- if (visited.find(producer_tv) != visited.end()) {
- continue;
- }
-
- if (mapped_ids.find(producer_tv) != mapped_ids.end()) {
- continue;
- }
-
- PairwiseRootDomainMap root_map(producer_tv, tv);
- auto c2p_map =
- root_map.mapConsumerToProducer(tv->domain(), producer_tv->domain());
- auto p_it = c2p_map.find(tv_id);
- if (p_it != c2p_map.end()) {
- mapped_ids.emplace(std::make_pair(producer_tv, p_it->second));
- to_visit.emplace_back(producer_tv);
- }
- }
- }
-}
-
-std::unordered_set<IterDomain*> FindAllMappedDims::from(
- TensorView* tv,
- IterDomain* id) {
- TORCH_INTERNAL_ASSERT(
- std::find_if(
- tv->getRootDomain().begin(),
- tv->getRootDomain().end(),
- [&id](IterDomain* root_id) { return root_id == id; }) !=
- tv->getRootDomain().end(),
- "Tried to map out ",
- id,
- " from TV ",
- tv,
- " to the rest of the fusion, but id does not belong to this tv.");
-
- FindAllMappedDims mapped_dims(tv, id);
-
- std::unordered_set<IterDomain*> mapped_id_set;
- for (auto entry : mapped_dims.mapped_ids) {
- mapped_id_set.emplace(entry.second);
- }
- return mapped_id_set;
-}
-
-bool shouldVectorize(
- TensorView* tv,
- std::unordered_set<IterDomain*> vector_dims) {
- const auto& root_dom = TensorDomain::noBroadcasts(
- TensorDomain::noReductions(tv->getRootDomain()));
-
- // Don't vectorize 0-dim tensors
- if (root_dom.size() == 0) {
- return false;
- }
-
- auto inner_most_dim = root_dom[root_dom.size() - 1];
-
- // Make sure inner most dimension is in the vector_dim set
- if (vector_dims.count(inner_most_dim) == 0) {
- return false;
- }
-
- auto root_pos_it = std::find_if(
- tv->getRootDomain().begin(),
- tv->getRootDomain().end(),
- [&inner_most_dim](IterDomain* id) { return inner_most_dim == id; });
-
- TORCH_INTERNAL_ASSERT(root_pos_it != tv->getRootDomain().end());
- auto inner_most_dim_pos =
- std::distance(tv->getRootDomain().begin(), root_pos_it);
-
- const auto& contiguity = tv->domain()->contiguity();
-
- TORCH_INTERNAL_ASSERT(contiguity.size() == tv->getRootDomain().size());
-
- // Don't vectorize if inner most dimension is not contiguous
- if (!contiguity[inner_most_dim_pos]) {
- return false;
- }
-
- return true;
-}
-
-std::vector<TensorView*> getVectorizableInputsOutputs(
- TensorView* reference_tv) {
- if (reference_tv->nDims() == 0) {
- return {};
- }
-
- IterDomain* inner_most_id = nullptr;
- for (auto it = reference_tv->getRootDomain().rbegin();
- it != reference_tv->getRootDomain().rend();
- it++) {
- if ((*it)->isReduction() && reference_tv->isFusionInput()) {
- continue;
- }
-    // Skip broadcast axes, remembering the first one seen so a fully
-    // broadcast domain still yields a candidate inner-most id.
-    if ((*it)->isBroadcast()) {
-      if (inner_most_id == nullptr) {
-        inner_most_id = *it;
-      }
-      continue;
-    }
-    inner_most_id = *it;
-    break;
- }
-
- if (inner_most_id == nullptr) {
- return {};
- }
-
- auto vectorizable_dims = FindAllMappedDims::from(reference_tv, inner_most_id);
-
- std::vector<TensorView*> vectorizable_tensors;
-
- for (auto input_tv :
- ir_utils::filterByType<TensorView>(reference_tv->fusion()->inputs())) {
- if (shouldVectorize(input_tv, vectorizable_dims)) {
- vectorizable_tensors.push_back(input_tv);
- }
- }
-
- for (auto output_tv :
- ir_utils::filterByType<TensorView>(reference_tv->fusion()->outputs())) {
- if (shouldVectorize(output_tv, vectorizable_dims)) {
- vectorizable_tensors.push_back(output_tv);
- }
- }
-
- return vectorizable_tensors;
-}
-
-std::vector<int64_t> mappedInputsOutputs(TensorView* reference_tv) {
- auto fusion = reference_tv->fusion();
- FusionGuard fg(fusion);
-
- // All input or output tensor views
- std::vector<TensorView*> in_out_tvs;
- {
- auto inp_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
- in_out_tvs.insert(in_out_tvs.end(), inp_tvs.begin(), inp_tvs.end());
- auto out_tvs = ir_utils::filterByType<TensorView>(fusion->outputs());
- in_out_tvs.insert(in_out_tvs.end(), out_tvs.begin(), out_tvs.end());
- }
-
- // Shouldn't matter which compute at map we use
- auto ca_index_map = ComputeAtMap(ComputeAtMap::MappingMode::INDEX);
- ca_index_map.build(fusion);
-
- auto ref_root_domain = reference_tv->getMaybeRFactorDomain();
- std::vector<int64_t> mapping_count(ref_root_domain.size(), 0);
-
- // Map all inputs and output domains to reference tv domains
- for (auto in_out_tv : in_out_tvs) {
- auto in_out_tv_domain = in_out_tv->getRootDomain();
- auto in_out_tv_domain_list = std::list<IterDomain*>(
- in_out_tv_domain.begin(), in_out_tv_domain.end());
- auto in_out_dtype_size = dataTypeSize(in_out_tv->getDataType().value());
-
- for (size_t ref_i = 0; ref_i < ref_root_domain.size(); ref_i++) {
- auto ref_id = ref_root_domain[ref_i];
-
- // If reference id is broadcast or reduction
- if (ref_id->isBroadcast() || ref_id->isReduction()) {
- continue;
- }
- auto map_it = std::find_if(
- in_out_tv_domain_list.begin(),
- in_out_tv_domain_list.end(),
- [&ref_id, &ca_index_map](IterDomain* in_out_tv_id) {
- return ca_index_map.areMapped(in_out_tv_id, ref_id);
- });
-
- if (map_it == in_out_tv_domain_list.end()) {
- continue;
- }
-
- // If input/output id is broadcast or reduction
- if ((*map_it)->isBroadcast() || (*map_it)->isReduction()) {
- continue;
- }
-
- mapping_count[ref_i] = mapping_count[ref_i] + (int64_t)in_out_dtype_size;
- in_out_tv_domain_list.erase(map_it);
- }
- }
- return mapping_count;
-}
-
-} // namespace scheduler_utils
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-class SchedulerRuntimeInfo;
-
-namespace scheduler_utils {
-
-constexpr int64_t register_file_size = 256 * 1024;
-constexpr int64_t x_grid_limit = ((int64_t)1 << (int64_t)31) - (int64_t)1;
-constexpr int64_t y_grid_limit = 65535;
-
-// Largest power of 2 less than or equal to n
-constexpr int64_t lastPow2(int64_t n) {
- TORCH_INTERNAL_ASSERT(n >= 0);
- n |= (n >> 1);
- n |= (n >> 2);
- n |= (n >> 4);
- n |= (n >> 8); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- n |= (n >> 16); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- n |= (n >> 32); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
- return std::max((int64_t)1, n - (n >> 1));
-}
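-// Illustrative (arithmetic only): the OR cascade smears the most significant
-// set bit into all lower bits, so n - (n >> 1) isolates that bit, e.g.
-// lastPow2(5): 0b101 -> 0b111, then 0b111 - 0b011 = 0b100 = 4. Likewise
-// lastPow2(8) == 8 and lastPow2(1) == 1.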
-
-// Merges all reduction axes to the right side and returns the total number of
-// reduction axes. dont_merge is typically used for trivial reductions.
-size_t mergeReduction(
- TensorView* tv,
- const std::unordered_set<IterDomain*>& dont_merge = {});
-
-// Merges all non-reduction axes to the left side and returns the total number
-// of iteration axes. dont_merge is typically used for trivial reductions.
-size_t mergeNonReduction(
- TensorView* tv,
- const std::unordered_set<IterDomain*>& dont_merge = {});
-
-TORCH_CUDA_CU_API void parallelizeAllLike(
- TensorView* reference_tv,
- const std::vector<TensorView*>& all_tvs);
-
-TORCH_CUDA_CU_API void computeAtInputs(
- TensorView* consumer,
- int pos,
- ComputeAtMode mode = ComputeAtMode::Standard);
-
-TORCH_CUDA_CU_API void computeWithOutputs(
- TensorView* producer,
- int pos,
- ComputeAtMode mode = ComputeAtMode::Standard);
-
-struct PersistentBufferInfo {
- std::vector<TensorView*> buffers;
- std::unordered_set<IterDomain*> unmappable_dims;
-};
-
-// Buffers whose roots can't map to all producer roots based on compute at.
-// These are the buffers we would make persistent in a persistent kernel, or
-// would have to recompute if we can't make a persistent kernel. This function
-// will also return inputs as being marked persistent if they follow this
-// pattern. It is important to note, however, that inputs don't strictly have
-// to be persistent as they can simply be read multiple times from GMEM in the
-// same kernel.
-PersistentBufferInfo persistentBuffers(Fusion* fusion);
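-// For example (illustrative, softmax-like pattern): in
-//   max_val = max(in, {1}); out = in - broadcast(max_val)
-// the reduced root domain of `in` cannot be mapped through the reduction for
-// inlining, so `in` (or its cache) would be reported as a persistent buffer.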
-
-struct TvProperties {
- // How many elements in tensor view are there to reduce
- int64_t reduction_numel = 1;
- // How many reductions do we need to perform, i.e. how many iter dimension
- // elements are there
- int64_t iteration_numel = 1;
-  // Do we reduce the fastest dimension? Marked true if there is no reduction
- bool fastest_dim_reduction = true;
- // What's the iter numel to the left of the reduction (if there is one)
- int64_t iter_outside_red = 1;
-  // What's the iter numel to the right of the reduction (whether or not there
-  // is one)
- int64_t iter_inside_red = 1;
-};
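-// Worked example (hypothetical shapes): reducing a [M, N] tensor over its
-// inner axis gives reduction_numel = N, iteration_numel = M,
-// fastest_dim_reduction = true, iter_outside_red = M, iter_inside_red = 1.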
-
-// Fill TvProperties structure about tv
-TvProperties getProperties(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- TensorView* tv);
-
-// Will call computeAt once on each producer, with the first consumer found that
-// is a consumer of the individual producer
-void computeAtBetween(
- const std::vector<TensorView*>& producers,
- const std::vector<TensorView*>& consumers,
- int pos,
- ComputeAtMode mode,
- std::unordered_set<IterDomain*> mapped_to_trivial_reduction = {});
-
-// Compute the amount of register space that would be needed to perform this
-// kernel persistently, based only on buffers that must be persistent and on
-// the maximum of all minimum size requirements, i.e. if a buffer must be
-// persistent, only its persistent dimension is held.
-int64_t persistentBufferSize(
- Fusion* fusion,
- SchedulerRuntimeInfo& runtime_info,
- PersistentBufferInfo& persistent_buffers,
- HeuristicSummary* data_cache = nullptr);
-
-// Returns a set of all iteration domains (in roots of tensors) that map to a
-// trivial reduction
-std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion);
-
-// Merges tensor view to the form:
-// [IterationDomain, ReductionDomain, TrivialReductionDim0,
-// TrivialReductionDim1, ...]
-// Returns whether <iteration dimensions, reduction dimensions> are present.
-std::pair<bool, bool> canonicalDimReduction(Fusion* fusion, TensorView* tv);
-
-// Return a list of tensor views that are outputs of reduction operations. If
-// multiple outputs of an expression are found, only include one in the list
-// (WelfordOp)
-std::vector<TensorView*> getReductionTvs(Fusion* fusion);
-
-// Consistent parallelization based on provided reduction parameters. The
-// provided tensor is expected to have been reduced by canonicalDimReduction
-// before being sent here. reduction_tv should be provided as the tensorview
-// to reduce. The rfactor of reduction_tv will be returned if applicable;
-// otherwise reduction_tv is returned.
-TensorView* scheduleReductionTV(
- const ReductionParams& rparams,
- TensorView* reduction_tv,
- bool has_iter_axis);
-
-// Reset inputs and outputs to global memory, everything else to local.
-void clearMemorySpace(Fusion* fusion);
-
-// Returns cache_after tensors of the fusion inputs if unrolling. Otherwise
-// returns an empty vector.
-std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll);
-
-// Returns the pairs of <cache of each fusion output, corresponding output> for
-// all outputs.
-std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
- Fusion* fusion,
- bool unroll);
-
-// Inlining function intended for single or multi reduction fusions.
-void multiReductionInliner(
- Fusion* fusion,
- const ReductionParams& rparams,
- TensorView* reduction_tv,
- TensorView* reference_tv,
- std::vector<TensorView*> reduction_tvs,
- std::vector<TensorView*> cached_inputs,
- std::vector<std::pair<TensorView*, TensorView*>> cached_outputs);
-
-// Uses a lot of logic from TransformPropagator in the implementation
-class FindAllMappedDims {
- private:
- FindAllMappedDims(TensorView* from, IterDomain* starting_id);
-
- private:
- std::unordered_map<TensorView*, IterDomain*> mapped_ids;
- TensorView* starting_tv = nullptr;
- IterDomain* starting_id = nullptr;
-
- public:
-  // Looks through the fusion and finds all dims that map to the one provided
-  // in the given tensorview. The iter domain must be a root domain.
- static std::unordered_set<IterDomain*> from(TensorView* tv, IterDomain* id);
-};
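-// Minimal usage sketch (illustrative, assuming tv is a TensorView in the
-// active fusion):
-//   auto mapped = FindAllMappedDims::from(tv, tv->getRootDomain().back());
-// collects every root IterDomain in the fusion that maps to tv's inner-most
-// root dimension.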
-
-// Checks if the tensor view has an iteration domain from vector_dims in its
-// inner-most root position (excluding broadcast and reduction), and checks
-// that it is a contiguous dimension
-bool shouldVectorize(
- TensorView* tv,
- std::unordered_set<IterDomain*> vector_dims);
-
-// Returns all inputs and outputs that share the inner-most dimension of the
-// provided reference. If the reference is an input, reduction axes are
-// ignored; broadcast axes are always ignored.
-std::vector<TensorView*> getVectorizableInputsOutputs(TensorView* reference_tv);
-
-// Returns a vector of counts, size = reference_tv->getRootDomain().size(), each
-// entry [i] is the number of inputs/outputs that have a non-broadcast dimension
-// mapped to the corresponding dimension in reference_tv. Count includes
-// reference_tv if reference_tv is an input or output. Count is multiplied by
-// data type size.
-std::vector<int64_t> mappedInputsOutputs(TensorView* reference_tv);
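-// E.g. (hypothetical): with reference root domain [I0, I1] and two float
-// inputs that both map to I0 while only one maps to I1, the counts would be
-// {8, 4} (two mappings * 4 bytes, and one mapping * 4 bytes).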
-
-} // namespace scheduler_utils
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
namespace {
-bool hasTypeAndDevice(const TensorTypePtr& op) {
- return op->device().has_value() && op->scalarType().has_value();
+bool hasTypeAndDim(const TensorTypePtr& op) {
+ return op->sizes().size().has_value() && op->scalarType().has_value();
}
/* NaiveTypePropagator
}
// unary operations that forward meta info:
case aten::neg:
- case aten::bitwise_not:
case aten::abs:
case aten::log:
case aten::log10:
case aten::relu:
case aten::sigmoid:
case aten::threshold:
- case aten::softplus:
case aten::clamp:
case aten::gelu:
- case aten::gelu_backward:
- case aten::silu:
case aten::tanh: {
TORCH_CHECK(
- hasTypeAndDevice(node->input(0)->type()->cast<TensorType>()),
- "Type and device propagation has failed, or was not provided enough information.");
+ hasTypeAndDim(node->input(0)->type()->cast<TensorType>()),
+ "Type, device, and dimensionality propagation has failed, or was not provided enough information.");
node->output()->setType(node->input(0)->type()->cast<TensorType>());
break;
}
// TODO: rand_like should support cast.
case aten::rand_like: {
TORCH_CHECK(
- hasTypeAndDevice(node->input(0)->type()->cast<TensorType>()),
- "Type and device propagation has failed, or was not provided enough information.");
+ hasTypeAndDim(node->input(0)->type()->cast<TensorType>()),
+ "Type, device, and dimensionality propagation has failed, or was not provided enough information.");
node->output()->setType(node->input(0)->type()->cast<TensorType>());
break;
}
node->output()->setType(promoted_type);
break;
}
- // Type can be int or bool for "and" and "or", if both are bool should be
- // bool, if both int should be int, otherwise would have errored
- case aten::__and__:
- case aten::__or__: {
- const auto promoted_type = binary_broadcast_type(
- node->input(0)->type()->cast<TensorType>(),
- node->input(1)->type()->cast<TensorType>(),
- node->input(0)->type()->cast<TensorType>()->scalarType() ==
- at::ScalarType::Bool
- ? at::ScalarType::Bool
- : at::ScalarType::Int);
-      node->output()->setType(promoted_type);
-      break;
- }
- // Real int ops
- case aten::__xor__:
- case aten::__lshift__:
- case aten::__rshift__: {
- const auto promoted_type = binary_broadcast_type(
- node->input(0)->type()->cast<TensorType>(),
- node->input(1)->type()->cast<TensorType>(),
- at::ScalarType::Int);
- node->output()->setType(promoted_type);
- break;
- }
+ // TODO: double check type casting logic for operations commented out.
case aten::lt:
case aten::le:
case aten::gt:
node->output()->setType(promoted_type);
break;
}
- case aten::dropout: {
- auto out_type = node->input(0)->type()->cast<TensorType>();
- node->output()->setType(out_type);
- break;
- }
- case aten::instance_norm:
- case aten::batch_norm: {
- auto out_type = node->input(0)->type()->cast<TensorType>();
- node->output()->setType(out_type);
- break;
- }
- case aten::_batch_norm_impl_index_backward: {
- auto grad_input_type = node->input(1)->type()->cast<TensorType>();
- TORCH_CHECK(
- hasTypeAndDevice(grad_input_type),
- "Type and device propagation has failed, or was not provided enough information.");
- node->output(0)->setType(grad_input_type);
-
- // TODO: double check with type promotion
- auto mean_rstd_type = TensorType::create(
- *grad_input_type->scalarType(),
- *grad_input_type->device(),
- c10::nullopt,
- c10::nullopt);
-
- node->output(1)->setType(mean_rstd_type);
- node->output(2)->setType(mean_rstd_type);
-
- break;
- }
- case aten::_batch_norm_impl_index: {
- auto out_type = node->input(0)->type()->cast<TensorType>();
- TORCH_CHECK(
- hasTypeAndDevice(out_type),
- "Type and device propagation has failed, or was not provided enough information.");
- node->output(0)->setType(out_type);
-
- auto mean_rstd_type = TensorType::create(
- *out_type->scalarType(),
- *out_type->device(),
- c10::nullopt,
- c10::nullopt);
-
- node->output(1)->setType(mean_rstd_type);
- node->output(2)->setType(mean_rstd_type);
- // TODO: not that it matters, but mark the right type here;
- // node->output(3)->setType(out_type->withScalarType());
- node->output(3)->setType(out_type);
- node->output(4)->setType(IntType::get());
-
- break;
- }
- case aten::native_batch_norm: {
- auto out_type = node->input(0)->type()->cast<TensorType>();
- TORCH_CHECK(
- hasTypeAndDevice(out_type),
- "Type and device propagation has failed, or was not provided enough information.");
- node->output(0)->setType(out_type);
-
- auto mean_rstd_type = TensorType::create(
- *out_type->scalarType(),
- *out_type->device(),
- c10::nullopt,
- c10::nullopt);
-
- node->output(1)->setType(mean_rstd_type);
- node->output(2)->setType(mean_rstd_type);
-
- break;
- }
- case aten::native_batch_norm_backward: {
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- auto out_mask_list = constant_as<c10::List<bool>>(node->input(9));
- TORCH_INTERNAL_ASSERT(
- out_mask_list.has_value(), "output mask for batch_norm_backward");
- std::vector<int> output_mask;
- for (const auto value : out_mask_list->vec()) {
- output_mask.emplace_back(static_cast<int>(value));
- }
-
- if (output_mask[0]) {
- auto in_type = node->input(1)->type()->cast<TensorType>();
- node->output(0)->setType(in_type);
- }
-
- if (output_mask[1]) {
- auto weight_type = node->input(2)->type()->cast<TensorType>();
- node->output(1)->setType(weight_type);
- }
-
- if (output_mask[2]) {
- auto weight_type = node->input(2)->type()->cast<TensorType>();
- auto bias_type = TensorType::create(
- *weight_type->scalarType(),
- *weight_type->device(),
- *weight_type->dim(),
- output_mask[2]);
- node->output(2)->setType(bias_type);
- }
- break;
- }
- case aten::layer_norm: {
- auto out_type = node->input(0)->type()->cast<TensorType>();
- node->output()->setType(out_type);
- break;
- }
- case aten::native_layer_norm: {
- auto out_type = node->input(0)->type()->cast<TensorType>();
- TORCH_CHECK(
- hasTypeAndDevice(out_type),
- "Type and device propagation has failed, or was not provided enough information.");
- node->output(0)->setType(out_type);
-
- auto mean_rstd_type = TensorType::create(
- *out_type->scalarType(), *out_type->device(), c10::nullopt, false);
-
- node->output(1)->setType(mean_rstd_type);
- node->output(2)->setType(mean_rstd_type);
-
- break;
- }
- case aten::native_layer_norm_backward: {
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- auto out_mask_list = constant_as<c10::List<bool>>(node->input(7));
- TORCH_INTERNAL_ASSERT(
- out_mask_list.has_value(), "output mask for layer_norm_backward");
- std::vector<int> output_mask;
- for (const auto value : out_mask_list->vec()) {
- output_mask.emplace_back(static_cast<int>(value));
- }
-
- if (output_mask[0]) {
- auto out_type = node->input(0)->type()->cast<TensorType>();
- node->output(0)->setType(out_type);
- }
-
- if (output_mask[1] &&
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- !node->input(5)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- auto weight_type = node->input(5)->type()->cast<TensorType>();
- node->output(1)->setType(weight_type);
- }
-
- if (output_mask[2] &&
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- !node->input(6)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
- auto bias_type = node->input(6)->type()->cast<TensorType>();
- node->output(2)->setType(bias_type);
- }
- break;
- }
- case aten::softmax: {
- auto out_type = node->input(0)->type()->cast<TensorType>();
-
- // accept dtype input to `aten::softmax` node
- if (!node->input(2)->type()->isSubtypeOf(
- static_cast<c10::TypePtr>(NoneType::get()))) {
- if (auto opt_ivalue = toIValue(node->input(2))) {
- out_type = out_type->withScalarType(opt_ivalue->toScalarType());
- }
- }
- node->output()->setType(out_type);
- break;
- }
- case aten::_softmax_backward_data: {
- auto out_type = node->input(0)->type()->cast<TensorType>();
- node->output()->setType(out_type);
- break;
- }
- case aten::mean:
case aten::sum: {
auto out_type = node->input(0)->type()->cast<TensorType>();
const auto dims = constant_as<c10::List<int64_t>>(node->input(1));
const auto keepdim = constant_as<bool>(node->input(2));
TORCH_CHECK(
- dims.has_value() && keepdim.has_value(),
+ dims.has_value() && keepdim.has_value() && !keepdim.value(),
"Shape inference cannot handle options.");
node->output()->setType(
unary_reduce_type(out_type, dims->vec(), keepdim.value()));
break;
}
- case aten::sum_to_size:
- case aten::_grad_sum_to_size: {
- auto out_type = node->input(0)->type()->cast<TensorType>();
- node->output()->setType(out_type->withDim(c10::nullopt));
- break;
- }
- case aten::type_as: {
- const auto type0 = node->input(0)->type()->cast<TensorType>();
- const auto type1 = node->input(1)->type()->cast<TensorType>();
- TORCH_CHECK(
- type0 != nullptr && type1 != nullptr &&
- type1->scalarType().has_value(),
- "input to type_as needs to be a tensor");
- node->output()->setType(type0->withScalarType(type1->scalarType()));
- break;
- }
- case aten::to: {
- const auto type0 = node->input(0)->type()->cast<TensorType>();
- const auto out_dtype = toIValue(node->input(1));
- TORCH_CHECK(out_dtype, "No output type specified");
- node->output()->setType(
- type0->withScalarType(out_dtype->toScalarType()));
- break;
- }
- case prim::add_optional: {
- const auto type0 = node->input(0)->type()->cast<TensorType>();
- const auto type1 = node->input(1)->type()->cast<TensorType>();
- TORCH_CHECK(type0 != nullptr);
- if (type1 != nullptr) {
- node->output()->setType(type0);
- } else {
- const auto promoted_type = binary_broadcast_type(type0, type1);
- node->output()->setType(promoted_type);
- }
- break;
- }
default:
TORCH_CHECK(
false,
- "type inference failed, unrecognized operation encountered:",
- node->kind().toDisplayString());
+ "type inference failed, unrecognized operation encountered.");
// TODO: generate a proper error log, as this probably means something
// went unexpected.
break;
const TensorTypePtr& op,
const std::vector<int64_t>& dims,
bool keepdim) {
- TORCH_CHECK(
- hasTypeAndDevice(op),
- "Type and device propagation has failed, or was not provided enough information.");
+ TORCH_CHECK(hasTypeAndDim(op), "requires complete shape on input");
+ auto input_size = op->sizes();
+ int64_t ndims = keepdim ? input_size.size().value() : 0;
+ if (!keepdim) {
+    for (size_t i = 0; i < input_size.size().value(); i++) {
+ if (std::find(dims.begin(), dims.end(), i) == dims.end()) {
+ ndims++;
+ }
+ }
+ }
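+  // E.g. (illustrative): a rank-4 input reduced over dims {1, 3} with
+  // keepdim == false keeps dims 0 and 2, so ndims == 2; with keepdim == true
+  // the full rank is preserved.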
return TensorType::create(
- *op->scalarType(), *op->device(), c10::nullopt, c10::nullopt);
+ *op->scalarType(), *op->device(), ndims, c10::nullopt);
}
// TODO: we should comply to codegen type promotion.
if (op0 != nullptr && op1 != nullptr) {
TORCH_CHECK(
- hasTypeAndDevice(op0) && hasTypeAndDevice(op1),
- "Type and device propagation has failed, or was not provided enough information.");
+ op0->sizes().size().has_value() && op1->sizes().size().has_value(),
+ "Cannot process input tensor without concrete number of dimensions.");
+ int64_t ndims = *op0->sizes().size() > *op1->sizes().size()
+ ? *op0->sizes().size()
+ : *op1->sizes().size();
+
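+    // E.g. (illustrative): broadcasting a rank-3 tensor against a rank-1
+    // tensor yields a rank-3 result, taking the larger of the two ranks.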
auto promoted_scalar_type = scalar_type.has_value()
? *scalar_type
: c10::promoteTypes(*op0->scalarType(), *op1->scalarType());
return TensorType::create(
- promoted_scalar_type, *op0->device(), c10::nullopt, c10::nullopt);
+ promoted_scalar_type, *op0->device(), ndims, c10::nullopt);
} else {
auto ptr = (op0 != nullptr) ? op0 : op1;
TORCH_CHECK(
- hasTypeAndDevice(ptr),
- "Type and device propagation has failed, or was not provided enough information.");
+ hasTypeAndDim(ptr),
+ "Type, device, and dimensionality propagation has failed, or was not provided enough information.");
return TensorType::create(
scalar_type.has_value() ? *scalar_type : *ptr->scalarType(),
*ptr->device(),
- c10::nullopt,
+ *ptr->sizes().size(),
c10::nullopt);
}
}
}
} // namespace
-TensorView::TensorView(TensorDomain* domain, DataType dtype, MemoryType mtype)
- : Val(ValType::TensorView, dtype), domain_(domain), memory_type_(mtype) {
- // Don't do this after transforms
- if (domain_->domain() == domain_->getRootDomain()) {
-    // Mark the size-1 axes as broadcast to support implicit broadcast semantics
- for (auto* id : domain_->domain()) {
- if (!id->isBroadcast() && !id->isReduction() && !id->isGather() &&
- id->extent()->isOneInt()) {
- id->convertToBroadcast();
- }
- }
- }
-}
+TensorView::TensorView(TensorDomain* _domain, DataType dtype, MemoryType mtype)
+ : Val(ValType::TensorView, dtype), domain_(_domain), memory_type_(mtype) {}
TensorView::TensorView(const std::shared_ptr<c10::TensorType>& tensor_type)
: Val(ValType::TensorView,
TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner)
: Val(src, ir_cloner),
domain_(ir_cloner->clone(src->domain_)),
- compute_at_pos_(src->compute_at_pos_),
- max_producer_pos_(src->max_producer_pos_),
- memory_type_(src->memory_type_),
- swizzle_type_(src->swizzle_type_) {
- for (const auto id : src->axesToSwizzle()) {
- axes_to_swizzle_.push_back(ir_cloner->clone(id));
- }
-}
-
-bool TensorView::hasAnyReduction() const {
- return domain()->noReductions().size() != domain()->domain().size();
-}
+ compute_at_view_(ir_cloner->clone(src->compute_at_view_)),
+ relative_compute_at_axis_(src->relative_compute_at_axis_),
+ this_compute_at_axis_(src->this_compute_at_axis_),
+ memory_type_(src->memory_type_) {}
bool TensorView::hasReduction() const {
return domain()->hasReduction();
return domain()->hasGridReduction();
}
+bool TensorView::hasBlockBroadcast() const {
+ return domain()->hasBlockBroadcast();
+}
+
bool TensorView::hasBroadcast() const {
return domain()->hasBroadcast();
}
return domain()->axis(pos);
}
-void TensorView::setComputeAt(unsigned int pos, bool decrease) {
- if (pos <= compute_at_pos_ && !decrease) {
- return;
- }
+TensorView* TensorView::unsafeClone() const {
+ TensorView* new_view = new TensorView(domain_, getDataType().value());
+ new_view->compute_at_view_ = compute_at_view_;
+ new_view->relative_compute_at_axis_ = relative_compute_at_axis_;
+ new_view->this_compute_at_axis_ = this_compute_at_axis_;
+ new_view->memory_type_ = memory_type_;
+ new_view->name_ = name();
+ return new_view;
+}
+
+void TensorView::setComputeAt(TensorView* computeAtView, int axis) {
+ compute_at_view_ = computeAtView;
+ relative_compute_at_axis_ = axis;
+ setThisComputeAtAxis();
TORCH_INTERNAL_ASSERT(
- (unsigned)pos <= nDims(),
- "Invalid this computeAt position for T",
- name(),
- ": ",
- pos);
+ getThisComputeAtAxis() >= 0 &&
+ (unsigned int)getThisComputeAtAxis() <= nDims(),
+ "Invalid computeAt on ",
+ this,
+ " tried to set to local axis ",
+ getThisComputeAtAxis());
- compute_at_pos_ = pos;
+ TORCH_INTERNAL_ASSERT(
+ std::none_of(
+ domain()->domain().begin(),
+ domain()->domain().begin() + getThisComputeAtAxis(),
+ [](IterDomain* id) { return id->isReduction(); }),
+ "Invalid computeAt, reduction domain inside computeAt axis.");
+}
+
+void TensorView::setComputeAt(
+ TensorView* computeAtView,
+ int thisPos,
+ int relPos) {
+ compute_at_view_ = computeAtView;
+ relative_compute_at_axis_ = relPos;
+ this_compute_at_axis_ = thisPos;
+ TORCH_INTERNAL_ASSERT(
+ this_compute_at_axis_ <= nDims(), "Manually set an invalid computeAt.");
}
-void TensorView::setMaxProducer(unsigned int pos, bool decrease) {
- if (pos <= max_producer_pos_ && !decrease) {
- return;
+// Where in compute_at_view does this->axis(pos) match up?
+// TODO: This doesn't seem like the safest function, as a fusion output can
+// reference another fusion output; we may want to check that there is a direct
+// consumer/producer relationship between this and the compute_at view before
+// using this function, and create another pass to handle relative outputs.
+int TensorView::getComputeAtRelPos(int pos) {
+ if (!hasComputeAt()) {
+ return pos;
}
- TORCH_INTERNAL_ASSERT(
- (unsigned)pos <= nDims(),
- "Invalid max producer position for T",
- name(),
- ": ",
- pos);
+ if (!compute_at_view_->hasBroadcast()) {
+ return pos;
+ }
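+
+  // Worked example (hypothetical domains): with this = [iS{i1}, iS{i2}] and
+  // compute_at_view_ = [bS{b0}, iS{i1}, iS{i2}], pos == 1 here corresponds to
+  // pos_cav == 2, since the consumer's leading broadcast axis has no
+  // counterpart in this tensor and is skipped.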
- max_producer_pos_ = pos;
-}
+ size_t pos_cav = 0, pos_this = 0;
-TensorView* TensorView::computeAt(
- TensorView* consumer,
- int position,
- ComputeAtMode mode) {
- // Make sure this and consumer are not the same tensor, that's illegal
- TORCH_CHECK(!sameAs(consumer), "Cannot call this->computeAt(this, ...)");
+ // We could be in an instance where pos == 0, but consumer[0] is bcast and
+ // this[0] is not
- // We support negative axes, so increment it by consumer->nDims() + 1 and make
- // sure the result is within consumer->nDims() + 1. being at consumer->nDims()
- // means producer will be computed inline with consumer, hence the +1.
- if (position < 0)
- position += int(consumer->nDims()) + 1;
+ while (compute_at_view_->axis(pos_cav)->isBroadcast() &&
+ !(axis(pos_this)->isBroadcast())) {
+ pos_cav++;
+ }
- TORCH_CHECK(
- (position >= 0 && (unsigned int)position < consumer->nDims() + 1) ||
- mode == ComputeAtMode::BestEffort,
- "Compute at called on an position outside valid range.");
+ while ((int)pos_this < pos) {
+ TORCH_INTERNAL_ASSERT(
+ pos_cav < compute_at_view_->nDims(),
+ "Error computing relative position in computeAt.");
- if (mode == ComputeAtMode::BestEffort) {
- position = std::max(-1, position);
- position = std::min((int)consumer->nDims(), position);
+ if (compute_at_view_->axis(pos_cav)->isBroadcast() &&
+ !(axis(pos_this)->isBroadcast())) {
+ pos_cav++;
+ } else {
+ pos_cav++;
+ pos_this++;
+ }
}
- ComputeAt::runAt(this, consumer, (unsigned int)position, mode);
+ return pos_cav;
+}
- return this;
+void TensorView::setThisComputeAtAxis() {
+ if (compute_at_view_ == nullptr) {
+ relative_compute_at_axis_ = 0;
+ this_compute_at_axis_ = 0;
+ return;
+ }
+
+  // this[iS{i1}, iS{i2}] -> compute at compute_at_view[bS{i0}, iS{i1}, iS{i2}]
+  // axis = 2, this compute at axis = 1
+
+ // pos in compute at view
+ size_t pos_cav = 0, pos_this = 0;
+ while (pos_cav < relative_compute_at_axis_ && pos_this < nDims()) {
+ if (compute_at_view_->axis(pos_cav)->isBroadcast() &&
+ !(axis(pos_this)->isBroadcast())) {
+ pos_cav++;
+ } else {
+ pos_cav++;
+ pos_this++;
+ }
+ }
+
+ TORCH_INTERNAL_ASSERT(
+ pos_cav == relative_compute_at_axis_ ||
+ (pos_cav < compute_at_view_->nDims() &&
+ compute_at_view_->axis(pos_cav)->isBroadcast()),
+ "Error seting up relative position between this and what we view into.");
+
+ this_compute_at_axis_ = pos_this;
}
-TensorView* TensorView::computeWith(
- TensorView* consumer,
- int position,
- ComputeAtMode mode) {
+TensorView* TensorView::computeAt(TensorView* consumer, int axis) {
// Make sure this and consumer are not the same tensor, that's illegal
TORCH_CHECK(!sameAs(consumer), "Cannot call this->computeAt(this, ...)");
- // We support negative axes, so increment it by this->nDims() + 1 and make
- // sure the result is within this->nDims() + 1. being at this->nDims()
- // means producer will be computed inline with this, hence the +1.
- if (position < 0)
- position += int(this->nDims()) + 1;
+ // We support negative axes, so increment it by consumer->nDims() + 1 and make
+ // sure the result is within consumer->nDims() + 1. being at consumer->nDims()
+ // means producer will be computed inline with consumer, hence the +1.
+ if (axis < 0)
+ axis += int(consumer->nDims()) + 1;
TORCH_CHECK(
- position >= 0 && (unsigned int)position < this->nDims() + 1,
- "Compute at called on an position outside valid range.");
+ axis >= 0 && (unsigned int)axis < consumer->nDims() + 1,
+ "Compute at called on an axis outside valid range.");
- ComputeAt::runWith(this, consumer, (unsigned int)position, mode);
+ ComputeAt::run(this, consumer, (unsigned int)axis);
return this;
}
-TensorView* TensorView::split(int axis_, Val* factor, bool inner_split) {
+TensorView* TensorView::split(int axis, Val* factor) {
// Only check things associated with axis, factor will be validated in
// IterDomain
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim TensorView");
- if (axis_ < 0)
- axis_ += domain()->nDims();
-
- TORCH_INTERNAL_ASSERT(
- axis_ >= 0,
- "Split axis is less than 0 even after adjusting for nDims: ",
- axis_);
+ if (axis < 0)
+ axis += domain()->nDims();
- TORCH_CHECK(
- axis_ >= (int)getComputeAtPosition(),
- "Cannot split axis within compute at position. Axis = ",
- axis_,
- " computeAtPosition = ",
- getComputeAtPosition());
-
- TORCH_CHECK(
- axis_ >= (int)getMaxProducerPosition(),
- "Cannot split axis within max producer position. Axis = ",
- axis_,
- " maxProducerPosition = ",
- getMaxProducerPosition());
-
- TORCH_CHECK(
- axis(axis_)->getParallelType() == ParallelType::Serial,
- "Splitting an axis of non-Serial parallel type is not supported at this time."
- " Parallelization strategy must be set after calling split.");
+ if (getComputeAtView() != nullptr)
+ if (axis < (int)getThisComputeAtAxis())
+ TORCH_CHECK(
+ false,
+ "Cannot split axis within compute at range. Axis = ",
+ axis,
+ " thisComputeAtAxis = ",
+ getThisComputeAtAxis());
- domain()->split(axis_, factor, inner_split);
+ domain()->split(axis, factor);
return this;
}
-TensorView* TensorView::split(int axis, unsigned int factor, bool inner_split) {
- split(axis, new Int(factor), inner_split);
+TensorView* TensorView::split(int axis, unsigned int factor) {
+ domain()->split(axis, new Int(factor));
return this;
}
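// What split does to an extent, sketched on plain integers rather than
// IterDomains (ceilDiv semantics; the helper is illustrative only): splitting
// an axis of extent I by factor f yields an outer axis of extent ceilDiv(I, f)
// and an inner axis of extent f.
#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> splitExtent(int64_t extent, int64_t factor) {
  int64_t outer = (extent + factor - 1) / factor; // ceilDiv(extent, factor)
  return {outer, factor};
}
// e.g. splitExtent(10, 4) == {3, 4}: the split iterates 12 positions, and the
// generated kernel predicates away the 2 out-of-bounds ones.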
// Merge "axis" and "axis+1" into 1 dimension
TensorView* TensorView::merge(int axis_o, int axis_i) {
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim TensorView");
-
if (axis_o < 0)
axis_o += domain()->nDims();
if (axis_i < 0)
axis_i += domain()->nDims();
- TORCH_CHECK(
- axis_o >= (int)getComputeAtPosition() &&
- axis_i >= (int)getComputeAtPosition(),
- false,
- "Cannot merge axes within compute at position. Either axis ",
- axis_o,
- " or ",
- axis_i,
- " are within computeAtPosition = ",
- getComputeAtPosition());
-
- TORCH_CHECK(
- axis_o >= (int)getMaxProducerPosition() &&
- axis_i >= (int)getMaxProducerPosition(),
- "Cannot merge axes within max producer position. Either axis ",
- axis_o,
- " or ",
- axis_i,
- " are within maxProducerPosition = ",
- getMaxProducerPosition());
-
- TORCH_CHECK(
- axis(axis_o)->getParallelType() == ParallelType::Serial ||
- axis(axis_i)->getParallelType() == ParallelType::Serial,
- "Merging axes of non-Serial parallel type is not supported at this time."
- " Parallelization strategy must be set after calling split.");
+ if (getComputeAtView() != nullptr)
+ if (axis_o + 1 < (int)getThisComputeAtAxis() ||
+ axis_i + 1 < (int)getThisComputeAtAxis())
+ TORCH_CHECK(
+ false,
+ "Cannot merge axis within compute at range. Either axis ",
+ axis_o,
+ " or ",
+ axis_i,
+ " are within thisComputeAtAxis = ",
+ getThisComputeAtAxis());
domain()->merge(axis_o, axis_i);
return this;
TORCH_INTERNAL_ASSERT(
!(nDims() == 0 && old2new_.size() > 0),
"Tried to reorder a 0-dim TensorView");
-
- for (auto entry : old2new_) {
- auto old_pos = entry.first < 0 ? entry.first + (int)nDims() : entry.first;
- auto new_pos =
- entry.second < 0 ? entry.second + (int)nDims() : entry.second;
- if (old_pos == new_pos) {
- continue;
- }
- TORCH_INTERNAL_ASSERT(
- old_pos >= 0,
- "Found \"old\" position that's less than 0 even though already adjusted by nDims: ",
- old_pos);
- TORCH_INTERNAL_ASSERT(
- new_pos >= 0,
- "Found \"new\" position that's less than 0 even though already adjusted by nDims: ",
- new_pos);
- TORCH_CHECK(
- old_pos >= (int)getComputeAtPosition() &&
- new_pos >= (int)getComputeAtPosition(),
- "Cannot reorder axes within compute at position. Either axis ",
- old_pos,
- " or ",
- new_pos,
- " are within computeAtPosition = ",
- getComputeAtPosition());
-
- TORCH_CHECK(
- old_pos >= (int)getMaxProducerPosition() &&
- new_pos >= (int)getMaxProducerPosition(),
- "Cannot reorder axes within max producer position. Either axis ",
- old_pos,
- " or ",
- new_pos,
- " are within maxProducerPosition = ",
- getMaxProducerPosition());
- }
-
domain()->reorder(old2new_);
return this;
}
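// The old2new contract above, sketched as a standalone permutation builder:
// entries may be negative (wrapped by nDims) and axes not mentioned keep their
// relative order. This is an assumption-level sketch; the helper name and the
// fill rule for unpinned axes are illustrative, not lifted from TensorDomain.
#include <map>
#include <vector>

std::vector<int> applyReorder(int ndims, std::map<int, int> old2new) {
  std::map<int, int> normalized;
  for (auto kv : old2new) {
    int old_pos = kv.first < 0 ? kv.first + ndims : kv.first;
    int new_pos = kv.second < 0 ? kv.second + ndims : kv.second;
    normalized[old_pos] = new_pos;
  }
  // Pin the requested axes, then fill remaining slots in original order.
  std::vector<int> new2old(ndims, -1);
  for (auto kv : normalized)
    new2old[kv.second] = kv.first;
  int next_old = 0;
  for (int i = 0; i < ndims; i++) {
    if (new2old[i] != -1)
      continue;
    while (normalized.count(next_old))
      next_old++;
    new2old[i] = next_old++;
  }
  return new2old; // new2old[new_pos] == old_pos
}
// e.g. applyReorder(3, {{0, -1}}) puts old axis 0 last: {1, 2, 0}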
-TensorView* TensorView::swizzle(
- SwizzleType type,
- const std::vector<int>& axes) {
- swizzle_type_ = type;
-
- // Clear previously set swizzle axes if any
- if (axes_to_swizzle_.size()) {
- axes_to_swizzle_.clear();
- }
-
- if (swizzle_type_ == SwizzleType::Transpose) {
- TORCH_CHECK(
- axes.size() == 2,
- "Invalid axis list: ",
- axes,
- ". Number of axes must be two.");
- TORCH_CHECK(
- axes[0] != axes[1],
- "Invalid axis list: ",
- axes,
- ". Two distinctive axes must be given.");
- TORCH_CHECK(
- getMemoryType() == MemoryType::Shared,
- "Transpose swizzle is meant for tensors on shared memory.");
- for (auto pos : axes) {
- if (pos < 0) {
- pos += nDims();
- }
- TORCH_CHECK(pos >= 0 && pos < (int)nDims(), "Invalid axis: ", pos);
- TORCH_CHECK(
- pos >= (int)getComputeAtPosition(),
- "Invalid axis: ",
- pos,
- ". Axis outside computeAt position is not allocated.");
- TORCH_CHECK(
- !axis(pos)->isReduction(),
- "Invalid axis: ",
- pos,
- ". Swizzling a reduction axis is not supported");
- TORCH_CHECK(
- !axis(pos)->isBroadcast(),
- "Invalid axis: ",
- pos,
- ". Swizzling a broadcast axis is not supported");
- axes_to_swizzle_.push_back(axis(pos));
- }
- }
-
- return this;
-}
-
TensorView* TensorView::rFactor(const std::vector<int>& axes) {
- // TODO: I think we should do this but
- // NVFuserTest.FusionSmemBlockGemmCache_CUDA prevents it from going in at the
- // moment.
-
- // TORCH_INTERNAL_ASSERT(
- // !hasComputeAt(), "Cannot rfactor tensors after compute at has been
- // set.");
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
- TORCH_INTERNAL_ASSERT(definition()->isA<ReductionOp>());
FusionGuard fg(fusion());
+ Expr* origin_expr = fusion()->origin(this);
TORCH_CHECK(
- definition() != nullptr &&
- definition()->getExprType() == ExprType::ReductionOp,
+ origin_expr != nullptr &&
+ origin_expr->getExprType() == ExprType::ReductionOp,
"Error rfactoring ",
this,
- " its definition is either a nullptr or not a reduction.");
+ " its origin is either a nullptr or not a reduction.");
TORCH_CHECK(
!domain()->hasRFactor(), "Cannot call rfactor on the same view twice.");
- ReductionOp* this_definition = definition()->as<ReductionOp>();
+ ReductionOp* this_origin = origin_expr->as<ReductionOp>();
// Split tensor view into 2 parts
auto domain_pair = domain()->rFactor(axes);
TensorView* consumer = this;
// Setup dependency chain, inserting producer before this op.
- // Expr* producer_definition =
+ // Expr* producer_origin =
new ReductionOp(
- this_definition->getReductionOpType(),
- this_definition->init(),
+ this_origin->getReductionOpType(),
+ this_origin->init(),
producer,
- this_definition->in());
+ this_origin->in());
- // Expr* consumer_definition =
+ // Expr* consumer_origin =
new ReductionOp(
- this_definition->getReductionOpType(),
- this_definition->init(),
+ this_origin->getReductionOpType(),
+ this_origin->init(),
consumer,
producer);
return producer;
}
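// A hedged usage sketch for rFactor: split a reduction axis, then factor the
// serial part into a new producer. The helper names makeSymbolicTensor and sum
// are assumptions borrowed from the test utilities of this era, not defined in
// this file.
void rFactorSketch(Fusion* fusion) {
  FusionGuard fg(fusion);
  TensorView* tv0 = makeSymbolicTensor(2); // [I0, I1]
  fusion->addInput(tv0);
  TensorView* tv1 = sum(tv0, {1}); // [I0, R1]
  fusion->addOutput(tv1);

  tv1->split(1, 128); // [I0, R1o, R1i{128}]
  // tv2 reduces serially over R1o; tv1 keeps only the 128-wide reduction,
  // which can then be bound to threads.
  TensorView* tv2 = tv1->rFactor({1}); // tv2: [I0, R1o, I1i{128}], tv1: [I0, R1i{128}]
  (void)tv2;
}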
-TensorView* TensorView::welfordRfactorHelper(
- TensorView* tv,
- const std::vector<int>& axes) {
- // Hack:
- // Semantically we should always keep the outputs of welfordOp scheduled
- // the same but the user end cannot guarantee that.
- // In order to guarantee that the rFactor is defined meaningfully the
- // scheduling of the output TV that got the rfactor call is force replayed
- // towards the other two
-
- if (!sameAs(tv)) {
- auto root = tv->getRootDomain();
- auto this_root = getRootDomain();
-
- // construct a trivial root domain map
- std::unordered_map<IterDomain*, IterDomain*> id_map;
- for (size_t i = 0; i < root.size(); i++) {
- id_map[this_root[i]] = root[i];
- }
-
- // replay on the target tv
- ReplayTransformations replay(domain()->domain(), id_map);
-
- // construct the new tensor domain
- std::vector<IterDomain*> new_id;
- for (auto id : domain()->domain()) {
- TORCH_INTERNAL_ASSERT(
- replay.getReplay().count(id), "Welford Replay Failed");
- new_id.push_back(replay.getReplay().at(id));
- }
-
- std::vector<bool> new_contig(
- tv->domain()->contiguity().begin(), tv->domain()->contiguity().end());
- // replace tensor domain of target tv
- tv->setDomain(new TensorDomain(tv->getRootDomain(), new_id, new_contig));
- }
-
- // Split tensor view into 2 parts
- auto domain_pair = tv->domain()->rFactor(axes);
- // Producer in the pair
- auto producer_domain = domain_pair.first;
- // Consumer in the pair
- auto consumer_domain = domain_pair.second;
-
- // This domain will be the consumer, so create the producer
- TensorView* producer =
- new TensorView(producer_domain, tv->getDataType().value());
-
- // Set domain of consumer
- tv->setDomain(consumer_domain);
-
- return producer;
-}
-
-WelfordResult TensorView::rFactor(
- const std::vector<int>& axes,
- TensorView* avg,
- TensorView* var,
- TensorView* n) {
- TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
- FusionGuard fg(fusion());
- TORCH_CHECK(
- definition() != nullptr &&
- definition()->getExprType() == ExprType::WelfordOp,
- "Error rfactoring welford ",
- this,
- " its definition is either a nullptr or not a welford.");
- TORCH_CHECK(
- !domain()->hasRFactor(), "Cannot call rfactor on the same view twice.");
-
- WelfordOp* wop = definition()->as<WelfordOp>();
-
- TORCH_INTERNAL_ASSERT(
- avg->sameAs(wop->outAvg()), "Welford rfactor not used correctly");
- TORCH_INTERNAL_ASSERT(
- var->sameAs(wop->outVar()), "Welford rfactor not used correctly");
- TORCH_INTERNAL_ASSERT(
- n->sameAs(wop->outN()), "Welford rfactor not used correctly");
-
- std::vector<std::pair<TensorView*, TensorView*>> tv2rf{
- {avg, nullptr}, {var, nullptr}, {n, nullptr}};
-
- // Make sure this gets rfactored last so everybody gets
- // replayed correctly
- for (auto& it : tv2rf) {
- if (!sameAs(it.first)) {
- it.second = welfordRfactorHelper(it.first, axes);
- }
- }
-
- for (auto& it : tv2rf) {
- if (sameAs(it.first)) {
- it.second = welfordRfactorHelper(it.first, axes);
- }
- }
-
- TensorView* producer_avg = tv2rf[0].second;
- TensorView* producer_var = tv2rf[1].second;
- TensorView* producer_n = tv2rf[2].second;
-
- // Setup dependency chain, inserting producer before this op.
- // Expr* producer_definition =
- new WelfordOp(
- producer_avg,
- producer_var,
- producer_n, /*out var/avg/count */
- wop->initAvg(),
- wop->initVar(),
- wop->initN(), /*init var/avg/count */
- wop->inAvg(),
- wop->inVar(),
- wop->inN());
-
- // Expr* consumer_definition =
- new WelfordOp(
- avg,
- var,
- n,
- wop->initAvg(),
- wop->initVar(),
- wop->initN(),
- producer_avg,
- producer_var,
- producer_n);
-
- return WelfordResult(producer_avg, producer_var, producer_n);
-}
-
TensorView* TensorView::cache_before() {
FusionGuard fg(fusion());
+ Expr* origin_expr = fusion()->origin(this);
TORCH_CHECK(
- definition() != nullptr && !isFusionInput(),
+ origin_expr != nullptr && !fusion()->hasInput(this),
"Error adding cache_before ",
this,
- " its definition is a nullptr and we restrict using cache_before on an input.");
+ " its origin is a nullptr and we restrict using cache_before on an input.");
TORCH_CHECK(
- isFusionOutput() ||
- definition()->getExprType() != ExprType::ReductionOp ||
- definition()->getExprType() != ExprType::WelfordOp,
+ fusion()->hasOutput(this) ||
+ origin_expr->getExprType() != ExprType::ReductionOp,
"Error adding cache_before ",
this,
- " its definition is a reduction and it is not an output, instead please use cache_after.");
-
- // Previously, caching computed-at tensors was allowed but was never
- // really robust. Make it an error unless it is really needed.
- TORCH_CHECK(
- !hasComputeAt(),
- "Caching computed-at tensors is not allowed. Apply caching before computeAt");
-
- // It also did additional transformation when a producer tensor has computeAt.
- // Make sure we no longer rely on that behavior.
- if (definition() != nullptr) {
- for (TensorView* producer_of_producer :
- ir_utils::filterByType<TensorView>(definition()->inputs())) {
- TORCH_CHECK(
- !producer_of_producer->hasComputeAt(),
- "Potentially invalid computeAt and caching detected. Apply caching before computeAt.");
- }
- }
+ " its origin is a reduction and it is not an output, instead please use cache_after.");
// Create Producer Domain
- // This domain will be the consumer which needs a new domain, so replace the
- // producers domain with this domain.
+ // This domain will be the consumer, so create the producer
auto root_domain = getRootDomain();
-
TensorView* producer = new TensorView(
new TensorDomain(
- domain()->getRootDomain(),
- domain()->domain(),
- domain()->contiguity()),
+ root_domain, std::vector<bool>(root_domain.size(), true)),
getDataType().value());
// Set domain of consumer
TensorView* consumer = this;
- size_t i = 0;
- auto no_reduction_root_domain = TensorDomain::noReductions(getRootDomain());
- std::vector<IterDomain*> new_root_domain(no_reduction_root_domain.size());
- for (const auto& dom : no_reduction_root_domain) {
- new_root_domain[i++] = dom->clone();
+  // If this TV is an output and its origin is a reduction,
+  // remove the reduction axes from this TV.
+ if (origin_expr->getExprType() == ExprType::ReductionOp) {
+ size_t i = 0;
+ auto no_reduction_root_domain = TensorDomain::noReductions(getRootDomain());
+ std::vector<IterDomain*> new_root_domain(no_reduction_root_domain.size());
+ for (const auto& dom : no_reduction_root_domain) {
+ new_root_domain[i++] = dom->clone();
+ }
+ consumer->setDomain(new TensorDomain(
+ new_root_domain, std::vector<bool>(new_root_domain.size(), true)));
}
- consumer->setDomain(new TensorDomain(
- new_root_domain, std::vector<bool>(new_root_domain.size(), true)));
-
// Insert producer - Cache_Before (CB) - before this TV.
- // Before: Prev TV -> [Definition Op] -> This TV
- // After: Prev TV -> [Definition Op] -> New CB TV -> [Set Op] -> This TV
+ // Before: Prev TV -> [Origin Op] -> This TV
+ // After: Prev TV -> [Origin Op] -> New CB TV -> [Set Op] -> This TV
// Get inputs for origin expression
- auto expr_inputs = definition()->inputs();
- // Expr* producer_definition =
- ir_utils::replaceValInExpr(definition(), this, producer);
+ auto expr_inputs = origin_expr->inputs();
+
+ // Expr* producer_origin =
+ createExprConsumer(origin_expr, producer);
// Expr* producer_uses =
new UnaryOp(UnaryOpType::Set, consumer, producer);
- // definition_ is no longer valid
- // setDefinition(nullptr);
-
- auto replayed_consumer_pair =
- TransformReplay::replayCasP(consumer, producer, -1);
- consumer->setDomain(replayed_consumer_pair.first);
+ // Before: This TV -> Next TV
+ // After: New TV (CB) -> This TV -> Next TV
+ if (hasComputeAt()) {
+ TransformReplay::replayPasC(producer, consumer, -1);
+ auto this_ca_pos = getThisComputeAtAxis();
+ producer->computeAt(consumer, this_ca_pos);
+ } else {
+ // Before: Prev TV -> This TV
+ // After: Prev TV -> New TV (CB) -> This TV
+ // Iterate over origin expression inputs for cache_before on outputs
+ for (TensorView* origin_input :
+ ir_utils::filterByType<TensorView>(expr_inputs)) {
+ if (origin_input->hasComputeAt() &&
+ origin_input->getComputeAtView() == this) {
+ TransformReplay::replayPasC(producer, consumer, -1);
+
+ auto origin_ca_pos = origin_input->getThisComputeAtAxis();
+ auto origin_rel_ca_pos = origin_input->getRelativeComputeAtAxis();
+ origin_input->computeAt(producer, origin_ca_pos);
+ producer->setComputeAt(consumer, origin_rel_ca_pos);
+ }
+ }
+ }
return producer;
}
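// A hedged usage sketch for cache_before: stage a reduction output through a
// cache TV so accumulation happens in the cache before the final write. The
// setMemoryType call is an assumption about typical scheduling, not something
// cache_before does itself.
//
// Before: tv0 -> [ReductionOp] -> tv1 (output)
// After:  tv0 -> [ReductionOp] -> tv1_cache -> [Set] -> tv1 (output)
void cacheBeforeSketch(TensorView* tv1) {
  TensorView* tv1_cache = tv1->cache_before();
  tv1_cache->setMemoryType(MemoryType::Shared);
}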
-TensorView* TensorView::cache_fork() {
- FusionGuard fg(fusion());
-
- // Before: [Expr] -> This TV (Global Output) -> [Usage Expr]
- // After: [Expr] -> This TV (Local) -> [Usage Expr] > Next TV
- // (Fork) -> [Set Expr] -> New TV (Global Output)
-
- TORCH_CHECK(
- fusion()->hasOutput(this) && !this->uses().empty(),
- "Error adding cache_fork ",
- this,
- " this TensorView must be an output with subsequent uses");
-
- // Previously, caching computed-at tensors was allowed but was never
- // really robust. Make it an error unless it is really needed.
- TORCH_CHECK(
- !hasComputeAt(),
- "Caching computed-at tensors is not allowed. Apply caching before computeAt");
-
- // This domain will be the producer, so create the consumer
- auto root_domain = TensorDomain::noReductions(getRootDomain());
- TensorView* new_output = new TensorView(
- new TensorDomain(
- IterDomain::clone(root_domain),
- std::vector<bool>(root_domain.size(), true)),
- getDataType().value());
-
- // Create write operation from this TV to new output
- new UnaryOp(UnaryOpType::Set, new_output, this);
-
- // The new TV becomes an output.
- // New TV has global memory type.
- // This TV has local memory type.
- fusion()->replaceOutput(this, new_output);
-
- // Transform new output according to this TV
- auto replayed_output_pair = TransformReplay::replayCasP(new_output, this, -1);
- new_output->setDomain(replayed_output_pair.first);
-
- return new_output;
-}
-
TensorView* TensorView::cache_after() {
FusionGuard fg(fusion());
- const bool kIsFusionInput = fusion()->hasInput(this);
-
// Get all the uses for this Tensorview
TORCH_CHECK(
!fusion()->hasOutput(this),
this,
" we restrict using cache_after on an output.");
- // Previously, caching computed-at tensors was allowed but was never
- // really robust. Make it an error unless it is really needed.
- TORCH_CHECK(
- !hasComputeAt(),
- "Caching computed-at tensors is not allowed. Apply caching before computeAt.");
-
- // It also did additional transformation when this tensor is an
- // input and the outputs of its consumers have computeAt. Make sure
- // we no longer rely on that behavior.
- if (kIsFusionInput) {
- for (const auto& expr : uses()) {
- for (TensorView* output :
- ir_utils::filterByType<TensorView>(expr->outputs())) {
- TORCH_CHECK(
- !output->hasComputeAt(),
- "Potentially invalid computeAt and caching detected. Apply caching before computeAt.");
- }
- }
- }
-
// Create Consumer Domain
// Keep Broadcast Axis (Permanent)
// Remove Reduction Axis
// Expr* consumer_uses =
for (auto expr : fusion()->unordered_uses(this)) {
- ir_utils::replaceValInExpr(expr, this, consumer);
+ createExprProducer(expr, this, consumer);
}
- // Expr* consumer_definition =
+ // Expr* consumer_origin =
new UnaryOp(UnaryOpType::Set, consumer, producer);
+ // Before: This TV -> Next TV
+ // After: This TV -> New TV (After) -> Next TV
+ if (hasComputeAt()) {
+ TransformReplay::replayCasP(consumer, producer, -1);
+
+ auto rel_ca_pos = getRelativeComputeAtAxis();
+ auto this_ca_pos = getThisComputeAtAxis();
+ auto this_ca_view = getComputeAtView();
+
+ computeAt(consumer, this_ca_pos);
+ consumer->setComputeAt(this_ca_view, rel_ca_pos);
+ } else {
+ // Check users of this TV for computeAt for cache_after on inputs
+ for (const auto& expr : fusion()->unordered_uses(consumer)) {
+ for (TensorView* output :
+ ir_utils::filterByType<TensorView>(expr->outputs())) {
+ if (output->hasComputeAt()) {
+ TransformReplay::replayPasC(consumer, output, -1);
+ auto output_ca_pos = output->getThisComputeAtAxis();
+ consumer->setComputeAt(output, output_ca_pos);
+ }
+ }
+ }
+ }
+
return consumer;
}
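// The complementary sketch for cache_after on a fusion input: read the global
// input once into a cache TV and compute from the cache. Again hedged; the
// memory-type choice is just a typical follow-up, not part of cache_after.
//
// Before: tv0 (input) -> [Expr] -> ...
// After:  tv0 (input) -> [Set] -> tv0_cache -> [Expr] -> ...
void cacheAfterSketch(TensorView* tv0) {
  TensorView* tv0_cache = tv0->cache_after();
  tv0_cache->setMemoryType(MemoryType::Shared);
}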
}
}
-void TensorView::clearReductionIterDomains() {
- TORCH_INTERNAL_ASSERT(
- !domain()->hasRFactor(),
- "should not call clearReductionIterDomains on rfactor tv");
+namespace {
- TORCH_INTERNAL_ASSERT(
- domain()->domain() == getRootDomain(),
- "should not call clearReductionIterDomains on already transformed TensorDomains");
-
- std::vector<IterDomain*> new_root;
- std::vector<bool> new_contig;
- for (size_t i = 0; i < getRootDomain().size(); i++) {
- if (!getRootDomain()[i]->isReduction()) {
- new_root.push_back(getRootDomain()[i]);
- new_contig.push_back(domain()->contiguity()[i]);
- }
+// Create a new Expr given the consumer - [the output of the expression]
+struct CreateExprConsumer : public OptInDispatch {
+ public:
+ static void create(Expr* expr, TensorView* consumer) {
+ CreateExprConsumer cec(consumer);
+ cec.handle(expr);
}
- setDomain(new TensorDomain(new_root, new_contig));
-}
+ private:
+ explicit CreateExprConsumer(TensorView* consumer) : consumer_(consumer) {}
-TensorViewBuilder& TensorViewBuilder::ndims(size_t ndims) {
- TORCH_CHECK(shape_.empty() || shape_.size() == ndims);
- TORCH_CHECK(contiguity_.empty() || contiguity_.size() == ndims);
- ndims_ = ndims;
- return *this;
-}
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
-TensorViewBuilder& TensorViewBuilder::dtype(DataType dtype) {
- dtype_ = dtype;
- return *this;
-}
+ void handle(UnaryOp* unary_expr) final {
+ new UnaryOp(unary_expr->getUnaryOpType(), consumer_, unary_expr->in());
+ }
-TensorViewBuilder& TensorViewBuilder::contiguity(std::vector<bool> contiguity) {
- TORCH_CHECK(contiguity_.empty(), "Attempting to reset contiguity");
- if (!contiguity.empty()) {
- TORCH_CHECK(ndims_ == 0 || ndims_ == contiguity.size());
- ndims_ = contiguity.size();
+ void handle(BinaryOp* binary_expr) final {
+ new BinaryOp(
+ binary_expr->getBinaryOpType(),
+ consumer_,
+ binary_expr->lhs(),
+ binary_expr->rhs());
}
- contiguity_ = std::move(contiguity);
- return *this;
-}
-TensorViewBuilder& TensorViewBuilder::shape(std::vector<int64_t> shape) {
- TORCH_CHECK(shape_.empty(), "Attempting to reset shape");
- if (!shape.empty()) {
- TORCH_CHECK(ndims_ == 0 || ndims_ == shape.size());
- ndims_ = shape.size();
+ void handle(TernaryOp* ternary_expr) final {
+ new TernaryOp(
+ ternary_expr->getTernaryOpType(),
+ consumer_,
+ ternary_expr->in1(),
+ ternary_expr->in2(),
+ ternary_expr->in3());
}
- shape_ = std::move(shape);
- return *this;
-}
-TensorView* TensorViewBuilder::build() const {
- // Build the domain
- std::vector<IterDomain*> domain(ndims_, nullptr);
- for (size_t i = 0; i < ndims_; i++) {
- if (shape_.empty() || shape_[i] == -1) {
- domain[i] = new IterDomain(new Int(0), new Int());
+ void handle(ReductionOp* reduction_expr) final {
+ new ReductionOp(
+ reduction_expr->getReductionOpType(),
+ reduction_expr->init(),
+ consumer_,
+ reduction_expr->in());
+ }
+
+ void handle(BroadcastOp* broadcast_expr) final {
+ new BroadcastOp(consumer_, broadcast_expr->in());
+ }
+
+ private:
+ TensorView* consumer_ = nullptr;
+};
+
+// Create a new Expr given the producer - [an input of the expression]
+struct CreateExprProducer : public OptInDispatch {
+ public:
+ static void create(Expr* expr, TensorView* current, TensorView* producer) {
+ CreateExprProducer cep(current, producer);
+ cep.handle(expr);
+ }
+
+ private:
+ explicit CreateExprProducer(TensorView* current, TensorView* producer)
+ : current_(current), producer_(producer) {}
+
+ void handle(Expr* expr) final {
+ OptInDispatch::handle(expr);
+ }
+
+ void handle(UnaryOp* unary_expr) final {
+ new UnaryOp(unary_expr->getUnaryOpType(), unary_expr->out(), producer_);
+ }
+
+ void handle(BinaryOp* binary_expr) final {
+ if (binary_expr->lhs()->sameAs(current_)) {
+ new BinaryOp(
+ binary_expr->getBinaryOpType(),
+ binary_expr->out(),
+ producer_,
+ binary_expr->rhs());
} else {
- TORCH_CHECK(
- shape_[i] >= 0,
- "Invalid extent value. ",
- "For a tensor representing a single scalar use ndims = 0 with no sizes set.");
- domain[i] = new IterDomain(new Int(0), new Int(shape_[i]));
+ new BinaryOp(
+ binary_expr->getBinaryOpType(),
+ binary_expr->out(),
+ binary_expr->lhs(),
+ producer_);
}
}
- // Create the final TensorView
- return new TensorView(new TensorDomain(domain, contiguity_), dtype_);
+ void handle(TernaryOp* ternary_expr) final {
+ if (ternary_expr->in1()->sameAs(current_)) {
+ new TernaryOp(
+ ternary_expr->getTernaryOpType(),
+ ternary_expr->out(),
+ producer_,
+ ternary_expr->in2(),
+ ternary_expr->in3());
+ } else if (ternary_expr->in2()->sameAs(current_)) {
+ new TernaryOp(
+ ternary_expr->getTernaryOpType(),
+ ternary_expr->out(),
+ ternary_expr->in1(),
+ producer_,
+ ternary_expr->in3());
+ } else {
+ new TernaryOp(
+ ternary_expr->getTernaryOpType(),
+ ternary_expr->out(),
+ ternary_expr->in1(),
+ ternary_expr->in2(),
+ producer_);
+ }
+ }
+
+ void handle(ReductionOp* reduction_expr) final {
+ new ReductionOp(
+ reduction_expr->getReductionOpType(),
+ reduction_expr->init(),
+ reduction_expr->out(),
+ producer_);
+ }
+
+ void handle(BroadcastOp* broadcast_expr) final {
+ new BroadcastOp(broadcast_expr->out(), producer_);
+ }
+
+ private:
+ TensorView* current_ = nullptr;
+ TensorView* producer_ = nullptr;
+};
+
+} // namespace
+
+// In Cache Before, for the origin expr of the original tensor,
+// we create a new operation where the original tensor is replaced
+// with the new cache tensor. This function creates a new expr
+// given the consumer, the output of the expression.
+void TensorView::createExprConsumer(Expr* expr, TensorView* consumer) {
+ CreateExprConsumer::create(expr, consumer);
+}
+
+// In Cache After, for all the uses of the original tensor, we create
+// a new operation where the original tensor is replaced with the new
+// cache tensor. This function creates a new expr given a producer,
+// an input for the expression.
+void TensorView::createExprProducer(
+ Expr* expr,
+ TensorView* current,
+ TensorView* producer) {
+ CreateExprProducer::create(expr, current, producer);
}
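// Both visitors above implement one pattern: rebuild an expression node with a
// single operand swapped out. A standalone sketch of that idea with a toy
// expression type, using std::visit where the real code uses OptInDispatch;
// everything here is hypothetical.
#include <type_traits>
#include <variant>

struct ToyUnary { int in; };
struct ToyBinary { int lhs, rhs; };
using ToyExpr = std::variant<ToyUnary, ToyBinary>;

// Return a copy of expr with the operand equal to `current` replaced by
// `repl`, mirroring the else-chains in CreateExprProducer::handle.
ToyExpr replaceProducer(const ToyExpr& expr, int current, int repl) {
  return std::visit(
      [&](auto e) -> ToyExpr {
        using T = std::decay_t<decltype(e)>;
        if constexpr (std::is_same_v<T, ToyUnary>) {
          e.in = repl;
        } else {
          if (e.lhs == current)
            e.lhs = repl;
          else
            e.rhs = repl;
        }
        return e;
      },
      expr);
}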
} // namespace cuda
"Transform traversal failed, modified a node but it was not a leaf node.");
// Replay the split onto mapped
- auto outs = IterDomain::split(mapped, s->factor(), s->innerSplit());
+ auto outs = IterDomain::split(mapped, s->factor());
// Remove mapped from the leaf IDs
leaf_ids_.erase(mapped);
BestEffortReplay::BestEffortReplay(
const std::vector<IterDomain*>& replay_domain,
const std::vector<IterDomain*>& target_domain,
- std::unordered_map<IterDomain*, IterDomain*> target2replay_map,
- std::unordered_map<IterDomain*, IterDomain*> forward_id_map)
- : target2replay_id_map_(std::move(target2replay_map)),
- forward_id_map_(std::move(forward_id_map)) {
- for (auto entry : target2replay_id_map_) {
+ std::unordered_map<IterDomain*, IterDomain*> replay_map,
+ bool forward_bcast_mismatch)
+ : id_map_(std::move(replay_map)) {
+ for (auto entry : id_map_)
leaf_ids_[entry.second] = counter++;
- }
// Grab expr history of iter domains in target_domain
- std::vector<Expr*> target_exprs = ExprSort::getExprs(
+ std::vector<Expr*> t_exprs = ExprSort::getExprs(
FusionGuard::getCurFusion(),
std::vector<Val*>(target_domain.begin(), target_domain.end()));
// replay_domain domain. This will be used to propagate the target_domain to
// replay_domain map.
- // Map replay domain's IterDomains to the Exprs they're used in
- std::vector<Expr*> replay_exprs = ExprSort::getExprs(
+ // Maps replay domain's IterDomains to the Exprs they're used in
+ std::vector<Expr*> r_exprs = ExprSort::getExprs(
FusionGuard::getCurFusion(),
std::vector<Val*>(replay_domain.begin(), replay_domain.end()));
-
- std::unordered_map<IterDomain*, Expr*> replay_id2expr_map;
- for (auto replay_expr : replay_exprs) {
- for (auto id : ir_utils::filterByType<IterDomain>(replay_expr->inputs())) {
+ std::unordered_map<IterDomain*, Expr*> replay_expr_map;
+ for (auto r_expr : r_exprs) {
+ for (auto id : ir_utils::filterByType<IterDomain>(r_expr->inputs())) {
TORCH_INTERNAL_ASSERT(
- replay_id2expr_map.find(id) == replay_id2expr_map.end(),
- "Error trying to map rfactor root domain during replay.",
- " An IterDomain was found to be used in more than one expression.");
+ replay_expr_map.find(id) == replay_expr_map.end(),
+ "Error trying to map rfactor root domain during replay. IterDomain's shouldn't have more than one use.");
// Only want to forward rfactor in map
- replay_id2expr_map[id] = replay_expr;
+ replay_expr_map[id] = r_expr;
}
}
std::string err_str(
- "Error during replay, a transformation was called that conflicts with an rfactor call.");
+ "Error during replay, a computeAt was called that conflicts with an rfactor call.");
// Iterate through target IterDomains' history and compare with what we
// recorded from replay_domain
- for (auto target_expr : target_exprs) {
- auto target_inps_filtered =
- ir_utils::filterByType<IterDomain>(target_expr->inputs());
-
- // If any input argument in target expression is in the forward map then
- // forward the mapped IterDomains in replay and continue to the next
- // expression as target_expr cannot match a replay_expr
- if (std::any_of(
- target_inps_filtered.begin(),
- target_inps_filtered.end(),
- [&](IterDomain* target_inp) {
- return this->inForwardMap(target_inp);
- })) {
- for (auto target_inp : target_inps_filtered) {
- if (inForwardMap(target_inp)) {
- auto target2replay_it = target2replay_id_map_.find(target_inp);
- if (target2replay_it != target2replay_id_map_.end()) {
- // Replace target_inp entry in target2replay_id_map_ with forwarded
- // id
- target2replay_id_map_[getForwardedId(target_inp)] =
- target2replay_it->second;
- target2replay_id_map_.erase(target_inp);
- }
- }
- }
- // Continue to next target_expr
- continue;
- }
-
- std::vector<IterDomain*> target_id_inps(
- target_inps_filtered.begin(), target_inps_filtered.end());
-
- std::vector<IterDomain*> replay_inps =
- std::vector<IterDomain*>(target_id_inps.size(), nullptr);
-
- bool missing_replay_input = false;
-
- // Map target_expr inputs to replay domain directly
- for (const auto t_i : c10::irange(target_id_inps.size())) {
- // There might not be a mapping, that could be okay (depends on rfactor
- // checking).
- auto it = target2replay_id_map_.find(target_id_inps[t_i]);
- if (it != target2replay_id_map_.end()) {
- replay_inps[t_i] = getForwardedId(it->second);
- } else {
- missing_replay_input = true;
- }
+ for (auto t_expr : t_exprs) {
+ auto t_inps_filtered = ir_utils::filterByType<IterDomain>(t_expr->inputs());
+ std::vector<IterDomain*> t_inps(
+ t_inps_filtered.begin(), t_inps_filtered.end());
+
+ std::vector<IterDomain*> r_inps =
+ std::vector<IterDomain*>(t_inps.size(), nullptr);
+
+ // Map t_expr inputs to replay domain directly
+ for (const auto t_i : c10::irange(t_inps.size())) {
+      // There might not be a mapping; that could be okay.
+ auto it = id_map_.find(t_inps[t_i]);
+ if (it != id_map_.end())
+ r_inps[t_i] = it->second;
}
- // Check if any of the associated replay id's are part of an rfactor domain
- bool replay_has_rfactor_inp =
- std::any_of(replay_inps.begin(), replay_inps.end(), [](IterDomain* id) {
+ bool has_rfactor =
+ std::any_of(r_inps.begin(), r_inps.end(), [](IterDomain* id) {
return id == nullptr ? false : id->isRFactorProduct();
});
- // If some replay id inputs are part of rfactor, make sure all target
- // expression inputs map to a replay input
- if (replay_has_rfactor_inp) {
+ if (has_rfactor) {
bool no_missing_exprs = std::none_of(
- replay_inps.begin(),
- replay_inps.end(),
- [&replay_id2expr_map](IterDomain* id) {
+ r_inps.begin(), r_inps.end(), [&replay_expr_map](IterDomain* id) {
if (id == nullptr) {
return true;
} else {
- return replay_id2expr_map.find(id) == replay_id2expr_map.end();
+ return replay_expr_map.find(id) == replay_expr_map.end();
}
});
TORCH_INTERNAL_ASSERT(no_missing_exprs, err_str);
}
- // If any inputs are missing, continue as this expr doesn't match.
- if (missing_replay_input) {
- TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
- continue;
+    // I would like to have this more generic, or have this whole function go
+    // through dispatch, but I am trying to make quick forward progress on
+    // https://github.com/csarofeen/pytorch/issues/286. This mapping reflects
+    // more closely what is done in ReplayTransform with a mismatched
+    // broadcast/merge.
+ if (forward_bcast_mismatch && !has_rfactor &&
+ t_expr->getExprType().value() == ExprType::Merge) {
+ auto t_merge = t_expr->as<Merge>();
+ auto t_outer = t_merge->outer();
+ auto t_inner = t_merge->inner();
+ IterDomain* r_outer = id_map_.find(t_outer) != id_map_.end()
+ ? id_map_.at(t_outer)
+ : nullptr;
+ IterDomain* r_inner = id_map_.find(t_inner) != id_map_.end()
+ ? id_map_.at(t_inner)
+ : nullptr;
+ if (r_outer != nullptr && r_inner == nullptr && t_inner->isBroadcast()) {
+ id_map_[t_merge->out()] = r_outer;
+ } else if (
+ r_inner != nullptr && r_outer == nullptr && t_outer->isBroadcast()) {
+ id_map_[t_merge->out()] = r_inner;
+ }
}
- // Find which replay_expr maps to the target_expr
- Expr* replay_expr = nullptr;
- // Check if all inputs have the same expression
- bool mismatched_replay_exprs = false;
- for (auto replay_inp : replay_inps) {
- auto it = replay_id2expr_map.find(replay_inp);
- if (it != replay_id2expr_map.end()) {
- if (replay_expr == nullptr) {
- replay_expr = it->second;
- } else {
- mismatched_replay_exprs =
- mismatched_replay_exprs || replay_expr != it->second;
+ Expr* r_expr = nullptr;
+ for (auto r_inp : r_inps) {
+ if (r_inp != nullptr) {
+ auto it = replay_expr_map.find(r_inp);
+ if (it != replay_expr_map.end()) {
+ r_expr = it->second;
+ break;
}
- } else {
- // If no expr is mapped then set mismatched epxrs to go to continue to
- // the next target expr
- mismatched_replay_exprs = true;
}
}
- // If expressions of mapped inputs don't match, then continue to next target
- // expr
- if (mismatched_replay_exprs || replay_expr == nullptr) {
- TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
+ if (r_expr == nullptr) {
+ TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
continue;
}
- bool mismatched_inputs = replay_inps.size() != replay_expr->inputs().size();
- for (size_t i = 0; i < replay_inps.size() && !mismatched_inputs; i++) {
- mismatched_inputs =
- mismatched_inputs || replay_expr->inputs()[i] != replay_inps[i];
+ bool mismatched_inputs = r_inps.size() != r_expr->inputs().size();
+ for (size_t i = 0; i < r_inps.size() && !mismatched_inputs; i++) {
+ if (r_inps[i] == nullptr) {
+ mismatched_inputs = true;
+ } else {
+ mismatched_inputs =
+ mismatched_inputs || r_expr->inputs()[i] != r_inps[i];
+ }
}
- // If there isn't an rfactor id in the replay's inputs and there's a
- // mismatched input, continue
if (mismatched_inputs) {
- TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
+ TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
continue;
}
- // If there isn't an rfactor id in the replay's inputs and there's a
- // mismatch in replay_expr's and target_expr's outputs, continue
- if (target_expr->outputs().size() != replay_expr->outputs().size()) {
- TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
+ if (t_expr->outputs().size() != r_expr->outputs().size()) {
+ TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
continue;
}
- // If there isn't an rfactor id in the replay's inputs and there's a
- // mismatch in replay_expr's and target_expr's expression type, continue
- if (replay_expr->getExprType().value() !=
- target_expr->getExprType().value()) {
- TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
+ if (r_expr->getExprType().value() != t_expr->getExprType().value()) {
+ TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
continue;
}
- // If there isn't an rfactor id in the replay's inputs and there's a
- // mismatch in replay_expr's and target_expr's split factor (if a split
- // expr), continue
- if (replay_expr->getExprType().value() == ExprType::Split) {
- auto r_split = replay_expr->as<Split>();
- auto t_split = target_expr->as<Split>();
- if (!r_split->factor()->sameAs(t_split->factor()) ||
- r_split->innerSplit() != t_split->innerSplit()) {
- TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str);
+    // If the expression is a split, make sure it's split by the same amount.
+    if (r_expr->getExprType().value() == ExprType::Split) {
+      if (!r_expr->as<Split>()->factor()->sameAs(
+              t_expr->as<Split>()->factor())) {
+ TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
continue;
}
}
- // Take replay expr inputs out of map:
- for (size_t t_i = 0; t_i < target_id_inps.size(); t_i++) {
- auto t_inp = target_id_inps[t_i];
- auto r_orig_inp = target2replay_id_map_.at(t_inp);
- auto r_maybe_forwarded_inp = replay_inps[t_i];
-
- // Remove original target2replay_it->second if it's in leaf_ids
- if (leaf_ids_.find(r_orig_inp) != leaf_ids_.end()) {
- leaf_ids_.erase(r_orig_inp);
- }
+ bool missing_input = std::any_of(
+ t_expr->inputs().begin(), t_expr->inputs().end(), [this](Val* inp) {
+ if (inp->getValType() == ValType::IterDomain) {
+ return id_map_.find(inp->as<IterDomain>()) == id_map_.end();
+ }
+ return false;
+ });
- // Check if we used a forwarded id, if so add forwarded id's to tracking.
- if (r_orig_inp != r_maybe_forwarded_inp) {
- forwarded_ids_.emplace_back(r_orig_inp);
+ if (missing_input) {
+ TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
+ continue;
+ }
+ // Take target_domain inputs out of map:
+ for (auto t_inp : ir_utils::filterByType<IterDomain>(t_expr->inputs())) {
+ auto it = id_map_.find(t_inp);
+ if (leaf_ids_.find(it->second) != leaf_ids_.end()) {
+ leaf_ids_.erase(it->second);
}
}
// Add outputs to map.
- for (const auto i : c10::irange(target_expr->outputs().size())) {
- auto t_out = target_expr->output(i);
- auto r_out = replay_expr->output(i);
+ for (const auto i : c10::irange(t_expr->outputs().size())) {
+ auto t_out = t_expr->output(i);
+ auto r_out = r_expr->output(i);
if (t_out->getValType() == ValType::IterDomain &&
r_out->getValType() == ValType::IterDomain) {
- target2replay_id_map_[t_out->as<IterDomain>()] =
- r_out->as<IterDomain>();
+ id_map_[t_out->as<IterDomain>()] = r_out->as<IterDomain>();
leaf_ids_[r_out->as<IterDomain>()] = counter++;
}
}
}
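// A worked example of the matching loop above, hedged to the simple case.
// Suppose target and replay both start at roots {i0, i1} with
// id_map_ = {t_i0 -> r_i0, t_i1 -> r_i1}:
//  - target history: split(t_i0, 4); replay history: split(r_i0, 4).
//    The exprs match, so t_i0/r_i0 leave leaf_ids_, the outputs are mapped
//    pairwise (id_map_[t_i0o] = r_i0o, id_map_[t_i0i] = r_i0i), and the
//    replay outputs become new leaves.
//  - If the replay had instead used split(r_i0, 8), the factor check above
//    rejects the match and t_i0's subtree is simply left un-replayed, which
//    is exactly the "best effort" in the name.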
BestEffortReplay ber(td2->domain(), td1->domain(), id_map);
- for (const auto i :
- c10::irange(std::max(td1->domain().size(), td2->domain().size()))) {
+
+ for (const auto i : c10::irange(td1->domain().size())) {
if (ber.getReplay().find(td1->axis(i)) == ber.getReplay().end()) {
return i;
}
return i;
}
}
- return std::min(td1->nDims(), td2->nDims());
-}
-
-namespace {
-
-// Maps that track information relevant to best effort replay about broadcast
-// axes in consumer that are not in producer
-//
-// For example if we have consumer: T0[i0, b1, b2, i3] and producer:
-// T1[i0, i3]
-//
-// If consumer transformations are:
-// -> T[i0, b1o, b1i, b2o, b2i, i3]
-// -> T[i0*b1i, b1o, b2o, b2i, i3]
-// -> T[i0*b1i*b2o, b1o, b2i, i3]
-// -> T[i0*b1i*b2o*i3, b1o, b2i]
-//
-// forwarding_map would forward i0->i0*b1i and i0*b1i->i0*b1i*b2o
-// compliment_map would have the entry i0->b1i and i0*b1i->b2o
-//
-// The first is to fast forward transformations in consumer involving broadcast
-// axes not in producer. The compliment map is to use later to compute what leaf
-// nodes we may have after the forwarding process is finished. Leaf nodes are
-// only important for replayCasP, so look there to see how this is done. Forward
-// map is used for replayCasP and replayPasC.
-struct ConsumerForwardingInfo {
- public:
- // Map IterDomain* axes that can safely be forwarded to their output.
- std::unordered_map<IterDomain*, IterDomain*> forwarding_map;
-
- // Given a forward id map id_input -> id_forwarded
- // Track the other inputs in the expr that id_input is an input to. These will
- // be used to adjust the replay's leaf tracking. Don't need to track one to
- // many as currently transformations on IterDomains can only have maximum 2
- // inputs, but maybe in the future we'll have more.
- std::unordered_map<IterDomain*, std::vector<IterDomain*>> compliment_map;
-
- ConsumerForwardingInfo(
- const TensorView* producer,
- const TensorView* consumer) {
- // Collect which root axes are in consumer that are not in producer because
- // of broadcasting
- std::unordered_set<IterDomain*> consumer_bcast_roots_not_in_producer;
-
- const auto c2p_root_map =
- PairwiseRootDomainMap(producer, consumer)
- .mapConsumerToProducer(consumer->domain(), producer->domain());
-
- for (auto consumer_root_id : consumer->getRootDomain()) {
- if (consumer_root_id->isBroadcast()) {
- if (c2p_root_map.find(consumer_root_id) == c2p_root_map.end()) {
- consumer_bcast_roots_not_in_producer.emplace(consumer_root_id);
- }
- }
- }
-
- // We have root axes in consumer that don't exist in producer, now forward
- // those to include all id's in consumer comprised of only axes not in
- // producer.
- auto consumer_bcast_ids_not_in_producer =
- consumer_bcast_roots_not_in_producer;
-
- std::vector<Expr*> consumer_history = ExprSort::getExprs(
- FusionGuard::getCurFusion(),
- std::vector<Val*>(
- consumer->domain()->domain().begin(),
- consumer->domain()->domain().end()));
-
- auto isIdOnlyInConsumer =
- [&consumer_bcast_ids_not_in_producer](IterDomain* input_id) {
- return consumer_bcast_ids_not_in_producer.find(input_id) !=
- consumer_bcast_ids_not_in_producer.end();
- };
-
- for (auto expr : consumer_history) {
- auto input_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
- // If expr inputs are all in consumer_bcast_ids_not_in_producer, than so
- // are all outputs
- if (std::all_of(input_ids.begin(), input_ids.end(), isIdOnlyInConsumer)) {
- // add all outputs to not being in producer
- for (auto output_ids :
- ir_utils::filterByType<IterDomain>(expr->outputs())) {
- consumer_bcast_ids_not_in_producer.emplace(output_ids);
- }
- } else if (
- expr->isA<Merge>() &&
- std::any_of(input_ids.begin(), input_ids.end(), isIdOnlyInConsumer)) {
- auto merge_expr = expr->as<Merge>();
- // If
- // - one of the inputs is made of id's in consumer that don't map to
- // producer (bcast axes),
- // - && the other input maps to an id in both consumer and producer
- // - && this is a merge
- // for the sake of BestEffortReplay we can forward the input mapping
- // to both consumer and producer to the output of the expression
- std::vector<IterDomain*> forwarded_ids;
- std::vector<IterDomain*> compliment_ids;
-
- for (auto input_id : input_ids) {
- if (!isIdOnlyInConsumer(input_id)) {
- forwarded_ids.emplace_back(input_id);
- forwarding_map.emplace(std::make_pair(input_id, merge_expr->out()));
- } else {
- compliment_ids.push_back(input_id);
- }
- }
-
- // Set up compliment map
- for (auto forwarded_id : forwarded_ids) {
- compliment_map.emplace(std::make_pair(forwarded_id, compliment_ids));
- }
- }
- }
- }
-};
-
-} // namespace
-
-BestEffortReplay BestEffortReplay::replayCasP(
- const TensorView* consumer,
- const TensorView* producer,
- int producer_compute_at_axis,
- const RootDomainMap& root_map) {
- if (producer_compute_at_axis < 0)
- producer_compute_at_axis += (int)producer->nDims() + 1;
-
- TORCH_INTERNAL_ASSERT(
- producer_compute_at_axis >= 0 &&
- (unsigned int)producer_compute_at_axis <= producer->nDims(),
- "Invalid axis provided to BestEffortReplay::replayCasP.");
-
- // producer ids we need to match in consumer
- std::vector<IterDomain*> producer_CA_ids(
- producer->domain()->domain().begin(),
- producer->domain()->domain().begin() + producer_compute_at_axis);
- producer_CA_ids = TensorDomain::noReductions(producer_CA_ids);
-
- // If producer has an rfactor root, that's what will match the consumer
- std::vector<IterDomain*> producer_root = producer->getMaybeRFactorDomain();
-
- // Figure out all inputs required to generate the compute_at dimensions. We
- // need all deps because inputs on producer may be in getRootDomain, but we
- // may need in rFactorDomain
- auto all_CA_id_deps = DependencyCheck::getAllValsBetween(
- {producer_root.begin(), producer_root.end()},
- {producer_CA_ids.begin(), producer_CA_ids.end()});
-
- // Figure out minimal set of root IDs needed to produce producer_CA_ids:
- std::unordered_set<IterDomain*> producer_CA_root_ids;
- for (IterDomain* id : producer_root) {
- if (std::find(all_CA_id_deps.begin(), all_CA_id_deps.end(), id) !=
- all_CA_id_deps.end()) {
- producer_CA_root_ids.emplace(id);
- }
- }
-
- const auto p2c_root_map = root_map.mapProducerToConsumer(
- producer->domain(), consumer->domain(), producer_CA_root_ids);
-
- // See FusionAdvancedComputeAt7 for an example of the forwarding logic
- ConsumerForwardingInfo consumer_forwarding_info(producer, consumer);
-
- auto consumer_replay = BestEffortReplay(
- consumer->domain()->domain(),
- producer_CA_ids,
- p2c_root_map,
- consumer_forwarding_info.forwarding_map);
-
- // Need to adjust leaf map based on forwarding before returning.
-
- // ID's could go through more than one forward iteration in the map before it
- // terminates. Grab every id between the forwarded id, and what it was
- // forwarded to
- std::function<void(IterDomain*, std::vector<IterDomain*>&)>
- collectForwardedIds =
- [&consumer_forwarding_info, &collectForwardedIds](
- IterDomain* forward_id,
- std::vector<IterDomain*>& forwarded_ids) -> void {
- if (consumer_forwarding_info.forwarding_map.find(forward_id) !=
- consumer_forwarding_info.forwarding_map.end()) {
- forwarded_ids.emplace_back(forward_id);
- collectForwardedIds(
- consumer_forwarding_info.forwarding_map.at(forward_id),
- forwarded_ids);
- }
- };
-
- std::vector<IterDomain*> expanded_forwarded_ids;
- for (auto forwarded_id : consumer_replay.forwarded_ids_) {
- collectForwardedIds(forwarded_id, expanded_forwarded_ids);
- }
-
- // Grab all compliments of forwarded ids.
- std::vector<IterDomain*> compliments;
- for (auto forwarded_id : expanded_forwarded_ids) {
- auto compliment_map_it =
- consumer_forwarding_info.compliment_map.find(forwarded_id);
- TORCH_INTERNAL_ASSERT(
- compliment_map_it != consumer_forwarding_info.compliment_map.end(),
- "Issue tracking forwarded broadcast merges in best effort replay consumer as producer.");
- compliments.insert(
- compliments.end(),
- compliment_map_it->second.begin(),
- compliment_map_it->second.end());
- }
-
- // Grab all exprs used to make the forwarded compliments
- auto compliment_exprs = ExprSort::getExprs(
- FusionGuard::getCurFusion(), {compliments.begin(), compliments.end()});
-
- // Figure out if there are any leaves in compliment_exprs that aren't
- // the forwarded id
- std::unordered_map<IterDomain*, size_t> leaf_ids;
-
- for (auto expr : compliment_exprs) {
- for (auto inp : ir_utils::filterByType<IterDomain>(expr->inputs())) {
- leaf_ids.erase(inp);
- }
- for (auto out : ir_utils::filterByType<IterDomain>(expr->outputs())) {
- // If we used the comliment for forwarded don't add to leaf nodes.
- if (std::find(compliments.begin(), compliments.end(), out) ==
- compliments.end()) {
- leaf_ids.emplace(std::make_pair(out, consumer_replay.counter++));
- }
- }
- }
-
- consumer_replay.leaf_ids_.insert(leaf_ids.begin(), leaf_ids.end());
-
- return consumer_replay;
-}
-
-// Runs a best effort replay that ignores broadcast axes that appear in
-// consumer that are not mapped to producer in root_map.
-BestEffortReplay BestEffortReplay::replayPasC(
- const TensorView* producer,
- const TensorView* consumer,
- int consumer_compute_at_axis,
- const RootDomainMap& root_map) {
- if (consumer_compute_at_axis < 0)
- consumer_compute_at_axis += (int)consumer->nDims() + 1;
- TORCH_INTERNAL_ASSERT(
- consumer_compute_at_axis >= 0 &&
- (unsigned int)consumer_compute_at_axis <= consumer->nDims(),
- "Invalid axis provided to BestEffortReplay::replayPasC.");
-
- // consumer ids we need to match in producer
- std::vector<IterDomain*> consumer_CA_ids(
- consumer->domain()->domain().begin(),
- consumer->domain()->domain().begin() + consumer_compute_at_axis);
-
- // Figure out all inputs required to generate the compute_at dimensions
- auto consumer_CA_root_vals = IterVisitor::getInputsTo(
- std::vector<Val*>(consumer_CA_ids.begin(), consumer_CA_ids.end()));
-
- std::unordered_set<IterDomain*> consumer_CA_root_ids;
- for (auto val : consumer_CA_root_vals) {
- if (val->getValType().value() == ValType::IterDomain) {
- consumer_CA_root_ids.emplace(val->as<IterDomain>());
- }
- }
-
- const auto c2p_root_map = root_map.mapConsumerToProducer(
- consumer->domain(), producer->domain(), consumer_CA_root_ids);
-
- ConsumerForwardingInfo consumer_forwarding_info(producer, consumer);
-
- // Instead of replaying from the root, lets try to play forward the history
- // of producer if they match ops on consumer. Enforce if we modify an
- // rfactor axis that those ops must match.
- return BestEffortReplay(
- producer->domain()->domain(),
- consumer_CA_ids,
- c2p_root_map,
- consumer_forwarding_info.forwarding_map);
+ return td1->nDims();
}
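// Example of the contract, hedged: with td1 = [i0, i1] where i0 was split by 4
// and td2 = [i0, i1] left untransformed, position 0 already differs (a split
// output vs. a root), so findFirstMismatchedID(td1, td2) == 0. Two identical
// untransformed domains fall through the loop and return td1->nDims().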
} // namespace cuda
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <unordered_map>
#include <vector>
class TORCH_CUDA_CU_API BestEffortReplay {
private:
- std::unordered_map<IterDomain*, IterDomain*> target2replay_id_map_;
- std::unordered_map<IterDomain*, IterDomain*> forward_id_map_;
+ std::unordered_map<IterDomain*, IterDomain*> id_map_;
std::unordered_map<IterDomain*, size_t> leaf_ids_;
- std::vector<IterDomain*> forwarded_ids_;
-
- // Need to track which id's have been forwarded. Later need to make sure leaf
- // nodes to produce compliment axes are properly tracked. i.e.
- // T[i0, b1, b2, i3]
- // -> T[i0, b1o, b1i, b2o, b2i, i3]
- // -> T[i0*b1i*b2o, b1o, b2i, i3]
- // -> T[i0*b1i*b2o*i3, b1o, b2i]
- // If we forwarded i0 -> i0*b1i*b2o*i3, we need to know that b1o and b2i
- // are leaf nodes even though their split wasn't part of targets replay.
-
- // Counter to make sure best effort replay leaf_ids can be grabbed
- // deterministicly
size_t counter = 0;
- bool inForwardMap(IterDomain* id) const {
- return forward_id_map_.find(id) != forward_id_map_.end();
- }
-
- IterDomain* getForwardedId(IterDomain* id) const {
- auto forwarded_id_it = forward_id_map_.find(id);
- if (forwarded_id_it == forward_id_map_.end()) {
- return id;
- } else {
- return getForwardedId(forwarded_id_it->second);
- }
- }
-
public:
- // Highly duplicated from the constructor above.
- // TODO: Remove other constructor
+ // replay_map: mapping of target root domains to corresponding
+ // replay root domains
BestEffortReplay(
const std::vector<IterDomain*>& replay_domain,
const std::vector<IterDomain*>& target_domain,
- std::unordered_map<IterDomain*, IterDomain*> target2replay_map,
- std::unordered_map<IterDomain*, IterDomain*> forward_id_map = {});
+ std::unordered_map<IterDomain*, IterDomain*> replay_map,
+ bool forward_bcast_mismatch = false);
// Return iter domain map from target_domain IDs to their "replayed"
// replay_domain IDs. If not in map, was not replayed.
const std::unordered_map<IterDomain*, IterDomain*>& getReplay() const {
- return target2replay_id_map_;
+ return id_map_;
}
// ids in replay that did not have matching transforms in target_domain
return leaf_vec_;
}
- // Runs a best effort replay that ignores broadcast axes that appear in
- // consumer that are not mapped to producer in root_map.
- static BestEffortReplay replayCasP(
- const TensorView* consumer,
- const TensorView* producer,
- int producer_compute_at_axis,
- const RootDomainMap& root_map);
-
- // Runs a best effort replay that ignores broadcast axes that appear in
- // consumer that are not mapped to producer in root_map.
- static BestEffortReplay replayPasC(
- const TensorView* producer,
- const TensorView* consumer,
- int consumer_compute_at_axis,
- const RootDomainMap& root_map);
-
// Find the first position i where td1[i] is not the same as td2[i]. "Same"
// means the DAG and input IDs to generate td1[i] and td2[i] are the same.
- // td1 and td2 are assumed to have some matching iter domains, as this is a
- // strict same-ness check.
static int findFirstMismatchedID(
const TensorDomain* td1,
const TensorDomain* td2);
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
-#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
-#include <deque>
+#include <vector>
namespace torch {
namespace jit {
"Transform traversal failed, modified a node but it was not a leaf node.");
// outer loop size
- Val* remainder = ceilDiv(mapped->extent(), s->factor());
+ Val* oe = ceilDiv(mapped->extent(), s->factor());
// Manually replay the split, following the output of the operations.
// This is so rfactor ops are replayed correctly.
IterDomain* ido = new IterDomain(
new Int(0),
- s->innerSplit() ? remainder->as<Int>() : s->factor(),
+ oe->as<Int>(),
s->outer()->getParallelType(),
s->outer()->getIterType(),
s->outer()->isRFactorProduct());
// inner IterDomain
IterDomain* idi = new IterDomain(
new Int(0),
- s->innerSplit() ? s->factor() : remainder->as<Int>(),
+ s->factor(),
s->inner()->getParallelType(),
- s->inner()->getIterType(),
+      s->inner()->getIterType(),
s->inner()->isRFactorProduct());
// Generate the split node
- new Split(ido, idi, mapped, s->factor(), s->innerSplit());
+ new Split(ido, idi, mapped, s->factor());
// Remove mapped id from leaf IDs
leaf_ids_.erase(mapped);
TensorDomain* TransformReplay::fullSelfReplay(
const TensorDomain* new_self_root,
const TensorDomain* self) {
- FUSER_PERF_SCOPE("TransformReplay::fullSelfReplay");
+ FUSER_PERF_SCOPE("fullSelfReplay");
TORCH_INTERNAL_ASSERT(
- new_self_root->getRootDomain().size() == self->getRootDomain().size(),
+ new_self_root->nDims() == self->getRootDomain().size(),
"Invalid number of IterDomains provided.");
// Map for replay, should be pretty simple.
size_t i = 0;
for (auto id : self->getRootDomain()) {
TORCH_INTERNAL_ASSERT(
- new_self_root->getRootDomain()[i]->start()->isZeroInt() &&
- id->start()->isZeroInt(),
- "Replay does not support IterDomains that do not start at 0, received: ",
- new_self_root->getRootDomain()[i]->start(),
- " and ",
- id->start()->isZeroInt());
+ new_self_root->axis(i)->start() == id->start(),
+ "Replay does not support IterDomains that do not start at 0.");
TORCH_INTERNAL_ASSERT(
- new_self_root->getRootDomain()[i]->getParallelType() ==
- id->getParallelType() &&
- new_self_root->getRootDomain()[i]->isReduction() ==
- id->isReduction() &&
- new_self_root->getRootDomain()[i]->isRFactorProduct() ==
+ new_self_root->axis(i)->getParallelType() == id->getParallelType() &&
+ new_self_root->axis(i)->isReduction() == id->isReduction() &&
+ new_self_root->axis(i)->isRFactorProduct() ==
id->isRFactorProduct() &&
- new_self_root->getRootDomain()[i]->isBroadcast() ==
- id->isBroadcast(),
- "Axes ",
- id,
- " and ",
- new_self_root->getRootDomain()[i],
- " do not match for self replay.");
- axis_map[id] = new_self_root->getRootDomain()[i];
+ new_self_root->axis(i)->isBroadcast() == id->isBroadcast(),
+ "Axes do not match for self replay.");
+ axis_map[id] = new_self_root->axis(i);
i++;
}
}
"Error during replay, didn't replay an axis.");
new_domain[i++] = it->second;
}
-
- if (self->hasRFactor()) {
- std::vector<IterDomain*> new_rfactor_domain(
- self->getMaybeRFactorDomain().size(), nullptr);
- size_t i = 0;
- for (auto id : self->getMaybeRFactorDomain()) {
- auto it = replay.getReplay().find(id);
- TORCH_INTERNAL_ASSERT(
- it != replay.getReplay().end(),
- "Error during replay, didn't replay an axis.");
- new_rfactor_domain[i++] = it->second;
- }
- return new TensorDomain(
- new_self_root->getRootDomain(),
- new_rfactor_domain,
- new_domain,
- new_self_root->contiguity());
- }
}
return new TensorDomain(
- new_self_root->getRootDomain(), new_domain, new_self_root->contiguity());
+ new_self_root->domain(), new_domain, self->contiguity());
}
// Producer could have rfactor axes which consumer may want replayed. We can
// "replay" them as long as it doesn't modify the root rfactor axes. What we
// really want to do is validate that if we replayed these axes to the ones they
// mapped to in the consumer, the operations would all be the same. Then we want
// to start the replay of the producer from the rfactor root axes, not the root.
std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
- const TensorView* producer,
- const TensorView* consumer,
- int consumer_compute_at_axis,
- const RootDomainMap& root_map) {
- FUSER_PERF_SCOPE("TransformReplay::replayPasC");
-
- // If this is a reduction operation, we may call transform_replay on the
- // tensor view. When this happens, just return thet target view.
- if (producer == consumer)
- return {producer->domain(), producer->nDims()};
+ const TensorDomain* producer,
+ const TensorDomain* consumer,
+ int consumer_compute_at_axis) {
+ FUSER_PERF_SCOPE("replayPasC");
if (consumer_compute_at_axis < 0)
consumer_compute_at_axis += (int)consumer->nDims() + 1;
// consumer ids we need to match in producer
std::vector<IterDomain*> consumer_CA_ids(
- consumer->domain()->domain().begin(),
- consumer->domain()->domain().begin() + consumer_compute_at_axis);
+ consumer->domain().begin(),
+ consumer->domain().begin() + consumer_compute_at_axis);
+
+ // Figure out all inputs required to generate the compute_at dimensions
+ std::unordered_set<Val*> consumer_CA_root_vals = IterVisitor::getInputsTo(
+ std::vector<Val*>(consumer_CA_ids.begin(), consumer_CA_ids.end()));
+
+ std::unordered_set<IterDomain*> consumer_CA_root_ids;
+ for (auto val : consumer_CA_root_vals) {
+ if (val->getValType().value() == ValType::IterDomain) {
+ consumer_CA_root_ids.emplace(val->as<IterDomain>());
+ }
+ }
+
+ // Map of consumer_CA_root_ids to related producer_CA_ids
+ auto replay_root_map =
+ TensorDomain::mapRootCtoP(consumer, producer, consumer_CA_root_ids);
+
+ // Track which root axes in producer we will send to replay
+ std::unordered_set<IterDomain*> producer_roots4replay;
+ for (auto entry : replay_root_map) {
+ producer_roots4replay.emplace(entry.second);
+ }
// Instead of replaying from the root, lets try to play forward the history of
// producer if they match ops on consumer. Enforce if we modify an rfactor
// axis that those ops must match.
- auto forward_replay = BestEffortReplay::replayPasC(
- producer, consumer, consumer_compute_at_axis, root_map);
+ BestEffortReplay forward_replay(
+ producer->domain(), consumer_CA_ids, replay_root_map);
// Make a new map based on all the leaves resulting from best effort replay
id_map forwarded_replay_map;
auto leaf_ids(replay_PasC.getUnorderedLeafIDs());
// Remove all ids that map to the compute at axis; we're going to replay the
- // rest, track all dims needed to match consumer CA dims
- std::vector<IterDomain*> needed_dims;
+ // rest
for (auto c_id : consumer_CA_ids) {
auto it = replay_PasC.getReplay().find(c_id);
if (it == replay_PasC.getReplay().end()) {
TORCH_INTERNAL_ASSERT(
- c_id->isBroadcast() || c_id->isGather(),
+ c_id->isBroadcast(),
"Could not find axis, ",
c_id,
", requested in replay.");
continue;
}
- TORCH_INTERNAL_ASSERT(
- leaf_ids.find(it->second) != leaf_ids.end(),
- "Replayed id to match consumer id ",
- c_id,
- " should be a leaf in replay map.");
- leaf_ids.erase(it->second);
- needed_dims.push_back(it->second);
+ if (leaf_ids.find(it->second) != leaf_ids.end())
+ leaf_ids.erase(it->second);
}
// leaf_ids now contains all producer ID products that are not used to satisfy
// the computeAt. Turn into a map so we can play forward these IDs in producer
// (if possible):
id_map producer_self_replay_map;
- for (auto entry : leaf_ids) {
+ for (auto entry : leaf_ids)
producer_self_replay_map[entry.first] = entry.first;
- }
-
- // Check which root domains were used to produce the leaf_ids. We may have
- // picked up extra roots in consumer because of broadcast forwarding.
- std::vector<Val*> unordered_non_root_leaf_vals;
- for (auto leaf_id : replay_PasC.getUnorderedLeafIDs()) {
- if (leaf_id.first->definition() == nullptr) {
- continue;
- } else {
- unordered_non_root_leaf_vals.emplace_back(leaf_id.first);
- }
- }
-
- auto processed_roots = IterVisitor::getInputsTo(unordered_non_root_leaf_vals);
auto producer_root = producer->getMaybeRFactorDomain();
// Any root domain that was not used to generate the computeAt IDs can also be
// put in the map to forward its transformations.
- for (auto producer_root_id : producer_root) {
- if (std::find(
- processed_roots.begin(), processed_roots.end(), producer_root_id) ==
- processed_roots.end() &&
- std::find(needed_dims.begin(), needed_dims.end(), producer_root_id) ==
- needed_dims.end()) {
+ for (auto producer_root_id : producer_root)
+ if (producer_roots4replay.find(producer_root_id) ==
+ producer_roots4replay.end()) {
producer_self_replay_map[producer_root_id] = producer_root_id;
}
- }
// Play forward transformations all producer IDs we can
auto producer_replayed_leaves = BestEffortReplay(
- producer->domain()->domain(),
- producer->domain()->domain(),
- producer_self_replay_map);
+ producer->domain(), producer->domain(), producer_self_replay_map);
/*
* Accumulate axes into the new domain in the following order, making sure to
auto it = replay_PasC.getReplay().find(c_id);
if (it == replay_PasC.getReplay().end()) {
TORCH_INTERNAL_ASSERT(
- c_id->isBroadcast() || c_id->isGather(),
+ c_id->isBroadcast(),
"Could not find axis, ",
c_id,
", requested in replay.");
}
unsigned int producer_compute_at_axis = new_IDs.size();
-
// Add axes in (2)
- for (auto c_id : consumer->domain()->domain()) {
+ for (auto c_id : consumer->domain()) {
auto it = replay_PasC.getReplay().find(c_id);
if (it != replay_PasC.getReplay().end()) {
auto id = it->second;
}
// Add axes in (3)
- for (auto id : producer->domain()->domain()) {
+ for (auto id : producer->domain()) {
if (producer_replayed_leaves.getUnorderedLeafIDs().find(id) !=
producer_replayed_leaves.getUnorderedLeafIDs().end()) {
if (used_IDs.find(id) == used_IDs.end()) {
producer->getRootDomain(),
producer->getRFactorDomain(),
new_IDs,
- producer->domain()->contiguity());
-
+ producer->contiguity());
return {replayed, producer_compute_at_axis};
}
std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
- const TensorView* consumer,
- const TensorView* producer,
- int producer_compute_at_axis,
- const RootDomainMap& root_map) {
- FUSER_PERF_SCOPE("TransformReplay::replayCasP");
-
- // If this is a reduction operation, we may call transform_replay on the same
- // tensor view. When this happens, just return the target view.
- if (consumer == producer)
- return {consumer->domain(), consumer->nDims()};
+ const TensorDomain* consumer,
+ const TensorDomain* producer,
+ int producer_compute_at_axis) {
+ FUSER_PERF_SCOPE("replayCasP");
if (producer_compute_at_axis < 0)
producer_compute_at_axis += (int)producer->nDims() + 1;
// producer ids we need to match in consumer
std::vector<IterDomain*> producer_CA_ids(
- producer->domain()->domain().begin(),
- producer->domain()->domain().begin() + producer_compute_at_axis);
+ producer->domain().begin(),
+ producer->domain().begin() + producer_compute_at_axis);
producer_CA_ids = TensorDomain::noReductions(producer_CA_ids);
+ // Grab root domains of producer and consumer
+ std::vector<IterDomain*> consumer_root = consumer->getRootDomain();
+
+ // If producer has an rfactor root, that's what will match the consumer
+ std::vector<IterDomain*> producer_root = producer->getMaybeRFactorDomain();
+
+ // Figure out all inputs required to generate the compute_at dimensions. We
+ // need all deps because producer's inputs may be in getRootDomain, while the
+ // ones we need may be in the rfactor domain.
+ std::unordered_set<Val*> all_CA_id_deps = DependencyCheck::getAllValsBetween(
+ {producer_root.begin(), producer_root.end()},
+ {producer_CA_ids.begin(), producer_CA_ids.end()});
+
+ // Figure out which root IDs we need:
+ std::unordered_set<IterDomain*> producer_CA_root_ids;
+ for (IterDomain* id : producer_root) {
+ if (all_CA_id_deps.find(id) != all_CA_id_deps.end())
+ producer_CA_root_ids.emplace(id);
+ }
+
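+ // Map the producer CA root IDs to the matching consumer root IDs.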
+ auto replay_root_map =
+ TensorDomain::mapRootPtoC(producer, consumer, producer_CA_root_ids);
+
+ // Track which root axes in consumer we will send to replay
+ std::unordered_set<IterDomain*> consumer_roots4replay;
+ for (auto entry : replay_root_map) {
+ consumer_roots4replay.emplace(entry.second);
+ }
+
// Instead of replaying from the root, let's try to forward the history of
// consumer if it matches ops on producer. If we modify an rfactor axis,
// enforce that those ops match.
- BestEffortReplay forward_replay = BestEffortReplay::replayCasP(
- consumer, producer, producer_compute_at_axis, root_map);
+ BestEffortReplay forward_replay(
+ consumer->domain(), producer_CA_ids, replay_root_map);
- // Track dangling leaves which can be produced in
- // BestEffortReplay::replayCasP; these don't have any equivalent in producer,
- // so they're not in the map. We will simply map them to themselves so we
- // don't lose them.
id_map forwarded_replay_map;
- auto forward_dangling_leaves = forward_replay.getUnorderedLeafIDs();
for (auto entry : forward_replay.getReplay()) {
- if (forward_dangling_leaves.find(entry.second) !=
- forward_dangling_leaves.end()) {
+ if (forward_replay.getUnorderedLeafIDs().find(entry.second) !=
+ forward_replay.getUnorderedLeafIDs().end())
forwarded_replay_map[entry.first] = entry.second;
- forward_dangling_leaves.erase(entry.second);
- }
}
// Replay producer dimensions.
auto leaf_ids(replay_CasP.getUnorderedLeafIDs());
// Remove all ids that map to the compute at axis; we're going to replay the
- // rest, track all dims that are needed to match producer CA dims
- std::vector<IterDomain*> needed_dims;
+ // rest
for (auto p_id : producer_CA_ids) {
auto it = replay_CasP.getReplay().find(p_id);
TORCH_INTERNAL_ASSERT(
"Could not find axis, ",
p_id,
", requested in replay.");
- TORCH_INTERNAL_ASSERT(
- leaf_ids.find(it->second) != leaf_ids.end(),
- "Replayed id to match producer id ",
- p_id,
- " should be a leaf in replay map.");
- leaf_ids.erase(it->second);
- needed_dims.push_back(it->second);
+ if (leaf_ids.find(it->second) != leaf_ids.end())
+ leaf_ids.erase(it->second);
}
// leaf_ids now contains all consumer ID products that are not used to satisfy
- // the computeAt. Turn into a map so we can play forward these IDs in
- // consumer (if possible):
+ // the computeAt. Turn into a map so we can play forward these IDs in consumer
+ // (if possible):
id_map consumer_self_replay_map;
- for (auto entry : leaf_ids) {
- consumer_self_replay_map[entry.first] = entry.first;
- }
-
- for (auto entry : forward_dangling_leaves) {
+ for (auto entry : leaf_ids)
consumer_self_replay_map[entry.first] = entry.first;
- }
-
- // Check which root domains were used to produce the leaf_ids. We may have
- // picked up extra roots in consumer because of broadcast forwarding.
- std::vector<Val*> unordered_non_root_leaf_vals;
- for (auto leaf_id : replay_CasP.getUnorderedLeafIDs()) {
- if (leaf_id.first->definition() == nullptr) {
- continue;
- } else {
- unordered_non_root_leaf_vals.emplace_back(leaf_id.first);
- }
- }
-
- auto processed_roots = IterVisitor::getInputsTo(unordered_non_root_leaf_vals);
-
- std::vector<IterDomain*> consumer_root = consumer->getRootDomain();
// Any root domain that was not used to generate the computeAt IDs can also be
// put in the map to forward its transformations.
- for (auto consumer_root_id : consumer_root) {
- if (std::find(
- processed_roots.begin(), processed_roots.end(), consumer_root_id) ==
- processed_roots.end() &&
- // Don't re-add roots that may have directly mapped in the replay
- std::find(needed_dims.begin(), needed_dims.end(), consumer_root_id) ==
- needed_dims.end()) {
+ for (auto consumer_root_id : consumer_root)
+ if (consumer_roots4replay.find(consumer_root_id) ==
+ consumer_roots4replay.end())
consumer_self_replay_map[consumer_root_id] = consumer_root_id;
- }
- }
// Play forward transformations all consumer IDs we can
auto consumer_replayed_leaves = BestEffortReplay(
- consumer->domain()->domain(),
- consumer->domain()->domain(),
- consumer_self_replay_map);
+ consumer->domain(), consumer->domain(), consumer_self_replay_map);
/*
* Accumulate axes into the new domain in the following order, making sure to
}
// Add axes in (2)
- for (auto p_id : producer->domain()->domain()) {
+ for (auto p_id : producer->domain()) {
auto it = replay_CasP.getReplay().find(p_id);
if (it != replay_CasP.getReplay().end()) {
auto id = it->second;
}
// Add axes in (3)
- for (auto id : consumer->domain()->domain()) {
+ for (auto id : consumer->domain()) {
if (consumer_replayed_leaves.getUnorderedLeafIDs().find(id) !=
consumer_replayed_leaves.getUnorderedLeafIDs().end()) {
if (used_IDs.find(id) == used_IDs.end()) {
consumer->getRootDomain(),
consumer->getRFactorDomain(),
new_IDs,
- consumer->domain()->contiguity());
+ consumer->contiguity());
return {replayed, producer_CA_ids.size()};
}
// replay Producer as Consumer
-std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
- const TensorView* producer,
- const TensorView* consumer,
+std::pair<TensorView*, unsigned int> TransformReplay::replayPasC(
+ TensorView* producer,
+ TensorView* consumer,
int compute_at_axis) {
- // Use the pairwise root map as a default mapper
- PairwiseRootDomainMap root_map(producer, consumer);
- return replayPasC(producer, consumer, compute_at_axis, root_map);
-}
-
-std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
- const TensorView* consumer,
- const TensorView* producer,
- int compute_at_axis) {
- // Use the pairwise root map as a default mapper
- PairwiseRootDomainMap root_map(producer, consumer);
- return replayCasP(consumer, producer, compute_at_axis, root_map);
-}
-
-namespace {
-
-std::deque<TensorView*> deduplicate(const std::deque<TensorView*>& tv_deque) {
- std::deque<TensorView*> deduplicated;
- std::unordered_set<TensorView*> inserted;
- for (auto tv_entry : tv_deque) {
- if (inserted.find(tv_entry) == inserted.end()) {
- deduplicated.emplace_back(tv_entry);
- inserted.emplace(tv_entry);
- }
- }
- return deduplicated;
-}
-
-std::deque<TensorView*> tvInputs(Expr* expr) {
- auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
- return std::deque<TensorView*>(tv_inputs.begin(), tv_inputs.end());
-}
-
-std::deque<TensorView*> tvOutputs(Expr* expr) {
- auto tv_outputs = ir_utils::filterByType<TensorView>(expr->outputs());
- return std::deque<TensorView*>(tv_outputs.begin(), tv_outputs.end());
-}
-
-std::deque<TensorView*> consumersOf(TensorView* tv) {
- std::deque<TensorView*> consumer_tvs;
- for (auto def : tv->uses()) {
- auto outs = tvOutputs(def);
- consumer_tvs.insert(consumer_tvs.end(), outs.begin(), outs.end());
- }
- return deduplicate(consumer_tvs);
-}
-
-std::deque<TensorView*> producersFor(TensorView* tv) {
- auto def = tv->definition();
- if (def == nullptr) {
- return {};
- }
-
- return deduplicate(tvInputs(def));
-}
-
-}; // namespace
-
-bool TransformPropagator::replayPasC(
- TensorView* producer_tv,
- TensorView* consumer_tv) {
- if (producer_tv == starting_tv) {
- return false;
- }
-
- auto consumer_pos_it = replayed_pos.find(consumer_tv);
- if (consumer_pos_it == replayed_pos.end()) {
- return false;
- }
-
- auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv);
- auto producerAsC = TransformReplay::replayPasC(
- producer_tv, consumer_tv, consumer_pos_it->second, pairwiseMap);
-
- if (replayed_pos.find(producer_tv) != replayed_pos.end()) {
- if (producerAsC.second <= replayed_pos.at(producer_tv)) {
- return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks)
- }
- }
-
- producer_tv->setDomain(producerAsC.first);
- replayed_pos[producer_tv] = producerAsC.second;
-
- return true;
-}
-
-bool TransformPropagator::replayCasP(
- TensorView* consumer_tv,
- TensorView* producer_tv) {
- if (consumer_tv == starting_tv) {
- return false;
- }
-
- auto producer_pos_it = replayed_pos.find(producer_tv);
- if (producer_pos_it == replayed_pos.end()) {
- return false;
- }
-
- auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv);
- auto consumerAsP = TransformReplay::replayCasP(
- consumer_tv, producer_tv, producer_pos_it->second, pairwiseMap);
-
- if (replayed_pos.find(consumer_tv) != replayed_pos.end()) {
- if (consumerAsP.second <= replayed_pos.at(consumer_tv)) {
- return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks)
- }
- }
-
- consumer_tv->setDomain(consumerAsP.first);
- replayed_pos[consumer_tv] = consumerAsP.second;
-
- return true;
-}
+ // If this is a reduction operation, we may call transform_replay on the
-TransformPropagator::TransformPropagator(TensorView* from) : starting_tv(from) {
- // Tensors we should try to propagate in the consumer direction
- std::deque<TensorView*> consumer_propagation{starting_tv};
-
- // Tensors we should try to propagate in the producer direction
- std::deque<TensorView*> producer_propagation{starting_tv};
-
- // Seed position with local tv
- replayed_pos[from] = from->nDims();
-
- // While tensor views are being replayed, if they're modified, make sure we
- // propagate back to all producers as well as consumers. This is definitely
- // not the most efficient implementation: any time a tv is changed, we
- // propagate both forward and backward. If a forward pass touches
- // every node, the backward pass will try to replay every node, potentially
- // multiple times.
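- // (Illustrative worst case: in a chain a -> b -> c, replaying from b touches
- // c, which re-queues b for the producer pass, and so on until no replay
- // advances a recorded position.)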
- while (!consumer_propagation.empty() || !producer_propagation.empty()) {
- while (!consumer_propagation.empty()) {
- // Tensor view we will replay onto consumers
- auto tv = consumer_propagation.front();
- consumer_propagation.pop_front();
-
- // Replay tv forward to its consumers.
- for (auto consumer_tv : consumersOf(tv)) {
- auto replayed = replayCasP(consumer_tv, tv);
- // If the consumer changed, mark that we should also propagate to its consumers
-
- if (replayed) {
- consumer_propagation.emplace_back(consumer_tv);
- producer_propagation.emplace_back(consumer_tv);
- }
- }
- }
+ // tensor view. When this happens, just return the target view.
+ if (producer == consumer)
+ return {producer, 0};
- while (!producer_propagation.empty()) {
- // Tensor view we will replay onto producers
- auto tv = producer_propagation.front();
- producer_propagation.pop_front();
- // Replay tv backward to its producers
- for (auto producer_tv : producersFor(tv)) {
- auto replayed = replayPasC(producer_tv, tv);
- if (replayed) {
- producer_propagation.emplace_back(producer_tv);
- consumer_propagation.emplace_back(producer_tv);
- }
- }
- }
- }
+ std::pair<TensorDomain*, unsigned int> replay =
+ replayPasC(producer->domain(), consumer->domain(), compute_at_axis);
+ producer->setDomain(replay.first);
+ return {producer, replay.second};
}
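// Typical use (hypothetical sketch): after scheduling `consumer`, e.g.
// consumer->split(0, 128), TransformReplay::replayPasC(producer, consumer, 1)
// rewrites producer's domain so it structurally matches consumer up to the
// compute-at axis.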
-void TransformPropagator::from(TensorView* tv) {
- TransformPropagator propagate(tv);
+std::pair<TensorView*, unsigned int> TransformReplay::replayCasP(
+ TensorView* consumer,
+ TensorView* producer,
+ int compute_at_axis) {
+ // If this is a reduction operation, we may call transform_replay on the same
+ // tensor view. When this happens, just return the target view.
+ if (consumer == producer)
+ return {consumer, 0};
+ std::pair<TensorDomain*, unsigned int> replay =
+ replayCasP(consumer->domain(), producer->domain(), compute_at_axis);
+ consumer->setDomain(replay.first);
+ return {consumer, replay.second};
}
} // namespace cuda
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <algorithm>
-#include <unordered_map>
#include <vector>
namespace torch {
class TensorDomain;
class TensorView;
-class RootDomainMap;
class TORCH_CUDA_CU_API TransformReplay {
public:
// Replay producer as consumer, returns {producer, producer_compute_at_axis}.
static std::pair<TensorDomain*, unsigned int> replayPasC(
- const TensorView* producer,
- const TensorView* consumer,
+ const TensorDomain* producer,
+ const TensorDomain* consumer,
+ int consumer_compute_at_axis);
+
+ // Replay producer as consumer, returns {producer, producer_compute_at_axis}.
+ static std::pair<TensorView*, unsigned int> replayPasC(
+ TensorView* producer,
+ TensorView* consumer,
int consumer_compute_at_axis);
- static std::pair<TensorDomain*, unsigned int> replayPasC(
- const TensorView* producer,
- const TensorView* consumer,
- int consumer_compute_at_axis,
- const RootDomainMap& root_map);
- // Replay producer as consumer, returns {replayed_consumer_domain,
- // consumer_compute_at_axis}.
+ // Replay consumer as producer, returns {consumer, consumer_compute_at_axis}.
static std::pair<TensorDomain*, unsigned int> replayCasP(
- const TensorView* consumer,
- const TensorView* producer,
+ const TensorDomain* consumer,
+ const TensorDomain* producer,
+ int producer_compute_at_axis);
+
+ // Replay consumer as producer, returns {consumer, consumer_compute_at_axis}.
+ static std::pair<TensorView*, unsigned int> replayCasP(
+ TensorView* consumer,
+ TensorView* producer,
int producer_compute_at_axis);
- static std::pair<TensorDomain*, unsigned int> replayCasP(
- const TensorView* consumer,
- const TensorView* producer,
- int producer_compute_at_axis,
- const RootDomainMap& root_map);
// Self replay.
static TensorDomain* fullSelfReplay(
const TensorDomain* self);
};
-class TORCH_CUDA_CU_API TransformPropagator {
- private:
- bool replayPasC(TensorView* producer_tv, TensorView* consumer_tv = nullptr);
- bool replayCasP(TensorView* consumer_tv, TensorView* producer_tv = nullptr);
-
- TransformPropagator(TensorView* from);
-
- private:
- std::unordered_map<TensorView*, unsigned int> replayed_pos;
- TensorView* starting_tv = nullptr;
-
- public:
- static void from(TensorView* tv);
-};
-
} // namespace cuda
} // namespace fuser
} // namespace jit
return ReplayTransformations::handle(s);
// outer loop size
- Val* remainder = ceilDiv(mapped->extent(), s->factor());
+ Val* oe = ceilDiv(mapped->extent(), s->factor());
// Manually replay the split, making reduction = false and rfactor = true
// outer IterDomain
IterDomain* ido = new IterDomain(
new Int(0),
- s->innerSplit() ? remainder->as<Int>() : s->factor(),
+ oe->as<Int>(),
mapped->getParallelType(),
rfactor_outer ? IterType::Reduction : IterType::Iteration,
true); // is_rfactor_domain
// inner IterDomain
IterDomain* idi = new IterDomain(
new Int(0),
- s->innerSplit() ? s->factor() : remainder->as<Int>(),
+ s->factor(),
mapped->getParallelType(),
rfactor_inner ? IterType::Reduction : IterType::Iteration,
true);
// Generate the split node
- new Split(ido, idi, mapped, s->factor(), s->innerSplit());
+ new Split(ido, idi, mapped, s->factor());
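+ // The Split expression registers itself on construction; we only need its
+ // outputs ido and idi.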
// Remove mapped id from leaf IDs
leaf_ids_.erase(mapped);
TensorDomain* TransformRFactor::runReplay(
TensorDomain* orig_td,
std::vector<int> axes) {
- FUSER_PERF_SCOPE("TransformRFactor::runReplay");
+ FUSER_PERF_SCOPE("runReplay");
TORCH_CHECK(!axes.empty(), "No axes provided to rfactor replay.");
TensorDomain* TransformRFactor::runReplay2(
TensorDomain* orig_td,
std::vector<int> axes) {
- FUSER_PERF_SCOPE("TransformRFactor::runReplay2");
+ FUSER_PERF_SCOPE("runReplay2");
int ndims = (int)orig_td->nDims();
namespace fuser {
namespace cuda {
-bool isFloatingPointType(DataType dtype) {
- switch (dtype) {
- case DataType::Bool:
- return false;
- case DataType::Double:
- case DataType::Float:
- case DataType::Half:
- return true;
- case DataType::Int:
- case DataType::Int32:
- return false;
- case DataType::Null:
- TORCH_CHECK(
- false, "Null type is not a valid argument to isFloatingPoint");
- default:
- TORCH_CHECK(false, "Type not supported in isFloatingPoint");
- }
-}
-
-bool isIntegralType(DataType dtype) {
- switch (dtype) {
- case DataType::Bool:
- case DataType::Double:
- case DataType::Float:
- case DataType::Half:
- return false;
- case DataType::Int:
- case DataType::Int32:
- return true;
- case DataType::Null:
- TORCH_CHECK(
- false, "Null type is not a valid argument to isFloatingPoint");
- default:
- TORCH_CHECK(false, "Type not supported in isFloatingPoint");
- }
-}
-
-bool isIntegerOp(const BinaryOpType bopt) {
- return bopt >= BinaryOpType::Mod && bopt <= BinaryOpType::Rshift;
-}
-
-bool isLogicalOp(const BinaryOpType bopt) {
- return bopt >= BinaryOpType::Eq && bopt <= BinaryOpType::NE;
-}
-
-bool alsoBooleanOperator(const BinaryOpType bopt) {
- return bopt >= BinaryOpType::And && bopt <= BinaryOpType::Xor;
-}
-
-bool alsoBooleanOperator(const UnaryOpType uopt) {
- return uopt >= UnaryOpType::Not && uopt <= UnaryOpType::Not;
-}
-
-bool noFullIntegerSupport(const BinaryOpType bopt) {
- return bopt == BinaryOpType::Div || bopt == BinaryOpType::Pow ||
- bopt == BinaryOpType::Fmod;
-}
-
// Return highest on list (smallest enum val)
DataType promote_type(const DataType& t1, const DataType& t2) {
TORCH_CHECK(
// Return highest on list (smallest enum val)
ValType promote_type(const ValType& t1, const ValType& t2) {
- if (t1 == ValType::TensorView || t2 == ValType::TensorView) {
- return ValType::TensorView;
- }
- if (t1 == ValType::Scalar &&
- (t2 == ValType::Scalar || t2 == ValType::NamedScalar)) {
- return ValType::Scalar;
- }
- if (t2 == ValType::Scalar &&
- (t1 == ValType::Scalar || t1 == ValType::NamedScalar)) {
- return ValType::Scalar;
- }
- TORCH_CHECK(false, "Expected promotable ValTypes but got: ", t1, " and ", t2);
+ TORCH_CHECK(
+ t1 >= ValType::TensorView && t2 >= ValType::TensorView,
+ "Expected promotable ValTypes but got: ",
+ t1,
+ " and ",
+ t2);
+ // Check that it's a promotable type (one with a dtype).
+ // TODO: could this be enforced with a static_assert?
+ return t1 < t2 ? t1 : t2;
}
static const char* data_type2string(DataType t) {
switch (t) {
case DataType::Bool:
return "bool";
- case DataType::Double:
- return "double";
case DataType::Float:
return "float";
case DataType::Half:
return "__half";
case DataType::Int:
return "int64_t";
- case DataType::Int32:
- return "int";
case DataType::Null:
return "nullptr";
default:
static const char* val_type2string(ValType t) {
switch (t) {
+ case ValType::TensorIndex:
+ return "TensorIndex";
case ValType::TensorView:
return "TensorView";
case ValType::TensorDomain:
return "Scalar";
case ValType::NamedScalar:
return "NamedScalar";
+ case ValType::KirIterDomain:
+ return "KirIterDomain";
+ case ValType::KirNamedScalar:
+ return "KirNamedScalar";
+ case ValType::KirScalar:
+ return "KirScalar";
+ case ValType::KirTensorDomain:
+ return "KirTensorDomain";
+ case ValType::KirTensorView:
+ return "KirTensorView";
default:
- TORCH_INTERNAL_ASSERT(false, "No string found for val type.");
+ break;
}
+ TORCH_INTERNAL_ASSERT(false, "No string found for val type.");
+ return nullptr;
}
static const char* expr_type2string(ExprType t) {
return "TernaryOp";
case ExprType::ReductionOp:
return "ReductionOp";
+ case ExprType::GridReduction:
+ return "GridReduction";
case ExprType::BroadcastOp:
return "BroadcastOp";
- case ExprType::ShiftOp:
- return "ShiftOp";
+ case ExprType::ForLoop:
+ return "ForLoop";
+ case ExprType::IfThenElse:
+ return "IfThenElse";
+ case ExprType::Allocate:
+ return "Allocate";
+ case ExprType::Sync:
+ return "SyncThreads";
case ExprType::Split:
return "Split";
case ExprType::Merge:
return "Merge";
+ case ExprType::KirUnaryOp:
+ return "KirUnaryOp";
+ case ExprType::KirBinaryOp:
+ return "KirBinaryOp";
+ case ExprType::KirTernaryOp:
+ return "KirTernaryOp";
+ case ExprType::KirReductionOp:
+ return "KirReductionOp";
+ case ExprType::KirBroadcastOp:
+ return "KirBroadcastOp";
default:
- TORCH_INTERNAL_ASSERT(false, "No string found for expr type.");
- }
-}
-
-bool needFloatSuffix(UnaryOpType t) {
- switch (t) {
- case UnaryOpType::Abs:
- case UnaryOpType::Cast:
- case UnaryOpType::Frac:
- case UnaryOpType::Gelu:
- case UnaryOpType::Silu:
- case UnaryOpType::Neg:
- case UnaryOpType::Relu:
- case UnaryOpType::Reciprocal:
- case UnaryOpType::Set:
- case UnaryOpType::Sigmoid:
- return false;
- default:
- return true;
+ break;
}
+ TORCH_INTERNAL_ASSERT(false, "No string found for expr type.");
+ return nullptr;
}
static const char* unary_op_type2string(UnaryOpType t) {
switch (t) {
case UnaryOpType::Abs:
- return "abs";
+ return "fabs";
case UnaryOpType::Acos:
- return "acos";
+ return "acosf";
case UnaryOpType::Asin:
- return "asin";
+ return "asinf";
case UnaryOpType::Atan:
- return "atan";
+ return "atanf";
case UnaryOpType::Atanh:
- return "atanh";
+ return "atanhf";
case UnaryOpType::Cast:
return "cast";
case UnaryOpType::Ceil:
- return "ceil";
+ return "ceilf";
case UnaryOpType::Cos:
- return "cos";
+ return "cosf";
case UnaryOpType::Cosh:
- return "cosh";
+ return "coshf";
case UnaryOpType::Exp:
- return "exp";
+ return "expf";
case UnaryOpType::Expm1:
- return "expm1";
+ return "expm1f";
case UnaryOpType::Erf:
- return "erf";
+ return "erff";
case UnaryOpType::Erfc:
- return "erfc";
+ return "erfcf";
case UnaryOpType::Floor:
- return "floor";
+ return "floorf";
case UnaryOpType::Frac:
return "frac";
case UnaryOpType::Gelu:
return "gelu";
- case UnaryOpType::Silu:
- return "silu";
case UnaryOpType::Lgamma:
- return "lgamma";
+ return "lgammaf";
case UnaryOpType::Log:
- return "log";
+ return "logf";
case UnaryOpType::Log10:
- return "log10";
+ return "log10f";
case UnaryOpType::Log1p:
- return "log1p";
+ return "log1pf";
case UnaryOpType::Log2:
- return "log2";
+ return "log2f";
case UnaryOpType::Neg:
return "neg";
- case UnaryOpType::Not:
- return "not";
case UnaryOpType::RandLike:
return "randLike";
case UnaryOpType::Reciprocal:
case UnaryOpType::Relu:
return "relu";
case UnaryOpType::Rsqrt:
- return "rsqrt";
+ return "rsqrtf";
case UnaryOpType::Round:
- return "nearbyint";
+ return "roundf";
case UnaryOpType::Set:
return "set";
case UnaryOpType::Sigmoid:
return "sigmoid";
case UnaryOpType::Sin:
- return "sin";
+ return "sinf";
case UnaryOpType::Sinh:
- return "sinh";
+ return "sinhf";
case UnaryOpType::Sqrt:
- return "sqrt";
+ return "sqrtf";
case UnaryOpType::Tan:
- return "tan";
+ return "tanf";
case UnaryOpType::Tanh:
- return "tanh";
+ return "tanhf";
case UnaryOpType::Trunc:
- return "trunc";
+ return "truncf";
default:
- TORCH_INTERNAL_ASSERT(false, "No string found for unary op type.");
+ break;
}
-}
-
-std::string stringifyBooleanOp(const UnaryOpType uopt) {
- TORCH_INTERNAL_ASSERT(
- uopt == UnaryOpType::Not, uopt, " is not a boolean operator.");
- return "!";
+ TORCH_INTERNAL_ASSERT(false, "No string found for unary op type.");
+ return nullptr;
}
static const char* unary_op_type_inline_op2string(UnaryOpType t) {
switch (t) {
case UnaryOpType::Neg:
return "-";
- case UnaryOpType::Not:
- return "~";
case UnaryOpType::Set:
return "";
- case UnaryOpType::Address:
- return "(int64_t) &";
default:
break;
}
return nullptr;
}
-bool needFloatSuffix(BinaryOpType t) {
- switch (t) {
- case BinaryOpType::Atan2:
- case BinaryOpType::Div:
- case BinaryOpType::Fmod:
- case BinaryOpType::Max:
- case BinaryOpType::Min:
- case BinaryOpType::Pow:
- return true;
- default:
- return false;
- }
-}
-
static const char* binary_op_type2string(BinaryOpType t) {
switch (t) {
case BinaryOpType::Add:
return "add";
case BinaryOpType::Atan2:
- return "atan2";
+ return "atan2f";
case BinaryOpType::Div:
return "div";
case BinaryOpType::Fmod:
- return "fmod";
+ return "fmodf";
case BinaryOpType::Max:
- return "fmax";
+ return "fmaxf";
case BinaryOpType::Min:
- return "fmin";
+ return "fminf";
case BinaryOpType::Mul:
return "mul";
case BinaryOpType::Pow:
- return "pow";
+ return "powf";
case BinaryOpType::Remainder:
return "remainder";
case BinaryOpType::Sub:
case BinaryOpType::NE:
return "notEqual";
default:
- TORCH_INTERNAL_ASSERT(false, "No string found for binary op type.");
- }
-}
-
-static const char* binary_op_integer_op2string(BinaryOpType t) {
- switch (t) {
- case BinaryOpType::Max:
- return "max";
- case BinaryOpType::Min:
- return "min";
- default:
break;
}
+ TORCH_INTERNAL_ASSERT(false, "No string found for binary op type.");
return nullptr;
}
return "+";
case BinaryOpType::Div:
return "/";
+ case BinaryOpType::Mod:
+ return "%";
case BinaryOpType::Mul:
return "*";
case BinaryOpType::Sub:
return "-";
- // Integer ops
- case BinaryOpType::Mod:
- return "%";
- case BinaryOpType::Lshift:
- return "<<";
- case BinaryOpType::Rshift:
- return ">>";
// Logical Ops
+ case BinaryOpType::And:
+ return "&&";
case BinaryOpType::Eq:
return "==";
case BinaryOpType::GE:
return "<";
case BinaryOpType::NE:
return "!=";
- // Assume bitwise, otherwise use stringifyBooleanOp
- case BinaryOpType::And:
- return "&";
- case BinaryOpType::Or:
- return "|";
- case BinaryOpType::Xor:
- return "^";
default:
break;
}
return nullptr;
}
-std::string stringifyBooleanOp(const BinaryOpType bopt) {
- switch (bopt) {
- case BinaryOpType::And:
- return "&&";
- case BinaryOpType::Or:
- return "||";
- case BinaryOpType::Xor:
- return "!=";
- default:
- TORCH_INTERNAL_ASSERT(false, bopt, " is not a boolean operator.")
- }
-}
-
static const char* ternary_op_type2string(TernaryOpType t) {
switch (t) {
case TernaryOpType::Clamp:
case TernaryOpType::Where:
return "where";
default:
- TORCH_INTERNAL_ASSERT(false, "Unexpected TernaryOpType", t);
+ break;
}
+ TORCH_INTERNAL_ASSERT(false, "No string found for ternary op type.");
+ return nullptr;
}
static const char* parallel_type2string(ParallelType t) {
return "threadIdx.x";
case ParallelType::Vectorize:
return "V";
- case ParallelType::MisalignedVectorize:
- return "MV";
case ParallelType::Unroll:
- return "UR";
- case ParallelType::Unswitch:
- return "US";
+ return "U";
case ParallelType::Serial:
return "S";
default:
- TORCH_INTERNAL_ASSERT(false, "Unexpected ParallelType", t);
+ break;
}
+ TORCH_INTERNAL_ASSERT(false, "No string found for parallel type.");
+ return nullptr;
}
static const char* memory_type2string(MemoryType t) {
case MemoryType::Global:
return "global";
default:
- TORCH_INTERNAL_ASSERT(false, "Unexpected MemoryType", t);
+ break;
}
+ TORCH_INTERNAL_ASSERT(false, "No string found for memory type.");
+ return nullptr;
}
static const char* iter_type2string(IterType t) {
return "sb";
case IterType::BroadcastWithoutStride:
return "b";
- case IterType::Gather:
- return "g";
default:
- // Don't try to print t as it would recursively call this function
- TORCH_INTERNAL_ASSERT(false, "Unexpected IterType");
+ TORCH_INTERNAL_ASSERT(false, "No string found for IterDomain type.");
+ return nullptr;
}
}
case ParallelType::TIDx:
return "blockDim.x";
default:
- TORCH_INTERNAL_ASSERT(false, "Unexpected parallel type", t);
+ break;
}
+ TORCH_INTERNAL_ASSERT(false, "Could not find size of the thread type ", t);
+ return nullptr;
}
const unsigned int _WORD_SHIFT = 16;
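// supported_switch_pair presumably packs the two DataType values into one
// integer, e.g. ((int)from << _WORD_SHIFT) + (int)to, so a cast pair can be
// used as a switch case label.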
static const char* supported_casts2string(
const std::pair<DataType, DataType>& t) {
switch (supported_switch_pair(std::get<0>(t), std::get<1>(t))) {
- case supported_switch_pair(DataType::Double, DataType::Float):
- return "(float)";
- case supported_switch_pair(DataType::Float, DataType::Double):
- return "(double)";
- case supported_switch_pair(DataType::Int32, DataType::Float):
- return "(float)";
- case supported_switch_pair(DataType::Int, DataType::Float):
- return "(double)";
- case supported_switch_pair(DataType::Int32, DataType::Int):
- return "(int64_t)";
case supported_switch_pair(DataType::Float, DataType::Half):
return "__float2half";
case supported_switch_pair(DataType::Half, DataType::Float):
return "__half2float";
- case supported_switch_pair(DataType::Bool, DataType::Float):
- return "float";
default:
- return nullptr;
+ break;
+ }
+ return nullptr;
+}
+
+bool is_logical_op(const BinaryOpType& bot) {
+ switch (bot) {
+ case BinaryOpType::And:
+ case BinaryOpType::Eq:
+ case BinaryOpType::GE:
+ case BinaryOpType::GT:
+ case BinaryOpType::LE:
+ case BinaryOpType::LT:
+ case BinaryOpType::NE:
+ return true;
+ default:
+ return false;
}
}
switch (scalar_type) {
case at::ScalarType::Bool:
return DataType::Bool;
- case at::ScalarType::Double:
- return DataType::Double;
case at::ScalarType::Float:
return DataType::Float;
case at::ScalarType::Half:
return DataType::Half;
case at::ScalarType::Long:
return DataType::Int;
- case at::ScalarType::Int:
- return DataType::Int32;
default:
+ TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type.");
return DataType::Null;
}
}
switch (data_type) {
case DataType::Bool:
return at::ScalarType::Bool;
- case DataType::Double:
- return at::ScalarType::Double;
case DataType::Float:
return at::ScalarType::Float;
case DataType::Half:
return at::ScalarType::Half;
case DataType::Int:
return at::ScalarType::Long;
- case DataType::Int32:
- return at::ScalarType::Int;
default:
TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type.");
+ return at::ScalarType::Undefined;
}
}
: c10::nullopt;
}
-c10::optional<std::string> integer_op_str(const BinaryOpType botype) {
- const char* str = binary_op_integer_op2string(botype);
- return str != nullptr ? c10::optional<std::string>(std::string(str))
- : c10::nullopt;
-}
-
std::string stringifyThreadSize(const ParallelType ptype) {
return thread_size2string(ptype);
}
return parallel_type2string(ptype);
}
-std::string typePrefix(const DataType data_type) {
- switch (data_type) {
- case DataType::Bool:
- return "b";
- case DataType::Double:
- return "d";
- case DataType::Float:
- case DataType::Half:
- return "f";
- case DataType::Int:
- case DataType::Int32:
- return "i";
- default:
- TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type.");
- }
-}
-
-bool isParallelTypeThreadDim(ParallelType ptype) {
- return ptype == ParallelType::TIDx || ptype == ParallelType::TIDy ||
- ptype == ParallelType::TIDz;
-}
-
-bool isParallelTypeBlockDim(ParallelType ptype) {
- return ptype == ParallelType::BIDx || ptype == ParallelType::BIDy ||
- ptype == ParallelType::BIDz;
-}
-
-bool isParallelTypeThread(ParallelType ptype) {
- return isParallelTypeBlockDim(ptype) || isParallelTypeThreadDim(ptype);
-}
-
-bool isParallelTypeVectorize(ParallelType ptype) {
- return ptype == ParallelType::Vectorize ||
- ptype == ParallelType::MisalignedVectorize;
-}
-
c10::optional<std::string> cast_func_str(
const std::pair<DataType, DataType>& cast) {
const char* str = supported_casts2string(cast);
switch (type) {
case DataType::Bool:
return sizeof(bool);
- case DataType::Double:
- return sizeof(double);
case DataType::Float:
- return sizeof(float);
+ return 4;
case DataType::Half:
- return sizeof(at::Half);
+ return 2;
case DataType::Int:
- return sizeof(uint64_t);
- case DataType::Int32:
- return sizeof(uint32_t);
+ return 4;
default:
TORCH_INTERNAL_ASSERT(false, "Size undefined for data type, ", type);
}
namespace fuser {
namespace cuda {
-enum class KernelIndexMode { INT32, INT64 };
-
// https://stackoverflow.com/questions/18837857/cant-use-enum-class-as-unordered-map-key
struct TypeHash {
template <typename T>
TensorView,
Scalar,
NamedScalar,
-};
-// Manual - The user provides the Bool value. Predicate generation is bypassed.
-// Inline corresponds with PredicateCompute::getInlinePredicate
-// Unswitch corresponds with UnswitchPredicate::get
-// Misaligned - PredicateCompute::getInlinePredicate + Misaligned flag
-// Shift - ShiftPredicateInserter::getShiftPredicate
-// Padding - ShiftPredicateInserter::getPaddingPredicate
-// ReductionWrite - Same as Inline but without reduction axes
-enum class PredicateType {
- Manual,
- Inline,
- Unswitch,
- Vectorize,
- Misaligned,
- Shift,
- Padding,
- ReductionWrite
+ // Temporary: Kernel IR nodes
+ TensorIndex,
+ KirNamedScalar,
+ KirScalar,
+ KirTensorDomain,
+ KirIterDomain,
+ KirTensorView,
};
-enum class DataType { Double, Float, Half, Int, Int32, Bool, Null };
-
-// Returns if the datatype is a floating point type
-bool isFloatingPointType(DataType dtype);
-// Returns if the datatype is an integer type
-bool isIntegralType(DataType dtype);
+enum class DataType { Bool, Float, Half, Int, Null };
enum class ExprType {
Invalid,
TernaryOp,
ReductionOp,
BroadcastOp,
- WelfordOp,
- TransposeOp,
- ShiftOp,
- GatherOp,
Split,
Merge,
+
+ // Temporary: Kernel IR nodes
+ GridReduction,
+ ForLoop,
+ IfThenElse,
+ Allocate,
+ Sync,
+ KirUnaryOp,
+ KirBinaryOp,
+ KirTernaryOp,
+ KirReductionOp,
+ KirBroadcastOp,
};
enum class UnaryOpType {
Abs,
Acos,
- Address,
Asin,
Atan,
Atanh,
Floor,
Frac,
Gelu,
- Silu,
Lgamma,
Log,
Log10,
Sqrt,
Tan,
Tanh,
- Trunc,
-
- // Might be a bitwise operator or boolean operator.
- Not
+ Trunc
};
-// Primarily for Not, which could be Not a boolean, or a bitwise not.
-bool alsoBooleanOperator(const UnaryOpType uopt);
-
// TODO: Order of this list is important as it affects type promotion. It's not
// in the right order now.
enum class BinaryOpType {
Sub,
// TypeAs,
- // Integer output ops. If changing modify isIntegerOp
+ // Logical Ops
+ // Int operations; leave Mod in this position, we depend on it being the
+ // first of these.
Mod,
CeilDiv,
- Lshift,
- Rshift,
-
- // Logical Ops
- // Int operations, leave position of Mod as first logical op see
- // isLogicalOp(BinaryOpType bopt)
+ And,
Eq,
GE,
GT,
LE,
LT,
- NE,
-
- // Maybe bitwise or boolean op, leave position of and as first bool/int
- // op. These are ops that have different operators based on output type. See
- // is boolean op. These ops also don't work on floating point inputs.
- And,
- Or,
- Xor
+ NE
};
-// Return if output of operator should be a boolean
-bool isIntegerOp(const BinaryOpType bopt);
-
-// Return if output of operator should be a boolean
-bool isLogicalOp(const BinaryOpType bopt);
-
-// Operations that could be a bitwise operation or a boolean operation depending
-// on input, for example bitwise_and is also used for boolean and in the jit
-bool alsoBooleanOperator(const BinaryOpType bopt);
-
-//! Operations that have tricky behaviors with all integer inputs
-bool noFullIntegerSupport(const BinaryOpType bopt);
-
enum class TernaryOpType { Clamp, Threshold, Where };
enum class ParallelType {
TIDy,
TIDx,
Vectorize,
- MisalignedVectorize,
Unroll,
- Unswitch,
Serial
};
Iteration,
Reduction,
BroadcastWithStride,
- BroadcastWithoutStride,
- Gather
+ BroadcastWithoutStride
};
-enum class SwizzleType { NoSwizzle, Transpose };
-
-// Returns if function needs an f suffix on the operator when operating on a
-// float value i.e. sin->sinf
-bool needFloatSuffix(UnaryOpType t);
-bool needFloatSuffix(BinaryOpType t);
-
ValType promote_type(const ValType& t1, const ValType& t2);
DataType promote_type(const DataType& t1, const DataType& t2);
+bool is_logical_op(const BinaryOpType& bot);
-// If type cannot be found (i.e. codegen does not support provided type) returns
-// DataType::Null
-TORCH_CUDA_CU_API DataType aten_to_data_type(const at::ScalarType& scalar_type);
-TORCH_CUDA_CU_API at::ScalarType data_type_to_aten(const DataType& data_type);
+DataType aten_to_data_type(const at::ScalarType& scalar_type);
+at::ScalarType data_type_to_aten(const DataType& data_type);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const ValType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const DataType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const MemoryType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const IterType);
-std::string stringifyBooleanOp(const UnaryOpType);
-std::string stringifyBooleanOp(const BinaryOpType);
-
std::string stringifyThreadSize(const ParallelType);
std::string stringifyThread(const ParallelType);
-std::string typePrefix(const DataType);
-
-// TODO: ThreadDim should be BlockDim and BlockDim should be GridDim
-TORCH_CUDA_CU_API bool isParallelTypeThreadDim(ParallelType);
-TORCH_CUDA_CU_API bool isParallelTypeBlockDim(ParallelType);
-TORCH_CUDA_CU_API bool isParallelTypeThread(ParallelType);
-
-TORCH_CUDA_CU_API bool isParallelTypeVectorize(ParallelType);
TORCH_CUDA_CU_API c10::optional<std::string> inline_op_str(const UnaryOpType);
TORCH_CUDA_CU_API c10::optional<std::string> inline_op_str(const BinaryOpType);
-TORCH_CUDA_CU_API c10::optional<std::string> integer_op_str(const BinaryOpType);
TORCH_CUDA_CU_API c10::optional<std::string> cast_func_str(
const std::pair<DataType, DataType>&);
-TORCH_CUDA_CU_API size_t dataTypeSize(DataType type);
+size_t dataTypeSize(DataType type);
enum class LaunchConfigType {
Compatible,
+++ /dev/null
-
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-
-#include <c10/util/string_view.h>
-
-#include <cstdlib>
-#include <unordered_map>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cuda {
-
-namespace {
-
-auto parseDebugDumpOptions() {
- std::unordered_map<DebugDumpOption, bool> options_map = {
- {DebugDumpOption::FusionIr, false},
- {DebugDumpOption::FusionIrMath, false},
- {DebugDumpOption::KernelIr, false},
- {DebugDumpOption::CudaKernel, false},
- {DebugDumpOption::CudaFull, false},
- {DebugDumpOption::CudaToFile, false},
- {DebugDumpOption::LaunchParam, false},
- {DebugDumpOption::FusionSegments, false},
- {DebugDumpOption::PrintRuntimeArgs, false},
- {DebugDumpOption::EffectiveBandwidth, false},
- {DebugDumpOption::FusionSegmentsDrawing, false},
- {DebugDumpOption::PrintPtxasLog, false},
- {DebugDumpOption::SchedulerDebug, false},
- {DebugDumpOption::ParallelDimensions, false}};
-
- if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_DUMP")) {
- c10::string_view options_view(dump_options);
- while (!options_view.empty()) {
- const auto end_pos = options_view.find_first_of(',');
- const auto token = options_view.substr(0, end_pos);
- if (token == "fusion_ir") {
- options_map[DebugDumpOption::FusionIr] = true;
- } else if (token == "fusion_ir_math") {
- options_map[DebugDumpOption::FusionIrMath] = true;
- } else if (token == "kernel_ir") {
- options_map[DebugDumpOption::KernelIr] = true;
- } else if (token == "cuda_kernel") {
- options_map[DebugDumpOption::CudaKernel] = true;
- } else if (token == "cuda_full") {
- options_map[DebugDumpOption::CudaFull] = true;
- } else if (token == "cuda_to_file") {
- options_map[DebugDumpOption::CudaToFile] = true;
- } else if (token == "launch_param") {
- options_map[DebugDumpOption::LaunchParam] = true;
- } else if (token == "segmented_fusion") {
- options_map[DebugDumpOption::FusionSegments] = true;
- } else if (token == "print_args") {
- options_map[DebugDumpOption::PrintRuntimeArgs] = true;
- } else if (token == "dump_eff_bandwidth") {
- options_map[DebugDumpOption::EffectiveBandwidth] = true;
- } else if (token == "draw_segmented_fusion") {
- options_map[DebugDumpOption::FusionSegmentsDrawing] = true;
- } else if (token == "ptxas_verbose") {
- options_map[DebugDumpOption::PrintPtxasLog] = true;
- } else if (token == "scheduler_params") {
- options_map[DebugDumpOption::SchedulerDebug] = true;
- } else if (token == "parallel_dimensions") {
- options_map[DebugDumpOption::ParallelDimensions] = true;
- } else {
- TORCH_CHECK(
- false,
- "Invalid debug dump option: '",
- token,
- "'\nAvailable options:\n",
- "\tfusion_ir, fusion_ir_math, kernel_ir, cuda_kernel, cuda_full,\n",
- "\tcuda_to_file, launch_param, segmented_fusion, print_args,\n",
- "\tdump_eff_bandwidth, draw_segmented_fusion, scheduler_params\n",
- "\tparallel_dimensions,\n");
- }
- options_view = (end_pos != c10::string_view::npos)
- ? options_view.substr(end_pos + 1)
- : "";
- }
- }
-
- return options_map;
-}
-
-} // namespace
-
-bool isDebugDumpEnabled(DebugDumpOption option) {
- const static auto dump_options = parseDebugDumpOptions();
- return dump_options.at(option);
-}
-
-bool useFallback() {
- const char* disable_fb_env = getenv("PYTORCH_NVFUSER_DISABLE_FALLBACK");
- return !(disable_fb_env ? atoi(disable_fb_env) : 0);
-}
-
-bool disableRNGUnrolling() {
- const char* disable_rng_unroll = getenv("PYTORCH_NVFUSER_DISABLE_RNG_UNROLL");
- return disable_rng_unroll ? atoi(disable_rng_unroll) : 0;
-}
-
-} // namespace cuda
-} // namespace fuser
-} // namespace jit
-} // namespace torch
namespace fuser {
namespace cuda {
-//! Types of debug print-outs
-//!
-//! These can be set through the `PYTORCH_NVFUSER_DUMP` environment variable
-//!
-enum class DebugDumpOption {
- FusionIr, //!< Dump the Fusion IR before lowering
- FusionIrMath, //!< Dump just the compute (math) part of the Fusion IR
- KernelIr, //!< Dump the compiler Kernel IR
- CudaKernel, //!< Dump the generated CUDA C++ kernel code
- CudaFull, //!< Dump the complete CUDA C++ code
- CudaToFile, //!< Dump CUDA Strings to File
- LaunchParam, //!< Dump the Launch parameters of kernel
- FusionSegments, //!< Dump Segmented Fusion Graph
- PrintRuntimeArgs, //!< Print the runtime arguments when launching kernels
- EffectiveBandwidth, //! Measure kernel performance and print effective
- //! bandwidth
- FusionSegmentsDrawing, //!< Dump Segmented Fusion Graph
- PrintPtxasLog, //!< Print the ptxas verbose log including register usage
- SchedulerDebug, //! Dump scheduler heuristic parameters
- ParallelDimensions //!< Dump known parallel dimensions
-};
-
-bool isDebugDumpEnabled(DebugDumpOption option);
-
-// Check if fallback path should be used which will dispatch to eagermode if any
-// errors are encountered. Helpful for debugging.
-bool useFallback();
-
-// Returns if unrolling should not be used for kernels with RNG in them.
-bool disableRNGUnrolling();
-
-//! Ceil integer division
+// Common Functions
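+// e.g. ceilDiv(10, 3) == 4: the quotient is rounded up rather than truncated.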
constexpr int64_t ceilDiv(int64_t a, int64_t b) {
return (a + b - 1) / b;
}
-//! Simple mixin for suppressing copy & move operations, ex:
-//!
-//! class Foo : public NonCopyable {
-//! ...
-//! };
-//!
+// Simple mixin for suppressing copy & move operations, ex:
+//
+// class Foo : public NonCopyable {
+// ...
+// };
+//
class NonCopyable {
public:
NonCopyable() = default;
NonCopyable& operator=(const NonCopyable&) = delete;
};
-//! A generic root for a hierarchy of polymorphic classes:
-//! - It ensures virtual destructors
-//! - Provides the base->as<Derived>() and node->isA<T>() notation
+// A generic root for a hierarchy of polymorphic classes:
+// - It ensures virtual destructors
+// - Provides the base->as<Derived>() and node->isA<T>() notation
class PolymorphicBase {
public:
virtual ~PolymorphicBase() = default;
return downcast_ptr;
}
- //! Check if the runtime type is T (or derived from T)
- //!
- //! \note Don't use this for conditional casts. Instead, use:
- //!
- //! if (auto t = dynamic_cast<T>(p)) { ... }
- //!
- //! instead of:
- //!
- //! if (p->isA<T>()) { auto t = p->as<T>(); ... }
- //!
+ // Check if the runtime type is T (or derived from T)
+ //
+ // NOTE: Don't use this for conditional casts. Use:
+ //
+ // if (auto t = dynamic_cast<T>(p)) { ... }
+ //
+ // instead of:
+ //
+ // if (p->isA<T>()) { auto t = p->as<T>(); ... }
+ //
template <class T>
bool isA() const {
return dynamic_cast<const T*>(this) != nullptr;
if (n->kind() == prim::Constant || n->kind() == prim::AutogradZero ||
n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk ||
- n->kind() == prim::profile || n->kind() == prim::profile_ivalue)
+ n->kind() == prim::profile)
return true;
if (n->isMemberOf(differentiable_ops))
// before any other pass that could insert a `prim::profile_ivalue` node on
// the `aten::_grad_sum_to_size` input.
InsertProfileNodesForSpecializeAutogradZero(pr_.get());
- // `InsertProfileNodesForCUDAFuser` inserts profile node for non-tensor
- // value
-#ifndef C10_MOBILE
- if (RegisterCudaFuseGraph::isRegistered()) {
- torch::jit::fuser::cuda::InsertProfileNodesForCUDAFuser(pr_.get());
- }
-#endif
GRAPH_DUMP("Profiled Graph: ", pr_->graph());
profiling_plan_ = ExecutionPlan(pr_->graph(), function_name_);
// fall-through
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
-#include <torch/csrc/jit/passes/cuda_graph_fuser.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/interpreter.h>
-#include <torch/csrc/jit/codegen/cuda/interface.h>
-#include <torch/csrc/jit/ir/ir.h>
-
namespace torch {
namespace jit {
}
bool needsProfiledInputs(Node* n) {
- if (tensorexpr::isSupported(n) ||
-#ifndef C10_MOBILE
- (RegisterCudaFuseGraph::isRegistered() && fuser::cuda::canFuseNode(n))
-#else
- false
-#endif
- ) {
+ if (tensorexpr::isSupported(n)) {
return true;
}
}
bool needsProfiledOutput(Node* n) {
- if (tensorexpr::isSupported(n) ||
-#ifndef C10_MOBILE
- (RegisterCudaFuseGraph::isRegistered() && fuser::cuda::canFuseNode(n))
-#else
- false
-#endif
- ) {
+ if (tensorexpr::isSupported(n)) {
return true;
}
result.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
return result
- def assertGraphContains(self, graph, kind, consider_subgraphs=False):
-
- if consider_subgraphs:
- strgraph = str(graph)
- count = strgraph.count(kind) - strgraph.count('with {}'.format(kind))
- self.assertTrue(count > 0)
- return
-
- def nodes(block):
- out = []
- for node in block.nodes():
- if node.kind() == kind:
- out.append(node)
- for block in node.blocks():
- out += nodes(block)
- return out
-
- out_nodes = nodes(graph)
- self.assertTrue(len(out_nodes) > 0)
+ def assertGraphContains(self, graph, kind):
+ self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
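+ # Note: unlike the removed block-recursive version, this only checks
+ # top-level nodes of the graph.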
def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
def perform_assert(graph, kind, actual, expected, consider_subgraphs):