return format(str, env);
}
-// If there is a single user of a node and it's a chunk operation, returns
-// that user. Returns nullptr otherwise.
-static Node* usedInFusedChunk(const Value* input) {
- auto uses = input->uses();
- if (uses.size() == 1) {
- Node* user = uses[0].user;
- if (user->kind() == prim::ConstantChunk) {
- return user;
- }
- }
- return nullptr;
-}
-
static void emitIndexingFor(
std::ostream& out,
const std::string& tensor,
}
// TODO: handle cases where we need to generate > 2^32 element tensors
-std::tuple<
- std::string,
- std::vector<PartitionDesc>,
- std::vector<PartitionDesc>,
- bool>
-generateKernel(
+std::string generateKernel(
const std::string& name,
const Graph& graph,
- const std::vector<TensorDesc>& input_desc,
- const std::vector<TensorDesc>& output_desc,
+ const std::vector<std::pair<const Value*, const TensorDesc>>& inputs,
+ const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
const bool use_cuda) {
TemplateEnv env;
env.s("kernelName", name);
env));
};
- // Writes input parameters and creates flattened inputs
- std::vector<PartitionDesc> chunk_desc;
- std::vector<std::pair<const Value*, const TensorDesc&>> flat_inputs;
- {
- size_t input_index = 0;
- for (const auto& p : graph.inputs()) {
- if (const Node* chunk = usedInFusedChunk(p)) {
- int64_t dim = chunk->i(attr::dim);
- int64_t chunks = chunk->i(attr::chunks);
- chunk_desc.emplace_back(input_desc[input_index++], chunks, dim);
- for (const auto* o : chunk->outputs()) {
- flat_inputs.emplace_back(o, *chunk_desc.back().subTensorDesc());
- }
- } else {
- chunk_desc.emplace_back();
- flat_inputs.emplace_back(p, input_desc[input_index++]);
- }
- }
- for (const auto& input : flat_inputs) {
- emitFormal(input.first, input.second);
- }
+ // Writes input parameters
+ for (const auto& input : inputs) {
+ emitFormal(input.first, input.second);
}
+
- // Writes output parameters and creates flattened outputs
- std::vector<PartitionDesc> concat_desc;
- std::vector<std::pair<const Value*, TensorDesc>> flat_output_nodes;
- {
- size_t i = 0;
- for (const auto& o : graph.outputs()) {
- const auto& desc = output_desc[i++];
- if (o->node()->kind() != prim::FusedConcat) {
- emitFormal(o, desc);
- concat_desc.emplace_back();
- flat_output_nodes.emplace_back(o, desc);
- } else {
- const auto cat = o->node();
- concat_desc.emplace_back(desc, cat->inputs().size(), cat->i(attr::dim));
- for (const auto& c : cat->inputs()) {
- emitFormal(c, *concat_desc.back().subTensorDesc());
- flat_output_nodes.emplace_back(c, desc);
- }
- }
- }
+ // Writes output parameters
+ for (const auto& output : outputs) {
+ emitFormal(output.first, output.second);
}
// Acquires input values
bool has_half_tensor = false;
size_t formal_count = 0;
- for (const auto input : flat_inputs) {
+ for (const auto& input : inputs) {
auto p = input.first;
env.s("node", valueName(p));
env.d("formal", formal_count++);
}
// Generates writes to output tensors
- for (const auto& output : flat_output_nodes) {
- const auto& o = output.first;
+ for (const auto& output : outputs) {
env.d("formal", formal_count++);
env.s("access", format("t${formal}.data[t${formal}_offset]", env));
- env.s("node", valueName(o));
+ env.s("node", valueName(output.first));
// Acquires and converts (if needed) outputs
// Note: conversion to half is only supported for CUDA kernels.
if (debugFuser()) {
std::cerr << "fusion code:" << code_string << std::endl;
}
- return std::make_tuple(
- code_string, std::move(chunk_desc), std::move(concat_desc), has_random);
+ return code_string;
}
} // namespace fuser
namespace fuser {
// Creates a CPU or CUDA kernel for the given graph.
-// Returns a tuple consisting of the generated code (as a string),
-// two vectors of PartitionDescs, the chunk and concat descriptions,
-// respectively, and a bool indicating whether the generated code
-// generates random numbers.
-// TODO: the partition descriptions should be generated by the executor.
-TORCH_API std::tuple<
- std::string,
- std::vector<PartitionDesc>,
- std::vector<PartitionDesc>,
- bool>
-generateKernel(
+// Returns the C++ or CUDA string implementing the kernel.
+TORCH_API std::string generateKernel(
const std::string& name,
const Graph& graph,
- const std::vector<TensorDesc>& input_desc,
- const std::vector<TensorDesc>& output_desc,
+ const std::vector<std::pair<const Value*, const TensorDesc>>& inputs,
+ const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
const bool use_cuda);
} // namespace fuser
PropagateInputShapes(graph);
- // Creates output descriptions
+ // Creates chunk and flattened input descriptions
+ std::vector<PartitionDesc> chunk_desc;
+ std::vector<std::pair<const Value*, const TensorDesc>> flat_inputs;
+ {
+ size_t input_index = 0;
+ for (const auto& p : graph->inputs()) {
+ if (const Node* chunk = usedInFusedChunk(p)) {
+ int64_t dim = chunk->i(attr::dim);
+ int64_t chunks = chunk->i(attr::chunks);
+ chunk_desc.emplace_back(input_desc[input_index++], chunks, dim);
+ for (const auto* o : chunk->outputs()) {
+ flat_inputs.emplace_back(o, *chunk_desc.back().subTensorDesc());
+ }
+ } else {
+ chunk_desc.emplace_back();
+ flat_inputs.emplace_back(p, input_desc[input_index++]);
+ }
+ }
+ }
+
+ // Creates output, concat, and flattened output descriptions
std::vector<TensorDesc> output_desc;
- for (const Value* output : graph->outputs()) {
+ std::vector<PartitionDesc> concat_desc;
+ std::vector<std::pair<const Value*, const TensorDesc>> flat_outputs;
+ for (const Value* o : graph->outputs()) {
+ // Creates output description
std::vector<int64_t> sizes = map_size;
- if (output->node()->kind() == prim::FusedConcat) {
- sizes.at(output->node()->i(attr::dim)) *= output->node()->inputs().size();
+ if (o->node()->kind() == prim::FusedConcat) {
+ sizes.at(o->node()->i(attr::dim)) *= o->node()->inputs().size();
}
- auto scalar_type =
- output->type()->expect<c10::TensorType const>()->scalarType();
+ auto scalar_type = o->type()->expect<c10::TensorType const>()->scalarType();
auto type = CompleteTensorType::create(scalar_type, device, sizes);
- output_desc.emplace_back(std::move(type));
+ output_desc.emplace_back(std::move(type));
+ const auto& desc = output_desc.back();
+
+ // Creates concat and flattened output descriptions (relies on output desc)
+ if (o->node()->kind() != prim::FusedConcat) {
+ concat_desc.emplace_back();
+ flat_outputs.emplace_back(o, desc);
+ } else {
+ const auto cat = o->node();
+ concat_desc.emplace_back(desc, cat->inputs().size(), cat->i(attr::dim));
+ for (const auto& c : cat->inputs()) {
+ flat_outputs.emplace_back(c, *concat_desc.back().subTensorDesc());
+ }
+ }
}
const std::string name = "kernel_" + std::to_string(next_kernel_id++);
const bool use_cuda = device.is_cuda();
- std::string code;
- std::vector<PartitionDesc> chunk_desc;
- std::vector<PartitionDesc> concat_desc;
- bool has_random;
- std::tie(code, chunk_desc, concat_desc, has_random) =
- generateKernel(name, *graph, input_desc, output_desc, use_cuda);
-
+ std::string code = generateKernel(name, *graph, flat_inputs, flat_outputs, use_cuda);
std::shared_ptr<FusedKernel> fused_kernel;
if (use_cuda) {
#if USE_CUDA_FUSER
output_desc,
chunk_desc,
concat_desc,
- has_random);
+ spec.hasRandom());
#else
throw std::runtime_error("CUDA Fusion is not supported on this build.");
#endif // USE_CUDA_FUSER
output_desc,
chunk_desc,
concat_desc,
- has_random);
+ spec.hasRandom());
#else
throw std::runtime_error("CPU Fusion is not supported on this build.");
#endif // USE_CPU_FUSER
// TODO: allow abstract kernels to use multiple generated kernels
// TODO: allow abstract kernels to reuse generated kernels from common pool
struct TORCH_API KernelSpec {
- KernelSpec(const int64_t _key, const std::shared_ptr<Graph>& _graph)
- : key_{_key},
- graph_{_graph},
- code_{_graph},
- nInputs_{_graph->inputs().size()},
- inputBroadcastGroups_{},
- inputChunks_{},
- kernels_{} {}
+ // Note: assumes the spec is a single block
+ // Note: This is the appropriate place to generalize if you want to add other
+ // passes to upfront compilation that walk the graph.
+ KernelSpec(
+ const int64_t _key,
+ const std::shared_ptr<Graph>& _graph)
+ : key_{_key},
+ graph_{_graph},
+ code_{_graph},
+ nInputs_{_graph->inputs().size()},
+ inputBroadcastGroups_{},
+ inputChunks_{},
+ has_random_{false},
+ kernels_{} {
+
+ for (const auto& n : graph_->nodes()) {
+ if (n->kind() == aten::rand_like) {
+ has_random_ = true;
+ break;
+ }
+ }
+ }
// Getters
int64_t key() const {
return inputChunks_;
}
+ bool hasRandom() const {
+ return has_random_;
+ }
+
// Cache functions
c10::optional<std::shared_ptr<FusedKernel>> findKernel(
const ArgSpec& arg_spec) const {
uint64_t nInputs_;
std::vector<std::vector<int64_t>> inputBroadcastGroups_;
std::vector<PartitionInfo> inputChunks_;
+ bool has_random_;
mutable std::mutex mutex_;
mutable std::
unordered_map<ArgSpec, std::shared_ptr<FusedKernel>, torch::hash<ArgSpec>>